"""
Efficient Difference-in-Differences estimator.
Implements the ATT estimator from Chen, Sant'Anna & Xie (2025).
Without covariates, achieves the semiparametric efficiency bound via
closed-form within-group covariances. With covariates, uses a doubly
robust path with sieve outcome regressions, sieve propensity ratios, and
kernel-smoothed conditional Omega*(X) (see class docstring for details).
Under PT-All the model is overidentified and EDiD exploits this for
tighter inference; under PT-Post it reduces to the standard
single-baseline estimator (Callaway-Sant'Anna).
The variance machinery is purely influence-function-based: per-unit EIF
values aggregate via ``sqrt(mean(EIF**2)/n)`` (unclustered, HC1-style),
Liang-Zeger CR1 on cluster-aggregated EIF (under ``cluster=``), or
Taylor Series Linearization on the combined IF (under ``survey_design=``).
Because the per-unit EIF aggregation has no equivalent single design
matrix, analytical-sandwich families ``{classical, hc2, hc2_bm}`` cannot
be defined and the ``vcov_type`` input contract is permanently narrow to
``{"hc1"}`` — see ``docs/methodology/REGISTRY.md`` "IF-based variance
estimators vs analytical-sandwich estimators" for the structural rationale.
"""
import warnings
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
from diff_diff.efficient_did_bootstrap import (
EDiDBootstrapResults,
EfficientDiDBootstrapMixin,
)
from diff_diff.efficient_did_covariates import (
compute_eif_cov,
compute_generated_outcomes_cov,
compute_omega_star_conditional,
compute_per_unit_weights,
estimate_inverse_propensity_sieve,
estimate_outcome_regression,
estimate_propensity_ratio_sieve,
)
from diff_diff.efficient_did_results import EfficientDiDResults, HausmanPretestResult
from diff_diff.efficient_did_weights import (
compute_efficient_weights,
compute_eif_nocov,
compute_generated_outcomes_nocov,
compute_omega_star_nocov,
enumerate_valid_triples,
)
from diff_diff.utils import safe_inference
# Public names of this module. EfficientDiDResults and EDiDBootstrapResults
# are imported above from sibling diff_diff modules and re-exported here for
# convenience, so callers can import them alongside the estimator class.
__all__ = ["EfficientDiD", "EfficientDiDResults", "EDiDBootstrapResults"]
def _validate_and_build_cluster_mapping(
    df: pd.DataFrame,
    unit: str,
    cluster: str,
    all_units: list,
) -> Tuple[np.ndarray, int]:
    """Check the cluster column and translate it into per-unit integer codes.

    Parameters
    ----------
    df : DataFrame
        Panel data holding both the ``unit`` and ``cluster`` columns.
    unit : str
        Name of the unit identifier column.
    cluster : str
        Name of the cluster identifier column.
    all_units : list
        Unit identifiers; the returned codes follow this order.

    Returns
    -------
    cluster_indices : ndarray
        One integer per entry of ``all_units``: the position of that unit's
        cluster label within the sorted distinct labels.
    n_clusters : int
        Number of distinct cluster labels.

    Raises
    ------
    ValueError
        If the column is absent, contains missing values, takes more than
        one value within a unit, or yields fewer than two clusters.
    """
    if cluster not in df.columns:
        raise ValueError(f"Cluster column '{cluster}' not found in data.")

    has_missing = df[cluster].isna().any()
    if has_missing:
        raise ValueError(f"Cluster column '{cluster}' contains missing values.")

    labels_by_unit = df.groupby(unit)[cluster]
    distinct_per_unit = labels_by_unit.nunique()
    if distinct_per_unit.gt(1).any():
        raise ValueError(
            f"Cluster column '{cluster}' varies within unit. "
            "Cluster assignment must be constant per unit."
        )

    # One label per unit, arranged in the caller's unit order.
    unit_labels = labels_by_unit.first().reindex(all_units).values
    distinct_labels = np.unique(unit_labels)
    n_clusters = len(distinct_labels)
    if n_clusters < 2:
        raise ValueError(f"Need at least 2 clusters for cluster-robust SEs, got {n_clusters}.")

    # Sorted position of each distinct label becomes its integer code.
    code_of = {label: code for code, label in enumerate(distinct_labels)}
    cluster_indices = np.array([code_of[label] for label in unit_labels])
    return cluster_indices, n_clusters
def _cluster_aggregate(
    eif_mat: np.ndarray,
    cluster_indices: np.ndarray,
    n_clusters: int,
) -> np.ndarray:
    """Total the EIF within each cluster, then subtract the across-cluster mean.

    Parameters
    ----------
    eif_mat : ndarray, shape (n_units,) or (n_units, k)
        EIF values: a vector for one estimand, or one column per estimand.
    cluster_indices : ndarray, shape (n_units,)
        Integer cluster code of every unit.
    n_clusters : int
        Number of clusters; used as the minimum length of the totals so
        clusters without units still get a (zero) entry.

    Returns
    -------
    ndarray, shape (n_clusters,) or (n_clusters, k)
        Cluster totals centered column-wise at zero.
    """

    def _totals(values: np.ndarray) -> np.ndarray:
        # Weighted bincount == sum of ``values`` grouped by cluster code.
        return np.bincount(cluster_indices, weights=values, minlength=n_clusters)

    if eif_mat.ndim == 1:
        cluster_sums = _totals(eif_mat).astype(float)
    else:
        per_estimand = [_totals(eif_mat[:, col]) for col in range(eif_mat.shape[1])]
        cluster_sums = np.column_stack(per_estimand).astype(float)

    center = cluster_sums.mean(axis=0)
    return cluster_sums - center
def _compute_se_from_eif(
    eif: np.ndarray,
    n_units: int,
    cluster_indices: Optional[np.ndarray] = None,
    n_clusters: Optional[int] = None,
) -> float:
    """Standard error of an estimand computed from its per-unit EIF values.

    Parameters
    ----------
    eif : ndarray
        Per-unit EIF values.
    n_units : int
        Number of units used to scale the variance.
    cluster_indices : ndarray, optional
        Integer cluster code per unit.
    n_clusters : int, optional
        Number of clusters.

    Returns
    -------
    float
        ``sqrt(mean(EIF^2) / n_units)`` unless both ``cluster_indices`` and
        ``n_clusters`` are given. In that case the Liang-Zeger form is used:
        EIF values are summed within clusters and centered, and the variance
        is ``G/(G-1) * sum(centered^2) / n_units^2`` (the ``G/(G-1)`` factor
        is skipped when ``G <= 1``), floored at zero before the square root.
    """
    use_clusters = cluster_indices is not None and n_clusters is not None
    if not use_clusters:
        return float(np.sqrt(np.mean(eif**2) / n_units))

    cluster_scores = _cluster_aggregate(eif, cluster_indices, n_clusters)
    if n_clusters > 1:
        small_sample = n_clusters / (n_clusters - 1)
    else:
        small_sample = 1.0
    variance = small_sample * np.sum(cluster_scores**2) / (n_units**2)
    return float(np.sqrt(max(variance, 0.0)))
def _hausman_quadratic_form(
    delta: np.ndarray,
    cov_post: np.ndarray,
    cov_all: np.ndarray,
) -> Tuple[float, int, float, int, bool]:
    """Compute the Hausman statistic for the PT-Post vs PT-All event studies.

    This is the Theorem A.1 test statistic of Chen, Sant'Anna & Xie (2025,
    arXiv:2506.17729v1). With ``V = cov_post - cov_all`` (restricted minus
    efficient covariance) and ``delta = ES_post - ES_all``, the statistic is
    ``H = delta' V^+ delta`` where ``V^+`` is the Moore-Penrose pseudoinverse.
    The chi-square degrees of freedom are the count of eigenvalues of ``V``
    above a small tolerance, which guards against a non-PSD ``V`` in finite
    samples and equals ``|E|`` when ``V`` is well-conditioned.

    Parameters
    ----------
    delta : ndarray, shape (|E|,)
        Event-study difference ``ES_post - ES_all``.
    cov_post, cov_all : ndarray, shape (|E|, |E|)
        Estimator-scale covariances of the restricted (PT-Post) and
        efficient (PT-All) event-study vectors.

    Returns
    -------
    H : float
        ``max(delta' V^+ delta, 0)``; NaN when ``V`` has non-finite entries
        or no eigenvalue above the tolerance.
    effective_rank : int
        Number of eigenvalues of ``V`` above the tolerance (the degrees of
        freedom).
    p_value : float
        Upper-tail ``chi2(effective_rank)`` probability of ``H``; NaN when
        ``H`` is NaN.
    n_negative : int
        Number of eigenvalues of ``V`` below minus the tolerance.
    finite_ok : bool
        False only when ``V`` contains non-finite entries.
    """
    from scipy.stats import chi2

    var_diff = cov_post - cov_all
    if not np.isfinite(var_diff).all():
        return np.nan, 0, np.nan, 0, False

    spectrum = np.linalg.eigvalsh(var_diff)
    largest = float(np.max(np.abs(spectrum))) if len(spectrum) > 0 else 0.0
    # Relative tolerance with an absolute floor for a near-zero matrix.
    cutoff = max(1e-10 * largest, 1e-15)
    n_negative = int(np.sum(spectrum < -cutoff))
    effective_rank = int(np.sum(spectrum > cutoff))
    if effective_rank == 0:
        return np.nan, 0, np.nan, n_negative, True

    # pinv takes its cutoff relative to the largest singular value.
    relative_cutoff = cutoff / largest if largest > 0 else 1e-10
    var_diff_pinv = np.linalg.pinv(var_diff, rcond=relative_cutoff)
    quad = float(delta @ var_diff_pinv @ delta)
    statistic = max(quad, 0.0)
    tail_prob = float(chi2.sf(statistic, df=effective_rank))
    return statistic, effective_rank, tail_prob, n_negative, True
