Source code for diff_diff.imputation_results

"""
Result containers for the Imputation DiD estimator.

This module contains ImputationBootstrapResults and ImputationDiDResults
dataclasses. Extracted from imputation.py for module size management.
"""

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

from diff_diff.results import _format_survey_block, _get_significance_stars

__all__ = [
    "ImputationBootstrapResults",
    "ImputationDiDResults",
]


[docs] @dataclass class ImputationBootstrapResults: """ Results from ImputationDiD bootstrap inference. Bootstrap is a library extension beyond Borusyak et al. (2024), which proposes only analytical inference via the conservative variance estimator. Provided for consistency with CallawaySantAnna and SunAbraham. Attributes ---------- n_bootstrap : int Number of bootstrap iterations. weight_type : str Type of bootstrap weights: "rademacher", "mammen", or "webb". alpha : float Significance level used for confidence intervals. overall_att_se : float Bootstrap standard error for overall ATT. overall_att_ci : tuple Bootstrap confidence interval for overall ATT. overall_att_p_value : float Bootstrap p-value for overall ATT. event_study_ses : dict, optional Bootstrap SEs for event study effects. event_study_cis : dict, optional Bootstrap CIs for event study effects. event_study_p_values : dict, optional Bootstrap p-values for event study effects. group_ses : dict, optional Bootstrap SEs for group effects. group_cis : dict, optional Bootstrap CIs for group effects. group_p_values : dict, optional Bootstrap p-values for group effects. bootstrap_distribution : np.ndarray, optional Full bootstrap distribution of overall ATT. """ n_bootstrap: int weight_type: str alpha: float overall_att_se: float overall_att_ci: Tuple[float, float] overall_att_p_value: float event_study_ses: Optional[Dict[int, float]] = None event_study_cis: Optional[Dict[int, Tuple[float, float]]] = None event_study_p_values: Optional[Dict[int, float]] = None group_ses: Optional[Dict[Any, float]] = None group_cis: Optional[Dict[Any, Tuple[float, float]]] = None group_p_values: Optional[Dict[Any, float]] = None bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)
[docs] @dataclass class ImputationDiDResults: """ Results from Borusyak-Jaravel-Spiess (2024) imputation DiD estimation. Attributes ---------- treatment_effects : pd.DataFrame Unit-level treatment effects with columns: unit, time, tau_hat, weight. overall_att : float Overall average treatment effect on the treated. overall_se : float Standard error of overall ATT. overall_t_stat : float T-statistic for overall ATT. overall_p_value : float P-value for overall ATT. overall_conf_int : tuple Confidence interval for overall ATT. event_study_effects : dict, optional Dictionary mapping relative time h to effect dict with keys: 'effect', 'se', 't_stat', 'p_value', 'conf_int', 'n_obs'. group_effects : dict, optional Dictionary mapping cohort g to effect dict. groups : list List of treatment cohorts. time_periods : list List of all time periods. n_obs : int Total number of observations. n_treated_obs : int Number of treated observations (:math:`|\\Omega_1|`). n_untreated_obs : int Number of untreated observations (:math:`|\\Omega_0|`). n_treated_units : int Number of ever-treated units. n_control_units : int Number of units contributing to Omega_0. alpha : float Significance level used. pretrend_results : dict, optional Populated by pretrend_test(). bootstrap_results : ImputationBootstrapResults, optional Bootstrap inference results. """ treatment_effects: pd.DataFrame overall_att: float overall_se: float overall_t_stat: float overall_p_value: float overall_conf_int: Tuple[float, float] event_study_effects: Optional[Dict[int, Dict[str, Any]]] group_effects: Optional[Dict[Any, Dict[str, Any]]] groups: List[Any] time_periods: List[Any] n_obs: int n_treated_obs: int n_untreated_obs: int n_treated_units: int n_control_units: int alpha: float = 0.05 anticipation: int = 0 pretrend_results: Optional[Dict[str, Any]] = field(default=None, repr=False) bootstrap_results: Optional[ImputationBootstrapResults] = field(default=None, repr=False) # Internal: stores data needed for pretrend_test() _estimator_ref: Optional[Any] = field(default=None, repr=False) # Survey design metadata (SurveyMetadata instance from diff_diff.survey) survey_metadata: Optional[Any] = field(default=None, repr=False) # --- Inference-field aliases (balance/external-adapter compatibility) --- @property def att(self) -> float: return self.overall_att @property def se(self) -> float: return self.overall_se @property def conf_int(self) -> Tuple[float, float]: return self.overall_conf_int @property def p_value(self) -> float: return self.overall_p_value @property def t_stat(self) -> float: return self.overall_t_stat
[docs] def __repr__(self) -> str: """Concise string representation.""" sig = _get_significance_stars(self.overall_p_value) return ( f"ImputationDiDResults(ATT={self.overall_att:.4f}{sig}, " f"SE={self.overall_se:.4f}, " f"n_groups={len(self.groups)}, " f"n_treated_obs={self.n_treated_obs})" )
@property def coef_var(self) -> float: """Coefficient of variation: SE / abs(overall ATT). NaN when ATT is 0 or SE non-finite.""" if not (np.isfinite(self.overall_se) and self.overall_se >= 0): return np.nan if not np.isfinite(self.overall_att) or self.overall_att == 0: return np.nan return self.overall_se / abs(self.overall_att)
[docs] def summary(self, alpha: Optional[float] = None) -> str: """ Generate formatted summary of estimation results. Parameters ---------- alpha : float, optional Significance level. Defaults to alpha used in estimation. Returns ------- str Formatted summary. """ alpha = alpha or self.alpha conf_level = int((1 - alpha) * 100) lines = [ "=" * 85, "Imputation DiD Estimator Results (Borusyak et al. 2024)".center(85), "=" * 85, "", f"{'Total observations:':<30} {self.n_obs:>10}", f"{'Treated observations:':<30} {self.n_treated_obs:>10}", f"{'Untreated observations:':<30} {self.n_untreated_obs:>10}", f"{'Treated units:':<30} {self.n_treated_units:>10}", f"{'Control units:':<30} {self.n_control_units:>10}", f"{'Treatment cohorts:':<30} {len(self.groups):>10}", f"{'Time periods:':<30} {len(self.time_periods):>10}", "", ] # Survey design info if self.survey_metadata is not None: sm = self.survey_metadata lines.extend(_format_survey_block(sm, 85)) # Overall ATT lines.extend( [ "-" * 85, "Overall Average Treatment Effect on the Treated".center(85), "-" * 85, f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} " f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", "-" * 85, ] ) t_str = ( f"{self.overall_t_stat:>10.3f}" if np.isfinite(self.overall_t_stat) else f"{'NaN':>10}" ) p_str = ( f"{self.overall_p_value:>10.4f}" if np.isfinite(self.overall_p_value) else f"{'NaN':>10}" ) sig = _get_significance_stars(self.overall_p_value) lines.extend( [ f"{'ATT':<15} {self.overall_att:>12.4f} {self.overall_se:>12.4f} " f"{t_str} {p_str} {sig:>6}", "-" * 85, "", f"{conf_level}% Confidence Interval: " f"[{self.overall_conf_int[0]:.4f}, {self.overall_conf_int[1]:.4f}]", ] ) cv = self.coef_var if np.isfinite(cv): lines.append(f"{'CV (SE/abs(ATT)):':<25} {cv:>10.4f}") lines.append("") # Event study effects if self.event_study_effects: lines.extend( [ "-" * 85, "Event Study (Dynamic) Effects".center(85), "-" * 85, f"{'Rel. Period':<15} {'Estimate':>12} {'Std. Err.':>12} " f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", "-" * 85, ] ) for h in sorted(self.event_study_effects.keys()): eff = self.event_study_effects[h] if eff.get("n_obs", 1) == 0: # Reference period marker lines.append( f"[ref: {h}]" f"{'0.0000':>17} {'---':>12} {'---':>10} {'---':>10} {'':>6}" ) elif np.isnan(eff["effect"]): lines.append(f"{h:<15} {'NaN':>12} {'NaN':>12} {'NaN':>10} {'NaN':>10} {'':>6}") else: e_sig = _get_significance_stars(eff["p_value"]) e_t = ( f"{eff['t_stat']:>10.3f}" if np.isfinite(eff["t_stat"]) else f"{'NaN':>10}" ) e_p = ( f"{eff['p_value']:>10.4f}" if np.isfinite(eff["p_value"]) else f"{'NaN':>10}" ) lines.append( f"{h:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} " f"{e_t} {e_p} {e_sig:>6}" ) lines.extend(["-" * 85, ""]) # Group effects if self.group_effects: lines.extend( [ "-" * 85, "Group (Cohort) Effects".center(85), "-" * 85, f"{'Cohort':<15} {'Estimate':>12} {'Std. Err.':>12} " f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", "-" * 85, ] ) for g in sorted(self.group_effects.keys()): eff = self.group_effects[g] if np.isnan(eff["effect"]): lines.append(f"{g:<15} {'NaN':>12} {'NaN':>12} {'NaN':>10} {'NaN':>10} {'':>6}") else: g_sig = _get_significance_stars(eff["p_value"]) g_t = ( f"{eff['t_stat']:>10.3f}" if np.isfinite(eff["t_stat"]) else f"{'NaN':>10}" ) g_p = ( f"{eff['p_value']:>10.4f}" if np.isfinite(eff["p_value"]) else f"{'NaN':>10}" ) lines.append( f"{g:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} " f"{g_t} {g_p} {g_sig:>6}" ) lines.extend(["-" * 85, ""]) # Pre-trend test if self.pretrend_results is not None: pt = self.pretrend_results lines.extend( [ "-" * 85, "Pre-Trend Test (Equation 9)".center(85), "-" * 85, f"{'F-statistic:':<30} {pt['f_stat']:>10.3f}", f"{'P-value:':<30} {pt['p_value']:>10.4f}", f"{'Degrees of freedom:':<30} {pt['df']:>10}", f"{'Number of leads:':<30} {pt['n_leads']:>10}", "-" * 85, "", ] ) lines.extend( [ "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1", "=" * 85, ] ) return "\n".join(lines)
[docs] def print_summary(self, alpha: Optional[float] = None) -> None: """Print summary to stdout.""" print(self.summary(alpha))
[docs] def to_dataframe(self, level: str = "observation") -> pd.DataFrame: """ Convert results to DataFrame. Parameters ---------- level : str, default="observation" Level of aggregation: - "observation": Unit-level treatment effects - "event_study": Event study effects by relative time - "group": Group (cohort) effects Returns ------- pd.DataFrame Results as DataFrame. """ if level == "observation": return self.treatment_effects.copy() elif level == "event_study": if self.event_study_effects is None: raise ValueError( "Event study effects not computed. " "Use aggregate='event_study' or aggregate='all'." ) rows = [] for h, data in sorted(self.event_study_effects.items()): rows.append( { "relative_period": h, "effect": data["effect"], "se": data["se"], "t_stat": data["t_stat"], "p_value": data["p_value"], "conf_int_lower": data["conf_int"][0], "conf_int_upper": data["conf_int"][1], "n_obs": data.get("n_obs", np.nan), } ) return pd.DataFrame(rows) elif level == "group": if self.group_effects is None: raise ValueError( "Group effects not computed. " "Use aggregate='group' or aggregate='all'." ) rows = [] for g, data in sorted(self.group_effects.items()): rows.append( { "group": g, "effect": data["effect"], "se": data["se"], "t_stat": data["t_stat"], "p_value": data["p_value"], "conf_int_lower": data["conf_int"][0], "conf_int_upper": data["conf_int"][1], "n_obs": data.get("n_obs", np.nan), } ) return pd.DataFrame(rows) else: raise ValueError( f"Unknown level: {level}. Use 'observation', 'event_study', or 'group'." )
[docs] def pretrend_test(self, n_leads: Optional[int] = None) -> Dict[str, Any]: """ Run a pre-trend test (Equation 9 of Borusyak et al. 2024). Adds pre-treatment lead indicators to the Step 1 OLS and tests their joint significance via a Wald F-test (cluster-robust, or design-based survey VCV when survey_design was provided at fit). Parameters ---------- n_leads : int, optional Number of pre-treatment leads to include. If None, uses all available pre-treatment periods minus one (for the reference period). Returns ------- dict Dictionary with keys: 'f_stat', 'p_value', 'df', 'n_leads', 'lead_coefficients'. """ if self._estimator_ref is None: raise RuntimeError( "Pre-trend test requires internal estimator reference. " "Re-fit the model to use this method." ) result = self._estimator_ref._pretrend_test(n_leads=n_leads) self.pretrend_results = result return result
@property def is_significant(self) -> bool: """Check if overall ATT is significant.""" return bool(self.overall_p_value < self.alpha) @property def significance_stars(self) -> str: """Significance stars for overall ATT.""" return _get_significance_stars(self.overall_p_value)