"""
Efficient Difference-in-Differences estimator.
Implements the ATT estimator from Chen, Sant'Anna & Xie (2025).
Without covariates, achieves the semiparametric efficiency bound via
closed-form within-group covariances. With covariates, uses a doubly
robust path with OLS outcome regression, sieve propensity ratios, and
kernel-smoothed conditional Omega*(X) (see class docstring for caveats).
Under PT-All the model is overidentified and EDiD exploits this for
tighter inference; under PT-Post it reduces to the standard
single-baseline estimator (Callaway-Sant'Anna).
"""
import warnings
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
from diff_diff.efficient_did_bootstrap import (
EDiDBootstrapResults,
EfficientDiDBootstrapMixin,
)
from diff_diff.efficient_did_covariates import (
compute_eif_cov,
compute_generated_outcomes_cov,
compute_omega_star_conditional,
compute_per_unit_weights,
estimate_inverse_propensity_sieve,
estimate_outcome_regression,
estimate_propensity_ratio_sieve,
)
from diff_diff.efficient_did_results import EfficientDiDResults, HausmanPretestResult
from diff_diff.efficient_did_weights import (
compute_efficient_weights,
compute_eif_nocov,
compute_generated_outcomes_nocov,
compute_omega_star_nocov,
enumerate_valid_triples,
)
from diff_diff.utils import safe_inference
# Re-export for convenience
__all__ = ["EfficientDiD", "EfficientDiDResults", "EDiDBootstrapResults"]
def _validate_and_build_cluster_mapping(
df: pd.DataFrame,
unit: str,
cluster: str,
all_units: list,
) -> Tuple[np.ndarray, int]:
"""Validate cluster column and build unit-to-cluster-index mapping.
Checks: column exists, no NaN, per-unit constancy, >= 2 clusters.
Returns (cluster_indices, n_clusters).
"""
if cluster not in df.columns:
raise ValueError(f"Cluster column '{cluster}' not found in data.")
if df[cluster].isna().any():
raise ValueError(f"Cluster column '{cluster}' contains missing values.")
cluster_by_unit = df.groupby(unit)[cluster]
if (cluster_by_unit.nunique() > 1).any():
raise ValueError(
f"Cluster column '{cluster}' varies within unit. "
"Cluster assignment must be constant per unit."
)
cluster_col = cluster_by_unit.first().reindex(all_units).values
unique_clusters = np.unique(cluster_col)
n_clusters = len(unique_clusters)
if n_clusters < 2:
raise ValueError(f"Need at least 2 clusters for cluster-robust SEs, got {n_clusters}.")
cluster_to_idx = {c: i for i, c in enumerate(unique_clusters)}
indices = np.array([cluster_to_idx[c] for c in cluster_col])
return indices, n_clusters
def _cluster_aggregate(
eif_mat: np.ndarray,
cluster_indices: np.ndarray,
n_clusters: int,
) -> np.ndarray:
"""Sum EIF values within clusters and center.
Parameters
----------
eif_mat : ndarray, shape (n_units,) or (n_units, k)
EIF values — 1-D for a single estimand, 2-D for multiple.
cluster_indices : ndarray, shape (n_units,)
Integer cluster assignment per unit.
n_clusters : int
Number of unique clusters.
Returns
-------
ndarray, shape (n_clusters,) or (n_clusters, k)
Centered cluster-level sums.
"""
if eif_mat.ndim == 1:
sums = np.bincount(cluster_indices, weights=eif_mat, minlength=n_clusters).astype(float)
else:
sums = np.column_stack(
[
np.bincount(cluster_indices, weights=eif_mat[:, j], minlength=n_clusters)
for j in range(eif_mat.shape[1])
]
).astype(float)
return sums - sums.mean(axis=0)
def _compute_se_from_eif(
eif: np.ndarray,
n_units: int,
cluster_indices: Optional[np.ndarray] = None,
n_clusters: Optional[int] = None,
) -> float:
"""SE from EIF values, optionally with cluster-robust correction.
Without clusters: ``sqrt(mean(EIF^2) / n)``.
With clusters: Liang-Zeger sandwich — aggregate EIF within clusters,
center, and apply G/(G-1) small-sample correction.
"""
if cluster_indices is not None and n_clusters is not None:
centered = _cluster_aggregate(eif, cluster_indices, n_clusters)
correction = n_clusters / (n_clusters - 1) if n_clusters > 1 else 1.0
var = correction * np.sum(centered**2) / (n_units**2)
return float(np.sqrt(max(var, 0.0)))
return float(np.sqrt(np.mean(eif**2) / n_units))
[docs]
class EfficientDiD(EfficientDiDBootstrapMixin):
"""Efficient DiD estimator (Chen, Sant'Anna & Xie 2025).
Without covariates, achieves the semiparametric efficiency bound for
ATT(g,t) using a closed-form estimator based on within-group sample
means and covariances.
With covariates, uses a doubly robust path: sieve-based propensity
score ratios (Eq 4.1-4.2), OLS outcome regression, sieve-estimated
inverse propensities (algorithm step 4), and kernel-smoothed
conditional Omega*(X) with per-unit efficient weights (Eq 3.12).
The DR property ensures consistency if either the OLS outcome model
or the sieve propensity ratio is correctly specified. The OLS
working model for outcome regressions does not generically guarantee
the semiparametric efficiency bound (see REGISTRY.md).
Parameters
----------
pt_assumption : str, default ``"all"``
Parallel trends variant: ``"all"`` (overidentified, uses all
pre-treatment periods and comparison groups) or ``"post"``
(just-identified, single baseline, equivalent to CS).
alpha : float, default 0.05
Significance level.
cluster : str or None
Column name for cluster-robust SEs. When set, analytical SEs
use the Liang-Zeger clustered sandwich estimator on EIF values.
With ``n_bootstrap > 0``, bootstrap weights are generated at the
cluster level (all units in a cluster share the same weight).
control_group : str, default ``"never_treated"``
Which units serve as the comparison group:
``"never_treated"`` requires a never-treated cohort (raises if
none exist); ``"last_cohort"`` reclassifies the latest treatment
cohort as pseudo-never-treated and drops post-treatment periods
for that cohort. Distinct from CallawaySantAnna's
``"not_yet_treated"`` — see REGISTRY.md for details.
n_bootstrap : int, default 0
Number of multiplier bootstrap iterations (0 = analytical only).
bootstrap_weights : str, default ``"rademacher"``
Bootstrap weight distribution.
seed : int or None
Random seed for reproducibility.
anticipation : int, default 0
Number of anticipation periods (shifts the effective treatment
boundary forward by this amount).
sieve_k_max : int or None
Maximum polynomial degree for sieve ratio estimation. None = auto
(``min(floor(n_gp^{1/5}), 5)``). Only used with covariates.
sieve_criterion : str, default ``"bic"``
Information criterion for sieve degree selection: ``"aic"`` or ``"bic"``.
ratio_clip : float, default 20.0
Clip sieve propensity ratios to ``[1/ratio_clip, ratio_clip]``.
kernel_bandwidth : float or None
Bandwidth for Gaussian kernel in conditional Omega* estimation.
None = Silverman's rule-of-thumb (automatic).
Examples
--------
>>> from diff_diff import EfficientDiD
>>> edid = EfficientDiD(pt_assumption="all")
>>> results = edid.fit(data, outcome="y", unit="id", time="t",
... first_treat="first_treat", aggregate="all")
>>> results.print_summary()
"""
[docs]
def __init__(
self,
pt_assumption: str = "all",
alpha: float = 0.05,
cluster: Optional[str] = None,
control_group: str = "never_treated",
n_bootstrap: int = 0,
bootstrap_weights: str = "rademacher",
seed: Optional[int] = None,
anticipation: int = 0,
sieve_k_max: Optional[int] = None,
sieve_criterion: str = "bic",
ratio_clip: float = 20.0,
kernel_bandwidth: Optional[float] = None,
):
self.pt_assumption = pt_assumption
self.alpha = alpha
self.cluster = cluster
self.control_group = control_group
self.n_bootstrap = n_bootstrap
self.bootstrap_weights = bootstrap_weights
self.seed = seed
self.anticipation = anticipation
self.sieve_k_max = sieve_k_max
self.sieve_criterion = sieve_criterion
self.ratio_clip = ratio_clip
self.kernel_bandwidth = kernel_bandwidth
self.is_fitted_ = False
self.results_: Optional[EfficientDiDResults] = None
self._unit_resolved_survey = None
self._validate_params()
def _validate_params(self) -> None:
"""Validate constrained parameters."""
if self.pt_assumption not in ("all", "post"):
raise ValueError(f"pt_assumption must be 'all' or 'post', got '{self.pt_assumption}'")
if self.control_group not in ("never_treated", "last_cohort"):
raise ValueError(
f"control_group must be 'never_treated' or 'last_cohort', "
f"got '{self.control_group}'"
)
valid_weights = ("rademacher", "mammen", "webb")
if self.bootstrap_weights not in valid_weights:
raise ValueError(
f"bootstrap_weights must be one of {valid_weights}, "
f"got '{self.bootstrap_weights}'"
)
if self.sieve_criterion not in ("aic", "bic"):
raise ValueError(
f"sieve_criterion must be 'aic' or 'bic', got '{self.sieve_criterion}'"
)
if not (np.isfinite(self.ratio_clip) and self.ratio_clip > 1.0):
raise ValueError(f"ratio_clip must be finite and > 1.0, got {self.ratio_clip}")
if self.kernel_bandwidth is not None:
if not (np.isfinite(self.kernel_bandwidth) and self.kernel_bandwidth > 0):
raise ValueError(
f"kernel_bandwidth must be finite and > 0 (or None for auto), "
f"got {self.kernel_bandwidth}"
)
if self.sieve_k_max is not None:
if not (isinstance(self.sieve_k_max, (int, np.integer)) and self.sieve_k_max > 0):
raise ValueError(
f"sieve_k_max must be a positive integer (or None for auto), "
f"got {self.sieve_k_max}"
)
# -- sklearn compatibility ------------------------------------------------
[docs]
def get_params(self) -> Dict[str, Any]:
    """Return the estimator's constructor parameters (sklearn-compatible).

    Keys are the ``__init__`` argument names, in a fixed order; values are
    the current attribute values.
    """
    param_names = (
        "pt_assumption",
        "anticipation",
        "alpha",
        "cluster",
        "control_group",
        "n_bootstrap",
        "bootstrap_weights",
        "seed",
        "sieve_k_max",
        "sieve_criterion",
        "ratio_clip",
        "kernel_bandwidth",
    )
    return {name: getattr(self, name) for name in param_names}