[docs]
class EfficientDiD(EfficientDiDBootstrapMixin):
"""Efficient DiD estimator (Chen, Sant'Anna & Xie 2025).
Without covariates, achieves the semiparametric efficiency bound for
ATT(g,t) using a closed-form estimator based on within-group sample
means and covariances.
With covariates, uses a doubly robust path: sieve-based propensity
score ratios (Eq 4.1-4.2), sieve outcome regressions (polynomial
basis, AIC/BIC order selection), sieve-estimated inverse propensities
(algorithm step 4), and kernel-smoothed conditional Omega*(X) with
per-unit efficient weights (Eq 3.12). The DR property ensures
consistency if either the outcome regression or the sieve propensity
ratio is correctly specified; because all nuisances are sieves /
kernel smoothers (the paper's flexible-nuisance specification), the
covariate path attains the semiparametric efficiency bound under the
paper's regularity conditions (see REGISTRY.md).
Parameters
----------
pt_assumption : str, default ``"all"``
Parallel trends variant: ``"all"`` (overidentified, uses all
pre-treatment periods and comparison groups) or ``"post"``
(just-identified, single baseline, equivalent to CS).
alpha : float, default 0.05
Significance level.
cluster : str or None
Column name for cluster-robust SEs. When set, analytical SEs
use the Liang-Zeger clustered sandwich estimator on EIF values.
With ``n_bootstrap > 0``, bootstrap weights are generated at the
cluster level (all units in a cluster share the same weight).
vcov_type : str, default ``"hc1"``
Variance-estimator family. Permanently narrow to ``{"hc1"}`` per
the Chen-Sant'Anna-Xie (2025) IF-based variance — analytical-sandwich
families ``{classical, hc2, hc2_bm}`` and ``conley`` are rejected
at ``__init__`` / ``set_params``. See REGISTRY.md for the
methodology rationale (no single design matrix on which hat-matrix
leverage or Bell-McCaffrey Satterthwaite DOF can be defined).
Use ``cluster=<col>`` for Liang-Zeger CR1 on cluster-aggregated EIF;
use ``survey_design=`` for Taylor Series Linearization on the
combined IF.
control_group : str, default ``"never_treated"``
Which units serve as the comparison group:
``"never_treated"`` requires a never-treated cohort (raises if
none exist); ``"last_cohort"`` reclassifies the latest treatment
cohort as pseudo-never-treated and drops periods at
``t >= last_g - anticipation`` so the pseudo-control's
pre-treatment window excludes anticipation-contaminated periods.
Distinct from CallawaySantAnna's ``"not_yet_treated"`` — see
REGISTRY.md for details.
n_bootstrap : int, default 0
Number of multiplier bootstrap iterations (0 = analytical only).
bootstrap_weights : str, default ``"rademacher"``
Bootstrap weight distribution.
seed : int or None
Random seed for reproducibility.
anticipation : int, default 0
Number of anticipation periods (shifts the effective treatment
boundary forward by this amount). When combined with
``control_group="last_cohort"``, also trims the pseudo-control
period set at ``t >= last_g - anticipation`` (see REGISTRY.md).
sieve_k_max : int or None
Maximum polynomial degree for the covariate-path sieves — the
propensity-ratio, inverse-propensity, AND outcome-regression fits all
use it. None = auto (``floor(n_pos^{1/5})`` over each group's
positive-weight support ``n_pos`` — the raw group size when unweighted —
a growing sieve with no fixed ceiling, bounded by ``n_basis < n_pos``;
zero-weight survey rows do not affect order selection). Only
used with covariates. ``sieve_k_max=1`` forces every covariate-path
sieve (outcome regression and both propensity sieves) to degree 1: it
recovers the pre-sieve linear-OLS *outcome regression* but also
degree-1-constrains the propensity sieves, so it does not reproduce the
exact pre-sieve estimator.
sieve_criterion : str, default ``"bic"``
Information criterion (``"aic"`` or ``"bic"``) for the order selection
of all covariate-path sieves (propensity ratio, inverse propensity, and
outcome regression).
ratio_clip : float, default 20.0
Clip sieve propensity ratios to ``[1/ratio_clip, ratio_clip]``.
kernel_bandwidth : float or None
Bandwidth for Gaussian kernel in conditional Omega* estimation.
None = Silverman's rule-of-thumb (automatic).
Examples
--------
>>> from diff_diff import EfficientDiD
>>> edid = EfficientDiD(pt_assumption="all")
>>> results = edid.fit(data, outcome="y", unit="id", time="t",
... first_treat="first_treat", aggregate="all")
>>> results.print_summary()
"""
[docs]
def __init__(
    self,
    pt_assumption: str = "all",
    alpha: float = 0.05,
    cluster: Optional[str] = None,
    vcov_type: str = "hc1",
    control_group: str = "never_treated",
    n_bootstrap: int = 0,
    bootstrap_weights: str = "rademacher",
    seed: Optional[int] = None,
    anticipation: int = 0,
    sieve_k_max: Optional[int] = None,
    sieve_criterion: str = "bic",
    ratio_clip: float = 20.0,
    kernel_bandwidth: Optional[float] = None,
):
    """Record the configuration and validate it eagerly.

    Every argument is stored unchanged on an attribute of the same name.
    The fitted-state attributes start out empty, and ``_validate_params``
    runs last, so an invalid configuration raises at construction time.
    """
    # Fitted-state placeholders: nothing has been estimated yet.
    self.is_fitted_ = False
    self.results_: Optional[EfficientDiDResults] = None
    self._unit_resolved_survey = None

    # Identification / comparison-group settings.
    self.pt_assumption = pt_assumption
    self.control_group = control_group
    self.anticipation = anticipation

    # Inference settings.
    self.alpha = alpha
    self.cluster = cluster
    self.vcov_type = vcov_type
    self.n_bootstrap = n_bootstrap
    self.bootstrap_weights = bootstrap_weights
    self.seed = seed

    # Covariate-path nuisance settings.
    self.sieve_k_max = sieve_k_max
    self.sieve_criterion = sieve_criterion
    self.ratio_clip = ratio_clip
    self.kernel_bandwidth = kernel_bandwidth

    self._validate_params()
def _validate_params(self) -> None:
    """Raise ``ValueError`` for the first constrained parameter that is invalid.

    Checks run in a fixed order: ``pt_assumption``, ``control_group``,
    ``bootstrap_weights``, ``sieve_criterion``, ``ratio_clip``,
    ``kernel_bandwidth``, ``sieve_k_max`` and finally ``vcov_type``
    (delegated to ``_validate_vcov_type``).
    """
    pt = self.pt_assumption
    if pt not in ("all", "post"):
        raise ValueError(f"pt_assumption must be 'all' or 'post', got '{pt}'")

    control = self.control_group
    if control not in ("never_treated", "last_cohort"):
        raise ValueError(
            "control_group must be 'never_treated' or 'last_cohort', "
            f"got '{control}'"
        )

    # Kept as a tuple: its repr is part of the error message.
    weight_schemes = ("rademacher", "mammen", "webb")
    scheme = self.bootstrap_weights
    if scheme not in weight_schemes:
        raise ValueError(
            f"bootstrap_weights must be one of {weight_schemes}, got '{scheme}'"
        )

    criterion = self.sieve_criterion
    if criterion not in ("aic", "bic"):
        raise ValueError(f"sieve_criterion must be 'aic' or 'bic', got '{criterion}'")

    clip = self.ratio_clip
    if not np.isfinite(clip) or clip <= 1.0:
        raise ValueError(f"ratio_clip must be finite and > 1.0, got {clip}")

    # None means "choose automatically" for the two settings below.
    bandwidth = self.kernel_bandwidth
    if bandwidth is not None and (not np.isfinite(bandwidth) or bandwidth <= 0):
        raise ValueError(
            "kernel_bandwidth must be finite and > 0 (or None for auto), "
            f"got {bandwidth}"
        )

    degree_cap = self.sieve_k_max
    if degree_cap is not None and (
        not isinstance(degree_cap, (int, np.integer)) or degree_cap <= 0
    ):
        raise ValueError(
            "sieve_k_max must be a positive integer (or None for auto), "
            f"got {degree_cap}"
        )

    self._validate_vcov_type(self.vcov_type)
@staticmethod
def _validate_vcov_type(vcov_type: str) -> None:
    """Enforce the narrow ``vcov_type`` contract of EfficientDiD.

    Only ``"hc1"`` is accepted. ``"classical"``, ``"hc2"`` and ``"hc2_bm"``
    are rejected as incompatible with the influence-function-based
    variance, ``"conley"`` is rejected as not yet supported, and any other
    value is rejected as invalid.

    Raises
    ------
    ValueError
        For every value other than ``"hc1"``; the message names the reason.
    """
    accepted = {"hc1"}
    if vcov_type in accepted:
        return

    sandwich_only = {"classical", "hc2", "hc2_bm"}
    not_yet_supported = {"conley"}

    if vcov_type in sandwich_only:
        reason = (
            f"EfficientDiD(vcov_type={vcov_type!r}) is rejected: "
            "EfficientDiD uses influence-function-based variance per Chen, "
            "Sant'Anna, and Xie (2025) achieving the semiparametric efficiency "
            "bound for ATT(g,t). The per-unit EIF aggregation has no equivalent "
            "single design matrix on which hat matrix leverage or Bell-McCaffrey "
            "Satterthwaite DOF can be defined, so analytical-sandwich families "
            "{classical, hc2, hc2_bm} are not paper-prescribed. Use "
            "vcov_type='hc1' (the default) with cluster=<col> for the "
            "Liang-Zeger clustered EIF sandwich estimator."
        )
    elif vcov_type in not_yet_supported:
        reason = (
            f"EfficientDiD(vcov_type={vcov_type!r}) is not yet supported: "
            "spatial-HAC composition with EIF aggregation has no reference "
            "implementation today. See TODO.md for the deferred follow-up row. "
            "Use vcov_type='hc1' (the default) with cluster=<col> for "
            "cluster-robust inference."
        )
    else:
        reason = (
            f"EfficientDiD(vcov_type={vcov_type!r}) is invalid. "
            f"Accepted: {sorted(accepted)}."
        )
    raise ValueError(reason)
