# Source code for diff_diff.staggered_results

"""
Result container classes for Callaway-Sant'Anna estimator.

This module provides dataclass containers for storing and presenting
group-time average treatment effects and their aggregations.
"""

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

from diff_diff.results import _format_survey_block, _get_significance_stars

if TYPE_CHECKING:
    from diff_diff.staggered_bootstrap import CSBootstrapResults


@dataclass
class GroupTimeEffect:
    """
    Treatment effect for a specific group-time combination.

    Attributes
    ----------
    group : any
        The treatment cohort (first treatment period).
    time : any
        The time period.
    effect : float
        The ATT(g,t) estimate.
    se : float
        Standard error of the ATT(g,t) estimate.
    t_stat : float
        t-statistic of the ATT(g,t) estimate.
    p_value : float
        P-value of the ATT(g,t) estimate.
    conf_int : tuple of float
        ``(lower, upper)`` confidence interval bounds.
    n_treated : int
        Number of treated observations.
    n_control : int
        Number of control observations.
    """

    group: Any
    time: Any
    effect: float
    se: float
    t_stat: float
    p_value: float
    conf_int: Tuple[float, float]
    n_treated: int
    n_control: int

    @property
    def is_significant(self) -> bool:
        """Check if the effect is significant at the fixed 0.05 level.

        The threshold is hard-coded and does not follow any estimation-time
        ``alpha``. A NaN p-value compares False and so is never significant.
        """
        return bool(self.p_value < 0.05)

    @property
    def significance_stars(self) -> str:
        """Return significance stars based on p-value."""
        return _get_significance_stars(self.p_value)