[docs]
def set_params(self, **params: Any) -> "EfficientDiD":
    """Set estimator parameters (sklearn-compatible).

    Only names reported by :meth:`get_params` are accepted, so fitted
    state (``results_``, ``is_fitted_``) and methods cannot be overwritten
    by accident. The update is all-or-nothing: if an unknown key is passed
    nothing is changed, and if the new configuration fails validation the
    previous values are restored before the error propagates.

    Returns
    -------
    EfficientDiD
        ``self``, to allow chaining.

    Raises
    ------
    ValueError
        If a key is not an estimator parameter, or the resulting
        configuration is invalid.
    """
    valid_keys = self.get_params()
    for key in params:
        if key not in valid_keys:
            raise ValueError(f"Unknown parameter: {key}")

    previous = {key: getattr(self, key) for key in params}
    for key, value in params.items():
        setattr(self, key, value)
    try:
        self._validate_params()
    except ValueError:
        # Roll back so a failed call leaves the estimator usable.
        for key, value in previous.items():
            setattr(self, key, value)
        raise
    return self
# -- Main estimation ------------------------------------------------------
[docs]
def fit(
    self,
    data: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]] = None,
    aggregate: Optional[str] = None,
    balance_e: Optional[int] = None,
    survey_design: Optional[Any] = None,
    store_eif: bool = False,
) -> EfficientDiDResults:
    """Fit the Efficient DiD estimator.

    Parameters
    ----------
    data : DataFrame
        Balanced panel data.
    outcome : str
        Outcome variable column name.
    unit : str
        Unit identifier column name.
    time : str
        Time period column name.
    first_treat : str
        Column indicating first treatment period.
        Use 0 or ``np.inf`` for never-treated units; missing values are
        rejected.
    covariates : list of str, optional
        Column names for time-invariant unit-level covariates.
        When provided, uses the doubly robust path (outcome regression
        + propensity score ratios). An empty list is treated as ``None``.
    aggregate : str, optional
        ``None``, ``"simple"``, ``"event_study"``, ``"group"``, or
        ``"all"``.
    balance_e : int, optional
        Balance event study at this relative period.
    survey_design : SurveyDesign, optional
        Survey design specification for design-based inference.
        Applies survey weights to all means, covariances, and cohort
        fractions, and uses Taylor Series Linearization for SE
        estimation. Cannot be combined with ``cluster``.
    store_eif : bool, default False
        Store per-(g,t) EIF vectors in the results object. Used
        internally by :meth:`hausman_pretest`; not needed for
        normal usage.

    Returns
    -------
    EfficientDiDResults
        Also stored on ``self.results_``; ``self.is_fitted_`` is set.

    Raises
    ------
    ValueError
        Missing columns, missing ``first_treat`` values, unbalanced panel,
        non-finite outcomes or covariates, duplicate unit-period rows,
        non-absorbing treatment, time-varying covariates, no usable
        comparison group (e.g. no never-treated units with
        ``control_group="never_treated"``), or no estimable group-time
        effects.
    NotImplementedError
        ``cluster`` combined with ``survey_design``, or ``n_bootstrap > 0``
        combined with a replicate-weight survey design.
    """
    self._validate_params()
    if self.cluster is not None and survey_design is not None:
        raise NotImplementedError(
            "cluster and survey_design cannot both be set. "
            "Use survey_design with PSU/strata for cluster-robust inference."
        )

    # ----- Survey design (panel level) -----
    from diff_diff.survey import _resolve_survey_for_fit

    resolved_survey, _, _, survey_metadata = _resolve_survey_for_fit(
        survey_design, data, "analytical"
    )
    # Validate within-unit constancy for panel survey designs
    if resolved_survey is not None:
        from diff_diff.survey import _validate_unit_constant_survey

        _validate_unit_constant_survey(data, unit, survey_design)
    # Store survey df for safe_inference calls (t-distribution with survey
    # df); refined to the unit-level design further below.
    self._survey_df = survey_metadata.df_survey if survey_metadata is not None else None
    # Guard: replicate design with undefined df -> df=0 sentinel (NaN inference)
    if (
        self._survey_df is None
        and resolved_survey is not None
        and hasattr(resolved_survey, "uses_replicate_variance")
        and resolved_survey.uses_replicate_variance
    ):
        self._survey_df = 0
    # Bootstrap + survey supported via PSU-level multiplier bootstrap.
    # Normalize empty covariates list to None (use nocov path)
    if covariates is not None and len(covariates) == 0:
        covariates = None
    use_covariates = covariates is not None

    # ----- Validate inputs -----
    required_cols = [outcome, unit, time, first_treat]
    missing = [c for c in required_cols if c not in data.columns]
    if missing:
        raise ValueError(f"Missing columns: {missing}")
    df = data.copy()
    df[time] = pd.to_numeric(df[time])
    df[first_treat] = pd.to_numeric(df[first_treat])
    # A NaN first_treat places a unit in neither a treatment cohort nor the
    # never-treated group, yet the unit would still be counted in n_units
    # (and hence every cohort-fraction denominator). Reject it explicitly.
    n_missing_ft = int(df[first_treat].isna().sum())
    if n_missing_ft > 0:
        raise ValueError(
            f"Found {n_missing_ft} missing value(s) in first_treat column "
            f"'{first_treat}'. Use 0 or np.inf for never-treated units."
        )
    # Normalize never-treated: inf -> 0 internally, keep track
    df["_never_treated"] = (df[first_treat] == 0) | (df[first_treat] == np.inf)
    df.loc[df[first_treat] == np.inf, first_treat] = 0
    time_periods = sorted(df[time].unique())
    treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0])
    # Validate balanced panel
    unit_period_counts = df.groupby(unit)[time].nunique()
    n_periods = len(time_periods)
    if (unit_period_counts != n_periods).any():
        raise ValueError(
            "Unbalanced panel detected. EfficientDiD requires a balanced "
            "panel where every unit is observed in every time period."
        )
    # Reject non-finite outcomes (NaN/Inf corrupt Omega*/EIF calculations)
    non_finite_mask = ~np.isfinite(df[outcome])
    if non_finite_mask.any():
        n_bad = int(non_finite_mask.sum())
        raise ValueError(
            f"Found {n_bad} non-finite value(s) in outcome column '{outcome}'. "
            "EfficientDiD requires finite outcomes for all unit-period observations."
        )
    # Reject duplicate (unit, time) rows
    dup_mask = df.duplicated(subset=[unit, time], keep=False)
    if dup_mask.any():
        n_dups = int(dup_mask.sum())
        raise ValueError(
            f"Found {n_dups} duplicate ({unit}, {time}) rows. "
            "EfficientDiD requires exactly one observation per unit-period."
        )
    # Validate absorbing treatment (vectorized)
    ft_nunique = df.groupby(unit)[first_treat].nunique()
    bad_units = ft_nunique[ft_nunique > 1]
    if len(bad_units) > 0:
        uid = bad_units.index[0]
        raise ValueError(
            f"Non-absorbing treatment detected for unit {uid}: "
            "first_treat value changes over time."
        )

    # Unit info: one row per unit with its cohort and never-treated flag
    unit_info = (
        df.groupby(unit)
        .agg(
            {
                first_treat: "first",
                "_never_treated": "first",
            }
        )
        .reset_index()
    )
    n_treated_units = int((unit_info[first_treat] > 0).sum())
    n_control_units = int(unit_info["_never_treated"].sum())

    # Control group logic
    if self.control_group == "last_cohort":
        # Always reclassify last cohort as pseudo-control when requested
        if not treatment_groups:
            raise ValueError(
                "No treated cohorts found. control_group='last_cohort' requires "
                "at least 2 treatment cohorts."
            )
        last_g = max(treatment_groups)
        treatment_groups = [g for g in treatment_groups if g != last_g]
        if not treatment_groups:
            raise ValueError("Only one treatment cohort; cannot use last_cohort control.")
        # Drop periods in which the pseudo-control cohort is (or anticipates
        # being) treated.
        effective_last = last_g - self.anticipation
        time_periods = [t for t in time_periods if t < effective_last]
        if len(time_periods) < 2:
            raise ValueError(
                "Fewer than 2 time periods remain after trimming for last_cohort control."
            )
        unit_info.loc[unit_info[first_treat] == last_g, first_treat] = 0
        unit_info.loc[unit_info[first_treat] == 0, "_never_treated"] = True
        n_treated_units = int((unit_info[first_treat] > 0).sum())
        n_control_units = int(unit_info["_never_treated"].sum())
    elif n_control_units == 0:
        raise ValueError(
            "No never-treated units found. Use control_group='last_cohort' "
            "to use the last treatment cohort as a pseudo-control."
        )

    # ----- Prepare data -----
    all_units = sorted(df[unit].unique())
    n_units = len(all_units)
    # Build unit-to-first-panel-row index aligned to all_units (sorted)
    # order. The previous approach (groupby cumcount == 0) yielded
    # first-appearance order which can differ from sorted order when the
    # input DataFrame is not pre-sorted by unit.
    first_pos: Dict[Any, int] = {}
    for i, u in enumerate(df[unit].values):
        if u not in first_pos:
            first_pos[u] = i
    self._unit_first_panel_row = np.array([first_pos[u] for u in all_units])
    # Build unit-level ResolvedSurveyDesign once (avoids repeated
    # construction in _compute_survey_eif_se and ensures consistent
    # unit-level df for safe_inference t-distribution).
    if resolved_survey is not None:
        row_idx = self._unit_first_panel_row
        unit_weights_s = resolved_survey.weights[row_idx]
        unit_strata = (
            resolved_survey.strata[row_idx] if resolved_survey.strata is not None else None
        )
        unit_psu = resolved_survey.psu[row_idx] if resolved_survey.psu is not None else None
        unit_fpc = resolved_survey.fpc[row_idx] if resolved_survey.fpc is not None else None
        n_strata_u = len(np.unique(unit_strata)) if unit_strata is not None else 0
        n_psu_u = len(np.unique(unit_psu)) if unit_psu is not None else 0
        self._unit_resolved_survey = resolved_survey.subset_to_units(
            row_idx,
            unit_weights_s,
            unit_strata,
            unit_psu,
            unit_fpc,
            n_strata_u,
            n_psu_u,
        )
        # Use unit-level df (not panel-level) for t-distribution
        self._survey_df = self._unit_resolved_survey.df_survey
        # Re-apply replicate guard: undefined df -> NaN inference
        if self._survey_df is None and self._unit_resolved_survey.uses_replicate_variance:
            self._survey_df = 0
    else:
        self._unit_resolved_survey = None

    # Reject replicate-weight designs for bootstrap -- replicate variance
    # is an analytical alternative, not compatible with bootstrap. Checked
    # here, before any estimation work, so the caller fails fast.
    if (
        self.n_bootstrap > 0
        and self._unit_resolved_survey is not None
        and self._unit_resolved_survey.uses_replicate_variance
    ):
        raise NotImplementedError(
            "EfficientDiD bootstrap (n_bootstrap > 0) is not supported "
            "with replicate-weight survey designs. Replicate weights provide "
            "analytical variance; use n_bootstrap=0 instead."
        )

    # Build cluster mapping if cluster-robust SEs requested
    if self.cluster is not None:
        unit_cluster_indices, n_clusters = _validate_and_build_cluster_mapping(
            df, unit, self.cluster, all_units
        )
        if n_clusters < 50:
            warnings.warn(
                f"Only {n_clusters} clusters. Analytical clustered SEs may "
                "be unreliable. Consider n_bootstrap > 0 for cluster "
                "bootstrap inference.",
                UserWarning,
                stacklevel=2,
            )
    else:
        unit_cluster_indices = None
        n_clusters = None

    period_to_col = {p: i for i, p in enumerate(time_periods)}
    period_1 = time_periods[0]
    period_1_col = period_to_col[period_1]
    # Pivot outcome to wide matrix (n_units, n_periods)
    pivot = df.pivot(index=unit, columns=time, values=outcome)
    # Reindex to match all_units ordering and time_periods column order
    pivot = pivot.reindex(index=all_units, columns=time_periods)
    outcome_wide = pivot.values.astype(float)

    # Build cohort masks and fractions
    unit_info_indexed = unit_info.set_index(unit)
    unit_cohorts = unit_info_indexed.reindex(all_units)[first_treat].values.astype(
        float
    )  # 0 = never-treated
    cohort_masks: Dict[float, np.ndarray] = {}
    for g in treatment_groups:
        cohort_masks[g] = unit_cohorts == g
    never_treated_mask = unit_cohorts == 0
    cohort_masks[np.inf] = never_treated_mask  # also keyed by inf sentinel

    # ----- Unit-level survey weights -----
    # Survey weights in the panel are at obs level (unit x time).
    # EfficientDiD works at unit level. Extract one weight per unit
    # by taking the first observation per unit (balanced panel, so
    # weights should be constant within unit).
    unit_level_weights: Optional[np.ndarray] = None
    if resolved_survey is not None:
        # Use the resolved survey's weights (already normalized per weight_type)
        # subset to unit level via _unit_first_panel_row (aligned to all_units)
        unit_level_weights = self._unit_resolved_survey.weights
    self._unit_level_weights = unit_level_weights
    cohort_fractions: Dict[float, float] = {}
    if unit_level_weights is not None:
        # Survey-weighted cohort fractions: sum(w_i for i in cohort) / sum(w_i)
        total_w = float(np.sum(unit_level_weights))
        for g in treatment_groups:
            cohort_fractions[g] = float(np.sum(unit_level_weights[cohort_masks[g]])) / total_w
        cohort_fractions[np.inf] = (
            float(np.sum(unit_level_weights[never_treated_mask])) / total_w
        )
    else:
        for g in treatment_groups:
            cohort_fractions[g] = float(np.sum(cohort_masks[g])) / n_units
        cohort_fractions[np.inf] = float(np.sum(never_treated_mask)) / n_units

    # ----- Small cohort warnings -----
    for g in treatment_groups:
        n_g = int(np.sum(cohort_masks[g]))
        frac_g = cohort_fractions[g]
        if n_g < 2:
            warnings.warn(
                f"Cohort {g} has only {n_g} unit. Omega* inversion and "
                "EIF computation may be numerically unstable.",
                UserWarning,
                stacklevel=2,
            )
        elif frac_g < 0.01:
            warnings.warn(
                f"Cohort {g} represents {frac_g:.1%} of the sample (< 1%). "
                "Efficient weights may be imprecise.",
                UserWarning,
                stacklevel=2,
            )
    # Guard: never-treated with zero survey weight -> no valid comparisons
    # Applies to both covariates (DR nuisance) and nocov (weighted means) paths
    if cohort_fractions.get(np.inf, 0.0) <= 0 and unit_level_weights is not None:
        raise ValueError(
            "Never-treated group has zero survey weight. EfficientDiD "
            "requires a never-treated control group with positive "
            "survey weight for estimation."
        )

    # ----- Covariate preparation (if provided) -----
    covariate_matrix: Optional[np.ndarray] = None
    # Nuisance caches are shared across (g, t) cells; entries are filled
    # lazily the first time a key is needed.
    m_hat_cache: Dict[Tuple, np.ndarray] = {}
    r_hat_cache: Dict[Tuple[float, float], np.ndarray] = {}
    s_hat_cache: Dict[float, np.ndarray] = {}  # inverse propensities per group
    if use_covariates:
        assert covariates is not None  # for type narrowing
        # Validate covariate columns exist
        missing_cov = [c for c in covariates if c not in data.columns]
        if missing_cov:
            raise ValueError(f"Missing covariate columns: {missing_cov}")
        # Validate no NaN/Inf in covariates
        for col_name in covariates:
            non_finite_cov = ~np.isfinite(pd.to_numeric(df[col_name], errors="coerce"))
            if non_finite_cov.any():
                n_bad = int(non_finite_cov.sum())
                raise ValueError(
                    f"Found {n_bad} non-finite value(s) in covariate column "
                    f"'{col_name}'. Covariates must be finite."
                )
        # Validate time-invariance: covariates must be constant within each unit
        for col_name in covariates:
            cov_nunique = df.groupby(unit)[col_name].nunique()
            varying = cov_nunique[cov_nunique > 1]
            if len(varying) > 0:
                uid = varying.index[0]
                raise ValueError(
                    f"Covariate '{col_name}' varies over time for unit {uid}. "
                    "EfficientDiD requires time-invariant covariates. "
                    "Extract base-period values before calling fit()."
                )
        # Extract unit-level covariate matrix from period_1 observations
        base_df = df[df[time] == period_1].set_index(unit).reindex(all_units)
        covariate_matrix = base_df[list(covariates)].values.astype(float)

    # ----- Core estimation: ATT(g, t) for each target -----
    # Precompute per-group unit counts (avoid repeated np.sum in loop)
    n_treated_per_g = {g: int(np.sum(cohort_masks[g])) for g in treatment_groups}
    n_control_count = int(np.sum(never_treated_mask))
    group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]] = {}
    eif_by_gt: Dict[Tuple[Any, Any], np.ndarray] = {}
    stored_weights: Dict[Tuple[Any, Any], np.ndarray] = {}
    stored_cond: Dict[Tuple[Any, Any], float] = {}
    for g in treatment_groups:
        # Under PT-Post, use per-group baseline Y_{g-1-anticipation}
        # instead of the universal Y_1. This implements the weaker
        # PT-Post assumption (parallel trends only from g-1 onward),
        # matching the Callaway-Sant'Anna estimator exactly.
        if self.pt_assumption == "post":
            effective_base = g - 1 - self.anticipation
            if effective_base not in period_to_col:
                warnings.warn(
                    f"Cohort g={g} dropped: baseline period {effective_base} "
                    f"(g-1-anticipation) is not in the data.",
                    UserWarning,
                    stacklevel=2,
                )
                continue
            effective_p1_col = period_to_col[effective_base]
        else:
            effective_p1_col = period_1_col
        # Guard: skip cohorts with zero survey weight (all units zero-weighted)
        if cohort_fractions[g] <= 0:
            warnings.warn(
                f"Cohort {g} has zero survey weight; skipping.",
                UserWarning,
                stacklevel=2,
            )
            continue
        # Estimate all (g, t) cells including pre-treatment. Under PT-Post,
        # pre-treatment cells serve as placebo/pre-trend diagnostics, matching
        # the CallawaySantAnna implementation. Users filter to t >= g for
        # post-treatment effects; pre-treatment cells are clearly labeled by
        # their (g, t) coordinates in the results object.
        for t in time_periods:
            # Skip period_1 -- it's the universal reference baseline,
            # not a target period
            if t == period_1:
                continue
            # Enumerate valid comparison pairs
            pairs = enumerate_valid_triples(
                target_g=g,
                treatment_groups=treatment_groups,
                time_periods=time_periods,
                period_1=period_1,
                pt_assumption=self.pt_assumption,
                anticipation=self.anticipation,
            )
            # Filter out comparison pairs with zero survey weight
            if unit_level_weights is not None and pairs:
                pairs = [
                    (gp, tpre)
                    for gp, tpre in pairs
                    if np.sum(
                        unit_level_weights[
                            never_treated_mask if np.isinf(gp) else cohort_masks[gp]
                        ]
                    )
                    > 0
                ]
            if not pairs:
                warnings.warn(
                    f"No valid comparison pairs for (g={g}, t={t}). " "ATT will be NaN.",
                    UserWarning,
                    stacklevel=2,
                )
                t_stat, p_val, ci = np.nan, np.nan, (np.nan, np.nan)
                group_time_effects[(g, t)] = {
                    "effect": np.nan,
                    "se": np.nan,
                    "t_stat": t_stat,
                    "p_value": p_val,
                    "conf_int": ci,
                    "n_treated": n_treated_per_g[g],
                    "n_control": n_control_count,
                }
                # Zero EIF so aggregations treat the cell as contributing nothing
                eif_by_gt[(g, t)] = np.zeros(n_units)
                continue
            if use_covariates:
                assert covariate_matrix is not None
                t_col_val = period_to_col[t]
                # Lazily populate nuisance caches for this (g, t)
                for gp, tpre in pairs:
                    tpre_col_val = period_to_col[tpre]
                    # m_{inf, t, tpre}(X)
                    key_inf_t = (np.inf, t_col_val, tpre_col_val)
                    if key_inf_t not in m_hat_cache:
                        m_hat_cache[key_inf_t] = estimate_outcome_regression(
                            outcome_wide,
                            covariate_matrix,
                            never_treated_mask,
                            t_col_val,
                            tpre_col_val,
                            unit_weights=unit_level_weights,
                        )
                    # m_{g', tpre, 1}(X)
                    key_gp_tpre = (gp, tpre_col_val, effective_p1_col)
                    if key_gp_tpre not in m_hat_cache:
                        gp_mask_for_reg = (
                            never_treated_mask if np.isinf(gp) else cohort_masks[gp]
                        )
                        m_hat_cache[key_gp_tpre] = estimate_outcome_regression(
                            outcome_wide,
                            covariate_matrix,
                            gp_mask_for_reg,
                            tpre_col_val,
                            effective_p1_col,
                            unit_weights=unit_level_weights,
                        )
                    # r_{g, inf}(X) and r_{g, g'}(X) via sieve (Eq 4.1-4.2)
                    for comp in {np.inf, gp}:
                        rkey = (g, comp)
                        if rkey not in r_hat_cache:
                            comp_mask = (
                                never_treated_mask if np.isinf(comp) else cohort_masks[comp]
                            )
                            r_hat_cache[rkey] = estimate_propensity_ratio_sieve(
                                covariate_matrix,
                                cohort_masks[g],
                                comp_mask,
                                k_max=self.sieve_k_max,
                                criterion=self.sieve_criterion,
                                ratio_clip=self.ratio_clip,
                                unit_weights=unit_level_weights,
                            )
                # Per-unit DR generated outcomes: shape (n_units, H)
                gen_out = compute_generated_outcomes_cov(
                    target_g=g,
                    target_t=t,
                    valid_pairs=pairs,
                    outcome_wide=outcome_wide,
                    cohort_masks=cohort_masks,
                    never_treated_mask=never_treated_mask,
                    period_to_col=period_to_col,
                    period_1_col=effective_p1_col,
                    cohort_fractions=cohort_fractions,
                    m_hat_cache=m_hat_cache,
                    r_hat_cache=r_hat_cache,
                )
                y_hat = np.mean(gen_out, axis=0)  # shape (H,)
                # Inverse propensity estimation (algorithm step 4)
                # s_hat_{g'}(X) = 1/p_{g'}(X) for Eq 3.12 scaling
                for group_id in {g, np.inf} | {gp for gp, _ in pairs}:
                    if group_id not in s_hat_cache:
                        group_mask_s = (
                            never_treated_mask if np.isinf(group_id) else cohort_masks[group_id]
                        )
                        s_hat_cache[group_id] = estimate_inverse_propensity_sieve(
                            covariate_matrix,
                            group_mask_s,
                            k_max=self.sieve_k_max,
                            criterion=self.sieve_criterion,
                            unit_weights=unit_level_weights,
                        )
                # Conditional Omega*(X) with per-unit propensities (Eq 3.12)
                omega_cond = compute_omega_star_conditional(
                    target_g=g,
                    target_t=t,
                    valid_pairs=pairs,
                    outcome_wide=outcome_wide,
                    cohort_masks=cohort_masks,
                    never_treated_mask=never_treated_mask,
                    period_to_col=period_to_col,
                    period_1_col=effective_p1_col,
                    cohort_fractions=cohort_fractions,
                    covariate_matrix=covariate_matrix,
                    s_hat_cache=s_hat_cache,
                    bandwidth=self.kernel_bandwidth,
                    unit_weights=unit_level_weights,
                )
                # Per-unit weights: (n_units, H)
                per_unit_w = compute_per_unit_weights(omega_cond)
                # ATT = (survey-)weighted mean of per-unit DR scores
                if per_unit_w.shape[1] > 0:
                    per_unit_scores = np.sum(per_unit_w * gen_out, axis=1)
                    if unit_level_weights is not None:
                        att_gt = float(np.average(per_unit_scores, weights=unit_level_weights))
                    else:
                        att_gt = float(np.mean(per_unit_scores))
                else:
                    att_gt = np.nan
                # EIF with per-unit weights (Remark 4.2: plug-in valid)
                # Center on scalar ATT, not per-pair means (ensures mean(EIF) ~ 0)
                eif_vals = compute_eif_cov(per_unit_w, gen_out, att_gt, n_units)
                eif_by_gt[(g, t)] = eif_vals
            else:
                # No-covariates path (closed-form)
                omega = compute_omega_star_nocov(
                    target_g=g,
                    target_t=t,
                    valid_pairs=pairs,
                    outcome_wide=outcome_wide,
                    cohort_masks=cohort_masks,
                    never_treated_mask=never_treated_mask,
                    period_to_col=period_to_col,
                    period_1_col=effective_p1_col,
                    cohort_fractions=cohort_fractions,
                    unit_weights=unit_level_weights,
                )
                weights, _, cond_num = compute_efficient_weights(omega)
                stored_weights[(g, t)] = weights
                if omega.size > 0:
                    stored_cond[(g, t)] = cond_num
                y_hat = compute_generated_outcomes_nocov(
                    target_g=g,
                    target_t=t,
                    valid_pairs=pairs,
                    outcome_wide=outcome_wide,
                    cohort_masks=cohort_masks,
                    never_treated_mask=never_treated_mask,
                    period_to_col=period_to_col,
                    period_1_col=effective_p1_col,
                    unit_weights=unit_level_weights,
                )
                att_gt = float(weights @ y_hat) if len(weights) > 0 else np.nan
                eif_vals = compute_eif_nocov(
                    target_g=g,
                    target_t=t,
                    weights=weights,
                    valid_pairs=pairs,
                    outcome_wide=outcome_wide,
                    cohort_masks=cohort_masks,
                    never_treated_mask=never_treated_mask,
                    period_to_col=period_to_col,
                    period_1_col=effective_p1_col,
                    cohort_fractions=cohort_fractions,
                    n_units=n_units,
                    unit_weights=unit_level_weights,
                )
                eif_by_gt[(g, t)] = eif_vals
            # Analytical SE = sqrt(mean(EIF^2) / n) [paper p.21]
            # With survey: use TSL variance via compute_survey_vcov
            if self._unit_resolved_survey is not None:
                se_gt = self._compute_survey_eif_se(eif_vals)
            else:
                se_gt = _compute_se_from_eif(
                    eif_vals, n_units, unit_cluster_indices, n_clusters
                )
            # self._survey_df is read after the SE call because the survey
            # helper may update it for this cell.
            t_stat, p_val, ci = safe_inference(
                att_gt, se_gt, alpha=self.alpha, df=self._survey_df
            )
            group_time_effects[(g, t)] = {
                "effect": att_gt,
                "se": se_gt,
                "t_stat": t_stat,
                "p_value": p_val,
                "conf_int": ci,
                "n_treated": n_treated_per_g[g],
                "n_control": n_control_count,
            }
    if not group_time_effects:
        raise ValueError(
            "Could not estimate any group-time effects. "
            "Check data has sufficient observations."
        )

    # ----- Aggregation -----
    overall_att, overall_se = self._aggregate_overall(
        group_time_effects,
        eif_by_gt,
        n_units,
        cohort_fractions,
        unit_cohorts,
        cluster_indices=unit_cluster_indices,
        n_clusters=n_clusters,
    )
    overall_t, overall_p, overall_ci = safe_inference(
        overall_att, overall_se, alpha=self.alpha, df=self._survey_df
    )
    event_study_effects = None
    group_effects = None
    if aggregate in ("event_study", "all"):
        event_study_effects = self._aggregate_event_study(
            group_time_effects,
            eif_by_gt,
            n_units,
            cohort_fractions,
            treatment_groups,
            time_periods,
            balance_e,
            unit_cohorts=unit_cohorts,
            cluster_indices=unit_cluster_indices,
            n_clusters=n_clusters,
        )
    if aggregate in ("group", "all"):
        group_effects = self._aggregate_by_group(
            group_time_effects,
            eif_by_gt,
            n_units,
            cohort_fractions,
            treatment_groups,
            unit_cohorts=unit_cohorts,
            cluster_indices=unit_cluster_indices,
            n_clusters=n_clusters,
        )

    # ----- Bootstrap -----
    # (Replicate-weight designs were already rejected above.)
    bootstrap_results = None
    if self.n_bootstrap > 0 and eif_by_gt:
        bootstrap_results = self._run_multiplier_bootstrap(
            group_time_effects=group_time_effects,
            eif_by_gt=eif_by_gt,
            n_units=n_units,
            aggregate=aggregate,
            balance_e=balance_e,
            treatment_groups=treatment_groups,
            cohort_fractions=cohort_fractions,
            cluster_indices=unit_cluster_indices,
            n_clusters=n_clusters,
            resolved_survey=self._unit_resolved_survey,
            unit_level_weights=self._unit_level_weights,
        )
        # Update estimates with bootstrap inference (SE, CI, p-value from
        # the bootstrap; t-stat recomputed from the bootstrap SE).
        overall_se = bootstrap_results.overall_att_se
        overall_t = safe_inference(overall_att, overall_se, alpha=self.alpha)[0]
        overall_p = bootstrap_results.overall_att_p_value
        overall_ci = bootstrap_results.overall_att_ci
        for gt in group_time_effects:
            if gt in bootstrap_results.group_time_ses:
                group_time_effects[gt]["se"] = bootstrap_results.group_time_ses[gt]
                group_time_effects[gt]["conf_int"] = bootstrap_results.group_time_cis[gt]
                group_time_effects[gt]["p_value"] = bootstrap_results.group_time_p_values[gt]
                eff = float(group_time_effects[gt]["effect"])
                se = float(group_time_effects[gt]["se"])
                group_time_effects[gt]["t_stat"] = safe_inference(eff, se, alpha=self.alpha)[0]
        es_cis = bootstrap_results.event_study_cis
        es_pvs = bootstrap_results.event_study_p_values
        if (
            event_study_effects is not None
            and bootstrap_results.event_study_ses is not None
            and es_cis is not None
            and es_pvs is not None
        ):
            for e in event_study_effects:
                if e in bootstrap_results.event_study_ses:
                    event_study_effects[e]["se"] = bootstrap_results.event_study_ses[e]
                    event_study_effects[e]["conf_int"] = es_cis[e]
                    event_study_effects[e]["p_value"] = es_pvs[e]
                    eff = float(event_study_effects[e]["effect"])
                    se = float(event_study_effects[e]["se"])
                    event_study_effects[e]["t_stat"] = safe_inference(
                        eff, se, alpha=self.alpha
                    )[0]
        g_cis = bootstrap_results.group_effect_cis
        g_pvs = bootstrap_results.group_effect_p_values
        if (
            group_effects is not None
            and bootstrap_results.group_effect_ses is not None
            and g_cis is not None
            and g_pvs is not None
        ):
            for g in group_effects:
                if g in bootstrap_results.group_effect_ses:
                    group_effects[g]["se"] = bootstrap_results.group_effect_ses[g]
                    group_effects[g]["conf_int"] = g_cis[g]
                    group_effects[g]["p_value"] = g_pvs[g]
                    eff = float(group_effects[g]["effect"])
                    se = float(group_effects[g]["se"])
                    group_effects[g]["t_stat"] = safe_inference(eff, se, alpha=self.alpha)[0]

    # ----- Build results -----
    self.results_ = EfficientDiDResults(
        group_time_effects=group_time_effects,
        overall_att=overall_att,
        overall_se=overall_se,
        overall_t_stat=overall_t,
        overall_p_value=overall_p,
        overall_conf_int=overall_ci,
        groups=treatment_groups,
        time_periods=time_periods,
        n_obs=n_units * len(time_periods),
        n_treated_units=n_treated_units,
        n_control_units=n_control_units,
        alpha=self.alpha,
        pt_assumption=self.pt_assumption,
        anticipation=self.anticipation,
        n_bootstrap=self.n_bootstrap,
        bootstrap_weights=self.bootstrap_weights,
        seed=self.seed,
        event_study_effects=event_study_effects,
        group_effects=group_effects,
        efficient_weights=stored_weights if stored_weights else None,
        omega_condition_numbers=stored_cond if stored_cond else None,
        control_group=self.control_group,
        cluster=self.cluster,
        influence_functions=eif_by_gt if store_eif else None,
        bootstrap_results=bootstrap_results,
        estimation_path="dr" if use_covariates else "nocov",
        sieve_k_max=self.sieve_k_max,
        sieve_criterion=self.sieve_criterion,
        ratio_clip=self.ratio_clip,
        kernel_bandwidth=self.kernel_bandwidth,
        survey_metadata=(
            self._recompute_unit_survey_metadata(survey_metadata)
            if survey_metadata is not None
            else None
        ),
    )
    self.is_fitted_ = True
    return self.results_
