"""Plotting functions of sid: Gantt charts of policies and infection-rate heatmaps."""

import itertools
from typing import Any
from typing import Dict
from typing import Optional
from typing import Union

import dask.dataframe as dd
import holoviews as hv
import numpy as np
import pandas as pd
from bokeh.models import HoverTool
from sid.colors import get_colors
from sid.policies import compute_pseudo_effect_sizes_of_policies


#: Default options passed to ``segments.opts`` when drawing the Gantt chart.
#: User-supplied ``fig_kwargs`` overwrite or extend these entries.
DEFAULT_FIGURE_KWARGS = {
    "height": 400,
    "width": 600,
    "line_width": 12,
    "title": "Gantt Chart of Policies",
}
def plot_policy_gantt_chart(
    policies,
    effects=False,
    colors="categorical",
    fig_kwargs=None,
):
    """Plot a Gantt chart of the policies.

    Parameters
    ----------
    policies : Union[dict, pandas.DataFrame]
        Either a dictionary mapping policy names to specifications which contain
        "start", "end", "policy" and "affected_contact_model" entries, or an
        already assembled :class:`pandas.DataFrame` with equivalent columns.
    effects : Union[bool, dict], optional
        If truthy, pseudo effect sizes of the policies are computed and encoded
        as the opacity of the segments. A dict is forwarded as keyword arguments
        to :func:`compute_pseudo_effect_sizes_of_policies`.
    colors : str, optional
        Name of the palette requested from ``get_colors``.
    fig_kwargs : Optional[dict], optional
        Keyword arguments which overwrite or extend ``DEFAULT_FIGURE_KWARGS``
        and are passed to ``segments.opts``.

    Returns
    -------
    gantt
        A HoloViews Segments object styled as a Gantt chart.

    Raises
    ------
    ValueError
        If ``policies`` is neither a dict nor a :class:`pandas.DataFrame`.

    """
    # Merge user options over the defaults; user entries win.
    if fig_kwargs is None:
        fig_kwargs = {}
    fig_kwargs = {**DEFAULT_FIGURE_KWARGS, **fig_kwargs}

    if isinstance(policies, dict):
        # Turn the mapping name -> specification into rows, parse the date
        # columns, and drop the callable "policy" entry which cannot be plotted.
        df = (
            pd.DataFrame(policies)
            .T.reset_index()
            .rename(columns={"index": "name"})
            .astype({"start": "datetime64", "end": "datetime64"})
            .drop(columns="policy")
        )
    elif isinstance(policies, pd.DataFrame):
        # NOTE(review): the frame is used as-is and mutated below (new "alpha",
        # "color", "position" columns) — callers passing a DataFrame share it.
        df = policies
    else:
        raise ValueError("'policies' should be either a dict or pandas.DataFrame.")

    if effects:
        # A dict-valued ``effects`` doubles as kwargs for the computation.
        effect_kwargs = effects if isinstance(effects, dict) else {}
        effects = compute_pseudo_effect_sizes_of_policies(
            policies=policies, **effect_kwargs
        )
        effects_s = pd.DataFrame(
            [{"policy": name, "effect": effects[name]["mean"]} for name in effects]
        ).set_index("policy")["effect"]
        df = df.merge(effects_s, left_on="name", right_index=True)
        # Map effect in [0, 1] to an alpha in [~0.09, 1] so that even policies
        # with maximal effect remain faintly visible.
        df["alpha"] = (1 - df["effect"] + 0.1) / 1.1
    else:
        df["alpha"] = 1

    df = df.reset_index()
    df = _complete_dates(df)
    df = _add_color_to_gantt_groups(df, colors)
    df = _add_positions(df)

    hv.extension("bokeh", logo=False)

    # Each policy becomes one horizontal segment from start to end at its
    # computed vertical position.
    segments = hv.Segments(
        df,
        [
            hv.Dimension("start", label="Date"),
            hv.Dimension("position", label="Affected contact model"),
            "end",
            "position",
        ],
    )
    y_ticks_and_labels = list(zip(*_create_y_ticks_and_labels(df)))

    tooltips = [("Name", "@name")]
    if effects:
        tooltips.append(("Effect", "@effect"))
    hover = HoverTool(tooltips=tooltips)

    gantt = segments.opts(
        color="color",
        alpha="alpha",
        tools=[hover],
        yticks=y_ticks_and_labels,
        **fig_kwargs,
    )

    return gantt
