"""
Result container classes for Callaway-Sant'Anna estimator.
This module provides dataclass containers for storing and presenting
group-time average treatment effects and their aggregations.
"""
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
from diff_diff.results import _format_survey_block, _get_significance_stars
if TYPE_CHECKING:
from diff_diff.staggered_bootstrap import CSBootstrapResults
[docs]
@dataclass
class GroupTimeEffect:
    """
    ATT(g,t) estimate for a single (treatment cohort, time period) cell.

    Attributes
    ----------
    group : any
        Treatment cohort, identified by its first treatment period.
    time : any
        Time period of the estimate.
    effect : float
        Point estimate of ATT(g,t).
    se : float
        Standard error of ``effect``.
    t_stat : float
        t-statistic of ``effect``.
    p_value : float
        P-value of ``effect``.
    conf_int : tuple of float
        (lower, upper) confidence interval for ``effect``.
    n_treated : int
        Number of treated observations.
    n_control : int
        Number of control observations.
    """

    group: Any
    time: Any
    effect: float
    se: float
    t_stat: float
    p_value: float
    conf_int: Tuple[float, float]
    n_treated: int
    n_control: int

    @property
    def is_significant(self) -> bool:
        """True when ``p_value`` is below the fixed 0.05 level (a NaN p-value gives False)."""
        significance_level = 0.05
        below_level = self.p_value < significance_level
        # bool() normalizes numpy booleans to a plain Python bool.
        return bool(below_level)

    @property
    def significance_stars(self) -> str:
        """Star annotation for ``p_value`` as produced by ``_get_significance_stars``."""
        stars = _get_significance_stars(self.p_value)
        return stars
