# Source code for diff_diff.visualization._continuous

"""Continuous DiD visualization functions (dose-response curves)."""

from typing import TYPE_CHECKING, Any, Optional, Tuple

import pandas as pd

if TYPE_CHECKING:
    from diff_diff.continuous_did_results import ContinuousDiDResults, DoseResponseCurve


def plot_dose_response(
    results: Optional["ContinuousDiDResults"] = None,
    *,
    curve: Optional["DoseResponseCurve"] = None,
    data: Optional[pd.DataFrame] = None,
    target: str = "att",
    alpha: float = 0.05,
    figsize: Tuple[float, float] = (10, 6),
    title: Optional[str] = None,
    xlabel: str = "Dose",
    ylabel: str = "Treatment Effect",
    color: str = "#2563eb",
    ci_color: Optional[str] = None,
    show_zero_line: bool = True,
    ax: Optional[Any] = None,
    show: bool = True,
    backend: str = "matplotlib",
) -> Any:
    """
    Plot dose-response curve from Continuous DiD estimation.

    Visualizes how the treatment effect varies with the treatment dose
    (intensity), with confidence bands.

    Parameters
    ----------
    results : ContinuousDiDResults, optional
        Results from ContinuousDiD estimator. Extracts the dose-response
        curve based on ``target``.
    curve : DoseResponseCurve, optional
        A DoseResponseCurve object directly. If it has a truthy ``target``
        attribute, that value overrides the ``target`` argument.
    data : pd.DataFrame, optional
        DataFrame with columns ``dose``, ``effect``, ``se`` (and optionally
        ``conf_int_lower``, ``conf_int_upper``). Rows are sorted by
        ``dose`` before plotting; the caller's frame is not modified.
    target : str, default="att"
        Which dose-response curve: ``"att"`` or ``"acrt"``.
    alpha : float, default=0.05
        Significance level, strictly between 0 and 1. It is used to build
        the band from ``se`` for DataFrame input without explicit CI
        columns, and to label the band (0.05 -> "95% CI"). For
        ``results``/``curve`` input the band is plotted as stored, so set
        ``alpha`` to the level it was computed at for an accurate label.
    figsize : tuple, default=(10, 6)
        Figure size (width, height) in inches. Matplotlib backend only,
        and only when ``ax`` is None.
    title : str, optional
        Plot title. Auto-generated from the target if None
        (e.g. "ATT Dose-Response Curve").
    xlabel : str, default="Dose"
        X-axis label.
    ylabel : str, default="Treatment Effect"
        Y-axis label.
    color : str, default="#2563eb"
        Color for the line.
    ci_color : str, optional
        Color for confidence band. Defaults to ``color`` with transparency.
    show_zero_line : bool, default=True
        Whether to show a horizontal line at y=0.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on (matplotlib backend only). If None, creates new
        figure.
    show : bool, default=True
        Whether to display the plot (``plt.show()`` / ``fig.show()``).
    backend : str, default="matplotlib"
        Plotting backend: ``"matplotlib"`` or ``"plotly"``.

    Returns
    -------
    matplotlib.axes.Axes or plotly.graph_objects.Figure
        The axes object (matplotlib) or figure (plotly).

    Raises
    ------
    ValueError
        If not exactly one of ``results``/``curve``/``data`` is given; if
        ``alpha``, ``backend`` or ``target`` is invalid; if ``results``
        has no curve for ``target``; or if ``data`` lacks the ``dose`` or
        ``effect`` column.
    """
    if sum(x is not None for x in (results, curve, data)) != 1:
        raise ValueError("Provide exactly one of 'results', 'curve', or 'data'.")
    if not 0 < alpha < 1:
        raise ValueError(f"alpha must be strictly between 0 and 1, got {alpha!r}")
    if backend not in ("matplotlib", "plotly"):
        raise ValueError(f"backend must be 'matplotlib' or 'plotly', got '{backend}'")

    # A curve passed directly carries its own target; otherwise the
    # caller-supplied target must name one of the two supported curves.
    if curve is not None and getattr(curve, "target", None):
        target = curve.target
    elif target not in ("att", "acrt"):
        raise ValueError(f"target must be 'att' or 'acrt', got '{target}'")

    if results is not None:
        curve = results.dose_response_att if target == "att" else results.dose_response_acrt
        if curve is None:
            raise ValueError(f"results has no '{target}' dose-response curve to plot.")

    if curve is not None:
        dose_grid = curve.dose_grid
        effects = curve.effects
        ci_lower = curve.conf_int_lower
        ci_upper = curve.conf_int_upper
    else:
        dose_grid, effects, ci_lower, ci_upper = _dose_response_from_frame(data, alpha)

    if title is None:
        title = f"{str(target).upper()} Dose-Response Curve"

    # ":g" trims float noise, e.g. (1 - 0.05) * 100 -> "95".
    ci_label = f"{(1 - alpha) * 100:g}% CI"

    if backend == "plotly":
        return _render_dose_response_plotly(
            dose_grid=dose_grid,
            effects=effects,
            ci_lower=ci_lower,
            ci_upper=ci_upper,
            title=title,
            xlabel=xlabel,
            ylabel=ylabel,
            color=color,
            ci_color=ci_color,
            show_zero_line=show_zero_line,
            show=show,
            ci_label=ci_label,
        )

    return _render_dose_response_mpl(
        dose_grid=dose_grid,
        effects=effects,
        ci_lower=ci_lower,
        ci_upper=ci_upper,
        figsize=figsize,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        color=color,
        ci_color=ci_color,
        show_zero_line=show_zero_line,
        ax=ax,
        show=show,
        ci_label=ci_label,
    )