[docs]def _complete_dates(df): """Complete dates.""" for column in ("start", "end"): df[column] = pd.to_datetime(df[column]) df["start"] = df["start"].fillna(df["start"].min()) df["end"] = df["end"].fillna(df["end"].max()) return df
def _add_color_to_gantt_groups(df, colors):
    """Assign one color to each affected contact model.

    The palette is cycled so that more contact models than palette entries
    still receive (repeated) colors.

    """
    palette = itertools.cycle(get_colors(colors, 4))
    unique_models = df["affected_contact_model"].unique()
    model_to_color = dict(zip(unique_models, palette))

    df["color"] = df["affected_contact_model"].replace(model_to_color)

    return df
def _add_positions(df):
    """Add positions.

    This functions computes the positions of policies, displayed as segments on the
    time line. For example, if two policies affecting the same contact model have an
    overlapping time windows, the segments are stacked and drawn onto different
    horizontal lines.

    """
    # Running offset so that positions of successive contact-model groups do
    # not collide; each group starts above the previous group's highest line.
    min_position = 0

    def _add_within_group_positions(df):
        """Add within group positions."""
        nonlocal min_position
        # Every policy starts on the group's base line.
        position = pd.Series(data=min_position, index=df.index)
        for i in range(1, len(df)):
            start = df.iloc[i]["start"]
            end = df.iloc[i]["end"]
            # A previous policy overlaps if this policy's start or end falls
            # inside its [start, end] window.
            is_overlapping = (
                (df.iloc[:i]["start"] <= start) & (start <= df.iloc[:i]["end"])
            ) | ((df.iloc[:i]["start"] <= end) & (end <= df.iloc[:i]["end"]))
            if is_overlapping.any():
                # Pick the lowest free line among the candidate lines.
                possible_positions = set(range(min_position, i + min_position + 1))
                positions_of_overlapping = set(position.iloc[:i][is_overlapping])
                position.iloc[i] = min(possible_positions - positions_of_overlapping)
        # The next group starts on a fresh line above this group's lines.
        min_position = max(position) + 1
        return position

    positions = df.groupby("affected_contact_model", group_keys=False).apply(
        _add_within_group_positions
    )
    df["position_local"] = positions
    # Convert (contact model, local line) pairs into one global, dense y-axis
    # position via the group number.
    df["position"] = df.groupby(
        ["affected_contact_model", "position_local"], sort=True
    ).ngroup()

    return df