[docs]
@dataclass
class CallawaySantAnnaResults:
    """
    Results from Callaway-Sant'Anna (2021) staggered DiD estimation.

    This class stores group-time average treatment effects ATT(g,t) and
    provides methods for aggregation into summary measures.

    Attributes
    ----------
    group_time_effects : dict
        Dictionary mapping (group, time) tuples to effect dictionaries
        (keys read here: ``effect``, ``se``, ``t_stat``, ``p_value``,
        ``conf_int``).
    overall_att : float
        Overall average treatment effect (weighted average of ATT(g,t)).
    overall_se : float
        Standard error of overall ATT.
    overall_t_stat : float
        t-statistic of overall ATT.
    overall_p_value : float
        P-value for overall ATT.
    overall_conf_int : tuple
        Confidence interval for overall ATT (computed at ``alpha``).
    groups : list
        List of treatment cohorts (first treatment periods).
    time_periods : list
        List of all time periods.
    n_obs : int
        Total number of observations.
    n_treated_units : int
        Number of ever-treated units (reported as observations when
        ``panel`` is False).
    n_control_units : int
        Number of never-treated units (excludes not-yet-treated dynamic
        controls); reported as observations when ``panel`` is False.
    alpha : float
        Significance level used during estimation.
    control_group : str
        Control-group specification used at fit time (default
        ``"never_treated"``).
    base_period : str
        Base-period specification used at fit time (default ``"varying"``).
    anticipation : int
        Anticipation periods (``k``) used at fit time.
    panel : bool
        Whether estimation used panel data; selects "units" vs "obs" labels
        in ``summary``.
    event_study_effects : dict, optional
        Effects aggregated by relative time (event study).
    group_effects : dict, optional
        Effects aggregated by treatment cohort.
    influence_functions : np.ndarray, optional
        Influence-function array from estimation (layout defined by the
        estimator).
    event_study_vcov : np.ndarray, optional
        Full event-study variance-covariance matrix.
    event_study_vcov_index : list, optional
        Labels indexing the rows/columns of ``event_study_vcov``.
    bootstrap_results : CSBootstrapResults, optional
        Bootstrap results object, when available.
    cband_crit_value : float, optional
        Critical value of the sup-t bootstrap simultaneous bands.
    pscore_trim : float
        Propensity score trimming bound used during estimation.
    survey_metadata : SurveyMetadata, optional
        Survey design metadata (from ``diff_diff.survey``).
    epv_diagnostics : dict, optional
        Events-per-variable diagnostics keyed by (group, time) cell.
    epv_threshold : float
        EPV threshold reported alongside low-EPV warnings.
    pscore_fallback : str
        Propensity-score fallback setting recorded at fit time.
    """

    group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]]
    overall_att: float
    overall_se: float
    overall_t_stat: float
    overall_p_value: float
    overall_conf_int: Tuple[float, float]
    groups: List[Any]
    time_periods: List[Any]
    n_obs: int
    n_treated_units: int
    n_control_units: int
    alpha: float = 0.05
    control_group: str = "never_treated"
    base_period: str = "varying"
    # Anticipation periods (``k``) used at fit time. Persisted on the
    # result so downstream diagnostics (``BusinessReport`` /
    # ``DiagnosticReport`` / ``compute_pretrends_power``) can classify
    # pre-period vs anticipation-window coefficients without re-
    # plumbing the kwarg through every call site. See REGISTRY.md
    # §CallawaySantAnna lines 355-395 for the shifted-boundary
    # contract.
    anticipation: int = 0
    panel: bool = True
    event_study_effects: Optional[Dict[int, Dict[str, Any]]] = field(default=None)
    group_effects: Optional[Dict[Any, Dict[str, Any]]] = field(default=None)
    influence_functions: Optional["np.ndarray"] = field(default=None, repr=False)
    # Full event-study VCV matrix (Phase 7d): indexed by event_study_vcov_index
    event_study_vcov: Optional["np.ndarray"] = field(default=None, repr=False)
    event_study_vcov_index: Optional[list] = field(default=None, repr=False)
    bootstrap_results: Optional["CSBootstrapResults"] = field(default=None, repr=False)
    cband_crit_value: Optional[float] = None
    pscore_trim: float = 0.01
    # Survey design metadata (SurveyMetadata instance from diff_diff.survey)
    survey_metadata: Optional[Any] = field(default=None, repr=False)
    # EPV diagnostics per (group, time) cell
    epv_diagnostics: Optional[Dict[Tuple[Any, Any], Dict[str, Any]]] = field(
        default=None, repr=False
    )
    epv_threshold: float = 10
    pscore_fallback: str = "error"

    # --- Inference-field aliases (balance/external-adapter compatibility) ---
    @property
    def att(self) -> float:
        """Alias for ``overall_att``."""
        return self.overall_att

    @property
    def se(self) -> float:
        """Alias for ``overall_se``."""
        return self.overall_se

    @property
    def conf_int(self) -> Tuple[float, float]:
        """Alias for ``overall_conf_int``."""
        return self.overall_conf_int

    @property
    def p_value(self) -> float:
        """Alias for ``overall_p_value``."""
        return self.overall_p_value

    @property
    def t_stat(self) -> float:
        """Alias for ``overall_t_stat``."""
        return self.overall_t_stat

    def __repr__(self) -> str:
        """Concise string representation."""
        sig = _get_significance_stars(self.overall_p_value)
        return (
            f"CallawaySantAnnaResults(ATT={self.overall_att:.4f}{sig}, "
            f"SE={self.overall_se:.4f}, "
            f"n_groups={len(self.groups)}, "
            f"n_periods={len(self.time_periods)})"
        )

    @property
    def coef_var(self) -> float:
        """Coefficient of variation: SE / abs(overall ATT). NaN when ATT is 0 or SE non-finite."""
        if not (np.isfinite(self.overall_se) and self.overall_se >= 0):
            return np.nan
        if not np.isfinite(self.overall_att) or self.overall_att == 0:
            return np.nan
        return self.overall_se / abs(self.overall_att)

    @staticmethod
    def _format_conf_level(alpha: float) -> str:
        """Format ``1 - alpha`` as a percentage string (0.05 -> "95", 0.025 -> "97.5")."""
        # ``:g`` rounds away float artifacts such as (1 - 0.32) * 100 == 67.999...,
        # which int() would truncate to 67, and keeps fractional levels like 97.5.
        return f"{(1 - alpha) * 100:g}"

    def summary(self, alpha: Optional[float] = None) -> str:
        """
        Generate formatted summary of estimation results.

        Parameters
        ----------
        alpha : float, optional
            Significance level. Defaults to alpha used in estimation. The
            stored confidence intervals and simultaneous bands were computed
            at the estimation alpha, so they are always labeled with that
            level; when a different alpha is requested a note is added
            instead of mislabeling the stored intervals.

        Returns
        -------
        str
            Formatted summary.
        """
        requested_alpha = self.alpha if alpha is None else alpha
        # Label stored intervals with the level they were actually computed at.
        conf_level = self._format_conf_level(self.alpha)
        unit_label = "units:" if self.panel else "obs:"
        lines = [
            "=" * 85,
            "Callaway-Sant'Anna Staggered Difference-in-Differences Results".center(85),
            "=" * 85,
            "",
            f"{'Total observations:':<30} {self.n_obs:>10}",
            f"{'Treated ' + unit_label:<30} {self.n_treated_units:>10}",
            f"{'Never-treated ' + unit_label:<30} {self.n_control_units:>10}",
            f"{'Treatment cohorts:':<30} {len(self.groups):>10}",
            f"{'Time periods:':<30} {len(self.time_periods):>10}",
            f"{'Control group:':<30} {self.control_group:>10}",
            f"{'Base period:':<30} {self.base_period:>10}",
            "",
        ]

        # Survey design info
        if self.survey_metadata is not None:
            lines.extend(_format_survey_block(self.survey_metadata, 85))

        # Overall ATT
        lines.extend(
            [
                "-" * 85,
                "Overall Average Treatment Effect on the Treated".center(85),
                "-" * 85,
                f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
                "-" * 85,
                f"{'ATT':<15} {self.overall_att:>12.4f} {self.overall_se:>12.4f} "
                f"{self.overall_t_stat:>10.3f} {self.overall_p_value:>10.4f} "
                f"{_get_significance_stars(self.overall_p_value):>6}",
                "-" * 85,
                "",
                f"{conf_level}% Confidence Interval: [{self.overall_conf_int[0]:.4f}, {self.overall_conf_int[1]:.4f}]",
            ]
        )
        if requested_alpha != self.alpha:
            lines.append(
                f"Note: intervals were computed at alpha={self.alpha:g} during "
                f"estimation; requested alpha={requested_alpha:g} is not applied."
            )

        cv = self.coef_var
        if np.isfinite(cv):
            lines.append(f"{'CV (SE/abs(ATT)):':<25} {cv:>10.4f}")
        lines.append("")

        # EPV diagnostics block (if any cohort has low EPV)
        if self.epv_diagnostics:
            low_epv = {k: v for k, v in self.epv_diagnostics.items() if v.get("is_low")}
            if low_epv:
                n_affected = len(low_epv)
                n_total = len(self.epv_diagnostics)
                # Single pass: cell key and diagnostics of the lowest-EPV cell.
                min_g, min_entry = min(low_epv.items(), key=lambda kv: kv[1]["epv"])
                lines.extend(
                    [
                        "-" * 85,
                        "Propensity Score Diagnostics".center(85),
                        "-" * 85,
                        f"WARNING: Low Events Per Variable (EPV) in "
                        f"{n_affected} of {n_total} cohort-time cell(s).",
                        f"Minimum EPV: {min_entry['epv']:.1f} "
                        f"(cohort g={min_g[0]}). Threshold: {self.epv_threshold:.0f}.",
                        "Consider: estimation_method='reg' or fewer covariates.",
                        "Call results.epv_summary() for per-cohort details.",
                        "-" * 85,
                        "",
                    ]
                )

        # Event study effects if available
        if self.event_study_effects:
            ci_label = "Simult. CI" if self.cband_crit_value is not None else "Pointwise CI"
            lines.extend(
                [
                    "-" * 85,
                    "Event Study (Dynamic) Effects".center(85),
                    "-" * 85,
                    f"{'Rel. Period':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
                    "-" * 85,
                ]
            )
            for rel_t in sorted(self.event_study_effects.keys()):
                eff = self.event_study_effects[rel_t]
                sig = _get_significance_stars(eff["p_value"])
                # ``!s`` so label types with a custom __format__ render as text.
                lines.append(
                    f"{rel_t!s:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} "
                    f"{eff['t_stat']:>10.3f} {eff['p_value']:>10.4f} {sig:>6}"
                )
            lines.append("-" * 85)
            if self.cband_crit_value is not None:
                lines.append(
                    f"{ci_label}: critical value = {self.cband_crit_value:.4f} "
                    f"(sup-t bootstrap, {conf_level}% family-wise)"
                )
            lines.append("")

        # Group effects if available
        if self.group_effects:
            lines.extend(
                [
                    "-" * 85,
                    "Effects by Treatment Cohort".center(85),
                    "-" * 85,
                    f"{'Cohort':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
                    "-" * 85,
                ]
            )
            for group in sorted(self.group_effects.keys()):
                eff = self.group_effects[group]
                sig = _get_significance_stars(eff["p_value"])
                # ``!s``: datetime-like cohorts would otherwise treat "<15" as a
                # strftime pattern and print the literal text "<15".
                lines.append(
                    f"{group!s:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} "
                    f"{eff['t_stat']:>10.3f} {eff['p_value']:>10.4f} {sig:>6}"
                )
            lines.extend(["-" * 85, ""])

        lines.extend(
            [
                "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1",
                "=" * 85,
            ]
        )
        return "\n".join(lines)

    def epv_summary(self, show_all: bool = False) -> pd.DataFrame:
        """
        Return per-cohort EPV diagnostics as a DataFrame.

        Parameters
        ----------
        show_all : bool, default False
            If False, only show cells with low EPV. If True, show all cells.

        Returns
        -------
        pd.DataFrame
            Columns: group, time, epv, n_events, n_params, is_low. Rows are
            ordered by (group, time); empty (with these columns) when there
            are no diagnostics or no matching cells.
        """
        cols = ["group", "time", "epv", "n_events", "n_params", "is_low"]
        if not self.epv_diagnostics:
            return pd.DataFrame(columns=cols)
        rows = [
            {
                "group": g,
                "time": t,
                "epv": diag.get("epv"),
                "n_events": diag.get("n_events"),
                # The diagnostics dict stores the parameter count under "k".
                "n_params": diag.get("k"),
                "is_low": diag.get("is_low", False),
            }
            for (g, t), diag in sorted(self.epv_diagnostics.items(), key=lambda kv: kv[0])
            if show_all or diag.get("is_low", False)
        ]
        return pd.DataFrame(rows, columns=cols)

    def print_summary(self, alpha: Optional[float] = None) -> None:
        """Print summary to stdout."""
        print(self.summary(alpha))

    def to_dataframe(self, level: str = "group_time") -> pd.DataFrame:
        """
        Convert results to DataFrame.

        Parameters
        ----------
        level : str, default="group_time"
            Level of aggregation: "group_time", "event_study", or "group".

        Returns
        -------
        pd.DataFrame
            Results as DataFrame. Column names are present even when there
            are no rows.

        Raises
        ------
        ValueError
            If the requested aggregation was not computed, or ``level`` is
            unknown.
        """
        if level == "group_time":
            cols = [
                "group", "time", "effect", "se", "t_stat", "p_value",
                "conf_int_lower", "conf_int_upper",
            ]
            rows = []
            has_epv = False
            for (g, t), data in self.group_time_effects.items():
                row = {
                    "group": g,
                    "time": t,
                    "effect": data["effect"],
                    "se": data["se"],
                    "t_stat": data["t_stat"],
                    "p_value": data["p_value"],
                    "conf_int_lower": data["conf_int"][0],
                    "conf_int_upper": data["conf_int"][1],
                }
                if self.epv_diagnostics and (g, t) in self.epv_diagnostics:
                    row["epv"] = self.epv_diagnostics[(g, t)].get("epv")
                    has_epv = True
                rows.append(row)
            # EPV column only appears when at least one cell has diagnostics.
            if has_epv:
                cols.append("epv")
            return pd.DataFrame(rows, columns=cols)
        elif level == "event_study":
            if self.event_study_effects is None:
                raise ValueError("Event study effects not computed. Use aggregate='event_study'.")
            cols = [
                "relative_period", "effect", "se", "t_stat", "p_value",
                "conf_int_lower", "conf_int_upper", "cband_lower", "cband_upper",
            ]
            rows = []
            for rel_t, data in sorted(self.event_study_effects.items()):
                # Bands may be absent or stored as None when not computed.
                cband_ci = data.get("cband_conf_int")
                if cband_ci is None:
                    cband_ci = (np.nan, np.nan)
                rows.append(
                    {
                        "relative_period": rel_t,
                        "effect": data["effect"],
                        "se": data["se"],
                        "t_stat": data["t_stat"],
                        "p_value": data["p_value"],
                        "conf_int_lower": data["conf_int"][0],
                        "conf_int_upper": data["conf_int"][1],
                        "cband_lower": cband_ci[0],
                        "cband_upper": cband_ci[1],
                    }
                )
            return pd.DataFrame(rows, columns=cols)
        elif level == "group":
            if self.group_effects is None:
                raise ValueError("Group effects not computed. Use aggregate='group'.")
            cols = [
                "group", "effect", "se", "t_stat", "p_value",
                "conf_int_lower", "conf_int_upper",
            ]
            rows = [
                {
                    "group": group,
                    "effect": data["effect"],
                    "se": data["se"],
                    "t_stat": data["t_stat"],
                    "p_value": data["p_value"],
                    "conf_int_lower": data["conf_int"][0],
                    "conf_int_upper": data["conf_int"][1],
                }
                for group, data in sorted(self.group_effects.items())
            ]
            return pd.DataFrame(rows, columns=cols)
        else:
            raise ValueError(
                f"Unknown level: {level}. Use 'group_time', 'event_study', or 'group'."
            )

    @property
    def is_significant(self) -> bool:
        """Check if overall ATT is significant."""
        return bool(self.overall_p_value < self.alpha)

    @property
    def significance_stars(self) -> str:
        """Significance stars for overall ATT."""
        return _get_significance_stars(self.overall_p_value)