def _recompute_unit_survey_metadata(self, panel_metadata):
    """Return survey metadata based on the unit-level design when one exists.

    When ``self._unit_resolved_survey`` is set, metadata is rebuilt from that
    unit-level design (and its weights). If ``self._survey_df`` holds an
    effective replicate df, it overrides ``df_survey`` on the rebuilt
    metadata. Without a unit-level design, ``panel_metadata`` is returned
    unchanged.
    """
    resolved = self._unit_resolved_survey
    if resolved is None:
        return panel_metadata

    from diff_diff.survey import compute_survey_metadata

    meta = compute_survey_metadata(resolved, resolved.weights)

    # A df of 0 is a sentinel for "undefined"; it is never written into the
    # metadata, so only a real (non-None, non-zero) replicate df is propagated.
    replicate_df = self._survey_df
    if replicate_df is not None and replicate_df != 0 and meta.df_survey != replicate_df:
        meta.df_survey = replicate_df
    return meta
# -- Survey SE helpers ----------------------------------------------------
def _compute_survey_eif_se(self, eif_vals: np.ndarray) -> float:
    """Standard error of an EIF-based estimate under the unit-level survey design.

    Relies on ``self._unit_resolved_survey``, the unit-level resolved design
    built once in ``fit()``, so every call sees consistent unit-level arrays.

    Two variance paths are used:

    * Replicate-weight designs: the EIF is rescaled to score form
      ``psi = w * eif / sum(w)`` (matching the TSL bread) and passed to
      ``compute_replicate_if_variance``. If fewer than ``n_replicates``
      replicates are usable, ``self._survey_df`` is updated as a side effect.
      A non-finite variance yields NaN.
    * Otherwise: Taylor Series Linearization via ``compute_survey_vcov`` with
      an intercept-only design matrix.
    """
    design = self._unit_resolved_survey

    if not design.uses_replicate_variance:
        from diff_diff.survey import compute_survey_vcov

        intercept = np.ones((len(eif_vals), 1))
        vcov = compute_survey_vcov(intercept, eif_vals, design)
        return float(np.sqrt(np.abs(vcov[0, 0])))

    from diff_diff.survey import compute_replicate_if_variance

    unit_w = design.weights
    psi = unit_w * eif_vals / unit_w.sum()
    variance, n_valid = compute_replicate_if_variance(psi, design)

    if n_valid < design.n_replicates:
        # Effective replicate df shrinks with the number of usable replicates.
        # NOTE(review): for n_valid <= 1 this stores None, whereas
        # _recompute_unit_survey_metadata treats 0 as the "undefined df"
        # sentinel; confirm which value downstream inference expects.
        self._survey_df = n_valid - 1 if n_valid > 1 else None

    if not np.isfinite(variance):
        return np.nan
    return float(np.sqrt(max(variance, 0.0)))