[docs]def _create_y_ticks_and_labels(df): """Create the positions and their related labels for the y axis.""" pos_per_group = df.groupby("position", as_index=False).first() mean_pos_per_group = ( pos_per_group.groupby("affected_contact_model")["position"].mean().reset_index() ) return mean_pos_per_group["position"], mean_pos_per_group["affected_contact_model"]
#: Error message raised when the simulation output lacks the channel column
#: that the infection-rate plots require.
ERROR_MISSING_CHANNEL = (
    "'channel_infected_by_contact' is necessary to plot infection rates by contact "
    "models. Re-run the simulation and pass `saved_columns={'channels': "
    "'channel_infected_by_contact'}` to `sid.get_simulate_func`."
)
#: Default options passed to ``heatmap.opts`` for the infection-rates-by-
#: contact-models plot. User-supplied ``fig_kwargs`` overwrite or extend them.
DEFAULT_IR_PER_CM_KWARGS = {
    "width": 600,
    "height": 400,
    "tools": ["hover"],
    "title": "Contribution of Contact Models to Infections",
    "xlabel": "Date",
    "ylabel": "Contact Model",
    "invert_yaxis": True,
    "colorbar": True,
    "cmap": "YlOrBr",
}
def plot_infection_rates_by_contact_models(
    df_or_time_series: Union[pd.DataFrame, dd.core.DataFrame],
    show_reported_cases: bool = False,
    unit: str = "share",
    fig_kwargs: Optional[Dict[str, Any]] = None,
) -> hv.HeatMap:
    """Plot infection rates by contact models.

    Parameters
    ----------
    df_or_time_series : Union[pandas.DataFrame, dask.dataframe.core.DataFrame]
        The input can be one of the following two.

        1. It is a :class:`dask.dataframe.core.DataFrame` which holds the time series
           from a simulation.
        2. It can be a :class:`pandas.DataFrame` which is created with
           :func:`prepare_data_for_infection_rates_by_contact_models`. It allows to
           compute the data for various simulations with different seeds and use the
           average over all seeds.

    show_reported_cases : bool, optional
        A boolean to select between reported or real cases of infections. Reported
        cases are identified via testing mechanisms.
    unit : str
        The arguments specifies the unit shown in the figure.

        - ``"share"`` means that daily units represent the share of infection caused
          by a contact model among all infections on the same day.
        - ``"population_share"`` means that daily units represent the share of
          infection caused by a contact model among all people on the same day.
        - ``"incidence"`` means that the daily units represent incidence levels per
          100,000 individuals.

    fig_kwargs : Optional[Dict[str, Any]], optional
        Additional keyword arguments which are passed to ``heatmap.opts`` to style the
        plot. The keyword arguments overwrite or extend the default arguments.

    Returns
    -------
    heatmap : hv.HeatMap
        The heatmap object.

    """
    # Always build a fresh dict. The previous implementation aliased the
    # module-level DEFAULT_IR_PER_CM_KWARGS when ``fig_kwargs`` was None, so a
    # mutation through the local name could have corrupted the shared default.
    # This also mirrors the merge pattern used by ``plot_policy_gantt_chart``.
    fig_kwargs = {**DEFAULT_IR_PER_CM_KWARGS, **(fig_kwargs or {})}

    if _is_data_prepared_for_heatmap(df_or_time_series):
        df = df_or_time_series
    else:
        # Raw simulation output: aggregate it into date x channel shares first.
        df = prepare_data_for_infection_rates_by_contact_models(
            df_or_time_series, show_reported_cases, unit
        )

    hv.extension("bokeh", logo=False)

    heatmap = hv.HeatMap(df)
    plot = heatmap.opts(**fig_kwargs)

    return plot
[docs]def _is_data_prepared_for_heatmap(df): """Is the data prepared for the heatmap plot.""" return ( isinstance(df, pd.DataFrame) and df.columns.isin(["date", "channel_infected_by_contact", "share"]).all() and not df["channel_infected_by_contact"] .isin(["not_infected_by_contact"])
.any() )
def prepare_data_for_infection_rates_by_contact_models(
    time_series: dd.core.DataFrame,
    show_reported_cases: bool = False,  # noqa: U100
    unit: str = "share",
) -> pd.DataFrame:
    """Prepare the data for the heatmap plot.

    Parameters
    ----------
    time_series : dask.dataframe.core.DataFrame
        The time series of a simulation.
    show_reported_cases : bool, optional
        A boolean to select between reported or real cases of infections. Reported
        cases are identified via testing mechanisms.
    unit : str
        The arguments specifies the unit shown in the figure.

        - ``"share"`` means that daily units represent the share of infection caused
          by a contact model among all infections on the same day.
        - ``"population_share"`` means that daily units represent the share of
          infection caused by a contact model among all people on the same day.
        - ``"incidence"`` means that the daily units represent incidence levels per
          100,000 individuals.

    Returns
    -------
    time_series : pandas.DataFrame
        The time series with the prepared data for the plot.

    Raises
    ------
    ValueError
        If ``time_series`` has the wrong type, the channel column is missing, or
        ``unit`` is not one of the supported values.

    """
    # Normalize the input to a dask DataFrame so the same lazy pipeline works
    # for both pandas and dask inputs.
    if isinstance(time_series, pd.DataFrame):
        time_series = dd.from_pandas(time_series, npartitions=1)
    elif not isinstance(time_series, dd.core.DataFrame):
        raise ValueError("'time_series' must be either pd.DataFrame or dask.dataframe.")

    if "channel_infected_by_contact" not in time_series:
        raise ValueError(ERROR_MISSING_CHANNEL)

    if show_reported_cases:
        # Shift infection channels to the date the case became known.
        time_series = _adjust_channel_infected_by_contact_to_new_known_cases(
            time_series
        )

    # Count individuals per day and infection channel. ``size()`` yields an
    # unnamed column which ``rename`` relabels to "n".
    counts = (
        time_series[["date", "channel_infected_by_contact"]]
        .groupby(["date", "channel_infected_by_contact"])
        .size()
        .reset_index()
        .rename(columns={0: "n"})
    )

    if unit == "share":
        # Share among infections only: drop the residual category first, then
        # normalize by the daily total. ``meta`` tells dask the output dtype.
        out = counts.query(
            "channel_infected_by_contact != 'not_infected_by_contact'"
        ).assign(
            share=lambda x: x["n"]
            / x.groupby("date")["n"].transform("sum", meta=("n", "f8")),
        )
    elif unit == "population_share":
        # Share among all individuals: normalize first (totals include the
        # residual category), then drop the residual category.
        out = counts.assign(
            share=lambda x: x["n"]
            / x.groupby("date")["n"].transform("sum", meta=("n", "f8")),
        ).query("channel_infected_by_contact != 'not_infected_by_contact'")
    elif unit == "incidence":
        # Weekly incidence per 100,000 individuals.
        # NOTE(review): this scales raw daily counts, presumably assuming a
        # population of 100,000 — confirm against the simulation setup.
        out = counts.query(
            "channel_infected_by_contact != 'not_infected_by_contact'"
        ).assign(share=lambda x: x["n"] * 7 / 100_000)
    else:
        raise ValueError(
            "'unit' should be one of 'share', 'population_share' or 'incidence'"
        )

    # Materialize the lazy dask pipeline into a pandas DataFrame.
    out = out.drop(columns="n").compute()

    return out