# -- sklearn compatibility ------------------------------------------------
[docs]
def get_params(self) -> Dict[str, Any]:
    """Return the constructor parameters as a name-to-value dict (sklearn-compatible)."""
    param_names = (
        "pt_assumption",
        "anticipation",
        "alpha",
        "cluster",
        "vcov_type",
        "control_group",
        "n_bootstrap",
        "bootstrap_weights",
        "seed",
        "sieve_k_max",
        "sieve_criterion",
        "ratio_clip",
        "kernel_bandwidth",
    )
    # Current attribute values, reported in the order listed above.
    return {name: getattr(self, name) for name in param_names}
[docs]
def set_params(self, **params: Any) -> "EfficientDiD":
    """Set estimator parameters (sklearn-compatible).

    Only names reported by ``get_params`` are accepted. Checking with
    ``hasattr`` instead would let a call such as ``set_params(fit=None)``
    or ``set_params(is_fitted_=True)`` silently overwrite a method or the
    fitted state, because those are attributes too.

    The update is atomic. All names are checked before anything is
    changed, the previous values are snapshotted, and if
    ``_validate_params`` rejects the new state every touched attribute is
    restored before the exception propagates. A caller that catches
    ``ValueError`` can therefore keep using the estimator unchanged.

    Parameters
    ----------
    **params
        Parameter names (as in ``get_params``) mapped to their new values.

    Returns
    -------
    EfficientDiD
        ``self``, to allow chaining.

    Raises
    ------
    ValueError
        If a name is not an estimator parameter, or the resulting
        configuration fails validation.
    """
    known = self.get_params()
    for key in params:
        if key not in known:
            raise ValueError(f"Unknown parameter: {key}")

    # Snapshot before mutating so a failed validation can be undone.
    previous: Dict[str, Any] = {key: getattr(self, key) for key in params}
    for key, value in params.items():
        setattr(self, key, value)
    try:
        self._validate_params()
    except Exception:
        for key, value in previous.items():
            setattr(self, key, value)
        raise
    return self
