Source code for diff_diff.results

"""
Results classes for difference-in-differences estimation.

Provides statsmodels-style output with a more Pythonic interface.
"""

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd


[docs] @dataclass class DiDResults: """ Results from a Difference-in-Differences estimation. Provides easy access to coefficients, standard errors, confidence intervals, and summary statistics in a Pythonic way. Attributes ---------- att : float Average Treatment effect on the Treated (ATT). se : float Standard error of the ATT estimate. t_stat : float T-statistic for the ATT estimate. p_value : float P-value for the null hypothesis that ATT = 0. conf_int : tuple[float, float] Confidence interval for the ATT. n_obs : int Number of observations used in estimation. n_treated : int Number of treated units. n_control : int Number of control units. """ att: float se: float t_stat: float p_value: float conf_int: Tuple[float, float] n_obs: int n_treated: int n_control: int alpha: float = 0.05 coefficients: Optional[Dict[str, float]] = field(default=None) vcov: Optional[np.ndarray] = field(default=None) residuals: Optional[np.ndarray] = field(default=None) fitted_values: Optional[np.ndarray] = field(default=None) r_squared: Optional[float] = field(default=None) # Bootstrap inference fields inference_method: str = field(default="analytical") n_bootstrap: Optional[int] = field(default=None) n_clusters: Optional[int] = field(default=None) bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)
[docs] def __repr__(self) -> str: """Concise string representation.""" return ( f"DiDResults(ATT={self.att:.4f}{self.significance_stars}, " f"SE={self.se:.4f}, " f"p={self.p_value:.4f})" )
[docs] def summary(self, alpha: Optional[float] = None) -> str: """ Generate a formatted summary of the estimation results. Parameters ---------- alpha : float, optional Significance level for confidence intervals. Defaults to the alpha used during estimation. Returns ------- str Formatted summary table. """ alpha = alpha or self.alpha conf_level = int((1 - alpha) * 100) lines = [ "=" * 70, "Difference-in-Differences Estimation Results".center(70), "=" * 70, "", f"{'Observations:':<25} {self.n_obs:>10}", f"{'Treated units:':<25} {self.n_treated:>10}", f"{'Control units:':<25} {self.n_control:>10}", ] if self.r_squared is not None: lines.append(f"{'R-squared:':<25} {self.r_squared:>10.4f}") # Add inference method info if self.inference_method != "analytical": lines.append(f"{'Inference method:':<25} {self.inference_method:>10}") if self.n_bootstrap is not None: lines.append(f"{'Bootstrap replications:':<25} {self.n_bootstrap:>10}") if self.n_clusters is not None: lines.append(f"{'Number of clusters:':<25} {self.n_clusters:>10}") lines.extend( [ "", "-" * 70, f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'':>5}", "-" * 70, f"{'ATT':<15} {self.att:>12.4f} {self.se:>12.4f} {self.t_stat:>10.3f} {self.p_value:>10.4f} {self.significance_stars:>5}", "-" * 70, "", f"{conf_level}% Confidence Interval: [{self.conf_int[0]:.4f}, {self.conf_int[1]:.4f}]", ] ) # Add significance codes lines.extend( [ "", "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1", "=" * 70, ] ) return "\n".join(lines)
[docs] def print_summary(self, alpha: Optional[float] = None) -> None: """Print the summary to stdout.""" print(self.summary(alpha))
[docs] def to_dict(self) -> Dict[str, Any]: """ Convert results to a dictionary. Returns ------- Dict[str, Any] Dictionary containing all estimation results. """ result = { "att": self.att, "se": self.se, "t_stat": self.t_stat, "p_value": self.p_value, "conf_int_lower": self.conf_int[0], "conf_int_upper": self.conf_int[1], "n_obs": self.n_obs, "n_treated": self.n_treated, "n_control": self.n_control, "r_squared": self.r_squared, "inference_method": self.inference_method, } if self.n_bootstrap is not None: result["n_bootstrap"] = self.n_bootstrap if self.n_clusters is not None: result["n_clusters"] = self.n_clusters return result
[docs] def to_dataframe(self) -> pd.DataFrame: """ Convert results to a pandas DataFrame. Returns ------- pd.DataFrame DataFrame with estimation results. """ return pd.DataFrame([self.to_dict()])
@property def is_significant(self) -> bool: """Check if the ATT is statistically significant at the alpha level.""" return bool(self.p_value < self.alpha) @property def significance_stars(self) -> str: """Return significance stars based on p-value.""" return _get_significance_stars(self.p_value)
def _get_significance_stars(p_value: float) -> str: """Return significance stars based on p-value. Returns empty string for NaN p-values (unidentified coefficients from rank-deficient matrices). """ import numpy as np if np.isnan(p_value): return "" if p_value < 0.001: return "***" elif p_value < 0.01: return "**" elif p_value < 0.05: return "*" elif p_value < 0.1: return "." return ""
[docs] @dataclass class PeriodEffect: """ Treatment effect for a single time period. Attributes ---------- period : any The time period identifier. effect : float The treatment effect estimate for this period. se : float Standard error of the effect estimate. t_stat : float T-statistic for the effect estimate. p_value : float P-value for the null hypothesis that effect = 0. conf_int : tuple[float, float] Confidence interval for the effect. """ period: Any effect: float se: float t_stat: float p_value: float conf_int: Tuple[float, float]
[docs] def __repr__(self) -> str: """Concise string representation.""" sig = _get_significance_stars(self.p_value) return ( f"PeriodEffect(period={self.period}, effect={self.effect:.4f}{sig}, " f"SE={self.se:.4f}, p={self.p_value:.4f})" )
@property def is_significant(self) -> bool: """Check if the effect is statistically significant at 0.05 level.""" return bool(self.p_value < 0.05) @property def significance_stars(self) -> str: """Return significance stars based on p-value.""" return _get_significance_stars(self.p_value)
[docs] @dataclass class MultiPeriodDiDResults: """ Results from a Multi-Period Difference-in-Differences estimation. Provides access to period-specific treatment effects as well as an aggregate average treatment effect. Attributes ---------- period_effects : dict[any, PeriodEffect] Dictionary mapping period identifiers to their PeriodEffect objects. Contains all estimated period effects (pre and post, excluding the reference period which is normalized to zero). avg_att : float Average Treatment effect on the Treated across post-periods only. avg_se : float Standard error of the average ATT. avg_t_stat : float T-statistic for the average ATT. avg_p_value : float P-value for the null hypothesis that average ATT = 0. avg_conf_int : tuple[float, float] Confidence interval for the average ATT. n_obs : int Number of observations used in estimation. n_treated : int Number of treated observations. n_control : int Number of control observations. pre_periods : list List of pre-treatment period identifiers. post_periods : list List of post-treatment period identifiers. reference_period : any, optional The reference (omitted) period. Its coefficient is zero by construction and it is excluded from ``period_effects``. interaction_indices : dict, optional Mapping from period identifier to column index in the full variance-covariance matrix. Used internally for sub-VCV extraction (e.g., by HonestDiD and PreTrendsPower). """ period_effects: Dict[Any, PeriodEffect] avg_att: float avg_se: float avg_t_stat: float avg_p_value: float avg_conf_int: Tuple[float, float] n_obs: int n_treated: int n_control: int pre_periods: List[Any] post_periods: List[Any] alpha: float = 0.05 coefficients: Optional[Dict[str, float]] = field(default=None) vcov: Optional[np.ndarray] = field(default=None) residuals: Optional[np.ndarray] = field(default=None) fitted_values: Optional[np.ndarray] = field(default=None) r_squared: Optional[float] = field(default=None) reference_period: Optional[Any] = field(default=None) interaction_indices: Optional[Dict[Any, int]] = field(default=None, repr=False)
[docs] def __repr__(self) -> str: """Concise string representation.""" sig = _get_significance_stars(self.avg_p_value) return ( f"MultiPeriodDiDResults(avg_ATT={self.avg_att:.4f}{sig}, " f"SE={self.avg_se:.4f}, " f"n_post_periods={len(self.post_periods)})" )
@property def pre_period_effects(self) -> Dict[Any, PeriodEffect]: """Pre-period effects only (for parallel trends assessment).""" return {p: pe for p, pe in self.period_effects.items() if p in self.pre_periods} @property def post_period_effects(self) -> Dict[Any, PeriodEffect]: """Post-period effects only.""" return {p: pe for p, pe in self.period_effects.items() if p in self.post_periods}
[docs] def summary(self, alpha: Optional[float] = None) -> str: """ Generate a formatted summary of the estimation results. Parameters ---------- alpha : float, optional Significance level for confidence intervals. Defaults to the alpha used during estimation. Returns ------- str Formatted summary table. """ alpha = alpha or self.alpha conf_level = int((1 - alpha) * 100) lines = [ "=" * 80, "Multi-Period Difference-in-Differences Estimation Results".center(80), "=" * 80, "", f"{'Observations:':<25} {self.n_obs:>10}", f"{'Treated observations:':<25} {self.n_treated:>10}", f"{'Control observations:':<25} {self.n_control:>10}", f"{'Pre-treatment periods:':<25} {len(self.pre_periods):>10}", f"{'Post-treatment periods:':<25} {len(self.post_periods):>10}", ] if self.r_squared is not None: lines.append(f"{'R-squared:':<25} {self.r_squared:>10.4f}") # Pre-period effects (parallel trends test) pre_effects = {p: pe for p, pe in self.period_effects.items() if p in self.pre_periods} if pre_effects: lines.extend( [ "", "-" * 80, "Pre-Period Effects (Parallel Trends Test)".center(80), "-" * 80, f"{'Period':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", "-" * 80, ] ) for period in self.pre_periods: if period in self.period_effects: pe = self.period_effects[period] stars = pe.significance_stars lines.append( f"{str(period):<15} {pe.effect:>12.4f} {pe.se:>12.4f} " f"{pe.t_stat:>10.3f} {pe.p_value:>10.4f} {stars:>6}" ) # Show reference period if self.reference_period is not None: lines.append( f"[ref: {self.reference_period}]" f"{'0.0000':>21} {'---':>12} {'---':>10} {'---':>10} {'':>6}" ) lines.append("-" * 80) # Post-period treatment effects lines.extend( [ "", "-" * 80, "Post-Period Treatment Effects".center(80), "-" * 80, f"{'Period':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", "-" * 80, ] ) for period in self.post_periods: pe = self.period_effects[period] stars = pe.significance_stars lines.append( f"{str(period):<15} {pe.effect:>12.4f} {pe.se:>12.4f} " f"{pe.t_stat:>10.3f} {pe.p_value:>10.4f} {stars:>6}" ) # Average effect lines.extend( [ "-" * 80, "", "-" * 80, "Average Treatment Effect (across post-periods)".center(80), "-" * 80, f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", "-" * 80, f"{'Avg ATT':<15} {self.avg_att:>12.4f} {self.avg_se:>12.4f} " f"{self.avg_t_stat:>10.3f} {self.avg_p_value:>10.4f} {self.significance_stars:>6}", "-" * 80, "", f"{conf_level}% Confidence Interval: [{self.avg_conf_int[0]:.4f}, {self.avg_conf_int[1]:.4f}]", ] ) # Add significance codes lines.extend( [ "", "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1", "=" * 80, ] ) return "\n".join(lines)
[docs] def print_summary(self, alpha: Optional[float] = None) -> None: """Print the summary to stdout.""" print(self.summary(alpha))
[docs] def get_effect(self, period) -> PeriodEffect: """ Get the treatment effect for a specific period. Parameters ---------- period : any The period identifier. Returns ------- PeriodEffect The treatment effect for the specified period. Raises ------ KeyError If the period is not found in post-treatment periods. """ if period not in self.period_effects: if hasattr(self, "reference_period") and period == self.reference_period: raise KeyError( f"Period '{period}' is the reference period (coefficient " f"normalized to zero by construction). Its effect is 0.0 with " f"no associated uncertainty." ) raise KeyError( f"Period '{period}' not found. " f"Available periods: {list(self.period_effects.keys())}" ) return self.period_effects[period]
[docs] def to_dict(self) -> Dict[str, Any]: """ Convert results to a dictionary. Returns ------- Dict[str, Any] Dictionary containing all estimation results. """ result: Dict[str, Any] = { "avg_att": self.avg_att, "avg_se": self.avg_se, "avg_t_stat": self.avg_t_stat, "avg_p_value": self.avg_p_value, "avg_conf_int_lower": self.avg_conf_int[0], "avg_conf_int_upper": self.avg_conf_int[1], "n_obs": self.n_obs, "n_treated": self.n_treated, "n_control": self.n_control, "n_pre_periods": len(self.pre_periods), "n_post_periods": len(self.post_periods), "r_squared": self.r_squared, "reference_period": self.reference_period, } # Add period-specific effects for period, pe in self.period_effects.items(): result[f"effect_period_{period}"] = pe.effect result[f"se_period_{period}"] = pe.se result[f"pval_period_{period}"] = pe.p_value return result
[docs] def to_dataframe(self) -> pd.DataFrame: """ Convert period-specific effects to a pandas DataFrame. Returns ------- pd.DataFrame DataFrame with one row per estimated period (pre and post). """ rows = [] for period, pe in self.period_effects.items(): rows.append( { "period": period, "effect": pe.effect, "se": pe.se, "t_stat": pe.t_stat, "p_value": pe.p_value, "conf_int_lower": pe.conf_int[0], "conf_int_upper": pe.conf_int[1], "is_significant": pe.is_significant, "is_post": period in self.post_periods, } ) return pd.DataFrame(rows)
@property def is_significant(self) -> bool: """Check if the average ATT is statistically significant at the alpha level.""" return bool(self.avg_p_value < self.alpha) @property def significance_stars(self) -> str: """Return significance stars for the average ATT based on p-value.""" return _get_significance_stars(self.avg_p_value)
[docs] @dataclass class SyntheticDiDResults: """ Results from a Synthetic Difference-in-Differences estimation. Combines DiD with synthetic control by re-weighting control units to match pre-treatment trends of treated units. Attributes ---------- att : float Average Treatment effect on the Treated (ATT). se : float Standard error of the ATT estimate (bootstrap or placebo-based). t_stat : float T-statistic for the ATT estimate. p_value : float P-value for the null hypothesis that ATT = 0. conf_int : tuple[float, float] Confidence interval for the ATT. n_obs : int Number of observations used in estimation. n_treated : int Number of treated units. n_control : int Number of control units. unit_weights : dict Dictionary mapping control unit IDs to their synthetic weights. time_weights : dict Dictionary mapping pre-treatment periods to their time weights. pre_periods : list List of pre-treatment period identifiers. post_periods : list List of post-treatment period identifiers. variance_method : str Method used for variance estimation: "bootstrap" or "placebo". """ att: float se: float t_stat: float p_value: float conf_int: Tuple[float, float] n_obs: int n_treated: int n_control: int unit_weights: Dict[Any, float] time_weights: Dict[Any, float] pre_periods: List[Any] post_periods: List[Any] alpha: float = 0.05 variance_method: str = field(default="placebo") noise_level: Optional[float] = field(default=None) zeta_omega: Optional[float] = field(default=None) zeta_lambda: Optional[float] = field(default=None) pre_treatment_fit: Optional[float] = field(default=None) placebo_effects: Optional[np.ndarray] = field(default=None) n_bootstrap: Optional[int] = field(default=None)
[docs] def __repr__(self) -> str: """Concise string representation.""" sig = _get_significance_stars(self.p_value) return ( f"SyntheticDiDResults(ATT={self.att:.4f}{sig}, " f"SE={self.se:.4f}, " f"p={self.p_value:.4f})" )
[docs] def summary(self, alpha: Optional[float] = None) -> str: """ Generate a formatted summary of the estimation results. Parameters ---------- alpha : float, optional Significance level for confidence intervals. Defaults to the alpha used during estimation. Returns ------- str Formatted summary table. """ alpha = alpha or self.alpha conf_level = int((1 - alpha) * 100) lines = [ "=" * 75, "Synthetic Difference-in-Differences Estimation Results".center(75), "=" * 75, "", f"{'Observations:':<25} {self.n_obs:>10}", f"{'Treated units:':<25} {self.n_treated:>10}", f"{'Control units:':<25} {self.n_control:>10}", f"{'Pre-treatment periods:':<25} {len(self.pre_periods):>10}", f"{'Post-treatment periods:':<25} {len(self.post_periods):>10}", ] if self.zeta_omega is not None: lines.append(f"{'Zeta (unit weights):':<25} {self.zeta_omega:>10.4f}") if self.zeta_lambda is not None: lines.append(f"{'Zeta (time weights):':<25} {self.zeta_lambda:>10.6f}") if self.noise_level is not None: lines.append(f"{'Noise level:':<25} {self.noise_level:>10.4f}") if self.pre_treatment_fit is not None: lines.append(f"{'Pre-treatment fit (RMSE):':<25} {self.pre_treatment_fit:>10.4f}") # Variance method info lines.append(f"{'Variance method:':<25} {self.variance_method:>10}") if self.variance_method == "bootstrap" and self.n_bootstrap is not None: lines.append(f"{'Bootstrap replications:':<25} {self.n_bootstrap:>10}") lines.extend( [ "", "-" * 75, f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'':>5}", "-" * 75, f"{'ATT':<15} {self.att:>12.4f} {self.se:>12.4f} {self.t_stat:>10.3f} {self.p_value:>10.4f} {self.significance_stars:>5}", "-" * 75, "", f"{conf_level}% Confidence Interval: [{self.conf_int[0]:.4f}, {self.conf_int[1]:.4f}]", ] ) # Show top unit weights if self.unit_weights: sorted_weights = sorted(self.unit_weights.items(), key=lambda x: x[1], reverse=True) top_n = min(5, len(sorted_weights)) lines.extend( [ "", "-" * 75, "Top Unit Weights (Synthetic Control)".center(75), "-" * 75, ] ) for unit, weight in sorted_weights[:top_n]: if weight > 0.001: # Only show meaningful weights lines.append(f" Unit {unit}: {weight:.4f}") # Show how many units have non-trivial weight n_nonzero = sum(1 for w in self.unit_weights.values() if w > 0.001) lines.append(f" ({n_nonzero} units with weight > 0.001)") # Add significance codes lines.extend( [ "", "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1", "=" * 75, ] ) return "\n".join(lines)
[docs] def print_summary(self, alpha: Optional[float] = None) -> None: """Print the summary to stdout.""" print(self.summary(alpha))
[docs] def to_dict(self) -> Dict[str, Any]: """ Convert results to a dictionary. Returns ------- Dict[str, Any] Dictionary containing all estimation results. """ result = { "att": self.att, "se": self.se, "t_stat": self.t_stat, "p_value": self.p_value, "conf_int_lower": self.conf_int[0], "conf_int_upper": self.conf_int[1], "n_obs": self.n_obs, "n_treated": self.n_treated, "n_control": self.n_control, "n_pre_periods": len(self.pre_periods), "n_post_periods": len(self.post_periods), "variance_method": self.variance_method, "noise_level": self.noise_level, "zeta_omega": self.zeta_omega, "zeta_lambda": self.zeta_lambda, "pre_treatment_fit": self.pre_treatment_fit, } if self.n_bootstrap is not None: result["n_bootstrap"] = self.n_bootstrap return result
[docs] def to_dataframe(self) -> pd.DataFrame: """ Convert results to a pandas DataFrame. Returns ------- pd.DataFrame DataFrame with estimation results. """ return pd.DataFrame([self.to_dict()])
[docs] def get_unit_weights_df(self) -> pd.DataFrame: """ Get unit weights as a pandas DataFrame. Returns ------- pd.DataFrame DataFrame with unit IDs and their weights. """ return pd.DataFrame( [{"unit": unit, "weight": weight} for unit, weight in self.unit_weights.items()] ).sort_values("weight", ascending=False)
[docs] def get_time_weights_df(self) -> pd.DataFrame: """ Get time weights as a pandas DataFrame. Returns ------- pd.DataFrame DataFrame with time periods and their weights. """ return pd.DataFrame( [{"period": period, "weight": weight} for period, weight in self.time_weights.items()] )
@property def is_significant(self) -> bool: """Check if the ATT is statistically significant at the alpha level.""" return bool(self.p_value < self.alpha) @property def significance_stars(self) -> str: """Return significance stars based on p-value.""" return _get_significance_stars(self.p_value)