def _dose_response_from_frame(data, alpha):
    """Return ``(dose_grid, effects, ci_lower, ci_upper)`` from a DataFrame.

    Rows are sorted by ``dose`` (stable sort, on a copy) because both the
    line and the band connect points in row order. CI bounds come from
    ``conf_int_lower``/``conf_int_upper`` when both exist, otherwise from
    ``effect +/- z * se`` at level ``alpha``, otherwise they are None.
    """
    if "dose" not in data.columns or "effect" not in data.columns:
        raise ValueError("DataFrame must have 'dose' and 'effect' columns")

    data = data.sort_values("dose", kind="mergesort")
    dose_grid = data["dose"].to_numpy()
    effects = data["effect"].to_numpy()

    if "conf_int_lower" in data.columns and "conf_int_upper" in data.columns:
        return (
            dose_grid,
            effects,
            data["conf_int_lower"].to_numpy(),
            data["conf_int_upper"].to_numpy(),
        )

    if "se" in data.columns:
        # Imported lazily: scipy is only needed for the normal-approximation band.
        from scipy import stats as scipy_stats

        z = scipy_stats.norm.ppf(1 - alpha / 2)
        se = data["se"].to_numpy()
        return dose_grid, effects, effects - z * se, effects + z * se

    return dose_grid, effects, None, None


def _render_dose_response_mpl(
    *,
    dose_grid,
    effects,
    ci_lower,
    ci_upper,
    figsize,
    title,
    xlabel,
    ylabel,
    color,
    ci_color,
    show_zero_line,
    ax,
    show,
    ci_label="95% CI",
):
    """Render a dose-response curve with matplotlib and return the Axes.

    Draws an optional dashed zero line, a shaded band (only when both
    ``ci_lower`` and ``ci_upper`` are given) whose legend entry is
    ``ci_label``, and the effect line. A new figure of size ``figsize``
    is created only when ``ax`` is None. Calls ``plt.show()`` if ``show``.
    """
    from diff_diff.visualization._common import _require_matplotlib

    plt = _require_matplotlib()

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    if show_zero_line:
        ax.axhline(y=0, color="gray", linestyle="--", linewidth=1, alpha=0.5)

    if ci_lower is not None and ci_upper is not None:
        ax.fill_between(
            dose_grid,
            ci_lower,
            ci_upper,
            alpha=0.15,
            color=ci_color or color,
            label=ci_label,
        )

    ax.plot(dose_grid, effects, color=color, linewidth=2, label="Effect")

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(loc="best")
    ax.grid(True, alpha=0.3)

    fig.tight_layout()

    if show:
        plt.show()

    return ax


def _render_dose_response_plotly(
    *,
    dose_grid,
    effects,
    ci_lower,
    ci_upper,
    title,
    xlabel,
    ylabel,
    color,
    ci_color,
    show_zero_line,
    show,
    ci_label="95% CI",
):
    """Render a dose-response curve with plotly and return the Figure.

    The band (drawn only when both bounds are given) is a closed polygon:
    upper bound left-to-right, then lower bound right-to-left, filled with
    ``fill="toself"``. Its legend entry is ``ci_label``. Calls
    ``fig.show()`` if ``show``.
    """
    from diff_diff.visualization._common import (
        _color_to_rgba,
        _plotly_default_layout,
        _require_plotly,
    )

    go = _require_plotly()
    fig = go.Figure()

    if show_zero_line:
        fig.add_hline(y=0, line_dash="dash", line_color="gray", line_width=1, opacity=0.5)

    if ci_lower is not None and ci_upper is not None:
        dose_list = list(dose_grid)
        fig.add_trace(
            go.Scatter(
                x=dose_list + dose_list[::-1],
                y=list(ci_upper) + list(ci_lower)[::-1],
                fill="toself",
                fillcolor=_color_to_rgba(ci_color or color, 0.15),
                line=dict(color="rgba(0,0,0,0)"),
                name=ci_label,
                hoverinfo="skip",
            )
        )

    fig.add_trace(
        go.Scatter(
            x=list(dose_grid),
            y=list(effects),
            mode="lines",
            line=dict(color=color, width=2),
            name="Effect",
        )
    )

    _plotly_default_layout(fig, title=title, xlabel=xlabel, ylabel=ylabel)

    if show:
        fig.show()

    return fig