# -- Main estimation ------------------------------------------------------
[docs]
def fit(
self,
data: pd.DataFrame,
outcome: str,
unit: str,
time: str,
first_treat: str,
covariates: Optional[List[str]] = None,
aggregate: Optional[str] = None,
balance_e: Optional[int] = None,
survey_design: Optional[Any] = None,
store_eif: bool = False,
) -> EfficientDiDResults:
"""Fit the Efficient DiD estimator.
Parameters
----------
data : DataFrame
Balanced panel data.
outcome : str
Outcome variable column name.
unit : str
Unit identifier column name.
time : str
Time period column name.
first_treat : str
Column indicating first treatment period.
Use 0 or ``np.inf`` for never-treated units.
covariates : list of str, optional
Column names for time-invariant unit-level covariates.
When provided, uses the doubly robust path (outcome regression
+ propensity score ratios).
aggregate : str, optional
``None``, ``"simple"``, ``"event_study"``, ``"group"``, or
``"all"``.
balance_e : int, optional
Balance event study at this relative period.
survey_design : SurveyDesign, optional
Survey design specification for design-based inference.
Applies survey weights to all means, covariances, and cohort
fractions, and uses Taylor Series Linearization for SE
estimation. Cannot be combined with ``cluster``.
store_eif : bool, default False
Store per-(g,t) EIF vectors in the results object. Used
internally by :meth:`hausman_pretest`; not needed for
normal usage.
Returns
-------
EfficientDiDResults
Raises
------
ValueError
Missing columns, unbalanced panel, non-absorbing treatment,
or PT-Post without a never-treated group.
"""
self._validate_params()
if self.cluster is not None and survey_design is not None:
raise NotImplementedError(
"cluster and survey_design cannot both be set. "
"Use survey_design with PSU/strata for cluster-robust inference."
)
# Resolve survey design if provided
from diff_diff.survey import _resolve_survey_for_fit
resolved_survey, survey_weights, survey_weight_type, survey_metadata = (
_resolve_survey_for_fit(survey_design, data, "analytical")
)
# Validate within-unit constancy for panel survey designs
if resolved_survey is not None:
from diff_diff.survey import _validate_unit_constant_survey
_validate_unit_constant_survey(data, unit, survey_design)
# Store survey df for safe_inference calls (t-distribution with survey df)
self._survey_df = survey_metadata.df_survey if survey_metadata is not None else None
# Guard: replicate design with undefined df → NaN inference
if (
self._survey_df is None
and resolved_survey is not None
and hasattr(resolved_survey, "uses_replicate_variance")
and resolved_survey.uses_replicate_variance
):
self._survey_df = 0
# Bootstrap + survey supported via PSU-level multiplier bootstrap.
# Normalize empty covariates list to None (use nocov path)
if covariates is not None and len(covariates) == 0:
covariates = None
use_covariates = covariates is not None
# ----- Validate inputs -----
required_cols = [outcome, unit, time, first_treat]
missing = [c for c in required_cols if c not in data.columns]
if missing:
raise ValueError(f"Missing columns: {missing}")
df = data.copy()
df[time] = pd.to_numeric(df[time])
df[first_treat] = pd.to_numeric(df[first_treat])
# Normalize never-treated: inf -> 0 internally, keep track
df["_never_treated"] = (df[first_treat] == 0) | (df[first_treat] == np.inf)
df.loc[df[first_treat] == np.inf, first_treat] = 0
time_periods = sorted(df[time].unique())
treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0])
# Validate balanced panel
unit_period_counts = df.groupby(unit)[time].nunique()
n_periods = len(time_periods)
if (unit_period_counts != n_periods).any():
raise ValueError(
"Unbalanced panel detected. EfficientDiD requires a balanced "
"panel where every unit is observed in every time period."
)
# Reject non-finite outcomes (NaN/Inf corrupt Omega*/EIF calculations)
non_finite_mask = ~np.isfinite(df[outcome])
if non_finite_mask.any():
n_bad = int(non_finite_mask.sum())
raise ValueError(
f"Found {n_bad} non-finite value(s) in outcome column '{outcome}'. "
"EfficientDiD requires finite outcomes for all unit-period observations."
)
# Reject duplicate (unit, time) rows
dup_mask = df.duplicated(subset=[unit, time], keep=False)
if dup_mask.any():
n_dups = int(dup_mask.sum())
raise ValueError(
f"Found {n_dups} duplicate ({unit}, {time}) rows. "
"EfficientDiD requires exactly one observation per unit-period."
)
# Validate absorbing treatment (vectorized)
ft_nunique = df.groupby(unit)[first_treat].nunique()
bad_units = ft_nunique[ft_nunique > 1]
if len(bad_units) > 0:
uid = bad_units.index[0]
raise ValueError(
f"Non-absorbing treatment detected for unit {uid}: "
"first_treat value changes over time."
)
# Unit info
unit_info = (
df.groupby(unit)
.agg(
{
first_treat: "first",
"_never_treated": "first",
}
)
.reset_index()
)
n_treated_units = int((unit_info[first_treat] > 0).sum())
n_control_units = int(unit_info["_never_treated"].sum())
# Control group logic
if self.control_group == "last_cohort":
# Always reclassify last cohort as pseudo-control when requested
if not treatment_groups:
raise ValueError(
"No treated cohorts found. control_group='last_cohort' requires "
"at least 2 treatment cohorts."
)
last_g = max(treatment_groups)
treatment_groups = [g for g in treatment_groups if g != last_g]
if not treatment_groups:
raise ValueError("Only one treatment cohort; cannot use last_cohort control.")
effective_last = last_g - self.anticipation
time_periods = [t for t in time_periods if t < effective_last]
if len(time_periods) < 2:
raise ValueError(
"Fewer than 2 time periods remain after trimming for last_cohort control."
)
unit_info.loc[unit_info[first_treat] == last_g, first_treat] = 0
unit_info.loc[unit_info[first_treat] == 0, "_never_treated"] = True
n_treated_units = int((unit_info[first_treat] > 0).sum())
n_control_units = int(unit_info["_never_treated"].sum())
elif n_control_units == 0:
raise ValueError(
"No never-treated units found. Use control_group='last_cohort' "
"to use the last treatment cohort as a pseudo-control."
)
# ----- Prepare data -----
all_units = sorted(df[unit].unique())
n_units = len(all_units)
# Build unit-to-first-panel-row index aligned to all_units (sorted)
# order. The previous approach (groupby cumcount == 0) yielded
# first-appearance order which can differ from sorted order when the
# input DataFrame is not pre-sorted by unit.
first_pos: Dict[Any, int] = {}
for i, u in enumerate(df[unit].values):
if u not in first_pos:
first_pos[u] = i
self._unit_first_panel_row = np.array([first_pos[u] for u in all_units])
# Build unit-level ResolvedSurveyDesign once (avoids repeated
# construction in _compute_survey_eif_se and ensures consistent
# unit-level df for safe_inference t-distribution).
if resolved_survey is not None:
row_idx = self._unit_first_panel_row
unit_weights_s = resolved_survey.weights[row_idx]
unit_strata = (
resolved_survey.strata[row_idx] if resolved_survey.strata is not None else None
)
unit_psu = resolved_survey.psu[row_idx] if resolved_survey.psu is not None else None
unit_fpc = resolved_survey.fpc[row_idx] if resolved_survey.fpc is not None else None
n_strata_u = len(np.unique(unit_strata)) if unit_strata is not None else 0
n_psu_u = len(np.unique(unit_psu)) if unit_psu is not None else 0
self._unit_resolved_survey = resolved_survey.subset_to_units(
row_idx,
unit_weights_s,
unit_strata,
unit_psu,
unit_fpc,
n_strata_u,
n_psu_u,
)
# Use unit-level df (not panel-level) for t-distribution
self._survey_df = self._unit_resolved_survey.df_survey
# Re-apply replicate guard: undefined df → NaN inference
if self._survey_df is None and self._unit_resolved_survey.uses_replicate_variance:
self._survey_df = 0
else:
self._unit_resolved_survey = None
# Build cluster mapping if cluster-robust SEs requested
if self.cluster is not None:
unit_cluster_indices, n_clusters = _validate_and_build_cluster_mapping(
df, unit, self.cluster, all_units
)
if n_clusters < 50:
warnings.warn(
f"Only {n_clusters} clusters. Analytical clustered SEs may "
"be unreliable. Consider n_bootstrap > 0 for cluster "
"bootstrap inference.",
UserWarning,
stacklevel=2,
)
else:
unit_cluster_indices = None
n_clusters = None
period_to_col = {p: i for i, p in enumerate(time_periods)}
period_1 = time_periods[0]
period_1_col = period_to_col[period_1]
# Pivot outcome to wide matrix (n_units, n_periods)
pivot = df.pivot(index=unit, columns=time, values=outcome)
# Reindex to match all_units ordering and time_periods column order
pivot = pivot.reindex(index=all_units, columns=time_periods)
outcome_wide = pivot.values.astype(float)
# Build cohort masks and fractions
unit_info_indexed = unit_info.set_index(unit)
unit_cohorts = unit_info_indexed.reindex(all_units)[first_treat].values.astype(
float
) # 0 = never-treated
cohort_masks: Dict[float, np.ndarray] = {}
for g in treatment_groups:
cohort_masks[g] = unit_cohorts == g
never_treated_mask = unit_cohorts == 0
cohort_masks[np.inf] = never_treated_mask # also keyed by inf sentinel
# ----- Unit-level survey weights -----
# Survey weights in the panel are at obs level (unit x time).
# EfficientDiD works at unit level. Extract one weight per unit
# by taking the first observation per unit (balanced panel, so
# weights should be constant within unit).
unit_level_weights: Optional[np.ndarray] = None
if resolved_survey is not None:
# Use the resolved survey's weights (already normalized per weight_type)
# subset to unit level via _unit_first_panel_row (aligned to all_units)
unit_level_weights = self._unit_resolved_survey.weights
self._unit_level_weights = unit_level_weights
cohort_fractions: Dict[float, float] = {}
if unit_level_weights is not None:
# Survey-weighted cohort fractions: sum(w_i for i in cohort) / sum(w_i)
total_w = float(np.sum(unit_level_weights))
for g in treatment_groups:
cohort_fractions[g] = float(np.sum(unit_level_weights[cohort_masks[g]])) / total_w
cohort_fractions[np.inf] = (
float(np.sum(unit_level_weights[never_treated_mask])) / total_w
)
else:
for g in treatment_groups:
cohort_fractions[g] = float(np.sum(cohort_masks[g])) / n_units
cohort_fractions[np.inf] = float(np.sum(never_treated_mask)) / n_units
# ----- Small cohort warnings -----
for g in treatment_groups:
n_g = int(np.sum(cohort_masks[g]))
frac_g = cohort_fractions[g]
if n_g < 2:
warnings.warn(
f"Cohort {g} has only {n_g} unit. Omega* inversion and "
"EIF computation may be numerically unstable.",
UserWarning,
stacklevel=2,
)
elif frac_g < 0.01:
warnings.warn(
f"Cohort {g} represents {frac_g:.1%} of the sample (< 1%). "
"Efficient weights may be imprecise.",
UserWarning,
stacklevel=2,
)
# Guard: never-treated with zero survey weight → no valid comparisons
# Applies to both covariates (DR nuisance) and nocov (weighted means) paths
if cohort_fractions.get(np.inf, 0.0) <= 0 and unit_level_weights is not None:
raise ValueError(
"Never-treated group has zero survey weight. EfficientDiD "
"requires a never-treated control group with positive "
"survey weight for estimation."
)
# ----- Covariate preparation (if provided) -----
covariate_matrix: Optional[np.ndarray] = None
m_hat_cache: Dict[Tuple, np.ndarray] = {}
r_hat_cache: Dict[Tuple[float, float], np.ndarray] = {}
s_hat_cache: Dict[float, np.ndarray] = {} # inverse propensities per group
if use_covariates:
assert covariates is not None # for type narrowing
# Validate covariate columns exist
missing_cov = [c for c in covariates if c not in data.columns]
if missing_cov:
raise ValueError(f"Missing covariate columns: {missing_cov}")
# Validate no NaN/Inf in covariates
for col_name in covariates:
non_finite_cov = ~np.isfinite(pd.to_numeric(df[col_name], errors="coerce"))
if non_finite_cov.any():
n_bad = int(non_finite_cov.sum())
raise ValueError(
f"Found {n_bad} non-finite value(s) in covariate column "
f"'{col_name}'. Covariates must be finite."
)
# Validate time-invariance: covariates must be constant within each unit
for col_name in covariates:
cov_nunique = df.groupby(unit)[col_name].nunique()
varying = cov_nunique[cov_nunique > 1]
if len(varying) > 0:
uid = varying.index[0]
raise ValueError(
f"Covariate '{col_name}' varies over time for unit {uid}. "
"EfficientDiD requires time-invariant covariates. "
"Extract base-period values before calling fit()."
)
# Extract unit-level covariate matrix from period_1 observations
base_df = df[df[time] == period_1].set_index(unit).reindex(all_units)
covariate_matrix = base_df[list(covariates)].values.astype(float)
# ----- Core estimation: ATT(g, t) for each target -----
# Precompute per-group unit counts (avoid repeated np.sum in loop)
n_treated_per_g = {g: int(np.sum(cohort_masks[g])) for g in treatment_groups}
n_control_count = int(np.sum(never_treated_mask))
group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]] = {}
eif_by_gt: Dict[Tuple[Any, Any], np.ndarray] = {}
stored_weights: Dict[Tuple[Any, Any], np.ndarray] = {}
stored_cond: Dict[Tuple[Any, Any], float] = {}
for g in treatment_groups:
# Under PT-Post, use per-group baseline Y_{g-1-anticipation}
# instead of the universal Y_1. This implements the weaker
# PT-Post assumption (parallel trends only from g-1 onward),
# matching the Callaway-Sant'Anna estimator exactly.
if self.pt_assumption == "post":
effective_base = g - 1 - self.anticipation
if effective_base not in period_to_col:
warnings.warn(
f"Cohort g={g} dropped: baseline period {effective_base} "
f"(g-1-anticipation) is not in the data.",
UserWarning,
stacklevel=2,
)
continue
effective_p1_col = period_to_col[effective_base]
else:
effective_p1_col = period_1_col
# Guard: skip cohorts with zero survey weight (all units zero-weighted)
if cohort_fractions[g] <= 0:
warnings.warn(
f"Cohort {g} has zero survey weight; skipping.",
UserWarning,
stacklevel=2,
)
continue
# Estimate all (g, t) cells including pre-treatment. Under PT-Post,
# pre-treatment cells serve as placebo/pre-trend diagnostics, matching
# the CallawaySantAnna implementation. Users filter to t >= g for
# post-treatment effects; pre-treatment cells are clearly labeled by
# their (g, t) coordinates in the results object.
for t in time_periods:
# Skip period_1 — it's the universal reference baseline,
# not a target period
if t == period_1:
continue
# Enumerate valid comparison pairs
pairs = enumerate_valid_triples(
target_g=g,
treatment_groups=treatment_groups,
time_periods=time_periods,
period_1=period_1,
pt_assumption=self.pt_assumption,
anticipation=self.anticipation,
)
# Filter out comparison pairs with zero survey weight
if unit_level_weights is not None and pairs:
pairs = [
(gp, tpre)
for gp, tpre in pairs
if np.sum(
unit_level_weights[
never_treated_mask if np.isinf(gp) else cohort_masks[gp]
]
)
> 0
]
if not pairs:
warnings.warn(
f"No valid comparison pairs for (g={g}, t={t}). " "ATT will be NaN.",
UserWarning,
stacklevel=2,
)
t_stat, p_val, ci = np.nan, np.nan, (np.nan, np.nan)
group_time_effects[(g, t)] = {
"effect": np.nan,
"se": np.nan,
"t_stat": t_stat,
"p_value": p_val,
"conf_int": ci,
"n_treated": n_treated_per_g[g],
"n_control": n_control_count,
}
eif_by_gt[(g, t)] = np.zeros(n_units)
continue
if use_covariates:
assert covariate_matrix is not None
t_col_val = period_to_col[t]
# Lazily populate nuisance caches for this (g, t)
for gp, tpre in pairs:
tpre_col_val = period_to_col[tpre]
# m_{inf, t, tpre}(X)
key_inf_t = (np.inf, t_col_val, tpre_col_val)
if key_inf_t not in m_hat_cache:
m_hat_cache[key_inf_t] = estimate_outcome_regression(
outcome_wide,
covariate_matrix,
never_treated_mask,
t_col_val,
tpre_col_val,
k_max=self.sieve_k_max,
criterion=self.sieve_criterion,
unit_weights=unit_level_weights,
)
# m_{g', tpre, 1}(X)
key_gp_tpre = (gp, tpre_col_val, effective_p1_col)
if key_gp_tpre not in m_hat_cache:
gp_mask_for_reg = (
never_treated_mask if np.isinf(gp) else cohort_masks[gp]
)
m_hat_cache[key_gp_tpre] = estimate_outcome_regression(
outcome_wide,
covariate_matrix,
gp_mask_for_reg,
tpre_col_val,
effective_p1_col,
k_max=self.sieve_k_max,
criterion=self.sieve_criterion,
unit_weights=unit_level_weights,
)
# r_{g, inf}(X) and r_{g, g'}(X) via sieve (Eq 4.1-4.2)
for comp in {np.inf, gp}:
rkey = (g, comp)
if rkey not in r_hat_cache:
comp_mask = (
never_treated_mask if np.isinf(comp) else cohort_masks[comp]
)
r_hat_cache[rkey] = estimate_propensity_ratio_sieve(
covariate_matrix,
cohort_masks[g],
comp_mask,
k_max=self.sieve_k_max,
criterion=self.sieve_criterion,
ratio_clip=self.ratio_clip,
unit_weights=unit_level_weights,
)
# Per-unit DR generated outcomes: shape (n_units, H)
gen_out = compute_generated_outcomes_cov(
target_g=g,
target_t=t,
valid_pairs=pairs,
outcome_wide=outcome_wide,
cohort_masks=cohort_masks,
never_treated_mask=never_treated_mask,
period_to_col=period_to_col,
period_1_col=effective_p1_col,
cohort_fractions=cohort_fractions,
m_hat_cache=m_hat_cache,
r_hat_cache=r_hat_cache,
)
y_hat = np.mean(gen_out, axis=0) # shape (H,)
# Inverse propensity estimation (algorithm step 4)
# s_hat_{g'}(X) = 1/p_{g'}(X) for Eq 3.12 scaling
for group_id in {g, np.inf} | {gp for gp, _ in pairs}:
if group_id not in s_hat_cache:
group_mask_s = (
never_treated_mask if np.isinf(group_id) else cohort_masks[group_id]
)
s_hat_cache[group_id] = estimate_inverse_propensity_sieve(
covariate_matrix,
group_mask_s,
k_max=self.sieve_k_max,
criterion=self.sieve_criterion,
unit_weights=unit_level_weights,
)
# Conditional Omega*(X) with per-unit propensities (Eq 3.12)
omega_cond = compute_omega_star_conditional(
target_g=g,
target_t=t,
valid_pairs=pairs,
outcome_wide=outcome_wide,
cohort_masks=cohort_masks,
never_treated_mask=never_treated_mask,
period_to_col=period_to_col,
period_1_col=effective_p1_col,
cohort_fractions=cohort_fractions,
covariate_matrix=covariate_matrix,
s_hat_cache=s_hat_cache,
bandwidth=self.kernel_bandwidth,
unit_weights=unit_level_weights,
)
# Per-unit weights: (n_units, H)
per_unit_w = compute_per_unit_weights(omega_cond)
# ATT = (survey-)weighted mean of per-unit DR scores
if per_unit_w.shape[1] > 0:
per_unit_scores = np.sum(per_unit_w * gen_out, axis=1)
if unit_level_weights is not None:
att_gt = float(np.average(per_unit_scores, weights=unit_level_weights))
else:
att_gt = float(np.mean(per_unit_scores))
else:
att_gt = np.nan
# EIF with per-unit weights (Remark 4.2: plug-in valid)
# Center on scalar ATT, not per-pair means (ensures mean(EIF) ≈ 0)
eif_vals = compute_eif_cov(per_unit_w, gen_out, att_gt, n_units)
eif_by_gt[(g, t)] = eif_vals
else:
# No-covariates path (closed-form)
omega = compute_omega_star_nocov(
target_g=g,
target_t=t,
valid_pairs=pairs,
outcome_wide=outcome_wide,
cohort_masks=cohort_masks,
never_treated_mask=never_treated_mask,
period_to_col=period_to_col,
period_1_col=effective_p1_col,
cohort_fractions=cohort_fractions,
unit_weights=unit_level_weights,
)
weights, _, cond_num = compute_efficient_weights(omega)
stored_weights[(g, t)] = weights
if omega.size > 0:
stored_cond[(g, t)] = cond_num
y_hat = compute_generated_outcomes_nocov(
target_g=g,
target_t=t,
valid_pairs=pairs,
outcome_wide=outcome_wide,
cohort_masks=cohort_masks,
never_treated_mask=never_treated_mask,
period_to_col=period_to_col,
period_1_col=effective_p1_col,
unit_weights=unit_level_weights,
)
att_gt = float(weights @ y_hat) if len(weights) > 0 else np.nan
eif_vals = compute_eif_nocov(
target_g=g,
target_t=t,
weights=weights,
valid_pairs=pairs,
outcome_wide=outcome_wide,
cohort_masks=cohort_masks,
never_treated_mask=never_treated_mask,
period_to_col=period_to_col,
period_1_col=effective_p1_col,
cohort_fractions=cohort_fractions,
n_units=n_units,
unit_weights=unit_level_weights,
)
eif_by_gt[(g, t)] = eif_vals
# Analytical SE = sqrt(mean(EIF^2) / n) [paper p.21]
# With survey: use TSL variance via compute_survey_vcov
if self._unit_resolved_survey is not None:
se_gt = self._compute_survey_eif_se(eif_vals)
else:
se_gt = _compute_se_from_eif(
eif_vals, n_units, unit_cluster_indices, n_clusters
)
t_stat, p_val, ci = safe_inference(
att_gt, se_gt, alpha=self.alpha, df=self._survey_df
)
group_time_effects[(g, t)] = {
"effect": att_gt,
"se": se_gt,
"t_stat": t_stat,
"p_value": p_val,
"conf_int": ci,
"n_treated": int(np.sum(cohort_masks[g])),
"n_control": int(np.sum(never_treated_mask)),
}
if not group_time_effects:
raise ValueError(
"Could not estimate any group-time effects. "
"Check data has sufficient observations."
)
# ----- Aggregation -----
overall_att, overall_se = self._aggregate_overall(
group_time_effects,
eif_by_gt,
n_units,
cohort_fractions,
unit_cohorts,
cluster_indices=unit_cluster_indices,
n_clusters=n_clusters,
)
overall_t, overall_p, overall_ci = safe_inference(
overall_att, overall_se, alpha=self.alpha, df=self._survey_df
)
event_study_effects = None
group_effects = None
if aggregate in ("event_study", "all"):
event_study_effects = self._aggregate_event_study(
group_time_effects,
eif_by_gt,
n_units,
cohort_fractions,
treatment_groups,
time_periods,
balance_e,
unit_cohorts=unit_cohorts,
cluster_indices=unit_cluster_indices,
n_clusters=n_clusters,
)
if aggregate in ("group", "all"):
group_effects = self._aggregate_by_group(
group_time_effects,
eif_by_gt,
n_units,
cohort_fractions,
treatment_groups,
unit_cohorts=unit_cohorts,
cluster_indices=unit_cluster_indices,
n_clusters=n_clusters,
)
# ----- Bootstrap -----
# Reject replicate-weight designs for bootstrap — replicate variance
# is an analytical alternative, not compatible with bootstrap
if (
self.n_bootstrap > 0
and self._unit_resolved_survey is not None
and self._unit_resolved_survey.uses_replicate_variance
):
raise NotImplementedError(
"EfficientDiD bootstrap (n_bootstrap > 0) is not supported "
"with replicate-weight survey designs. Replicate weights provide "
"analytical variance; use n_bootstrap=0 instead."
)
bootstrap_results = None
if self.n_bootstrap > 0 and eif_by_gt:
bootstrap_results = self._run_multiplier_bootstrap(
group_time_effects=group_time_effects,
eif_by_gt=eif_by_gt,
n_units=n_units,
aggregate=aggregate,
balance_e=balance_e,
treatment_groups=treatment_groups,
cohort_fractions=cohort_fractions,
cluster_indices=unit_cluster_indices,
n_clusters=n_clusters,
resolved_survey=self._unit_resolved_survey,
unit_level_weights=self._unit_level_weights,
)
# Update estimates with bootstrap inference
overall_se = bootstrap_results.overall_att_se
overall_t = safe_inference(overall_att, overall_se, alpha=self.alpha)[0]
overall_p = bootstrap_results.overall_att_p_value
overall_ci = bootstrap_results.overall_att_ci
for gt in group_time_effects:
if gt in bootstrap_results.group_time_ses:
group_time_effects[gt]["se"] = bootstrap_results.group_time_ses[gt]
group_time_effects[gt]["conf_int"] = bootstrap_results.group_time_cis[gt]
group_time_effects[gt]["p_value"] = bootstrap_results.group_time_p_values[gt]
eff = float(group_time_effects[gt]["effect"])
se = float(group_time_effects[gt]["se"])
group_time_effects[gt]["t_stat"] = safe_inference(eff, se, alpha=self.alpha)[0]
es_cis = bootstrap_results.event_study_cis
es_pvs = bootstrap_results.event_study_p_values
if (
event_study_effects is not None
and bootstrap_results.event_study_ses is not None
and es_cis is not None
and es_pvs is not None
):
for e in event_study_effects:
if e in bootstrap_results.event_study_ses:
event_study_effects[e]["se"] = bootstrap_results.event_study_ses[e]
event_study_effects[e]["conf_int"] = es_cis[e]
event_study_effects[e]["p_value"] = es_pvs[e]
eff = float(event_study_effects[e]["effect"])
se = float(event_study_effects[e]["se"])
event_study_effects[e]["t_stat"] = safe_inference(
eff, se, alpha=self.alpha
)[0]
g_cis = bootstrap_results.group_effect_cis
g_pvs = bootstrap_results.group_effect_p_values
if (
group_effects is not None
and bootstrap_results.group_effect_ses is not None
and g_cis is not None
and g_pvs is not None
):
for g in group_effects:
if g in bootstrap_results.group_effect_ses:
group_effects[g]["se"] = bootstrap_results.group_effect_ses[g]
group_effects[g]["conf_int"] = g_cis[g]
group_effects[g]["p_value"] = g_pvs[g]
eff = float(group_effects[g]["effect"])
se = float(group_effects[g]["se"])
group_effects[g]["t_stat"] = safe_inference(eff, se, alpha=self.alpha)[0]
# ----- Build results -----
self.results_ = EfficientDiDResults(
group_time_effects=group_time_effects,
overall_att=overall_att,
overall_se=overall_se,
overall_t_stat=overall_t,
overall_p_value=overall_p,
overall_conf_int=overall_ci,
groups=treatment_groups,
time_periods=time_periods,
n_obs=n_units * len(time_periods),
n_treated_units=n_treated_units,
n_control_units=n_control_units,
alpha=self.alpha,
pt_assumption=self.pt_assumption,
anticipation=self.anticipation,
n_bootstrap=self.n_bootstrap,
bootstrap_weights=self.bootstrap_weights,
seed=self.seed,
event_study_effects=event_study_effects,
group_effects=group_effects,
efficient_weights=stored_weights if stored_weights else None,
omega_condition_numbers=stored_cond if stored_cond else None,
control_group=self.control_group,
# 2-branch cluster_name/n_clusters resolution: suppress under any
# survey design (analytical TSL or replicate); populate under bare
# ``cluster=``; default to None under unclustered, non-survey fits.
# The default per-unit EIF SE ``sqrt(mean(EIF^2)/n)`` is HC1-style
# (not auto-cluster-at-unit), so no third unit-default branch.
cluster_name=(
None
if resolved_survey is not None
else (self.cluster if self.cluster is not None else None)
),
n_clusters=(
None
if resolved_survey is not None
else (n_clusters if self.cluster is not None else None)
),
vcov_type=self.vcov_type,
influence_functions=eif_by_gt if store_eif else None,
bootstrap_results=bootstrap_results,
estimation_path="dr" if use_covariates else "nocov",
sieve_k_max=self.sieve_k_max,
sieve_criterion=self.sieve_criterion,
ratio_clip=self.ratio_clip,
kernel_bandwidth=self.kernel_bandwidth,
survey_metadata=(
self._recompute_unit_survey_metadata(survey_metadata)
if survey_metadata is not None
else None
),
)
self.is_fitted_ = True
return self.results_
def _recompute_unit_survey_metadata(self, panel_metadata):
    """Return survey metadata rebuilt from the unit-level design.

    When no unit-level survey design is stored on the estimator, the
    ``panel_metadata`` argument is returned unchanged. Otherwise the
    metadata is recomputed from the unit-level design and its weights,
    and ``df_survey`` is overwritten with ``self._survey_df`` when that
    value is set, is not the ``0`` sentinel, and differs from the
    freshly computed one.
    """
    unit_design = self._unit_resolved_survey
    if unit_design is None:
        # Nothing to recompute from: hand the panel-level metadata back.
        return panel_metadata

    from diff_diff.survey import compute_survey_metadata

    unit_meta = compute_survey_metadata(unit_design, unit_design.weights)

    # Propagate the effective df, but never the df=0 sentinel, which marks
    # an undefined df and must not leak into the metadata.
    effective_df = self._survey_df
    df_is_defined = effective_df is not None and effective_df != 0
    if df_is_defined and unit_meta.df_survey != effective_df:
        unit_meta.df_survey = effective_df
    return unit_meta