def _adjust_channel_infected_by_contact_to_new_known_cases(df):
    """Move infection channels from the infection date to the detection date.

    Channels are recorded on the day an individual got infected, which differs
    from the day the individual tests positive with a PCR test. This function
    first extracts each individual's infection channel and then re-attaches it
    so that it appears on the date the case becomes known.

    """
    per_individual_channel = _find_channel_of_infection_for_individuals(df)
    patched = _patch_channel_infected_by_contact(df, per_individual_channel)
    return patched
def _find_channel_of_infection_for_individuals(df):
    """Find the channel of infected by contact for each individual.

    Returns a pandas Series holding, for every individual that was infected by
    a contact, the contact model responsible for the infection.

    """
    # Materialize the categories of the dask categorical so that categories
    # can be removed below.
    df["channel_infected_by_contact"] = df["channel_infected_by_contact"].cat.as_known()
    # Removing the residual category turns those entries into NaN ...
    df["channel_infected_by_contact"] = df[
        "channel_infected_by_contact"
    ].cat.remove_categories(["not_infected_by_contact"])
    # ... which are dropped here, keeping only truly infected individuals.
    df = df.dropna(subset=["channel_infected_by_contact"])
    # Trigger the dask computation and return a pandas Series.
    df = df["channel_infected_by_contact"].compute()
    return df
def _patch_channel_infected_by_contact(df, s):
    """Patch channel of infections by contact to only show channels for known cases."""
    # Replace the original channel column with the per-individual channels
    # computed by ``_find_channel_of_infection_for_individuals``.
    df = df.drop(columns="channel_infected_by_contact")
    # NOTE(review): no common column remains after the drop, so this left merge
    # presumably aligns on the index of ``s`` — confirm against the dask/pandas
    # merge semantics in use here.
    df = df.merge(s.to_frame(name="channel_infected_by_contact"), how="left")
    # Hide the channel for rows where the case is not (yet) known.
    df["channel_infected_by_contact"] = df["channel_infected_by_contact"].mask(
        ~df["new_known_case"], np.nan
    )
    # Re-introduce the residual category and use it for all masked/missing
    # entries so the column is fully categorical again.
    df["channel_infected_by_contact"] = (
        df["channel_infected_by_contact"]
        .cat.add_categories("not_infected_by_contact")
        .fillna("not_infected_by_contact")
    )
    return df