How to plot infection rates by contact models
The following simulation is identical to the one in the tutorial on contact models. The heatmap shows, for each day of the simulation, how much each contact model contributes to new infections. The values are shares of the whole population.
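To make the plotted quantity concrete, here is a minimal pandas sketch of one way to compute such shares from a long-format time series. It assumes one row per individual and day with the columns `date`, `newly_infected` and `channel_infected_by_contact`, and that uninfected individuals carry the label `not_infected_by_contact`. The data are made up, and `plot_infection_rates_by_contact_models` may compute the shares differently.

```python
import pandas as pd

# Made-up long-format data: 4 individuals observed on 2 days.
time_series = pd.DataFrame(
    {
        "date": pd.to_datetime(["2020-03-01"] * 4 + ["2020-03-02"] * 4),
        "newly_infected": [True, True, False, False, True, False, True, False],
        "channel_infected_by_contact": [
            "meet_household",
            "random_encounters",
            "not_infected_by_contact",  # assumed label for "not infected"
            "not_infected_by_contact",
            "random_encounters",
            "not_infected_by_contact",
            "random_encounters",
            "not_infected_by_contact",
        ],
    }
)

# Population size per day: the denominator of the shares.
population = time_series.groupby("date").size()

# Count new infections per day and contact model.
new_infections = time_series[time_series["newly_infected"]]
counts = (
    new_infections.groupby(["date", "channel_infected_by_contact"])
    .size()
    .unstack(fill_value=0)
    .reindex(population.index, fill_value=0)  # keep days without infections
)

# Share of the whole population infected through each contact model per day.
shares = counts.div(population, axis=0)
print(shares)
```

Dividing by the full population rather than by the number of new infections is what makes the values "shares in relation to the whole population": the shares of one day sum to that day's overall incidence, not to one.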
[1]:
import warnings
import holoviews as hv
import numpy as np
import pandas as pd
import sid
from sid.config import INDEX_NAMES
from sid.plotting import plot_infection_rates_by_contact_models
# Silence a pandas performance warning about indexing a MultiIndex.
warnings.filterwarnings(
    "ignore", message="indexing past lexsort depth may impact performance."
)
[2]:
n_individuals = 10_000
available_ages = [
    "0-9",
    "10-19",
    "20-29",
    "30-39",
    "40-49",
    "50-59",
    "60-69",
    "70-79",
    "80-100",
]
# Draw an age group and a region for every individual.
ages = np.random.choice(available_ages, size=n_individuals)
regions = np.random.choice(["North", "East", "South", "West"], size=n_individuals)
# Assign every individual to one of int(n_individuals / 1.6) household ids.
hh_id = pd.Series(np.random.choice(int(n_individuals / 1.6), size=n_individuals))
initial_states = pd.DataFrame(
    {"age_group": ages, "region": regions, "hh_id": hh_id}
).astype("category")
initial_states.head(5)
[2]:
 | age_group | region | hh_id
---|---|---|---
0 | 60-69 | South | 4825
1 | 20-29 | West | 4856
2 | 30-39 | South | 2389
3 | 0-9 | West | 5623
4 | 10-19 | East | 5200
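Because `hh_id` is drawn uniformly from `int(10_000 / 1.6) = 6_250` ids, some ids are never drawn and household sizes vary. A quick way to inspect the realized sizes is sketched below; it uses a local NumPy generator with an arbitrary seed, which is an assumption for reproducibility (the cell above uses the unseeded global generator, so its numbers will differ).

```python
import numpy as np
import pandas as pd

n_individuals = 10_000
n_household_ids = int(n_individuals / 1.6)  # 6250 possible ids

rng = np.random.default_rng(0)  # arbitrary seed for this sketch
hh_id = pd.Series(rng.choice(n_household_ids, size=n_individuals))

# Number of members for every household id that was drawn at least once.
household_sizes = hh_id.value_counts()

print("households:", len(household_sizes))
print("mean size:", household_sizes.mean())
print(household_sizes.value_counts().sort_index())  # distribution of sizes
```

The mean size of the occupied households ends up above 1.6, because the 10,000 individuals are spread over fewer than 6,250 distinct ids.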
[3]:
contact_models = {}


def random_encounters(states, params, seed):
    # Every individual has 0, 1 or 2 random contacts on a given day.
    np.random.seed(seed)
    contacts = np.random.choice(np.arange(3), size=states.shape[0])
    return pd.Series(index=states.index, data=contacts)


contact_models["random_encounters"] = {
    "model": random_encounters,
    "assort_by": ["age_group", "region"],
    "is_recurrent": False,
}


def meet_household(states, params, seed):
    # Every individual meets the members of their household every day.
    return pd.Series(index=states.index, data=True)


contact_models["meet_household"] = {
    "model": meet_household,
    "assort_by": "hh_id",
    "is_recurrent": True,
}
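Both contact model functions share the signature `(states, params, seed)` and return a `pd.Series` aligned with `states.index`. `random_encounters` reseeds NumPy's global random state on every call. A variant that draws from a local `numpy.random.Generator` returns the same kind of output without touching global state. The function name below is hypothetical, the draws differ from those of `np.random.seed` plus `np.random.choice`, and it assumes the `seed` passed in is a non-negative integer.

```python
import numpy as np
import pandas as pd


def random_encounters_local_rng(states, params, seed):
    # Hypothetical variant: a local generator leaves NumPy's global state alone.
    rng = np.random.default_rng(seed)
    contacts = rng.integers(0, 3, size=len(states))  # 0, 1 or 2 contacts
    return pd.Series(contacts, index=states.index)


# Toy states with a non-default index to check the alignment.
states = pd.DataFrame(
    {"age_group": ["0-9", "10-19", "20-29"], "region": ["North", "East", "South"]},
    index=[10, 11, 12],
)

first = random_encounters_local_rng(states, params=None, seed=1)
second = random_encounters_local_rng(states, params=None, seed=1)
print(first)
```

Calling the function twice with the same seed yields identical contacts, so reproducibility is preserved while other code that relies on the global generator is unaffected.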
[4]:
params = sid.load_epidemiological_parameters()
immunity_params = pd.read_csv("immunity_params.csv", index_col=INDEX_NAMES)
params = pd.concat((params, immunity_params))
params.loc[("assortative_matching", "random_encounters", "age_group"), "value"] = 0.2
params.loc[("assortative_matching", "random_encounters", "region"), "value"] = 0.9
params.loc[("infection_prob", "random_encounters", "random_encounters"), "value"] = 0.1
params.loc[("infection_prob", "meet_household", "meet_household"), "value"] = 0.15
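The `assortative_matching` values can be read as the probability that a random encounter stays within the individual's own group. The sketch below builds the implied group-to-group probabilities under two assumptions that are not confirmed here: the remaining probability is split evenly across the other groups, and the probabilities of several `assort_by` variables combine multiplicatively (a Kronecker product). sid's internal matching may differ, and `group_transition_probs` is a hypothetical helper.

```python
import numpy as np


def group_transition_probs(n_groups, assort_prob):
    """Row i: probability that a member of group i meets someone from each group."""
    # Assumption: the leftover probability is spread evenly over the other groups.
    probs = np.full((n_groups, n_groups), (1 - assort_prob) / (n_groups - 1))
    np.fill_diagonal(probs, assort_prob)
    return probs


age_probs = group_transition_probs(n_groups=9, assort_prob=0.2)  # 9 age groups
region_probs = group_transition_probs(n_groups=4, assort_prob=0.9)  # 4 regions

# Assumption: joint probabilities over (age group, region) cells are products,
# so a row or column index equals age_index * 4 + region_index.
joint_probs = np.kron(age_probs, region_probs)

print(age_probs[0, :3])  # own age group 0.2, each other age group 0.1
print(region_probs[0])  # own region 0.9, each other region about 0.033
print(joint_probs.shape)  # (36, 36)
```

Under these assumptions, 0.2 for nine age groups makes an encounter only twice as likely to involve one's own age group as any particular other age group, whereas 0.9 for four regions keeps most random encounters within the region.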
[5]:
simulate = sid.get_simulate_func(
    params=params,
    initial_states=initial_states,
    contact_models=contact_models,
    duration={"start": "2020-03-01", "end": "2020-04-01"},
    saved_columns={"other": ["channel_infected_by_contact"]},
    seed=144,
)
result = simulate(params)
Start the simulation...
2020-04-01: 100%|██████████| 32/32 [00:10<00:00, 3.06it/s]
[6]:
contact_models
[6]:
{'random_encounters': {'model': <function __main__.random_encounters(states, params, seed)>,
  'assort_by': ['age_group', 'region'],
  'is_recurrent': False},
 'meet_household': {'model': <function __main__.meet_household(states, params, seed)>,
  'assort_by': 'hh_id',
  'is_recurrent': True}}
[7]:
result["time_series"]
[7]:
Dask DataFrame Structure (npartitions=32):
column | dtype
---|---
ever_infected | bool
infectious | bool
knows_immune | bool
is_tested_positive_by_rapid_test | bool
newly_vaccinated | bool
symptomatic | bool
channel_infected_by_contact | category[unknown]
date | datetime64[ns]
ever_vaccinated | bool
cd_infectious_false | int16
new_known_case | bool
newly_deceased | bool
hh_id | category[unknown]
needs_icu | bool
newly_infected | bool
knows_infectious | bool
age_group | category[unknown]
dead | bool
region | category[unknown]
Dask Name: read-parquet, 32 tasks
[8]:
heatmap = plot_infection_rates_by_contact_models(result["time_series"], unit="share")
heatmap