# -- Survey SE helpers ----------------------------------------------------
def _compute_survey_eif_se(self, eif_vals: np.ndarray) -> float:
    """Standard error of an EIF vector under the stored survey design.

    Works from ``self._unit_resolved_survey`` and takes one of two routes:

    * Replicate-variance designs: the EIF is rescaled to
      ``psi = w * eif / sum(w)`` and passed to
      ``compute_replicate_if_variance``. If fewer replicates were valid
      than the design declares, ``self._survey_df`` is updated to
      ``n_valid - 1`` (or ``None`` when ``n_valid <= 1``). A non-finite
      variance yields NaN; otherwise the variance is floored at zero
      before the square root.
    * All other designs: ``compute_survey_vcov`` is called with a single
      column of ones as the design matrix, and the square root of the
      absolute value of the ``[0, 0]`` entry is returned.
    """
    design = self._unit_resolved_survey

    if not design.uses_replicate_variance:
        from diff_diff.survey import compute_survey_vcov

        intercept_only = np.ones((len(eif_vals), 1))
        vcov = compute_survey_vcov(intercept_only, eif_vals, design)
        return float(np.sqrt(np.abs(vcov[0, 0])))

    from diff_diff.survey import compute_replicate_if_variance

    # Score-scale the IF so it matches the TSL bread: psi = w * eif / sum(w)
    svy_weights = design.weights
    psi_scaled = svy_weights * eif_vals / svy_weights.sum()
    variance, n_valid = compute_replicate_if_variance(psi_scaled, design)

    # Fewer valid replicates than declared: shrink the survey df accordingly.
    if n_valid < design.n_replicates:
        self._survey_df = n_valid - 1 if n_valid > 1 else None

    if not np.isfinite(variance):
        return np.nan
    return float(np.sqrt(max(variance, 0.0)))