@dataclass
class CallawaySantAnnaResults:
    """
    Results from Callaway-Sant'Anna (2021) staggered DiD estimation.

    This class stores group-time average treatment effects ATT(g,t) and
    provides methods for aggregation into summary measures.

    Attributes
    ----------
    group_time_effects : dict
        Dictionary mapping (group, time) tuples to effect dictionaries with
        keys ``"effect"``, ``"se"``, ``"t_stat"``, ``"p_value"`` and
        ``"conf_int"``.
    overall_att : float
        Overall average treatment effect (weighted average of ATT(g,t)).
    overall_se : float
        Standard error of overall ATT.
    overall_t_stat : float
        t-statistic of overall ATT.
    overall_p_value : float
        P-value for overall ATT.
    overall_conf_int : tuple
        Confidence interval for overall ATT, computed at level ``alpha``.
    groups : list
        List of treatment cohorts (first treatment periods).
    time_periods : list
        List of all time periods.
    n_obs : int
        Total number of observations.
    n_treated_units : int
        Number of ever-treated units (observations when ``panel`` is False).
    n_control_units : int
        Number of never-treated units (excludes not-yet-treated dynamic
        controls); observations when ``panel`` is False.
    alpha : float
        Significance level used during estimation; the stored confidence
        intervals correspond to this level.
    control_group : str
        Control-group choice recorded from estimation.
    base_period : str
        Base-period choice recorded from estimation.
    anticipation : int
        Anticipation periods (``k``) used at fit time.
    panel : bool
        Whether estimation used panel data. When False, ``summary()`` reports
        treated / never-treated counts as observations rather than units.
    event_study_effects : dict, optional
        Effects aggregated by relative time (event study). Entries may carry
        a ``"cband_conf_int"`` simultaneous band.
    group_effects : dict, optional
        Effects aggregated by treatment cohort.
    influence_functions : np.ndarray, optional
        Influence functions retained from estimation.
    event_study_vcov : np.ndarray, optional
        Full event-study variance-covariance matrix, indexed by
        ``event_study_vcov_index``.
    event_study_vcov_index : list, optional
        Relative periods labelling the rows/columns of ``event_study_vcov``.
    bootstrap_results : CSBootstrapResults, optional
        Bootstrap inference results, when a bootstrap was run.
    cband_crit_value : float, optional
        Critical value of the sup-t simultaneous confidence band for the
        event-study effects.
    pscore_trim : float
        Propensity score trimming bound used during estimation.
    survey_metadata : SurveyMetadata, optional
        Survey design metadata, rendered in ``summary()`` when present.
    epv_diagnostics : dict, optional
        Per-(group, time) events-per-variable diagnostics with keys such as
        ``"epv"``, ``"n_events"``, ``"k"`` and ``"is_low"``.
    epv_threshold : float
        EPV threshold reported alongside low-EPV warnings in ``summary()``.
    pscore_fallback : str
        Propensity-score fallback policy recorded from estimation.
    vcov_type : str
        Variance type used during estimation. CallawaySantAnna only supports
        ``"hc1"`` — see REGISTRY.md "IF-based variance estimators vs
        analytical-sandwich estimators" for why analytical-sandwich families
        don't compose with the per-(g,t) doubly-robust / IPW /
        outcome-regression structure.
    cluster_name : str, optional
        Canonical cluster column. Set to ``survey_design.psu`` when an
        explicit survey PSU was provided (regardless of bare ``cluster=``),
        otherwise to ``self.cluster`` when bare cluster synthesizes or
        injects a PSU. ``None`` when no clustering is active.
    n_clusters : int, optional
        Number of unique clusters (PSUs) used for variance estimation.
        ``None`` when no clustering is active.
    df_inference : int, optional
        Cluster-level degrees of freedom for downstream inference (e.g.,
        ``HonestDiD`` t-critical-value selection). It is populated on the
        bare-``cluster=`` synthesize path ONLY. On that path
        ``survey_metadata`` is intentionally ``None``, preserving the
        survey/non-survey contract for ``DiagnosticReport`` / ``summary()``.

        When the user provides an explicit ``survey_design=`` (inject or
        conflict branches), ``df_inference`` stays ``None``. The canonical
        df carrier is then ``survey_metadata.df_survey``, which holds the
        actual CS-internal df, including any post-resolve tightening (e.g.,
        the ``overall_effective_df`` recompute for replicate aggregations).

        ``HonestDiD`` reads ``survey_metadata.df_survey`` first and falls
        back to ``df_inference`` only when ``survey_metadata`` is absent.
        This narrow contract prevents HonestDiD from silently overriding a
        tightened survey df with the original ``resolved_survey.df_survey``.
    """

    # NOTE: field order defines the positional __init__ signature; do not
    # reorder.
    group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]]
    overall_att: float
    overall_se: float
    overall_t_stat: float
    overall_p_value: float
    overall_conf_int: Tuple[float, float]
    groups: List[Any]
    time_periods: List[Any]
    n_obs: int
    n_treated_units: int
    n_control_units: int
    alpha: float = 0.05
    control_group: str = "never_treated"
    base_period: str = "varying"
    # Anticipation periods (``k``) used at fit time. Persisted on the
    # result so downstream diagnostics (``BusinessReport`` /
    # ``DiagnosticReport`` / ``compute_pretrends_power``) can classify
    # pre-period vs anticipation-window coefficients without re-plumbing
    # the kwarg through every call site. See REGISTRY.md
    # §CallawaySantAnna for the shifted-boundary contract.
    anticipation: int = 0
    panel: bool = True
    event_study_effects: Optional[Dict[int, Dict[str, Any]]] = field(default=None)
    group_effects: Optional[Dict[Any, Dict[str, Any]]] = field(default=None)
    influence_functions: Optional["np.ndarray"] = field(default=None, repr=False)
    # Full event-study VCV matrix (Phase 7d): indexed by event_study_vcov_index
    event_study_vcov: Optional["np.ndarray"] = field(default=None, repr=False)
    event_study_vcov_index: Optional[list] = field(default=None, repr=False)
    bootstrap_results: Optional["CSBootstrapResults"] = field(default=None, repr=False)
    cband_crit_value: Optional[float] = None
    pscore_trim: float = 0.01
    # Survey design metadata (SurveyMetadata instance from diff_diff.survey)
    survey_metadata: Optional[Any] = field(default=None, repr=False)
    # EPV diagnostics per (group, time) cell
    epv_diagnostics: Optional[Dict[Tuple[Any, Any], Dict[str, Any]]] = field(
        default=None, repr=False
    )
    epv_threshold: float = 10
    pscore_fallback: str = "error"
    # Variance / clustering metadata (narrow vcov_type contract + cluster=
    # wiring). vcov_type is permanently "hc1" for CS per the IF-based
    # variance structure (REGISTRY.md). cluster_name + n_clusters surface
    # the effective clustering level for downstream introspection and
    # label rendering.
    vcov_type: str = "hc1"
    cluster_name: Optional[str] = None
    n_clusters: Optional[int] = None
    # Cluster-level df, populated on the bare-cluster-synthesize path ONLY.
    # With an explicit survey_design= the canonical carrier is
    # survey_metadata.df_survey (which includes any post-resolve tightening
    # from the overall_effective_df recompute in staggered.py). HonestDiD
    # reads survey_metadata.df_survey first and falls back to df_inference
    # only when survey_metadata is absent, so a tightened survey df is never
    # silently overridden by the original resolved_survey.df_survey.
    df_inference: Optional[int] = None

    # --- Inference-field aliases (balance/external-adapter compatibility) ---
    @property
    def att(self) -> float:
        """Alias for ``overall_att``."""
        return self.overall_att

    @property
    def se(self) -> float:
        """Alias for ``overall_se``."""
        return self.overall_se

    @property
    def conf_int(self) -> Tuple[float, float]:
        """Alias for ``overall_conf_int``."""
        return self.overall_conf_int

    @property
    def p_value(self) -> float:
        """Alias for ``overall_p_value``."""
        return self.overall_p_value

    @property
    def t_stat(self) -> float:
        """Alias for ``overall_t_stat``."""
        return self.overall_t_stat

    def __repr__(self) -> str:
        """Concise string representation."""
        sig = _get_significance_stars(self.overall_p_value)
        return (
            f"CallawaySantAnnaResults(ATT={self.overall_att:.4f}{sig}, "
            f"SE={self.overall_se:.4f}, "
            f"n_groups={len(self.groups)}, "
            f"n_periods={len(self.time_periods)})"
        )

    @property
    def coef_var(self) -> float:
        """Coefficient of variation: SE / abs(overall ATT).

        NaN when the SE is non-finite or negative, or when the ATT is
        non-finite or exactly 0.
        """
        if not (np.isfinite(self.overall_se) and self.overall_se >= 0):
            return np.nan
        if not np.isfinite(self.overall_att) or self.overall_att == 0:
            return np.nan
        return self.overall_se / abs(self.overall_att)

    @staticmethod
    def _format_conf_level(alpha: float) -> str:
        """Render ``(1 - alpha)`` as a percentage label, e.g. 0.05 -> "95".

        Rounds before formatting so binary float error cannot truncate the
        label (``int((1 - 0.07) * 100)`` is 92, not 93), and keeps
        non-integer levels such as "97.5".
        """
        return f"{round((1 - alpha) * 100, 6):g}"

    @staticmethod
    def _effect_table_header(first_col: str) -> str:
        """Column header line shared by every effect table in ``summary()``."""
        return (
            f"{first_col:<15} {'Estimate':>12} {'Std. Err.':>12} "
            f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}"
        )

    @staticmethod
    def _effect_table_rows(effects: Dict[Any, Dict[str, Any]]) -> List[str]:
        """Format one table row per aggregated effect, ordered by key."""
        rows = []
        for key in sorted(effects):
            eff = effects[key]
            # str() first: a format spec on e.g. pd.Timestamp would be routed
            # to strftime and print the spec text instead of the label.
            rows.append(
                f"{str(key):<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} "
                f"{eff['t_stat']:>10.3f} {eff['p_value']:>10.4f} "
                f"{_get_significance_stars(eff['p_value']):>6}"
            )
        return rows

    def _epv_warning_lines(self, width: int) -> List[str]:
        """Low-EPV warning block for ``summary()``; empty when nothing is low."""
        if not self.epv_diagnostics:
            return []
        low_epv = {k: v for k, v in self.epv_diagnostics.items() if v.get("is_low")}
        if not low_epv:
            return []
        # One pass; ties resolve to the first cell in insertion order.
        min_key, min_entry = min(low_epv.items(), key=lambda kv: kv[1]["epv"])
        rule = "-" * width
        return [
            rule,
            "Propensity Score Diagnostics".center(width),
            rule,
            f"WARNING: Low Events Per Variable (EPV) in "
            f"{len(low_epv)} of {len(self.epv_diagnostics)} cohort-time cell(s).",
            f"Minimum EPV: {min_entry['epv']:.1f} "
            f"(cohort g={min_key[0]}). Threshold: {self.epv_threshold:.0f}.",
            "Consider: estimation_method='reg' or fewer covariates.",
            "Call results.epv_summary() for per-cohort details.",
            rule,
            "",
        ]

    def summary(self, alpha: Optional[float] = None) -> str:
        """
        Generate formatted summary of estimation results.

        Parameters
        ----------
        alpha : float, optional
            Requested significance level. The stored pointwise and
            simultaneous intervals were computed at the estimation level
            ``self.alpha`` and cannot be recomputed here. They are therefore
            always labelled with that level. If ``alpha`` differs from
            ``self.alpha``, a note saying so is added under the overall
            confidence interval.

        Returns
        -------
        str
            Formatted summary.
        """
        width = 85
        rule = "-" * width
        # Label intervals with the level they were actually computed at.
        conf_level = self._format_conf_level(self.alpha)
        count_label = "units:" if self.panel else "obs:"

        lines = [
            "=" * width,
            "Callaway-Sant'Anna Staggered Difference-in-Differences Results".center(width),
            "=" * width,
            "",
            f"{'Total observations:':<30} {self.n_obs:>10}",
            f"{'Treated ' + count_label:<30} {self.n_treated_units:>10}",
            f"{'Never-treated ' + count_label:<30} {self.n_control_units:>10}",
            f"{'Treatment cohorts:':<30} {len(self.groups):>10}",
            f"{'Time periods:':<30} {len(self.time_periods):>10}",
            f"{'Control group:':<30} {self.control_group:>10}",
            f"{'Base period:':<30} {self.base_period:>10}",
            "",
        ]

        if self.survey_metadata is not None:
            lines.extend(_format_survey_block(self.survey_metadata, width))

        # Overall ATT
        lines.extend(
            [
                rule,
                "Overall Average Treatment Effect on the Treated".center(width),
                rule,
                self._effect_table_header("Parameter"),
                rule,
                f"{'ATT':<15} {self.overall_att:>12.4f} {self.overall_se:>12.4f} "
                f"{self.overall_t_stat:>10.3f} {self.overall_p_value:>10.4f} "
                f"{_get_significance_stars(self.overall_p_value):>6}",
                rule,
                "",
                f"{conf_level}% Confidence Interval: "
                f"[{self.overall_conf_int[0]:.4f}, {self.overall_conf_int[1]:.4f}]",
            ]
        )
        if alpha is not None and alpha != self.alpha:
            lines.append(
                f"Note: intervals were computed at alpha={self.alpha} during "
                f"estimation; requested alpha={alpha} is not applied."
            )

        cv = self.coef_var
        if np.isfinite(cv):
            lines.append(f"{'CV (SE/abs(ATT)):':<25} {cv:>10.4f}")
        lines.append("")

        lines.extend(self._epv_warning_lines(width))

        if self.event_study_effects:
            lines.extend(
                [
                    rule,
                    "Event Study (Dynamic) Effects".center(width),
                    rule,
                    self._effect_table_header("Rel. Period"),
                    rule,
                ]
            )
            lines.extend(self._effect_table_rows(self.event_study_effects))
            lines.append(rule)
            if self.cband_crit_value is not None:
                lines.append(
                    f"Simult. CI: critical value = {self.cband_crit_value:.4f} "
                    f"(sup-t bootstrap, {conf_level}% family-wise)"
                )
            lines.append("")

        if self.group_effects:
            lines.extend(
                [
                    rule,
                    "Effects by Treatment Cohort".center(width),
                    rule,
                    self._effect_table_header("Cohort"),
                    rule,
                ]
            )
            lines.extend(self._effect_table_rows(self.group_effects))
            lines.extend([rule, ""])

        lines.extend(
            [
                "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1",
                "=" * width,
            ]
        )
        return "\n".join(lines)

    def epv_summary(self, show_all: bool = False) -> pd.DataFrame:
        """
        Return per-cohort EPV diagnostics as a DataFrame.

        Parameters
        ----------
        show_all : bool, default False
            If False, only show cells with low EPV. If True, show all cells.

        Returns
        -------
        pd.DataFrame
            Columns: group, time, epv, n_events, n_params, is_low. Rows are
            ordered by (group, time). The columns are present even when there
            are no rows.
        """
        cols = ["group", "time", "epv", "n_events", "n_params", "is_low"]
        if not self.epv_diagnostics:
            return pd.DataFrame(columns=cols)
        # Sort on the (group, time) key only, so the diagnostic dicts are
        # never compared.
        rows = [
            {
                "group": g,
                "time": t,
                "epv": diag.get("epv"),
                "n_events": diag.get("n_events"),
                "n_params": diag.get("k"),
                "is_low": diag.get("is_low", False),
            }
            for (g, t), diag in sorted(self.epv_diagnostics.items(), key=lambda kv: kv[0])
            if show_all or diag.get("is_low", False)
        ]
        return pd.DataFrame(rows, columns=cols)

    def print_summary(self, alpha: Optional[float] = None) -> None:
        """Print summary to stdout."""
        print(self.summary(alpha))

    def to_dataframe(self, level: str = "group_time") -> pd.DataFrame:
        """
        Convert results to DataFrame.

        Parameters
        ----------
        level : str, default="group_time"
            Level of aggregation: "group_time", "event_study", or "group".

        Returns
        -------
        pd.DataFrame
            Results as DataFrame. The expected columns are present even when
            there are no rows. For "group_time", an ``epv`` column is added
            when any cell has EPV diagnostics.

        Raises
        ------
        ValueError
            If ``level`` is unknown, or the requested aggregation was not
            computed.
        """
        stat_cols = ["effect", "se", "t_stat", "p_value", "conf_int_lower", "conf_int_upper"]

        def stats(data: Dict[str, Any]) -> Dict[str, Any]:
            # Per-effect columns shared by every level, in output order.
            return {
                "effect": data["effect"],
                "se": data["se"],
                "t_stat": data["t_stat"],
                "p_value": data["p_value"],
                "conf_int_lower": data["conf_int"][0],
                "conf_int_upper": data["conf_int"][1],
            }

        if level == "group_time":
            epv = self.epv_diagnostics or {}
            rows = []
            for (g, t), data in self.group_time_effects.items():
                row = {"group": g, "time": t, **stats(data)}
                if (g, t) in epv:
                    row["epv"] = epv[(g, t)].get("epv")
                rows.append(row)
            cols = ["group", "time", *stat_cols]
            if any("epv" in row for row in rows):
                cols.append("epv")
            return pd.DataFrame(rows, columns=cols)

        if level == "event_study":
            if self.event_study_effects is None:
                raise ValueError("Event study effects not computed. Use aggregate='event_study'.")
            rows = []
            for rel_t, data in sorted(self.event_study_effects.items(), key=lambda kv: kv[0]):
                # A missing or None band becomes NaN bounds rather than a crash.
                cband_ci = data.get("cband_conf_int")
                if cband_ci is None:
                    cband_ci = (np.nan, np.nan)
                rows.append(
                    {
                        "relative_period": rel_t,
                        **stats(data),
                        "cband_lower": cband_ci[0],
                        "cband_upper": cband_ci[1],
                    }
                )
            return pd.DataFrame(
                rows, columns=["relative_period", *stat_cols, "cband_lower", "cband_upper"]
            )

        if level == "group":
            if self.group_effects is None:
                raise ValueError("Group effects not computed. Use aggregate='group'.")
            rows = [
                {"group": group, **stats(data)}
                for group, data in sorted(self.group_effects.items(), key=lambda kv: kv[0])
            ]
            return pd.DataFrame(rows, columns=["group", *stat_cols])

        raise ValueError(
            f"Unknown level: {level}. Use 'group_time', 'event_study', or 'group'."
        )

    @property
    def is_significant(self) -> bool:
        """Check if overall ATT is significant at the estimation ``alpha``."""
        return bool(self.overall_p_value < self.alpha)

    @property
    def significance_stars(self) -> str:
        """Significance stars for overall ATT."""
        return _get_significance_stars(self.overall_p_value)