def _eif_se(
    self,
    eif_vals: np.ndarray,
    n_units: int,
    cluster_indices: Optional[np.ndarray] = None,
    n_clusters: Optional[int] = None,
) -> float:
    """Standard error implied by a vector of aggregated EIF scores.

    While a unit-level survey design is resolved (``_unit_resolved_survey``
    is set during ``fit``), the survey variance path is used. Otherwise the
    module-level ``_compute_se_from_eif`` helper is used, with the cluster
    arguments selecting its cluster-robust or standard formula.
    """
    if self._unit_resolved_survey is None:
        return _compute_se_from_eif(eif_vals, n_units, cluster_indices, n_clusters)
    return self._compute_survey_eif_se(eif_vals)
# -- Aggregation helpers --------------------------------------------------
def _compute_wif_contribution(
    self,
    keepers: List[Tuple],
    effects: np.ndarray,
    unit_cohorts: np.ndarray,
    cohort_fractions: Dict[float, float],
    n_units: int,
    unit_weights: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Weight influence function (WIF) term for a cohort-share-weighted average.

    Aggregations weight ATT(g,t) cells by ``p_g / sum(p_g)``, where ``p_g``
    are estimated cohort shares. This term captures the sampling uncertainty
    of those weights. It follows the R ``did`` package WIF formula and is
    returned on the same O(1) scale as the aggregated EIF, so the two can be
    summed directly.

    Parameters
    ----------
    keepers : list of (g, t) tuples
        Group-time cells included in the aggregation.
    effects : ndarray, shape (n_keepers,)
        ATT estimate for each keeper, in the same order.
    unit_cohorts : ndarray, shape (n_units,)
        Cohort of each unit (0 = never-treated).
    cohort_fractions : dict
        ``{cohort: share}``. Cohorts absent from the dict get share 0.
    n_units : int
        Number of units (length of the all-zero result when the total share is 0).
    unit_weights : ndarray, shape (n_units,), optional
        Survey weights. When given, the indicator term becomes
        ``w_i * 1{G_i=g} - p_g`` instead of ``1{G_i=g} - p_g``.

    Returns
    -------
    ndarray, shape (n_units,)
        Per-unit WIF contribution. All zeros if the keepers' shares sum to 0.
    """
    keeper_groups = np.array([cell[0] for cell in keepers])
    keeper_pg = np.array([cohort_fractions.get(cell[0], 0.0) for cell in keepers])
    pg_total = keeper_pg.sum()
    if pg_total == 0:
        return np.zeros(n_units)

    # membership[i, k] = 1 if unit i belongs to the cohort of keeper k.
    membership = (unit_cohorts[:, None] == keeper_groups[None, :]).astype(float)
    if unit_weights is not None:
        # Survey-weighted variant: w_i * 1{G_i=g}, NOT the bare indicator.
        membership = membership * unit_weights[:, None]
    centered = membership - keeper_pg
    row_totals = centered.sum(axis=1)

    with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
        # Numerator effect of estimating p_g, minus denominator effect of sum(p_g).
        numerator_term = centered / pg_total
        denominator_term = np.outer(row_totals, keeper_pg) / pg_total**2
        return (numerator_term - denominator_term) @ effects
def _aggregate_overall(
    self,
    group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]],
    eif_by_gt: Dict[Tuple[Any, Any], np.ndarray],
    n_units: int,
    cohort_fractions: Dict[float, float],
    unit_cohorts: np.ndarray,
    cluster_indices: Optional[np.ndarray] = None,
    n_clusters: Optional[int] = None,
) -> Tuple[float, float]:
    """Cohort-share-weighted overall ATT and its WIF-adjusted standard error.

    Only cells with ``t >= g - anticipation`` and a finite effect are used.
    Each is weighted by its cohort share normalised over the kept cells.

    Parameters
    ----------
    group_time_effects : dict
        ``{(g, t): {"effect": ...}}`` group-time ATT estimates.
    eif_by_gt : dict
        Per-unit EIF vector for each (g, t).
    n_units : int
        Number of units.
    cohort_fractions : dict
        Cohort size fractions.
    unit_cohorts : ndarray, shape (n_units,)
        Cohort of each unit (for the WIF term).
    cluster_indices, n_clusters : optional
        Forwarded to the non-survey SE computation.

    Returns
    -------
    (overall_att, se)
        ``(nan, nan)`` when no cell qualifies or the kept cohort shares sum to 0.
    """
    anticipation = self.anticipation
    keepers = [
        gt
        for gt, cell in group_time_effects.items()
        if gt[1] >= gt[0] - anticipation and np.isfinite(cell["effect"])
    ]
    if not keepers:
        return np.nan, np.nan

    cohort_pg = np.array([cohort_fractions.get(gt[0], 0.0) for gt in keepers])
    pg_total = cohort_pg.sum()
    if pg_total == 0:
        return np.nan, np.nan

    weights = cohort_pg / pg_total
    effects = np.array([group_time_effects[gt]["effect"] for gt in keepers])
    overall_att = float(np.sum(weights * effects))

    agg_eif = np.zeros(n_units)
    for weight, gt in zip(weights, keepers):
        agg_eif += weight * eif_by_gt[gt]

    # Uncertainty from estimating the cohort-share weights themselves.
    wif = self._compute_wif_contribution(
        keepers,
        effects,
        unit_cohorts,
        cohort_fractions,
        n_units,
        unit_weights=self._unit_level_weights,
    )

    design = self._unit_resolved_survey
    if design is None:
        return overall_att, self._eif_se(agg_eif + wif, n_units, cluster_indices, n_clusters)

    # Survey path: build the score-scale psi directly. The WIF already carries
    # w_i, and compute_survey_vcov would re-apply weights (double-weighting).
    unit_w = self._unit_level_weights
    total_w = float(np.sum(unit_w))
    psi_total = unit_w * agg_eif / total_w + wif / total_w
    if getattr(design, "uses_replicate_variance", False):
        from diff_diff.survey import compute_replicate_if_variance

        variance, _ = compute_replicate_if_variance(psi_total, design)
    else:
        from diff_diff.survey import compute_survey_if_variance

        variance = compute_survey_if_variance(psi_total, design)

    se = float(np.sqrt(max(variance, 0.0))) if np.isfinite(variance) else np.nan
    return overall_att, se
def _aggregate_event_study(
    self,
    group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]],
    eif_by_gt: Dict[Tuple[Any, Any], np.ndarray],
    n_units: int,
    cohort_fractions: Dict[float, float],
    treatment_groups: List[Any],
    time_periods: List[Any],
    balance_e: Optional[int] = None,
    unit_cohorts: Optional[np.ndarray] = None,
    cluster_indices: Optional[np.ndarray] = None,
    n_clusters: Optional[int] = None,
) -> Dict[int, Dict[str, Any]]:
    """Aggregate finite ATT(g,t) estimates by relative time ``e = t - g``.

    Within each horizon, cells are weighted by cohort share (equal weights
    if all shares are 0). The EIF is aggregated with the same weights, and
    a WIF term is added when ``unit_cohorts`` is provided.

    Parameters
    ----------
    group_time_effects : dict
        ``{(g, t): {"effect": ...}}`` group-time ATT estimates.
    eif_by_gt : dict
        Per-unit EIF vector for each (g, t).
    n_units : int
        Number of units.
    cohort_fractions : dict
        Cohort size fractions.
    treatment_groups, time_periods : list
        Accepted for interface compatibility; not read here.
    balance_e : int, optional
        If given, keep only cohorts with a finite effect at ``e == balance_e``.
        A ``UserWarning`` is emitted if that leaves nothing.
    unit_cohorts : ndarray, optional
        Cohort of each unit. Enables the WIF correction.
    cluster_indices, n_clusters : optional
        Forwarded to the non-survey SE computation.

    Returns
    -------
    dict
        ``{e: {"effect", "se", "t_stat", "p_value", "conf_int", "n_groups"}}``
        in ascending ``e`` order.
    """
    Cell = Tuple[Tuple[Any, Any], float, float]

    def _bucket_by_horizon(allowed_groups=None) -> Dict[int, List[Cell]]:
        # Finite cells grouped by e as ((g, t), effect, cohort_share);
        # optionally restricted to cohorts in ``allowed_groups``.
        buckets: Dict[int, List[Cell]] = {}
        for (g, t), data in group_time_effects.items():
            if not np.isfinite(data["effect"]):
                continue
            if allowed_groups is not None and g not in allowed_groups:
                continue
            buckets.setdefault(int(t - g), []).append(
                ((g, t), data["effect"], cohort_fractions.get(g, 0.0))
            )
        return buckets

    effects_by_e = _bucket_by_horizon()
    if balance_e is not None:
        anchor_groups = {cell[0][0] for cell in effects_by_e.get(balance_e, [])}
        effects_by_e = _bucket_by_horizon(anchor_groups)
        if not effects_by_e:
            warnings.warn(
                f"balance_e={balance_e}: no cohort has a finite effect at the "
                "anchor horizon. Event study will be empty.",
                UserWarning,
                stacklevel=2,
            )

    result: Dict[int, Dict[str, Any]] = {}
    for e in sorted(effects_by_e):
        cells = effects_by_e[e]
        gt_pairs = [cell[0] for cell in cells]
        effs = np.array([cell[1] for cell in cells])
        pgs = np.array([cell[2] for cell in cells])

        total_pg = pgs.sum()
        if total_pg > 0:
            w = pgs / total_pg
        else:
            w = np.ones(len(pgs)) / len(pgs)
        agg_eff = float(np.sum(w * effs))

        agg_eif = np.zeros(n_units)
        for weight, gt in zip(w, gt_pairs):
            agg_eif += weight * eif_by_gt[gt]

        if unit_cohorts is None:
            wif_e = np.zeros(n_units)
        else:
            wif_e = self._compute_wif_contribution(
                gt_pairs,
                effs,
                unit_cohorts,
                cohort_fractions,
                n_units,
                unit_weights=self._unit_level_weights,
            )

        design = self._unit_resolved_survey
        if design is None:
            agg_se = self._eif_se(agg_eif + wif_e, n_units, cluster_indices, n_clusters)
        else:
            # Score-scale psi so survey weights are applied exactly once.
            unit_w = self._unit_level_weights
            total_w = float(np.sum(unit_w))
            psi_total = unit_w * agg_eif / total_w + wif_e / total_w
            if getattr(design, "uses_replicate_variance", False):
                from diff_diff.survey import compute_replicate_if_variance

                variance, _ = compute_replicate_if_variance(psi_total, design)
            else:
                from diff_diff.survey import compute_survey_if_variance

                variance = compute_survey_if_variance(psi_total, design)
            agg_se = float(np.sqrt(max(variance, 0.0))) if np.isfinite(variance) else np.nan

        t_stat, p_val, ci = safe_inference(agg_eff, agg_se, alpha=self.alpha, df=self._survey_df)
        result[e] = {
            "effect": agg_eff,
            "se": agg_se,
            "t_stat": t_stat,
            "p_value": p_val,
            "conf_int": ci,
            "n_groups": len(cells),
        }
    return result
def _aggregate_by_group(
    self,
    group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]],
    eif_by_gt: Dict[Tuple[Any, Any], np.ndarray],
    n_units: int,
    cohort_fractions: Dict[float, float],
    treatment_groups: List[Any],
    unit_cohorts: Optional[np.ndarray] = None,
    cluster_indices: Optional[np.ndarray] = None,
    n_clusters: Optional[int] = None,
) -> Dict[Any, Dict[str, Any]]:
    """Average each cohort's post-treatment ATT(g,t) with equal weights.

    A cohort's cells are those with ``t >= g - anticipation`` and a finite
    effect. Cohorts with no such cell are omitted from the result. Because
    the weights are fixed (1 / #cells) rather than estimated, no WIF term is
    needed.

    Parameters
    ----------
    group_time_effects : dict
        ``{(g, t): {"effect": ...}}`` group-time ATT estimates.
    eif_by_gt : dict
        Per-unit EIF vector for each (g, t).
    n_units : int
        Number of units.
    cohort_fractions : dict
        Cohort size fractions (not read here).
    treatment_groups : list
        Cohorts to aggregate, in output order.
    unit_cohorts : ndarray, optional
        Unused; kept for interface symmetry with the other aggregators.
    cluster_indices, n_clusters : optional
        Forwarded to the SE computation.

    Returns
    -------
    dict
        ``{g: {"effect", "se", "t_stat", "p_value", "conf_int", "n_periods"}}``.
    """
    by_group: Dict[Any, Dict[str, Any]] = {}
    for g in treatment_groups:
        cells = [
            (gg, t)
            for (gg, t), data in group_time_effects.items()
            if gg == g and t >= g - self.anticipation and np.isfinite(data["effect"])
        ]
        if not cells:
            continue

        n_cells = len(cells)
        w = np.ones(n_cells) / n_cells
        effs = np.array([group_time_effects[gt]["effect"] for gt in cells])
        agg_eff = float(np.sum(w * effs))

        agg_eif = np.zeros(n_units)
        for weight, gt in zip(w, cells):
            agg_eif += weight * eif_by_gt[gt]

        agg_se = self._eif_se(agg_eif, n_units, cluster_indices, n_clusters)
        t_stat, p_val, ci = safe_inference(agg_eff, agg_se, alpha=self.alpha, df=self._survey_df)
        by_group[g] = {
            "effect": agg_eff,
            "se": agg_se,
            "t_stat": t_stat,
            "p_value": p_val,
            "conf_int": ci,
            "n_periods": n_cells,
        }
    return by_group
[docs]
def summary(self) -> str:
    """Return the formatted summary of the fitted results.

    Raises
    ------
    RuntimeError
        If the estimator has not been fitted yet.
    """
    if self.is_fitted_:
        fitted = self.results_
        assert fitted is not None
        return fitted.summary()
    raise RuntimeError("Model must be fitted before calling summary()")
[docs]
def print_summary(self) -> None:
    """Write the estimation summary (as returned by ``summary()``) to stdout."""
    report = self.summary()
    print(report)
# -- Hausman pretest -------------------------------------------------------
[docs]
@classmethod
def hausman_pretest(
    cls,
    data: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]] = None,
    cluster: Optional[str] = None,
    anticipation: int = 0,
    control_group: str = "never_treated",
    alpha: float = 0.05,
    **nuisance_kwargs: Any,
) -> HausmanPretestResult:
    """Hausman pretest for PT-All vs PT-Post (Theorem A.1).

    Fits the estimator under both parallel trends assumptions and
    compares the results. Under H0 (PT-All holds), both are consistent
    but PT-All is more efficient. Rejection suggests PT-All is too
    strong; use PT-Post instead.

    The comparison uses post-treatment event-study effects ES(e), e >= 0.
    Both ES(e) series are built from the *same* (g, t) cells: those present
    in both fits with a finite ATT and an all-finite EIF in both. This way
    ``ES_post(e) - ES_all(e)`` contrasts two estimators of one target. The
    statistic uses the generalized inverse of ``V = Cov_post - Cov_all``
    restricted to V's positive eigenspace. Its chi-square degrees of freedom
    equal the dimension of that space.

    Parameters
    ----------
    data, outcome, unit, time, first_treat, covariates
        Same as :meth:`fit`.
    cluster : str, optional
        Cluster column for cluster-robust covariance.
    anticipation : int
        Anticipation periods.
    control_group : str
        ``"never_treated"`` or ``"last_cohort"``.
    alpha : float
        Significance level for the test.
    **nuisance_kwargs
        Passed to both fits (e.g. ``sieve_k_max``, ``ratio_clip``).

    Returns
    -------
    HausmanPretestResult
        The statistic and p-value are NaN and the recommendation is
        ``"inconclusive"`` in any of these cases: no common cells or
        horizons exist, V is non-finite, or V has no eigenvalue above
        tolerance.
    """
    from scipy.stats import chi2

    # Fit under both assumptions (analytical SEs only, no bootstrap)
    common_kwargs = dict(
        cluster=cluster,
        control_group=control_group,
        anticipation=anticipation,
        n_bootstrap=0,
        **nuisance_kwargs,
    )
    fit_kwargs = dict(
        data=data,
        outcome=outcome,
        unit=unit,
        time=time,
        first_treat=first_treat,
        covariates=covariates,
        aggregate=None,
    )
    edid_all = cls(pt_assumption="all", alpha=alpha, **common_kwargs)
    result_all = edid_all.fit(**fit_kwargs, store_eif=True)
    edid_post = cls(pt_assumption="post", alpha=alpha, **common_kwargs)
    result_post = edid_post.fit(**fit_kwargs, store_eif=True)

    # Find common (g,t) pairs — PT-Post pairs are a subset of PT-All
    common_gts = sorted(
        set(result_all.group_time_effects.keys()) & set(result_post.group_time_effects.keys())
    )

    def _nan_result() -> HausmanPretestResult:
        return HausmanPretestResult(
            statistic=np.nan,
            p_value=np.nan,
            df=0,
            reject=False,
            alpha=alpha,
            att_all=result_all.overall_att,
            att_post=result_post.overall_att,
            recommendation="inconclusive",
            gt_details=None,
        )

    if not common_gts:
        return _nan_result()

    eif_all = result_all.influence_functions
    eif_post = result_post.influence_functions
    assert eif_all is not None and eif_post is not None
    n_units = len(next(iter(eif_all.values())))

    # --- Aggregate to post-treatment ES(e) per Theorem A.1 ---
    # Derive cohort fractions from data for proper weights
    all_units_list = sorted(data[unit].unique())
    unit_cohorts = (
        data.groupby(unit)[first_treat].first().reindex(all_units_list).values.astype(float)
    )
    cohort_fractions: Dict[float, float] = {}
    for g in set(result_all.groups) | set(result_post.groups):
        cohort_fractions[g] = float(np.sum(unit_cohorts == g)) / n_units

    def _cell_usable(gt_effects: Dict, eif_dict: Dict, gt: Tuple[Any, Any]) -> bool:
        """True if ``gt`` has a finite ATT and an all-finite EIF vector."""
        if not np.isfinite(gt_effects[gt]["effect"]):
            return False
        eif_vec = eif_dict.get(gt)
        return eif_vec is not None and bool(np.all(np.isfinite(eif_vec)))

    # Post-treatment cells usable in BOTH fits. Aggregating each fit over its
    # own usable cells could put different cohorts into ES_all(e) and
    # ES_post(e), so delta would contrast different targets.
    shared_gts = [
        (g, t)
        for (g, t) in common_gts
        if int(t - g) >= 0
        and _cell_usable(result_all.group_time_effects, eif_all, (g, t))
        and _cell_usable(result_post.group_time_effects, eif_post, (g, t))
    ]

    # Bucket the shared cells by horizon once; both fits use these buckets.
    cells_by_e: Dict[int, List[Tuple[Any, Any]]] = {}
    for g, t in shared_gts:
        cells_by_e.setdefault(int(t - g), []).append((g, t))

    def _aggregate_es(gt_effects: Dict, eif_dict: Dict) -> Dict[int, Tuple[float, np.ndarray]]:
        """Aggregate the shared cells to ES(e) with a WIF-corrected EIF."""
        out: Dict[int, Tuple[float, np.ndarray]] = {}
        for e, gt_pairs_e in cells_by_e.items():
            effs = np.array([gt_effects[gt]["effect"] for gt in gt_pairs_e])
            pgs = np.array([cohort_fractions.get(g, 0.0) for g, _ in gt_pairs_e])
            total_pg = pgs.sum()
            w = pgs / total_pg if total_pg > 0 else np.ones(len(pgs)) / len(pgs)
            es_eif = np.zeros(n_units)
            for k, gt in enumerate(gt_pairs_e):
                es_eif += w[k] * eif_dict[gt]
            # WIF correction for the estimated cohort-size weights, using the
            # same helper as the fit-time aggregations (no unit weights passed,
            # so the unweighted formula applies).
            es_eif = es_eif + edid_all._compute_wif_contribution(
                gt_pairs_e, effs, unit_cohorts, cohort_fractions, n_units
            )
            out[e] = (float(np.sum(w * effs)), es_eif)
        return out

    es_all = _aggregate_es(result_all.group_time_effects, eif_all)
    es_post = _aggregate_es(result_post.group_time_effects, eif_post)

    # Find common post-treatment horizons
    common_e = sorted(set(es_all.keys()) & set(es_post.keys()))
    if not common_e:
        return _nan_result()

    delta = np.array([es_post[e][0] - es_all[e][0] for e in common_e])

    # Build ES(e)-level EIF matrices
    eif_all_mat = np.column_stack([es_all[e][1] for e in common_e])
    eif_post_mat = np.column_stack([es_post[e][1] for e in common_e])

    # Filter units with non-finite EIF values
    row_finite = np.all(np.isfinite(eif_all_mat), axis=1) & np.all(
        np.isfinite(eif_post_mat), axis=1
    )

    cl_idx: Optional[np.ndarray] = None
    n_cl: Optional[int] = None
    if cluster is not None:
        cl_idx, n_cl = _validate_and_build_cluster_mapping(data, unit, cluster, all_units_list)

    if not np.all(row_finite):
        eif_all_mat = eif_all_mat[row_finite]
        eif_post_mat = eif_post_mat[row_finite]
        n_units = int(np.sum(row_finite))
        if cl_idx is not None:
            cl_idx = cl_idx[row_finite]
            # Recompute effective cluster count and remap to contiguous
            # indices — entire clusters may have been dropped by filtering
            unique_cl, cl_idx = np.unique(cl_idx, return_inverse=True)
            n_cl = len(unique_cl)

    # Compute full covariance matrices
    if cl_idx is not None and n_cl is not None:

        def _eif_cov(eif_mat: np.ndarray) -> np.ndarray:
            centered = _cluster_aggregate(eif_mat, cl_idx, n_cl)
            correction = n_cl / (n_cl - 1) if n_cl > 1 else 1.0
            return correction * (centered.T @ centered) / (n_units**2)

        cov_all = _eif_cov(eif_all_mat)
        cov_post = _eif_cov(eif_post_mat)
    else:
        with np.errstate(over="ignore", invalid="ignore"):
            cov_all = (eif_all_mat.T @ eif_all_mat) / (n_units**2)
            cov_post = (eif_post_mat.T @ eif_post_mat) / (n_units**2)

    V = cov_post - cov_all

    if not np.all(np.isfinite(V)):
        warnings.warn(
            "Hausman covariance matrix contains non-finite values. " "The test is unreliable.",
            UserWarning,
            stacklevel=2,
        )
        return _nan_result()

    # Eigendecompose V — check for non-PSD
    eigvals, eigvecs = np.linalg.eigh(V)
    max_eigval = float(np.max(np.abs(eigvals))) if len(eigvals) > 0 else 0.0
    tol = max(1e-10 * max_eigval, 1e-15)
    n_negative = int(np.sum(eigvals < -tol))
    if n_negative > 0:
        warnings.warn(
            f"Hausman variance-difference matrix V has {n_negative} "
            "substantially negative eigenvalue(s). The test may be "
            "unreliable (finite-sample efficiency reversal).",
            UserWarning,
            stacklevel=2,
        )

    positive = eigvals > tol
    effective_rank = int(np.sum(positive))
    if effective_rank == 0:
        return _nan_result()

    # Generalized inverse restricted to V's positive eigenspace. The
    # quadratic form and the chi2 df (= effective_rank) then refer to the
    # same subspace; directions with negative eigenvalues (a finite-sample
    # artifact) no longer subtract from H. For PSD V this equals pinv(V).
    proj = eigvecs[:, positive].T @ delta
    H = float(np.sum(proj**2 / eigvals[positive]))
    H = max(H, 0.0)
    p_value = float(chi2.sf(H, df=effective_rank))
    reject = p_value < alpha

    es_details = pd.DataFrame(
        {
            "relative_period": common_e,
            "es_all": [es_all[e][0] for e in common_e],
            "es_post": [es_post[e][0] for e in common_e],
            "delta": delta,
        }
    )

    return HausmanPretestResult(
        statistic=H,
        p_value=p_value,
        df=effective_rank,
        reject=reject,
        alpha=alpha,
        att_all=result_all.overall_att,
        att_post=result_post.overall_att,
        recommendation="pt_post" if reject else "pt_all",
        gt_details=es_details,
    )