def _eif_se(
    self,
    eif_vals: np.ndarray,
    n_units: int,
    cluster_indices: Optional[np.ndarray] = None,
    n_clusters: Optional[int] = None,
) -> float:
    """Standard error for a vector of (aggregated) EIF scores.

    If a unit-level survey design is stored on the estimator, the survey
    helper ``_compute_survey_eif_se`` is used and the cluster arguments
    are ignored. Otherwise the computation is delegated to
    ``_compute_se_from_eif`` with the unit count and cluster mapping.
    """
    survey_active = self._unit_resolved_survey is not None
    if not survey_active:
        return _compute_se_from_eif(eif_vals, n_units, cluster_indices, n_clusters)
    return self._compute_survey_eif_se(eif_vals)
# -- Aggregation helpers --------------------------------------------------
def _compute_wif_contribution(
    self,
    keepers: List[Tuple],
    effects: np.ndarray,
    unit_cohorts: np.ndarray,
    cohort_fractions: Dict[float, float],
    n_units: int,
    unit_weights: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Weight-influence-function (WIF) term for cohort-share aggregation.

    The aggregation weights ``p_g / sum(p_g)`` are themselves estimated,
    so this per-unit correction is added to the aggregated EIF. The
    original author notes it follows the WIF formula in
    ``staggered_aggregation.py``, kept on the same O(1) scale as the EIF.

    Parameters
    ----------
    keepers : list of (g, t) tuples
        Group-time cells entering the aggregate.
    effects : ndarray, shape (n_keepers,)
        ATT estimate for each keeper, in the same order.
    unit_cohorts : ndarray, shape (n_units,)
        Cohort label of every unit.
    cohort_fractions : dict
        Cohort share by cohort label; cohorts missing from the mapping
        contribute a share of 0.
    n_units : int
        Number of units; used only for the all-zero fallback.
    unit_weights : ndarray, shape (n_units,), optional
        Unit-level survey weights. When given, the cohort indicator is
        multiplied by the weight before the cohort share is subtracted,
        i.e. ``w_i * 1{G_i = g} - p_g``.

    Returns
    -------
    ndarray, shape (n_units,)
        WIF contribution, additive with the aggregated EIF. All zeros
        when the keeper shares sum to zero.
    """
    keeper_cohorts = np.array([cohort for cohort, _ in keepers])
    keeper_shares = np.array([cohort_fractions.get(cohort, 0.0) for cohort, _ in keepers])
    share_total = keeper_shares.sum()
    if share_total == 0:
        return np.zeros(n_units)

    # (n_units, n_keepers) membership matrix: 1 where the unit belongs to
    # the keeper's cohort, optionally scaled by the unit's survey weight.
    membership = (unit_cohorts[:, None] == keeper_cohorts[None, :]).astype(float)
    if unit_weights is not None:
        membership = membership * unit_weights[:, None]

    share_dev = membership - keeper_shares
    row_dev = np.sum(share_dev, axis=1)

    with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
        # Derivative of p_k / sum(p) w.r.t. p_k, minus the renormalisation term.
        direct_term = share_dev / share_total
        renorm_term = np.outer(row_dev, keeper_shares) / share_total**2
        wif_matrix = direct_term - renorm_term
        wif_contrib = wif_matrix @ effects
    return wif_contrib
def _aggregate_overall(
    self,
    group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]],
    eif_by_gt: Dict[Tuple[Any, Any], np.ndarray],
    n_units: int,
    cohort_fractions: Dict[float, float],
    unit_cohorts: np.ndarray,
    cluster_indices: Optional[np.ndarray] = None,
    n_clusters: Optional[int] = None,
) -> Tuple[float, float]:
    """Cohort-share-weighted overall ATT and its WIF-adjusted SE.

    Parameters
    ----------
    group_time_effects : dict
        ``{(g, t): {"effect": ..., ...}}`` group-time estimates.
    eif_by_gt : dict
        Per-unit EIF vector for each ``(g, t)``.
    n_units : int
        Number of units (length of every EIF vector).
    cohort_fractions : dict
        Cohort share by cohort label.
    unit_cohorts : ndarray, shape (n_units,)
        Cohort label of every unit.
    cluster_indices : ndarray, optional
        Cluster index per unit, forwarded to ``_eif_se``.
    n_clusters : int, optional
        Number of clusters, forwarded to ``_eif_se``.

    Returns
    -------
    (float, float)
        ``(overall_att, se)``. Both are NaN when no finite cell satisfies
        ``t >= g - anticipation`` or when the selected cohort shares sum
        to zero.
    """
    # Keep finite cells at or after treatment onset (shifted by anticipation).
    post_cells = []
    for (g, t), stats in group_time_effects.items():
        if t >= g - self.anticipation and np.isfinite(stats["effect"]):
            post_cells.append((g, t))
    if not post_cells:
        return np.nan, np.nan

    # Normalised cohort-share weights, one per retained cell.
    cohort_share = np.array([cohort_fractions.get(g, 0.0) for (g, _) in post_cells])
    share_total = cohort_share.sum()
    if share_total == 0:
        return np.nan, np.nan
    agg_weights = cohort_share / share_total

    cell_effects = np.array([group_time_effects[cell]["effect"] for cell in post_cells])
    overall_att = float(np.sum(agg_weights * cell_effects))

    # Same weights applied to the per-unit influence functions.
    combined_eif = np.zeros(n_units)
    for weight, cell in zip(agg_weights, post_cells):
        combined_eif += weight * eif_by_gt[cell]

    # Correction for the cohort shares being estimated.
    wif = self._compute_wif_contribution(
        post_cells,
        cell_effects,
        unit_cohorts,
        cohort_fractions,
        n_units,
        unit_weights=self._unit_level_weights,
    )

    design = self._unit_resolved_survey
    if design is None:
        se = self._eif_se(combined_eif + wif, n_units, cluster_indices, n_clusters)
        return overall_att, se

    # Survey path: build the score-level psi by hand so the already
    # survey-weighted WIF term is not weighted a second time downstream.
    unit_w = self._unit_level_weights
    weight_total = float(np.sum(unit_w))
    psi_total = unit_w * combined_eif / weight_total + wif / weight_total
    if getattr(design, "uses_replicate_variance", False):
        from diff_diff.survey import compute_replicate_if_variance

        variance, _ = compute_replicate_if_variance(psi_total, design)
    else:
        from diff_diff.survey import compute_survey_if_variance

        variance = compute_survey_if_variance(psi_total, design)

    if np.isfinite(variance):
        se = float(np.sqrt(max(variance, 0.0)))
    else:
        se = np.nan
    return overall_att, se
def _aggregate_event_study(
    self,
    group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]],
    eif_by_gt: Dict[Tuple[Any, Any], np.ndarray],
    n_units: int,
    cohort_fractions: Dict[float, float],
    treatment_groups: List[Any],
    time_periods: List[Any],
    balance_e: Optional[int] = None,
    unit_cohorts: Optional[np.ndarray] = None,
    cluster_indices: Optional[np.ndarray] = None,
    n_clusters: Optional[int] = None,
) -> Dict[int, Dict[str, Any]]:
    """Aggregate ATT(g, t) into event-study effects by ``e = t - g``.

    Parameters
    ----------
    group_time_effects : dict
        ``{(g, t): {"effect": ..., ...}}`` group-time estimates; cells
        with a non-finite effect are ignored.
    eif_by_gt : dict
        Per-unit EIF vector for each ``(g, t)``.
    n_units : int
        Number of units (length of every EIF vector).
    cohort_fractions : dict
        Cohort share by cohort label; used as aggregation weights.
    treatment_groups : list
        Treatment cohort identifiers (not read by this method).
    time_periods : list
        All time periods (not read by this method).
    balance_e : int, optional
        When given, only cohorts with a finite effect at relative period
        ``balance_e`` are kept, at every horizon.
    unit_cohorts : ndarray, optional
        Cohort label per unit. The WIF correction is applied only when
        this is provided.
    cluster_indices : ndarray, optional
        Cluster index per unit, forwarded to ``_eif_se``.
    n_clusters : int, optional
        Number of clusters, forwarded to ``_eif_se``.

    Returns
    -------
    dict
        ``{e: {"effect", "se", "t_stat", "p_value", "conf_int",
        "n_groups"}}`` in increasing order of ``e``. Empty when no cell
        qualifies; a ``UserWarning`` is issued if that happens because
        of ``balance_e``.
    """

    def _bucket_by_horizon(allowed_cohorts=None):
        """Group finite cells by relative time, optionally restricting cohorts."""
        buckets: Dict[int, List[Tuple[Tuple[Any, Any], float, float]]] = {}
        for (g, t), cell in group_time_effects.items():
            if not np.isfinite(cell["effect"]):
                continue
            if allowed_cohorts is not None and g not in allowed_cohorts:
                continue
            horizon = int(t - g)
            buckets.setdefault(horizon, []).append(
                ((g, t), cell["effect"], cohort_fractions.get(g, 0.0))
            )
        return buckets

    effects_by_e = _bucket_by_horizon()

    if balance_e is not None:
        # Cohorts observed at the anchor horizon define the balanced panel.
        anchor_cohorts = {cell[0] for cell, _, _ in effects_by_e.get(balance_e, [])}
        effects_by_e = _bucket_by_horizon(anchor_cohorts)
        if not effects_by_e:
            warnings.warn(
                f"balance_e={balance_e}: no cohort has a finite effect at the "
                "anchor horizon. Event study will be empty.",
                UserWarning,
                stacklevel=2,
            )

    result: Dict[int, Dict[str, Any]] = {}
    for rel_time in sorted(effects_by_e):
        members = effects_by_e[rel_time]
        cells = [member[0] for member in members]
        effs = np.array([member[1] for member in members])
        shares = np.array([member[2] for member in members])

        # Cohort-share weights; equal weights if every share is zero.
        share_total = shares.sum()
        if share_total > 0:
            w = shares / share_total
        else:
            w = np.ones(len(shares)) / len(shares)
        agg_eff = float(np.sum(w * effs))

        agg_eif = np.zeros(n_units)
        for weight, cell in zip(w, cells):
            agg_eif += weight * eif_by_gt[cell]

        # WIF correction for the estimated cohort-share weights.
        if unit_cohorts is None:
            wif_e = np.zeros(n_units)
        else:
            wif_e = self._compute_wif_contribution(
                list(cells),
                effs,
                unit_cohorts,
                cohort_fractions,
                n_units,
                unit_weights=self._unit_level_weights,
            )

        design = self._unit_resolved_survey
        if design is None:
            agg_eif = agg_eif + wif_e
            agg_se = self._eif_se(agg_eif, n_units, cluster_indices, n_clusters)
        else:
            # Score-level psi: weight the EIF by hand, add the WIF as-is.
            unit_w = self._unit_level_weights
            weight_total = float(np.sum(unit_w))
            psi_total = unit_w * agg_eif / weight_total + wif_e / weight_total
            if getattr(design, "uses_replicate_variance", False):
                from diff_diff.survey import compute_replicate_if_variance

                variance, _ = compute_replicate_if_variance(psi_total, design)
            else:
                from diff_diff.survey import compute_survey_if_variance

                variance = compute_survey_if_variance(psi_total, design)
            if np.isfinite(variance):
                agg_se = float(np.sqrt(max(variance, 0.0)))
            else:
                agg_se = np.nan

        t_stat, p_val, ci = safe_inference(
            agg_eff, agg_se, alpha=self.alpha, df=self._survey_df
        )
        result[rel_time] = {
            "effect": agg_eff,
            "se": agg_se,
            "t_stat": t_stat,
            "p_value": p_val,
            "conf_int": ci,
            "n_groups": len(members),
        }
    return result
def _aggregate_by_group(
    self,
    group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]],
    eif_by_gt: Dict[Tuple[Any, Any], np.ndarray],
    n_units: int,
    cohort_fractions: Dict[float, float],
    treatment_groups: List[Any],
    unit_cohorts: Optional[np.ndarray] = None,
    cluster_indices: Optional[np.ndarray] = None,
    n_clusters: Optional[int] = None,
) -> Dict[Any, Dict[str, Any]]:
    """Average ATT(g, t) over post-treatment periods within each cohort.

    Parameters
    ----------
    group_time_effects : dict
        ``{(g, t): {"effect": ..., ...}}`` group-time estimates.
    eif_by_gt : dict
        Per-unit EIF vector for each ``(g, t)``.
    n_units : int
        Number of units (length of every EIF vector).
    cohort_fractions : dict
        Cohort shares (not read: periods are equally weighted).
    treatment_groups : list
        Cohorts to report, in output order.
    unit_cohorts : ndarray, optional
        Not read by this method; no WIF term is needed because the
        weights are fixed.
    cluster_indices : ndarray, optional
        Cluster index per unit, forwarded to ``_eif_se``.
    n_clusters : int, optional
        Number of clusters, forwarded to ``_eif_se``.

    Returns
    -------
    dict
        ``{g: {"effect", "se", "t_stat", "p_value", "conf_int",
        "n_periods"}}``. Cohorts with no finite cell satisfying
        ``t >= g - anticipation`` are omitted.
    """
    result: Dict[Any, Dict[str, Any]] = {}
    for cohort in treatment_groups:
        # Finite cells of this cohort at or after (anticipation-adjusted) onset.
        cohort_cells = []
        for (gg, t), stats in group_time_effects.items():
            if not gg == cohort:
                continue
            if not (t >= cohort - self.anticipation):
                continue
            if not np.isfinite(stats["effect"]):
                continue
            cohort_cells.append((gg, t))
        if not cohort_cells:
            continue

        effs = np.array([group_time_effects[cell]["effect"] for cell in cohort_cells])
        n_cells = len(effs)
        equal_w = np.full(n_cells, 1.0 / n_cells)
        agg_eff = float(np.sum(equal_w * effs))

        agg_eif = np.zeros(n_units)
        for weight, cell in zip(equal_w, cohort_cells):
            agg_eif += weight * eif_by_gt[cell]

        agg_se = self._eif_se(agg_eif, n_units, cluster_indices, n_clusters)
        t_stat, p_val, ci = safe_inference(
            agg_eff, agg_se, alpha=self.alpha, df=self._survey_df
        )
        result[cohort] = {
            "effect": agg_eff,
            "se": agg_se,
            "t_stat": t_stat,
            "p_value": p_val,
            "conf_int": ci,
            "n_periods": len(cohort_cells),
        }
    return result
[docs]
def summary(self) -> str:
    """Return the text summary of the fitted results.

    Raises
    ------
    RuntimeError
        If the model has not been fitted yet.
    """
    if self.is_fitted_:
        assert self.results_ is not None
        return self.results_.summary()
    raise RuntimeError("Model must be fitted before calling summary()")
[docs]
def print_summary(self) -> None:
    """Write the text returned by :meth:`summary` to standard output."""
    report = self.summary()
    print(report)
# -- Hausman pretest -------------------------------------------------------
[docs]
@classmethod
def hausman_pretest(
    cls,
    data: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]] = None,
    cluster: Optional[str] = None,
    anticipation: int = 0,
    control_group: str = "never_treated",
    alpha: float = 0.05,
    **nuisance_kwargs: Any,
) -> HausmanPretestResult:
    """Hausman pretest for PT-All vs PT-Post (Theorem A.1).

    Fits the estimator under both parallel trends assumptions and
    compares the results. Under H0 (PT-All holds), both are consistent
    but PT-All is more efficient. Rejection suggests PT-All is too
    strong; use PT-Post instead.

    The comparison is made on post-treatment event-study effects ES(e):
    the group-time effects and their per-unit EIFs from each fit are
    aggregated by relative period ``e = t - g`` with cohort-share
    weights (plus a WIF correction), and the vector of differences
    ``ES_post(e) - ES_all(e)`` over the common horizons is passed,
    together with the two EIF-based covariance matrices, to
    ``_hausman_quadratic_form``.

    Parameters
    ----------
    data, outcome, unit, time, first_treat, covariates
        Same as :meth:`fit`.
    cluster : str, optional
        Cluster column for cluster-robust covariance.
    anticipation : int
        Anticipation periods.
    control_group : str
        ``"never_treated"`` or ``"last_cohort"``.
    alpha : float
        Significance level for the test. Also passed to both estimator
        instances as their ``alpha``.
    **nuisance_kwargs
        Forwarded to the estimator constructor for both fits
        (e.g. ``sieve_k_max``, ``ratio_clip``).

    Returns
    -------
    HausmanPretestResult
        ``recommendation`` is ``"pt_post"`` when the test rejects and
        ``"pt_all"`` otherwise. An ``"inconclusive"`` result with NaN
        statistic and p-value, ``df=0`` and ``gt_details=None`` is
        returned when the two fits share no (g, t) cell or no
        post-treatment horizon, when the covariance contains non-finite
        values, or when the effective rank is zero.

    Warns
    -----
    UserWarning
        If the covariance matrix contains non-finite values, or if the
        variance-difference matrix has substantially negative
        eigenvalues.
    """
    # Fit under both assumptions (analytical SEs only, no bootstrap)
    common_kwargs = dict(
        cluster=cluster,
        control_group=control_group,
        anticipation=anticipation,
        n_bootstrap=0,
        **nuisance_kwargs,
    )
    # aggregate=None: only the (g, t) cells are needed; the ES(e)
    # aggregation is redone below so the EIFs can be carried along.
    fit_kwargs = dict(
        data=data,
        outcome=outcome,
        unit=unit,
        time=time,
        first_treat=first_treat,
        covariates=covariates,
        aggregate=None,
    )
    # store_eif=True keeps the per-unit influence functions on the results.
    edid_all = cls(pt_assumption="all", alpha=alpha, **common_kwargs)
    result_all = edid_all.fit(**fit_kwargs, store_eif=True)
    edid_post = cls(pt_assumption="post", alpha=alpha, **common_kwargs)
    result_post = edid_post.fit(**fit_kwargs, store_eif=True)
    # Find common (g,t) pairs — PT-Post pairs are a subset of PT-All
    common_gts = sorted(
        set(result_all.group_time_effects.keys()) & set(result_post.group_time_effects.keys())
    )

    def _nan_result() -> HausmanPretestResult:
        # Shared "test could not be computed" result; still reports both
        # overall ATTs so the caller can inspect them.
        return HausmanPretestResult(
            statistic=np.nan,
            p_value=np.nan,
            df=0,
            reject=False,
            alpha=alpha,
            att_all=result_all.overall_att,
            att_post=result_post.overall_att,
            recommendation="inconclusive",
            gt_details=None,
        )

    if not common_gts:
        return _nan_result()
    eif_all = result_all.influence_functions
    eif_post = result_post.influence_functions
    assert eif_all is not None and eif_post is not None
    # Unit count is taken from the length of a stored EIF vector.
    n_units = len(next(iter(eif_all.values())))
    # --- Aggregate to post-treatment ES(e) per Theorem A.1 ---
    # Derive cohort fractions from data for proper weights
    all_units_list = sorted(data[unit].unique())
    # One cohort label per unit (first observed first_treat value),
    # ordered like all_units_list.
    # NOTE(review): assumes this ordering matches the row order of the
    # stored EIF vectors — confirm against fit().
    unit_cohorts = (
        data.groupby(unit)[first_treat].first().reindex(all_units_list).values.astype(float)
    )
    cohort_fractions: Dict[float, float] = {}
    for g in set(result_all.groups) | set(result_post.groups):
        cohort_fractions[g] = float(np.sum(unit_cohorts == g)) / n_units

    def _aggregate_es(
        gt_effects: Dict, eif_dict: Dict, groups: List, ant: int
    ) -> Dict[int, Tuple[float, np.ndarray]]:
        """Aggregate (g,t) effects to post-treatment ES(e) with WIF-corrected EIF."""
        # ``groups`` is accepted but not read in this helper.
        by_e: Dict[int, List[Tuple[Tuple, float, float, np.ndarray]]] = {}
        for (g, t), d in gt_effects.items():
            e = int(t - g)
            # Drop cells earlier than the anticipation window.
            if e < -ant:
                continue
            if not np.isfinite(d["effect"]):
                continue
            if (g, t) not in eif_dict:
                continue
            eif_vec = eif_dict[(g, t)]
            # A cell is used only if every unit's EIF value is finite.
            if not np.all(np.isfinite(eif_vec)):
                continue
            pg = cohort_fractions.get(g, 0.0)
            if e not in by_e:
                by_e[e] = []
            by_e[e].append(((g, t), d["effect"], pg, eif_vec))
        result: Dict[int, Tuple[float, np.ndarray]] = {}
        for e, items in by_e.items():
            # Only horizons e >= 0 enter the test, even with anticipation.
            if e < 0:
                continue
            effs = np.array([x[1] for x in items])
            pgs = np.array([x[2] for x in items])
            eifs = [x[3] for x in items]
            gt_pairs_e = [x[0] for x in items]
            total_pg = pgs.sum()
            # Cohort-share weights; equal weights if all shares are zero.
            w = pgs / total_pg if total_pg > 0 else np.ones(len(pgs)) / len(pgs)
            es_eff = float(np.sum(w * effs))
            es_eif = np.zeros(n_units)
            for k_idx in range(len(eifs)):
                es_eif += w[k_idx] * eifs[k_idx]
            # WIF correction for estimated cohort-size weights
            groups_e = np.array([g for (g, t) in gt_pairs_e])
            pg_e = np.array([cohort_fractions.get(g, 0.0) for g, t in gt_pairs_e])
            sum_pg = pg_e.sum()
            if sum_pg > 0:
                # Unweighted version of the WIF formula: (indicator - p_g)
                # terms, normalised by the share total.
                indicator = (unit_cohorts[:, None] == groups_e[None, :]).astype(float)
                indicator_sum = np.sum(indicator - pg_e, axis=1)
                with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
                    if1 = (indicator - pg_e) / sum_pg
                    if2 = np.outer(indicator_sum, pg_e) / sum_pg**2
                    wif = (if1 - if2) @ effs
                es_eif = es_eif + wif
            result[e] = (es_eff, es_eif)
        return result

    es_all = _aggregate_es(
        result_all.group_time_effects, eif_all, result_all.groups, anticipation
    )
    es_post = _aggregate_es(
        result_post.group_time_effects, eif_post, result_post.groups, anticipation
    )
    # Find common post-treatment horizons
    common_e = sorted(set(es_all.keys()) & set(es_post.keys()))
    if not common_e:
        return _nan_result()
    # Difference vector: PT-Post minus PT-All, one entry per common horizon.
    delta = np.array([es_post[e][0] - es_all[e][0] for e in common_e])
    # Build ES(e)-level EIF matrices
    eif_all_mat = np.column_stack([es_all[e][1] for e in common_e])
    eif_post_mat = np.column_stack([es_post[e][1] for e in common_e])
    # Filter units with non-finite EIF values
    row_finite = np.all(np.isfinite(eif_all_mat), axis=1) & np.all(
        np.isfinite(eif_post_mat), axis=1
    )
    cl_idx: Optional[np.ndarray] = None
    n_cl: Optional[int] = None
    if cluster is not None:
        cl_idx, n_cl = _validate_and_build_cluster_mapping(data, unit, cluster, all_units_list)
    if not np.all(row_finite):
        eif_all_mat = eif_all_mat[row_finite]
        eif_post_mat = eif_post_mat[row_finite]
        # n_units is rebound here; the covariance formulas below (and the
        # _eif_cov closure) use the filtered count.
        n_units = int(np.sum(row_finite))
        if cl_idx is not None:
            cl_idx = cl_idx[row_finite]
            # Recompute effective cluster count and remap to contiguous
            # indices — entire clusters may have been dropped by filtering
            unique_cl, cl_idx = np.unique(cl_idx, return_inverse=True)
            n_cl = len(unique_cl)
    # Compute full covariance matrices
    if cl_idx is not None and n_cl is not None:

        def _eif_cov(eif_mat: np.ndarray) -> np.ndarray:
            # Cluster-level outer product with an n_cl / (n_cl - 1)
            # small-sample factor (skipped when there is a single cluster).
            centered = _cluster_aggregate(eif_mat, cl_idx, n_cl)
            correction = n_cl / (n_cl - 1) if n_cl > 1 else 1.0
            return correction * (centered.T @ centered) / (n_units**2)

        cov_all = _eif_cov(eif_all_mat)
        cov_post = _eif_cov(eif_post_mat)
    else:
        # Unclustered: per-unit EIF outer product scaled by 1 / n^2.
        with np.errstate(over="ignore", invalid="ignore"):
            cov_all = (eif_all_mat.T @ eif_all_mat) / (n_units**2)
            cov_post = (eif_post_mat.T @ eif_post_mat) / (n_units**2)
    # The helper returns the statistic, the effective rank (used as df),
    # the p-value, the number of negative eigenvalues and a finiteness flag.
    H, effective_rank, p_value, n_negative, finite_ok = _hausman_quadratic_form(
        delta, cov_post, cov_all
    )
    if not finite_ok:
        warnings.warn(
            "Hausman covariance matrix contains non-finite values. " "The test is unreliable.",
            UserWarning,
            stacklevel=2,
        )
        return _nan_result()
    # Negative eigenvalues only trigger a warning; the test is still reported.
    if n_negative > 0:
        warnings.warn(
            f"Hausman variance-difference matrix V has {n_negative} "
            "substantially negative eigenvalue(s). The test may be "
            "unreliable (finite-sample efficiency reversal).",
            UserWarning,
            stacklevel=2,
        )
    if effective_rank == 0:
        return _nan_result()
    reject = p_value < alpha
    # Per-horizon breakdown returned to the caller via ``gt_details``.
    es_details = pd.DataFrame(
        {
            "relative_period": common_e,
            "es_all": [es_all[e][0] for e in common_e],
            "es_post": [es_post[e][0] for e in common_e],
            "delta": delta,
        }
    )
    return HausmanPretestResult(
        statistic=H,
        p_value=p_value,
        df=effective_rank,
        reject=reject,
        alpha=alpha,
        att_all=result_all.overall_att,
        att_post=result_post.overall_att,
        recommendation="pt_post" if reject else "pt_all",
        gt_details=es_details,
    )