"""
Continuous Difference-in-Differences estimator.
Implements Callaway, Goodman-Bacon & Sant'Anna (2024),
"Difference-in-Differences with a Continuous Treatment" (NBER WP 32117).
Estimates dose-response curves ATT(d) and ACRT(d), as well as summary
parameters ATT^{glob} and ACRT^{glob}, with optional multiplier bootstrap
inference.
"""
import warnings
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
from diff_diff.bootstrap_utils import (
compute_effect_bootstrap_stats,
generate_bootstrap_weights_batch,
)
from diff_diff.continuous_did_bspline import (
bspline_derivative_design_matrix,
bspline_design_matrix,
build_bspline_basis,
default_dose_grid,
)
from diff_diff.continuous_did_results import (
ContinuousDiDResults,
DoseResponseCurve,
)
from diff_diff.linalg import solve_ols
from diff_diff.survey import (
ResolvedSurveyDesign,
_resolve_survey_for_fit,
_validate_unit_constant_survey,
compute_survey_vcov,
)
from diff_diff.utils import safe_inference
__all__ = ["ContinuousDiD", "ContinuousDiDResults", "DoseResponseCurve"]
[docs]
class ContinuousDiD:
"""
Continuous Difference-in-Differences estimator.
Implements the methodology from Callaway, Goodman-Bacon & Sant'Anna (2024)
for estimating dose-response curves when treatment has a continuous intensity.
Parameters
----------
degree : int, default=3
B-spline degree (3 = cubic).
num_knots : int, default=0
Number of interior knots for the B-spline basis.
dvals : array-like, optional
Custom dose evaluation grid. If None, uses quantile-based default.
control_group : str, default="never_treated"
``"never_treated"`` or ``"not_yet_treated"``.
anticipation : int, default=0
Number of periods of treatment anticipation.
base_period : str, default="varying"
``"varying"`` or ``"universal"``.
alpha : float, default=0.05
Significance level for confidence intervals.
n_bootstrap : int, default=0
Number of multiplier bootstrap iterations. 0 for analytical SEs only.
bootstrap_weights : str, default="rademacher"
Bootstrap weight type: ``"rademacher"``, ``"mammen"``, or ``"webb"``.
seed : int, optional
Random seed for reproducibility.
rank_deficient_action : str, default="warn"
Action for rank-deficient B-spline OLS: ``"warn"``, ``"error"``, or ``"silent"``.
Examples
--------
>>> from diff_diff import ContinuousDiD, generate_continuous_did_data
>>> data = generate_continuous_did_data(n_units=200, seed=42)
>>> est = ContinuousDiD(n_bootstrap=199, seed=42)
>>> results = est.fit(data, outcome="outcome", unit="unit",
... time="period", first_treat="first_treat",
... dose="dose", aggregate="dose")
>>> results.overall_att # doctest: +SKIP
"""
_VALID_CONTROL_GROUPS = {"never_treated", "not_yet_treated"}
_VALID_BASE_PERIODS = {"varying", "universal"}
[docs]
def __init__(
self,
degree: int = 3,
num_knots: int = 0,
dvals: Optional[np.ndarray] = None,
control_group: str = "never_treated",
anticipation: int = 0,
base_period: str = "varying",
alpha: float = 0.05,
n_bootstrap: int = 0,
bootstrap_weights: str = "rademacher",
seed: Optional[int] = None,
rank_deficient_action: str = "warn",
):
self.degree = degree
self.num_knots = num_knots
self.dvals = np.asarray(dvals, dtype=float) if dvals is not None else None
self.control_group = control_group
self.anticipation = anticipation
self.base_period = base_period
self.alpha = alpha
self.n_bootstrap = n_bootstrap
self.bootstrap_weights = bootstrap_weights
self.seed = seed
self.rank_deficient_action = rank_deficient_action
self._validate_constrained_params()
def _validate_constrained_params(self) -> None:
"""Validate control_group and base_period values."""
if self.control_group not in self._VALID_CONTROL_GROUPS:
raise ValueError(
f"Invalid control_group: '{self.control_group}'. "
f"Must be one of {self._VALID_CONTROL_GROUPS}."
)
if self.base_period not in self._VALID_BASE_PERIODS:
raise ValueError(
f"Invalid base_period: '{self.base_period}'. "
f"Must be one of {self._VALID_BASE_PERIODS}."
)
[docs]
def get_params(self) -> Dict[str, Any]:
"""Return estimator parameters as a dictionary."""
return {
"degree": self.degree,
"num_knots": self.num_knots,
"dvals": self.dvals,
"control_group": self.control_group,
"anticipation": self.anticipation,
"base_period": self.base_period,
"alpha": self.alpha,
"n_bootstrap": self.n_bootstrap,
"bootstrap_weights": self.bootstrap_weights,
"seed": self.seed,
"rank_deficient_action": self.rank_deficient_action,
}
[docs]
def set_params(self, **params) -> "ContinuousDiD":
"""Set estimator parameters and return self."""
for key, value in params.items():
if not hasattr(self, key):
raise ValueError(f"Invalid parameter: {key}")
setattr(self, key, value)
self._validate_constrained_params()
return self
# ------------------------------------------------------------------
# Main fit
# ------------------------------------------------------------------
[docs]
def fit(
self,
data: pd.DataFrame,
outcome: str,
unit: str,
time: str,
first_treat: str,
dose: str,
aggregate: Optional[str] = None,
survey_design: object = None,
) -> ContinuousDiDResults:
"""
Fit the continuous DiD estimator.
Parameters
----------
data : pd.DataFrame
Panel data.
outcome : str
Outcome column name.
unit : str
Unit identifier column.
time : str
Time period column.
first_treat : str
First treatment period column (0 or inf for never-treated).
dose : str
Continuous dose column.
aggregate : str, optional
``"dose"`` for dose-response aggregation, ``"eventstudy"`` for
binarized event study.
survey_design : SurveyDesign, optional
Survey design specification for design-based inference.
Supports weighted estimation and Taylor series linearization
variance with strata, PSU, and FPC.
Returns
-------
ContinuousDiDResults
"""
# 1. Validate & prepare
_VALID_AGGREGATES = (None, "dose", "eventstudy")
if aggregate not in _VALID_AGGREGATES:
raise ValueError(
f"Invalid aggregate: '{aggregate}'. " f"Must be one of {_VALID_AGGREGATES}."
)
# Resolve survey design if provided
resolved_survey, survey_weights, survey_weight_type, survey_metadata = (
_resolve_survey_for_fit(survey_design, data, "analytical")
)
# Validate within-unit constancy for panel survey designs
if resolved_survey is not None:
_validate_unit_constant_survey(data, unit, survey_design)
# Bootstrap + survey supported via PSU-level multiplier bootstrap.
df = data.copy()
for col in [outcome, unit, time, first_treat, dose]:
if col not in df.columns:
raise ValueError(f"Column '{col}' not found in data.")
# Verify dose is time-invariant
dose_nunique = df.groupby(unit)[dose].nunique()
if dose_nunique.max() > 1:
bad_units = dose_nunique[dose_nunique > 1].index.tolist()
raise ValueError(
f"Dose must be time-invariant. Units with varying dose: {bad_units[:5]}"
)
# Normalize first_treat: +inf → 0 (R-style never-treated encoding).
# Count rows recategorized so users can see how many units just
# crossed from "treated at some point" to "never treated" — silent
# recategorization here would shift the control composition (axis-E
# silent coercion). Only positive infinity is recoded (to match the
# existing `.replace([np.inf, float("inf")], 0)` semantics on the
# next line).
first_treat_vals = df[first_treat].values
# Reject NaN first_treat explicitly. NaN survives preprocessing but
# satisfies neither the treated (g > 0) nor never-treated (g == 0)
# mask, so affected units would be silently excluded from the
# estimator (same silent-failure shape as `first_treat < 0`).
nan_mask = pd.isna(df[first_treat])
n_nan_first_treat = int(nan_mask.sum())
if n_nan_first_treat > 0:
raise ValueError(
f"{n_nan_first_treat} row(s) have NaN '{first_treat}' "
f"values. Valid values are 0 (never-treated) or a positive "
f"treatment period; such units would otherwise be silently "
f"excluded from both treated and control pools."
)
inf_mask = np.isposinf(first_treat_vals)
n_inf_first_treat = int(inf_mask.sum())
if n_inf_first_treat > 0:
warnings.warn(
f"{n_inf_first_treat} row(s) have inf in '{first_treat}'; "
f"treating the corresponding units as never-treated. Pass an "
f"explicit never-treated marker (0) if this is not intended.",
UserWarning,
stacklevel=2,
)
# Reject negative first_treat values (including -inf) explicitly.
# Without this guard they would survive preprocessing but fall out of
# both the treated (g > 0) and never-treated (g == 0) masks, silently
# excluding the affected units.
negative_mask = first_treat_vals < 0
n_negative_first_treat = int(negative_mask.sum())
if n_negative_first_treat > 0:
raise ValueError(
f"{n_negative_first_treat} row(s) have negative '{first_treat}' "
f"values (including -inf). Valid values are 0 (never-treated) "
f"or a positive treatment period; such units would otherwise "
f"be silently excluded from both treated and control pools."
)
df[first_treat] = df[first_treat].replace([np.inf, float("inf")], 0)
# Drop units with positive first_treat but zero dose (R convention)
unit_info = df.groupby(unit).first()[[first_treat, dose]]
drop_units = unit_info[(unit_info[first_treat] > 0) & (unit_info[dose] == 0)].index
if len(drop_units) > 0:
warnings.warn(
f"Dropping {len(drop_units)} units with positive first_treat but zero dose.",
UserWarning,
stacklevel=2,
)
df = df[~df[unit].isin(drop_units)]
# Validate no negative doses among treated units
treated_doses = df.loc[df[first_treat] > 0, dose]
if (treated_doses < 0).any():
n_neg = int((treated_doses < 0).sum())
raise ValueError(
f"Found {n_neg} treated unit(s) with negative dose. "
f"Dose must be strictly positive for treated units (D > 0)."
)
# Detect discrete (integer-valued) dose among treated units
unit_doses = df.loc[df[first_treat] > 0].groupby(unit)[dose].first()
unique_pos_doses = unit_doses[unit_doses > 0].unique()
is_integer = len(unique_pos_doses) > 0 and np.allclose(
unique_pos_doses, np.round(unique_pos_doses)
)
if is_integer:
warnings.warn(
f"Dose appears discrete ({len(unique_pos_doses)} unique integer values). "
"B-spline smoothing may be inappropriate for discrete treatments. "
"Consider a saturated regression approach (not yet implemented).",
UserWarning,
stacklevel=2,
)
# Force dose=0 for never-treated units with nonzero dose. Report the
# affected row count via UserWarning so users can see whether their
# never-treated rows had unintended nonzero doses — silent zeroing
# here would quietly shift part of the control trajectory (axis-E
# silent coercion, paired with the `first_treat=inf -> 0` fix above).
never_treated_mask = df[first_treat] == 0
nonzero_dose_rows = never_treated_mask & (df[dose] != 0)
n_nonzero_dose_never_treated = int(nonzero_dose_rows.sum())
if n_nonzero_dose_never_treated > 0:
warnings.warn(
f"{n_nonzero_dose_never_treated} row(s) have '{first_treat}'=0 "
f"(never-treated) but nonzero '{dose}'; zeroing the dose. Pass "
f"dose=0 for never-treated rows to avoid this coercion.",
UserWarning,
stacklevel=2,
)
df.loc[never_treated_mask, dose] = 0.0
# Verify balanced panel
all_periods = set(df[time].unique())
unit_periods = df.groupby(unit)[time].apply(set)
is_unbalanced = unit_periods.apply(lambda s: s != all_periods)
if is_unbalanced.any():
n_bad = int(is_unbalanced.sum())
raise ValueError(
"Unbalanced panel detected. ContinuousDiD requires a balanced panel. "
f"{n_bad} unit(s) have missing periods."
)
# Identify groups and time periods
unit_cohort = df.groupby(unit)[first_treat].first()
treatment_groups = sorted([g for g in unit_cohort.unique() if g > 0])
time_periods = sorted(df[time].unique())
if len(treatment_groups) == 0:
raise ValueError("No treated units found (all first_treat == 0).")
n_control = int((unit_cohort == 0).sum())
if self.control_group == "never_treated" and n_control == 0:
raise ValueError(
"No never-treated units found. Use control_group='not_yet_treated' "
"or add never-treated units."
)
if self.control_group == "not_yet_treated" and n_control == 0:
raise ValueError(
"No never-treated (D=0) units found. With control_group='not_yet_treated', "
"dose-response curve identification requires P(D=0) > 0 "
"(Remark 3.1 in Callaway et al. is not yet implemented). "
"Add never-treated units or use a dataset with D=0 observations."
)
# Re-resolve survey design on filtered df if rows were dropped
# (survey arrays must align with df, not the original data)
if resolved_survey is not None and len(df) < len(data):
resolved_survey, survey_weights, survey_weight_type, survey_metadata = (
_resolve_survey_for_fit(survey_design, df, "analytical")
)
# 2. Precompute structures
precomp = self._precompute_structures(
df,
outcome,
unit,
time,
first_treat,
dose,
time_periods,
survey_weights=survey_weights,
)
# Compute dvals (evaluation grid)
all_treated_doses = precomp["dose_vector"][precomp["dose_vector"] > 0]
if self.dvals is not None:
dvals = self.dvals
else:
dvals = default_dose_grid(all_treated_doses)
# Build B-spline knots from all treated doses
knots, degree = build_bspline_basis(
all_treated_doses, degree=self.degree, num_knots=self.num_knots
)
# 3. Iterate over (g,t) cells
gt_results = {}
gt_bootstrap_info = {}
for g in treatment_groups:
for t in time_periods:
result = self._compute_dose_response_gt(
precomp,
g,
t,
knots,
degree,
dvals,
survey_weights=precomp.get("unit_survey_weights"),
resolved_survey=resolved_survey,
)
if result is not None:
gt_results[(g, t)] = result
gt_bootstrap_info[(g, t)] = result.get("_bootstrap_info", {})
# Filter out NaN cells (e.g., from zero effective survey mass)
gt_results = {
gt: r for gt, r in gt_results.items()
if np.isfinite(r.get("att_glob", np.nan))
}
if len(gt_results) == 0:
raise ValueError("No valid (g,t) cells computed.")
# 4. Aggregate
post_gt = {(g, t): r for (g, t), r in gt_results.items() if t >= g - self.anticipation}
# Dose-response aggregation
n_grid = len(dvals)
# NaN-initialized SE/CI fields (used when post_gt is empty or as defaults)
att_d_se = np.full(n_grid, np.nan)
att_d_ci_lower = np.full(n_grid, np.nan)
att_d_ci_upper = np.full(n_grid, np.nan)
acrt_d_se = np.full(n_grid, np.nan)
acrt_d_ci_lower = np.full(n_grid, np.nan)
acrt_d_ci_upper = np.full(n_grid, np.nan)
overall_att_se = np.nan
overall_att_t = np.nan
overall_att_p = np.nan
overall_att_ci = (np.nan, np.nan)
overall_acrt_se = np.nan
overall_acrt_t = np.nan
overall_acrt_p = np.nan
overall_acrt_ci = (np.nan, np.nan)
att_d_p = None
acrt_d_p = None
# Event study aggregation (binarized) — runs on ALL (g,t) cells
event_study_effects = None
if aggregate == "eventstudy":
event_study_effects = self._aggregate_event_study(
gt_results,
gt_bootstrap_info=gt_bootstrap_info,
unit_survey_weights=precomp.get("unit_survey_weights"),
unit_cohorts=precomp["unit_cohorts"],
anticipation=self.anticipation,
)
_survey_df = None # Set by analytical branch when survey is active
if len(post_gt) == 0:
warnings.warn(
"No post-treatment (g,t) cells available for aggregation. "
"This can occur when all treatments start after the last observed "
"period or all cells were skipped due to insufficient data.",
UserWarning,
stacklevel=2,
)
overall_att = np.nan
overall_acrt = np.nan
agg_att_d = np.full(n_grid, np.nan)
agg_acrt_d = np.full(n_grid, np.nan)
else:
# Compute cell weights: group-proportional (matching R's contdid convention).
# Each group g gets weight proportional to its number of treated units.
# When survey weights present, use sum(w_g) / sum(w) instead of n_g / N.
# Within each group, weight is divided equally among post-treatment cells.
group_n_treated = {}
group_n_post_cells = {}
unit_sw = precomp.get("unit_survey_weights")
for (g, t), r in post_gt.items():
if g not in group_n_treated:
if unit_sw is not None:
# Survey-weighted group size: sum of weights for treated units in g
g_mask = precomp["unit_cohorts"] == g
group_n_treated[g] = float(np.sum(unit_sw[g_mask]))
else:
group_n_treated[g] = float(r["n_treated"])
group_n_post_cells[g] = 0
group_n_post_cells[g] += 1
total_treated = sum(group_n_treated.values())
cell_weights = {}
if total_treated > 0:
for (g, t), r in post_gt.items():
pg = group_n_treated[g] / total_treated
cell_weights[(g, t)] = pg / group_n_post_cells[g]
agg_att_d = np.zeros(n_grid)
agg_acrt_d = np.zeros(n_grid)
overall_att = 0.0
overall_acrt = 0.0
for gt, w in cell_weights.items():
r = post_gt[gt]
agg_att_d += w * r["att_d"]
agg_acrt_d += w * r["acrt_d"]
overall_att += w * r["att_glob"]
overall_acrt += w * r["acrt_glob"]
# 5. Bootstrap / Analytical SE
if self.n_bootstrap > 0:
boot_result = self._run_bootstrap(
precomp,
gt_results,
gt_bootstrap_info,
post_gt,
cell_weights,
knots,
degree,
dvals,
overall_att,
overall_acrt,
agg_att_d,
agg_acrt_d,
event_study_effects,
resolved_survey=resolved_survey,
)
att_d_se = boot_result["att_d_se"]
att_d_ci_lower = boot_result["att_d_ci_lower"]
att_d_ci_upper = boot_result["att_d_ci_upper"]
acrt_d_se = boot_result["acrt_d_se"]
acrt_d_ci_lower = boot_result["acrt_d_ci_lower"]
acrt_d_ci_upper = boot_result["acrt_d_ci_upper"]
att_d_p = boot_result["att_d_p"]
acrt_d_p = boot_result["acrt_d_p"]
overall_att_se = boot_result["overall_att_se"]
overall_att_t = safe_inference(overall_att, overall_att_se, self.alpha)[0]
overall_att_p = boot_result["overall_att_p"]
overall_att_ci = boot_result["overall_att_ci"]
overall_acrt_se = boot_result["overall_acrt_se"]
overall_acrt_t = safe_inference(overall_acrt, overall_acrt_se, self.alpha)[0]
overall_acrt_p = boot_result["overall_acrt_p"]
overall_acrt_ci = boot_result["overall_acrt_ci"]
if event_study_effects is not None:
for e, info in event_study_effects.items():
if e in boot_result.get("es_se", {}):
info["se"] = boot_result["es_se"][e]
info["t_stat"] = safe_inference(info["effect"], info["se"], self.alpha)[
0
]
info["p_value"] = boot_result["es_p"][e]
info["conf_int"] = boot_result["es_ci"][e]
else:
# Analytical SEs via influence functions
analytic = self._compute_analytical_se(
precomp,
gt_results,
gt_bootstrap_info,
post_gt,
cell_weights,
knots,
degree,
dvals,
agg_att_d,
agg_acrt_d,
resolved_survey=resolved_survey,
)
att_d_se = analytic["att_d_se"]
acrt_d_se = analytic["acrt_d_se"]
overall_att_se = analytic["overall_att_se"]
overall_acrt_se = analytic["overall_acrt_se"]
# Survey df for t-distribution inference (unit-level, not panel-level)
_survey_df = analytic.get("df_survey")
# Guard: replicate design with undefined df → NaN inference
if (_survey_df is None and resolved_survey is not None
and hasattr(resolved_survey, 'uses_replicate_variance')
and resolved_survey.uses_replicate_variance):
_survey_df = 0
# Recompute survey_metadata from unit-level design so reported
# effective_n/n_psu/df_survey match the inference actually run
_unit_resolved = analytic.get("unit_resolved")
if _unit_resolved is not None:
from diff_diff.survey import compute_survey_metadata
raw_w_unit = _unit_resolved.weights
survey_metadata = compute_survey_metadata(_unit_resolved, raw_w_unit)
# Propagate replicate df override to survey_metadata for display
# (but not the df=0 sentinel — keep metadata as None for undefined df)
if (_survey_df is not None and _survey_df != 0
and survey_metadata is not None):
if survey_metadata.df_survey != _survey_df:
survey_metadata.df_survey = _survey_df
overall_att_t, overall_att_p, overall_att_ci = safe_inference(
overall_att, overall_att_se, self.alpha, df=_survey_df
)
overall_acrt_t, overall_acrt_p, overall_acrt_ci = safe_inference(
overall_acrt, overall_acrt_se, self.alpha, df=_survey_df
)
# Per-grid-point inference for dose-response
for idx in range(n_grid):
_, _, ci = safe_inference(
agg_att_d[idx], att_d_se[idx], self.alpha, df=_survey_df
)
att_d_ci_lower[idx] = ci[0]
att_d_ci_upper[idx] = ci[1]
_, _, ci = safe_inference(
agg_acrt_d[idx], acrt_d_se[idx], self.alpha, df=_survey_df
)
acrt_d_ci_lower[idx] = ci[0]
acrt_d_ci_upper[idx] = ci[1]
# Event study analytical SEs
if event_study_effects is not None:
n_units = precomp["n_units"]
unit_sw = precomp.get("unit_survey_weights")
# Build unit-level ResolvedSurveyDesign once (reused per bin)
unit_resolved_es = None
if resolved_survey is not None:
row_idx = precomp["unit_first_panel_row"]
uw = (
precomp.get("unit_survey_weights")
if precomp.get("unit_survey_weights") is not None
else np.ones(n_units)
)
us = (
resolved_survey.strata[row_idx]
if resolved_survey.strata is not None
else None
)
up = (
resolved_survey.psu[row_idx]
if resolved_survey.psu is not None
else None
)
uf = (
resolved_survey.fpc[row_idx]
if resolved_survey.fpc is not None
else None
)
n_strata_u = len(np.unique(us)) if us is not None else 0
n_psu_u = len(np.unique(up)) if up is not None else 0
unit_resolved_es = resolved_survey.subset_to_units(
row_idx, uw, us, up, uf, n_strata_u, n_psu_u,
)
for e_val, info_e in event_study_effects.items():
# Collect (g,t) cells for this event-time bin
e_gts = [gt for gt in gt_results if gt[1] - gt[0] == e_val]
if not e_gts:
continue
# Weights within this bin: survey-weighted mass or n_treated
if unit_sw is not None:
unit_cohorts = precomp["unit_cohorts"]
ns = np.array(
[float(np.sum(unit_sw[unit_cohorts == gt[0]])) for gt in e_gts],
dtype=float,
)
else:
ns = np.array(
[gt_results[gt]["n_treated"] for gt in e_gts],
dtype=float,
)
total_n = ns.sum()
if total_n == 0:
continue
ws = ns / total_n
# Build per-unit IF for this event-time bin
if_es = np.zeros(n_units)
for idx_cell, gt in enumerate(e_gts):
b_info = gt_bootstrap_info.get(gt, {})
if not b_info:
continue
w = ws[idx_cell]
treated_idx = b_info["treated_indices"]
control_idx = b_info["control_indices"]
n_t = b_info["n_treated"]
n_c = b_info["n_control"]
# Use survey-weighted masses when available
if "w_treated" in b_info:
n_t = b_info["w_treated"]
n_c = b_info["w_control"]
n_total_gt = n_t + n_c
p_1 = n_t / n_total_gt
p_0 = n_c / n_total_gt
att_glob_gt = b_info["att_glob"]
mu_0 = b_info["mu_0"]
delta_y_treated = b_info["delta_y_treated"]
ee_control = b_info["ee_control"]
sw_treated = b_info.get("w_treated_arr")
for k, uid in enumerate(treated_idx):
score_k = delta_y_treated[k] - att_glob_gt - mu_0
if sw_treated is not None:
score_k = sw_treated[k] * score_k
if_es[uid] += w * score_k / p_1 / n_total_gt
for k, uid in enumerate(control_idx):
if_es[uid] -= w * ee_control[k] / p_0 / n_total_gt
# Compute SE: survey-aware TSL or standard sqrt(sum(IF^2))
if unit_resolved_es is not None:
if unit_resolved_es.uses_replicate_variance:
from diff_diff.survey import compute_replicate_if_variance
# Score-scale: psi = w * if_es (matches TSL bread)
psi_es = unit_resolved_es.weights * if_es
variance, _nv = compute_replicate_if_variance(psi_es, unit_resolved_es)
es_se = float(np.sqrt(max(variance, 0.0))) if np.isfinite(variance) else np.nan
else:
X_ones_es = np.ones((n_units, 1))
tsl_scale_es = float(unit_resolved_es.weights.sum())
if_es_tsl = if_es * tsl_scale_es
vcov_es = compute_survey_vcov(X_ones_es, if_es_tsl, unit_resolved_es)
es_se = float(np.sqrt(np.abs(vcov_es[0, 0])))
else:
es_se = float(np.sqrt(np.sum(if_es**2)))
t_stat, p_val, ci_es = safe_inference(
info_e["effect"], es_se, self.alpha, df=_survey_df
)
info_e["se"] = es_se
info_e["t_stat"] = t_stat
info_e["p_value"] = p_val
info_e["conf_int"] = ci_es
# 6. Assemble results
dose_response_att = DoseResponseCurve(
dose_grid=dvals,
effects=agg_att_d,
se=att_d_se,
conf_int_lower=att_d_ci_lower,
conf_int_upper=att_d_ci_upper,
target="att",
p_value=att_d_p,
n_bootstrap=self.n_bootstrap,
df_survey=_survey_df,
)
dose_response_acrt = DoseResponseCurve(
dose_grid=dvals,
effects=agg_acrt_d,
se=acrt_d_se,
conf_int_lower=acrt_d_ci_lower,
conf_int_upper=acrt_d_ci_upper,
target="acrt",
p_value=acrt_d_p,
n_bootstrap=self.n_bootstrap,
df_survey=_survey_df,
)
# Strip bootstrap internals from gt_results
clean_gt = {}
for gt, r in gt_results.items():
clean_gt[gt] = {k: v for k, v in r.items() if not k.startswith("_")}
return ContinuousDiDResults(
dose_response_att=dose_response_att,
dose_response_acrt=dose_response_acrt,
overall_att=overall_att,
overall_att_se=overall_att_se,
overall_att_t_stat=overall_att_t,
overall_att_p_value=overall_att_p,
overall_att_conf_int=overall_att_ci,
overall_acrt=overall_acrt,
overall_acrt_se=overall_acrt_se,
overall_acrt_t_stat=overall_acrt_t,
overall_acrt_p_value=overall_acrt_p,
overall_acrt_conf_int=overall_acrt_ci,
group_time_effects=clean_gt,
dose_grid=dvals,
groups=treatment_groups,
time_periods=time_periods,
n_obs=len(df),
n_treated_units=int((unit_cohort > 0).sum()),
n_control_units=n_control,
alpha=self.alpha,
control_group=self.control_group,
degree=self.degree,
num_knots=self.num_knots,
base_period=self.base_period,
anticipation=self.anticipation,
n_bootstrap=self.n_bootstrap,
bootstrap_weights=self.bootstrap_weights,
seed=self.seed,
rank_deficient_action=self.rank_deficient_action,
event_study_effects=event_study_effects,
survey_metadata=survey_metadata,
)
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _precompute_structures(
self,
df: pd.DataFrame,
outcome: str,
unit: str,
time: str,
first_treat: str,
dose: str,
time_periods: List[Any],
survey_weights: Optional[np.ndarray] = None,
) -> Dict[str, Any]:
"""Pivot to wide format and build lookup structures."""
all_units = sorted(df[unit].unique())
unit_to_idx = {u: i for i, u in enumerate(all_units)}
n_units = len(all_units)
n_periods = len(time_periods)
period_to_col = {t: j for j, t in enumerate(time_periods)}
# Outcome matrix: (n_units, n_periods)
outcome_matrix = np.full((n_units, n_periods), np.nan)
for _, row in df.iterrows():
i = unit_to_idx[row[unit]]
j = period_to_col[row[time]]
outcome_matrix[i, j] = row[outcome]
# Per-unit cohort and dose
unit_cohorts = np.zeros(n_units, dtype=float)
dose_vector = np.zeros(n_units, dtype=float)
unit_first = df.groupby(unit).first()
for u in all_units:
i = unit_to_idx[u]
unit_cohorts[i] = unit_first.loc[u, first_treat]
dose_vector[i] = unit_first.loc[u, dose]
# Build unit-to-first-panel-row mapping (for subsetting panel-level arrays)
# This maps each unit index to the positional index of its first row in df.
unit_first_panel_row = np.zeros(n_units, dtype=int)
seen_units: set = set()
for pos_idx, (_, row) in enumerate(df.iterrows()):
u = row[unit]
if u not in seen_units:
seen_units.add(u)
unit_first_panel_row[unit_to_idx[u]] = pos_idx
# Per-unit survey weights (take first obs per unit from panel data)
unit_survey_weights = None
if survey_weights is not None:
unit_survey_weights = survey_weights[unit_first_panel_row]
# Cohort masks
cohort_masks = {}
unique_cohorts = np.unique(unit_cohorts)
for c in unique_cohorts:
cohort_masks[c] = unit_cohorts == c
never_treated_mask = unit_cohorts == 0
return {
"all_units": all_units,
"unit_to_idx": unit_to_idx,
"outcome_matrix": outcome_matrix,
"period_to_col": period_to_col,
"unit_cohorts": unit_cohorts,
"dose_vector": dose_vector,
"cohort_masks": cohort_masks,
"never_treated_mask": never_treated_mask,
"time_periods": time_periods,
"n_units": n_units,
"unit_survey_weights": unit_survey_weights,
"unit_first_panel_row": unit_first_panel_row,
}
def _compute_dose_response_gt(
self,
precomp: Dict[str, Any],
g: Any,
t: Any,
knots: np.ndarray,
degree: int,
dvals: np.ndarray,
survey_weights: Optional[np.ndarray] = None,
resolved_survey: object = None,
) -> Optional[Dict[str, Any]]:
"""Compute dose-response for a single (g,t) cell."""
period_to_col = precomp["period_to_col"]
outcome_matrix = precomp["outcome_matrix"]
unit_cohorts = precomp["unit_cohorts"]
dose_vector = precomp["dose_vector"]
never_treated_mask = precomp["never_treated_mask"]
time_periods = precomp["time_periods"]
# Base period selection
is_post = t >= g - self.anticipation
if self.base_period == "varying":
if is_post:
base_t = g - 1 - self.anticipation
else:
# Pre-treatment: use t-1
t_idx = time_periods.index(t)
if t_idx == 0:
return None # No prior period
base_t = time_periods[t_idx - 1]
else:
# Universal base period
base_t = g - 1 - self.anticipation
if base_t not in period_to_col or t not in period_to_col:
return None
col_t = period_to_col[t]
col_base = period_to_col[base_t]
# Treated units: first_treat == g and dose > 0
treated_mask = (unit_cohorts == g) & (dose_vector > 0)
n_treated = int(np.sum(treated_mask))
if n_treated == 0:
return None
# Control units
if self.control_group == "never_treated":
control_mask = never_treated_mask
else:
# Not-yet-treated: never-treated + first_treat > t
control_mask = never_treated_mask | (
(unit_cohorts > t + self.anticipation) & (unit_cohorts != g)
)
n_control = int(np.sum(control_mask))
if n_control == 0:
warnings.warn(
f"No control units for (g={g}, t={t}). Skipping.",
UserWarning,
stacklevel=3,
)
return None
# Outcome changes
delta_y_treated = (
outcome_matrix[treated_mask, col_t] - outcome_matrix[treated_mask, col_base]
)
delta_y_control = (
outcome_matrix[control_mask, col_t] - outcome_matrix[control_mask, col_base]
)
# Subset survey weights to the cell
w_treated = None
w_control = None
if survey_weights is not None:
w_treated = survey_weights[treated_mask]
w_control = survey_weights[control_mask]
# Guard against zero effective mass (e.g., after subpopulation)
if np.sum(w_treated) <= 0 or np.sum(w_control) <= 0:
return {
"att_glob": np.nan, "acrt_glob": np.nan,
"n_treated": 0, "n_control": 0,
"att_d": np.full(len(dvals), np.nan),
"acrt_d": np.full(len(dvals), np.nan),
}
# Control counterfactual (weighted mean when survey weights present)
if w_control is not None:
mu_0 = float(np.average(delta_y_control, weights=w_control))
else:
mu_0 = float(np.mean(delta_y_control))
# Demean
delta_tilde_y = delta_y_treated - mu_0
# Treated doses
treated_doses = dose_vector[treated_mask]
# B-spline OLS
Psi = bspline_design_matrix(treated_doses, knots, degree, include_intercept=True)
n_basis = Psi.shape[1]
# Check for all-same dose
if np.all(treated_doses == treated_doses[0]):
warnings.warn(
f"All treated doses identical in (g={g}, t={t}). " "ACRT(d) will be 0 everywhere.",
UserWarning,
stacklevel=3,
)
# Skip if not enough treated units for OLS (need n > K for residual df)
# When survey weights are present, use positive-weight count as
# the effective sample size — subpopulation() can zero weights
# leaving rows present but the weighted regression underidentified.
n_eff = int(np.count_nonzero(w_treated > 0)) if w_treated is not None else n_treated
if n_eff <= n_basis:
label = "positive-weight treated units" if w_treated is not None else "treated units"
warnings.warn(
f"Not enough {label} ({n_eff}) for {n_basis} basis functions "
f"in (g={g}, t={t}). Skipping cell.",
UserWarning,
stacklevel=3,
)
return None
# OLS or WLS regression
if w_treated is not None:
# WLS: apply sqrt(w) transformation
sqrt_w = np.sqrt(w_treated)
Psi_w = Psi * sqrt_w[:, np.newaxis]
delta_tilde_y_w = delta_tilde_y * sqrt_w
beta_hat, _, _ = solve_ols(
Psi_w,
delta_tilde_y_w,
return_vcov=False,
rank_deficient_action=self.rank_deficient_action,
)
# Residuals on original scale (for influence functions)
beta_pred_tmp = np.where(np.isnan(beta_hat), 0.0, beta_hat)
residuals = delta_tilde_y - Psi @ beta_pred_tmp
else:
beta_hat, residuals, _ = solve_ols(
Psi,
delta_tilde_y,
return_vcov=False,
rank_deficient_action=self.rank_deficient_action,
)
# For prediction: zero out NaN (dropped rank-deficient columns).
# solve_ols sets dropped-column coefficients to NaN (R convention);
# zeroing them produces correct predictions: ATT(d) = intercept
# (constant), ACRT(d) = 0 (derivative of intercept is 0).
beta_pred = np.where(np.isnan(beta_hat), 0.0, beta_hat)
# Evaluate ATT(d) and ACRT(d) at dvals
Psi_eval = bspline_design_matrix(dvals, knots, degree, include_intercept=True)
dPsi_eval = bspline_derivative_design_matrix(dvals, knots, degree, include_intercept=True)
att_d = Psi_eval @ beta_pred
acrt_d = dPsi_eval @ beta_pred
# Summary parameters
if w_treated is not None:
att_glob = float(np.average(delta_y_treated, weights=w_treated) - mu_0)
else:
att_glob = float(np.mean(delta_y_treated) - mu_0)
# ACRT^{glob}: plug-in average of ACRT(D_i) for treated
dPsi_treated = bspline_derivative_design_matrix(
treated_doses, knots, degree, include_intercept=True
)
if w_treated is not None:
acrt_glob = float(np.average(dPsi_treated @ beta_pred, weights=w_treated))
else:
acrt_glob = float(np.mean(dPsi_treated @ beta_pred))
# Store bootstrap info for influence function computation
# bread = (Psi'WPsi / n_treated)^{-1} when survey, (Psi'Psi / n_treated)^{-1} otherwise
if w_treated is not None:
w_treated_sum = float(np.sum(w_treated))
PtWP = Psi.T @ (Psi * w_treated[:, np.newaxis])
# Normalize bread by weighted mass (not raw count) for consistency
# with downstream IF score denominators that also use weighted mass
try:
bread = np.linalg.inv(PtWP / w_treated_sum)
except np.linalg.LinAlgError:
bread = np.linalg.pinv(PtWP / w_treated_sum)
else:
PtP = Psi.T @ Psi
try:
bread = np.linalg.inv(PtP / n_treated)
except np.linalg.LinAlgError:
bread = np.linalg.pinv(PtP / n_treated)
# ee_treated: per-unit estimating equation vectors (K-vector per unit)
# For WLS (survey weights), the score is w_i * X_i * u_i to match the
# weighted bread inv(X'WX / sum(w)). Without this factor the sandwich
# is inconsistent. For OLS (no survey weights), the score is X_i * u_i.
if w_treated is not None:
ee_treated = Psi * (w_treated * residuals)[:, np.newaxis] # (n_treated, K)
else:
ee_treated = Psi * residuals[:, np.newaxis] # (n_treated, K)
# ee_control: per-unit deviation from control mean (weighted for WLS)
if w_control is not None:
ee_control = w_control * (delta_y_control - mu_0) # (n_control,)
else:
ee_control = delta_y_control - mu_0 # (n_control,)
# psi_bar: mean basis vector for treated (weighted when survey)
if w_treated is not None:
psi_bar = np.average(Psi, axis=0, weights=w_treated)
else:
psi_bar = np.mean(Psi, axis=0) # (K,)
# Unit indices for bootstrap
treated_indices = np.where(treated_mask)[0]
control_indices = np.where(control_mask)[0]
# dpsi_bar: mean derivative basis vector (weighted when survey)
if w_treated is not None:
dpsi_bar = np.average(dPsi_treated, axis=0, weights=w_treated)
else:
dpsi_bar = np.mean(dPsi_treated, axis=0)
bootstrap_info = {
"bread": bread,
"ee_treated": ee_treated,
"ee_control": ee_control,
"psi_bar": psi_bar,
"dpsi_bar": dpsi_bar,
"beta_hat": beta_hat,
"beta_pred": beta_pred,
"treated_indices": treated_indices,
"control_indices": control_indices,
"n_treated": n_treated,
"n_control": n_control,
"Psi_eval": Psi_eval,
"dPsi_eval": dPsi_eval,
"dPsi_treated": dPsi_treated,
"delta_y_treated": delta_y_treated,
"delta_y_control": delta_y_control,
"mu_0": mu_0,
"att_glob": att_glob,
"acrt_glob": acrt_glob,
}
# Store survey-weighted masses and per-unit arrays for IF linearization
if w_treated is not None:
bootstrap_info["w_treated"] = float(np.sum(w_treated))
bootstrap_info["w_control"] = float(np.sum(w_control))
bootstrap_info["w_treated_arr"] = w_treated
bootstrap_info["w_control_arr"] = w_control
return {
"att_d": att_d,
"acrt_d": acrt_d,
"att_glob": att_glob,
"acrt_glob": acrt_glob,
"beta_hat": beta_hat,
"n_treated": n_treated,
"n_control": n_control,
"_bootstrap_info": bootstrap_info,
}
def _aggregate_event_study(
self,
gt_results: Dict[Tuple, Dict],
gt_bootstrap_info: Dict[Tuple, Dict] = None,
unit_survey_weights: Optional[np.ndarray] = None,
unit_cohorts: Optional[np.ndarray] = None,
anticipation: int = 0,
) -> Dict[int, Dict[str, Any]]:
"""Aggregate binarized ATT_glob by relative period."""
effects_by_e: Dict[int, List[Tuple[float, float, Tuple]]] = {}
for (g, t), r in gt_results.items():
e = t - g
if anticipation > 0 and e < -anticipation:
continue
if e not in effects_by_e:
effects_by_e[e] = []
# Compute weight for this (g,t) cell
if unit_survey_weights is not None and unit_cohorts is not None:
# Survey-weighted: sum of survey weights for treated units in group g
g_mask = unit_cohorts == g
cell_weight = float(np.sum(unit_survey_weights[g_mask]))
else:
cell_weight = float(r["n_treated"])
effects_by_e[e].append((r["att_glob"], cell_weight, (g, t)))
result = {}
for e, entries in sorted(effects_by_e.items()):
effects = np.array([x[0] for x in entries])
weights = np.array([x[1] for x in entries])
if np.sum(weights) > 0:
w = weights / np.sum(weights)
agg = float(np.sum(w * effects))
else:
agg = np.nan
result[e] = {
"effect": agg,
"se": np.nan,
"t_stat": np.nan,
"p_value": np.nan,
"conf_int": (np.nan, np.nan),
}
return result
def _compute_analytical_se(
self,
precomp: Dict[str, Any],
gt_results: Dict[Tuple, Dict],
gt_bootstrap_info: Dict[Tuple, Dict],
post_gt: Dict[Tuple, Dict],
cell_weights: Dict[Tuple, float],
knots: np.ndarray,
degree: int,
dvals: np.ndarray,
agg_att_d: np.ndarray,
agg_acrt_d: np.ndarray,
resolved_survey: object = None,
) -> Dict[str, Any]:
"""Compute analytical SEs using influence functions."""
n_units = precomp["n_units"]
n_grid = len(dvals)
# Build per-unit influence functions for aggregated parameters
# IF_i for overall ATT_glob (binarized)
if_att_glob = np.zeros(n_units)
if_acrt_glob = np.zeros(n_units)
if_att_d = np.zeros((n_units, n_grid))
if_acrt_d = np.zeros((n_units, n_grid))
for gt, w in cell_weights.items():
if w == 0:
continue
info = gt_bootstrap_info[gt]
if not info:
continue
treated_idx = info["treated_indices"]
control_idx = info["control_indices"]
n_t = info["n_treated"]
n_c = info["n_control"]
# Use survey-weighted masses when available
if "w_treated" in info:
n_t = info["w_treated"]
n_c = info["w_control"]
bread = info["bread"]
ee_treated = info["ee_treated"]
ee_control = info["ee_control"]
psi_bar = info["psi_bar"]
dpsi_bar = info["dpsi_bar"]
Psi_eval = info["Psi_eval"]
dPsi_eval = info["dPsi_eval"]
att_glob_gt = info["att_glob"]
mu_0 = info["mu_0"]
delta_y_treated = info["delta_y_treated"]
# Per-unit survey weight array (None when no survey)
sw_treated = info.get("w_treated_arr")
n_total = n_t + n_c
p_1 = n_t / n_total
p_0 = n_c / n_total
# IF for ATT_glob (binarized DiD)
# When survey weights are present, each unit's score includes its
# survey weight w_k so the sandwich is consistent with the weighted
# estimand. ee_control already contains the w_k factor (set in
# _compute_dose_response_gt); delta_y_treated needs it here.
for k, idx in enumerate(treated_idx):
score_k = delta_y_treated[k] - att_glob_gt - mu_0
if sw_treated is not None:
score_k = sw_treated[k] * score_k
if_att_glob[idx] += w * score_k / p_1 / n_total
for k, idx in enumerate(control_idx):
if_att_glob[idx] -= w * ee_control[k] / p_0 / n_total
# IF for beta perturbation → ATT(d) and ACRT(d)
# beta perturbation from treated: bread @ (1/n_t) * sum w_i * ee_treated_i
# beta perturbation from control: -bread @ psi_bar * (1/n_c) * sum w_i * ee_control_i
# ATT_b(d) = Psi_eval @ beta_b => IF_i(d) contribution
# Treated unit contributions to beta
for k, idx in enumerate(treated_idx):
beta_pert = bread @ ee_treated[k] / n_t
if_att_d[idx] += w * (Psi_eval @ beta_pert)
if_acrt_d[idx] += w * (dPsi_eval @ beta_pert)
# Control unit contributions to beta (through mu_0)
for k, idx in enumerate(control_idx):
beta_pert = -bread @ psi_bar * ee_control[k] / n_c
if_att_d[idx] += w * (Psi_eval @ beta_pert)
if_acrt_d[idx] += w * (dPsi_eval @ beta_pert)
# ACRT_glob IF: (1/n_t) sum_j dpsi(D_j)' @ beta_pert
for k, idx in enumerate(treated_idx):
beta_pert = bread @ ee_treated[k] / n_t
if_acrt_glob[idx] += w * float(dpsi_bar @ beta_pert)
for k, idx in enumerate(control_idx):
beta_pert = -bread @ psi_bar * ee_control[k] / n_c
if_acrt_glob[idx] += w * float(dpsi_bar @ beta_pert)
# Compute SEs from influence functions
if resolved_survey is not None:
# Survey design: use TSL variance on the aggregate influence functions.
# The IFs serve as "residuals" in the TSL sandwich; X is a column of ones
# (the estimand is a scalar/vector mean of the IFs).
#
# The resolved_survey has panel-level arrays (n_obs = n_units * n_periods),
# but influence functions are unit-level (n_units). Build a unit-level
# ResolvedSurveyDesign by subsetting to one obs per unit.
row_idx = precomp["unit_first_panel_row"]
unit_weights = precomp.get("unit_survey_weights")
if unit_weights is None:
unit_weights = np.ones(n_units)
unit_strata = (
resolved_survey.strata[row_idx] if resolved_survey.strata is not None else None
)
unit_psu = resolved_survey.psu[row_idx] if resolved_survey.psu is not None else None
unit_fpc = resolved_survey.fpc[row_idx] if resolved_survey.fpc is not None else None
# Count unique strata/PSU in the unit-level subset
n_strata_unit = len(np.unique(unit_strata)) if unit_strata is not None else 0
n_psu_unit = len(np.unique(unit_psu)) if unit_psu is not None else 0
unit_resolved = resolved_survey.subset_to_units(
row_idx, unit_weights, unit_strata, unit_psu, unit_fpc,
n_strata_unit, n_psu_unit,
)
X_ones = np.ones((n_units, 1))
if unit_resolved.uses_replicate_variance:
# Replicate-weight variance: score-scale IFs to match TSL bread.
# TSL path does: scores = w * (if * tsl_scale), bread = 1/sum(w)^2
# Equivalent psi for replicate: w * if_vals * tsl_scale / sum(w) = w * if_vals
from diff_diff.survey import compute_replicate_if_variance
_w_rep = unit_resolved.weights
_rep_n_valid = unit_resolved.n_replicates # track effective count
def _rep_se(if_vals):
nonlocal _rep_n_valid
psi_scaled = _w_rep * if_vals
v, nv = compute_replicate_if_variance(psi_scaled, unit_resolved)
_rep_n_valid = min(_rep_n_valid, nv) # worst-case valid count
return float(np.sqrt(max(v, 0.0))) if np.isfinite(v) else np.nan
overall_att_se = _rep_se(if_att_glob)
overall_acrt_se = _rep_se(if_acrt_glob)
att_d_se = np.zeros(n_grid)
acrt_d_se = np.zeros(n_grid)
for d_idx in range(n_grid):
att_d_se[d_idx] = _rep_se(if_att_d[:, d_idx])
acrt_d_se[d_idx] = _rep_se(if_acrt_d[:, d_idx])
else:
# TSL: rescale IFs from 1/n convention to score scale for sandwich.
tsl_scale = float(unit_resolved.weights.sum())
if_att_glob_tsl = if_att_glob * tsl_scale
if_acrt_glob_tsl = if_acrt_glob * tsl_scale
if_att_d_tsl = if_att_d * tsl_scale
if_acrt_d_tsl = if_acrt_d * tsl_scale
vcov_att = compute_survey_vcov(X_ones, if_att_glob_tsl, unit_resolved)
overall_att_se = float(np.sqrt(np.abs(vcov_att[0, 0])))
vcov_acrt = compute_survey_vcov(X_ones, if_acrt_glob_tsl, unit_resolved)
overall_acrt_se = float(np.sqrt(np.abs(vcov_acrt[0, 0])))
att_d_se = np.zeros(n_grid)
acrt_d_se = np.zeros(n_grid)
for d_idx in range(n_grid):
vcov_d = compute_survey_vcov(X_ones, if_att_d_tsl[:, d_idx], unit_resolved)
att_d_se[d_idx] = float(np.sqrt(np.abs(vcov_d[0, 0])))
vcov_d = compute_survey_vcov(X_ones, if_acrt_d_tsl[:, d_idx], unit_resolved)
acrt_d_se[d_idx] = float(np.sqrt(np.abs(vcov_d[0, 0])))
else:
# SE = sqrt(sum(IF_i^2)), matching CallawaySantAnna's convention
# (per-unit IFs already contain 1/n_t, 1/n_c scaling)
overall_att_se = float(np.sqrt(np.sum(if_att_glob**2)))
overall_acrt_se = float(np.sqrt(np.sum(if_acrt_glob**2)))
att_d_se = np.sqrt(np.sum(if_att_d**2, axis=0))
acrt_d_se = np.sqrt(np.sum(if_acrt_d**2, axis=0))
# Return unit-level survey df and resolved design for metadata recomputation
# Only override with n_valid-based df when replicates were actually dropped
if resolved_survey is not None and hasattr(resolved_survey, 'uses_replicate_variance') and resolved_survey.uses_replicate_variance:
if _rep_n_valid < unit_resolved.n_replicates:
unit_df_survey = _rep_n_valid - 1 if _rep_n_valid > 1 else None
else:
unit_df_survey = unit_resolved.df_survey
else:
unit_df_survey = unit_resolved.df_survey if resolved_survey is not None else None
return {
"overall_att_se": overall_att_se,
"overall_acrt_se": overall_acrt_se,
"att_d_se": att_d_se,
"acrt_d_se": acrt_d_se,
"df_survey": unit_df_survey,
"unit_resolved": unit_resolved if resolved_survey is not None else None,
}
def _run_bootstrap(
self,
precomp: Dict[str, Any],
gt_results: Dict[Tuple, Dict],
gt_bootstrap_info: Dict[Tuple, Dict],
post_gt: Dict[Tuple, Dict],
cell_weights: Dict[Tuple, float],
knots: np.ndarray,
degree: int,
dvals: np.ndarray,
original_att: float,
original_acrt: float,
original_att_d: np.ndarray,
original_acrt_d: np.ndarray,
event_study_effects: Optional[Dict[int, Dict]],
resolved_survey: object = None,
) -> Dict[str, Any]:
"""Run multiplier bootstrap inference."""
if self.n_bootstrap < 50:
warnings.warn(
f"n_bootstrap={self.n_bootstrap} is low. Consider n_bootstrap >= 199 "
"for reliable inference.",
UserWarning,
stacklevel=3,
)
# Reject replicate-weight designs for bootstrap — replicate variance
# is an analytical alternative to bootstrap, not compatible with it
if resolved_survey is not None and hasattr(resolved_survey, "uses_replicate_variance") and resolved_survey.uses_replicate_variance:
raise NotImplementedError(
"ContinuousDiD bootstrap (n_bootstrap > 0) is not supported "
"with replicate-weight survey designs. Replicate weights provide "
"analytical variance; use n_bootstrap=0 instead."
)
rng = np.random.default_rng(self.seed)
n_units = precomp["n_units"]
n_grid = len(dvals)
# Build unit-level ResolvedSurveyDesign for survey-aware bootstrap
unit_resolved = None
if resolved_survey is not None:
from diff_diff.survey import ResolvedSurveyDesign
row_idx = precomp["unit_first_panel_row"]
unit_weights = precomp.get("unit_survey_weights")
if unit_weights is None:
unit_weights = np.ones(n_units)
unit_strata = (
resolved_survey.strata[row_idx] if resolved_survey.strata is not None else None
)
unit_psu = resolved_survey.psu[row_idx] if resolved_survey.psu is not None else None
unit_fpc = resolved_survey.fpc[row_idx] if resolved_survey.fpc is not None else None
n_strata_u = len(np.unique(unit_strata)) if unit_strata is not None else 0
n_psu_u = len(np.unique(unit_psu)) if unit_psu is not None else 0
unit_resolved = resolved_survey.subset_to_units(
row_idx, unit_weights, unit_strata, unit_psu, unit_fpc,
n_strata_u, n_psu_u,
)
# Generate bootstrap weights — PSU-level when survey design is present
_use_survey_bootstrap = unit_resolved is not None and (
unit_resolved.strata is not None
or unit_resolved.psu is not None
or unit_resolved.fpc is not None
)
if _use_survey_bootstrap:
from diff_diff.bootstrap_utils import (
generate_survey_multiplier_weights_batch,
)
psu_weights, psu_ids = generate_survey_multiplier_weights_batch(
self.n_bootstrap, unit_resolved, self.bootstrap_weights, rng
)
# Build unit -> PSU column map
if unit_resolved.psu is not None:
psu_id_to_col = {int(p): c for c, p in enumerate(psu_ids)}
unit_to_psu_col = np.array(
[psu_id_to_col[int(unit_resolved.psu[i])] for i in range(n_units)]
)
else:
unit_to_psu_col = np.arange(n_units)
all_weights = psu_weights[:, unit_to_psu_col]
else:
all_weights = generate_bootstrap_weights_batch(
self.n_bootstrap, n_units, self.bootstrap_weights, rng
)
boot_att_glob = np.zeros(self.n_bootstrap)
boot_acrt_glob = np.zeros(self.n_bootstrap)
boot_att_d = np.zeros((self.n_bootstrap, n_grid))
boot_acrt_d = np.zeros((self.n_bootstrap, n_grid))
# Event study bootstrap — compute weights per event-time bin
es_keys = sorted(event_study_effects.keys()) if event_study_effects else []
boot_es = {e: np.zeros(self.n_bootstrap) for e in es_keys}
# Per-(g,t) weight within event-time bin — use survey-weighted cohort
# masses when available, matching _aggregate_event_study.
unit_sw = precomp.get("unit_survey_weights")
unit_cohorts = precomp["unit_cohorts"]
es_cell_weights: Dict[Tuple, float] = {}
if event_study_effects is not None:
from collections import defaultdict
es_bin_total: Dict[int, float] = defaultdict(float)
for gt, r in gt_results.items():
g_val, t_val = gt
e = t_val - g_val
if self.anticipation > 0 and e < -self.anticipation:
continue
if unit_sw is not None:
g_mask = unit_cohorts == g_val
cell_mass = float(np.sum(unit_sw[g_mask]))
else:
cell_mass = float(r["n_treated"])
es_bin_total[e] += cell_mass
for gt, r in gt_results.items():
g_val, t_val = gt
e = t_val - g_val
if self.anticipation > 0 and e < -self.anticipation:
continue
if unit_sw is not None:
g_mask = unit_cohorts == g_val
cell_mass = float(np.sum(unit_sw[g_mask]))
else:
cell_mass = float(r["n_treated"])
if es_bin_total[e] > 0:
es_cell_weights[gt] = cell_mass / es_bin_total[e]
# Helper to bootstrap a single (g,t) cell
def _bootstrap_gt_cell(gt, info):
"""Returns att_glob_b array (B,) for this cell."""
treated_idx = info["treated_indices"]
control_idx = info["control_indices"]
n_t = info["n_treated"]
n_c = info["n_control"]
# Use survey-weighted masses when available (matching analytical SE)
if "w_treated" in info:
n_t = info["w_treated"]
n_c = info["w_control"]
bread = info["bread"]
ee_treated = info["ee_treated"]
ee_control = info["ee_control"]
psi_bar = info["psi_bar"]
beta_pred = info["beta_pred"]
Psi_eval = info["Psi_eval"]
dPsi_eval = info["dPsi_eval"]
dPsi_treated = info["dPsi_treated"]
delta_y_treated = info["delta_y_treated"]
mu_0 = info["mu_0"]
att_glob_gt = info["att_glob"]
sw_treated = info.get("w_treated_arr")
w_treated = all_weights[:, treated_idx]
w_control = all_weights[:, control_idx]
with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
treated_sum = w_treated @ ee_treated / n_t
control_sum = (w_control @ ee_control) / n_c
psi_bar_outer = psi_bar[np.newaxis, :]
delta_beta = (treated_sum - control_sum[:, np.newaxis] * psi_bar_outer) @ bread.T
beta_b = beta_pred[np.newaxis, :] + delta_beta
att_d_b = beta_b @ Psi_eval.T
acrt_d_b = beta_b @ dPsi_eval.T
mu_0_pert = (w_control @ ee_control) / n_c
# ATT_glob perturbation: weight scores by survey weight w_k
# when present, matching the analytical IF path.
att_glob_score = delta_y_treated - att_glob_gt - mu_0
if sw_treated is not None:
att_glob_score = sw_treated * att_glob_score
mean_dy_treated_pert = (w_treated @ att_glob_score) / n_t
att_glob_b = att_glob_gt + mean_dy_treated_pert - mu_0_pert
if sw_treated is not None:
sw_norm = sw_treated / sw_treated.sum()
dpsi_mean = sw_norm @ dPsi_treated
else:
dpsi_mean = np.mean(dPsi_treated, axis=0)
acrt_glob_b = delta_beta @ dpsi_mean
return att_d_b, acrt_d_b, att_glob_b, acrt_glob_b, info.get("acrt_glob", 0.0)
# Iterate over post-treatment cells for dose-response/overall aggregation
for gt, w in cell_weights.items():
if w == 0:
continue
info = gt_bootstrap_info[gt]
if not info:
continue
att_d_b, acrt_d_b, att_glob_b, acrt_glob_b, acrt_glob_pt = _bootstrap_gt_cell(gt, info)
boot_att_d += w * att_d_b
boot_acrt_d += w * acrt_d_b
boot_att_glob += w * att_glob_b
boot_acrt_glob += w * (acrt_glob_pt + acrt_glob_b)
# Event study bootstrap — iterate over ALL (g,t) cells
if event_study_effects is not None:
for gt, r in gt_results.items():
info = gt_bootstrap_info[gt]
if not info:
continue
g_val, t_val = gt
e = t_val - g_val
if e not in boot_es:
continue
es_w = es_cell_weights.get(gt, 0.0)
if es_w == 0:
continue
_, _, att_glob_b, _, _ = _bootstrap_gt_cell(gt, info)
boot_es[e] += es_w * att_glob_b
# Compute statistics
result: Dict[str, Any] = {}
# Per-grid-point
att_d_se = np.full(n_grid, np.nan)
att_d_ci_lower = np.full(n_grid, np.nan)
att_d_ci_upper = np.full(n_grid, np.nan)
acrt_d_se = np.full(n_grid, np.nan)
acrt_d_ci_lower = np.full(n_grid, np.nan)
acrt_d_ci_upper = np.full(n_grid, np.nan)
att_d_p = np.full(n_grid, np.nan)
acrt_d_p = np.full(n_grid, np.nan)
for idx in range(n_grid):
se, ci, p = compute_effect_bootstrap_stats(
original_att_d[idx],
boot_att_d[:, idx],
alpha=self.alpha,
context=f"ATT(d) at grid point {idx}",
)
att_d_se[idx] = se
att_d_ci_lower[idx] = ci[0]
att_d_ci_upper[idx] = ci[1]
att_d_p[idx] = p
se, ci, p = compute_effect_bootstrap_stats(
original_acrt_d[idx],
boot_acrt_d[:, idx],
alpha=self.alpha,
context=f"ACRT(d) at grid point {idx}",
)
acrt_d_se[idx] = se
acrt_d_ci_lower[idx] = ci[0]
acrt_d_ci_upper[idx] = ci[1]
acrt_d_p[idx] = p
result["att_d_se"] = att_d_se
result["att_d_ci_lower"] = att_d_ci_lower
result["att_d_ci_upper"] = att_d_ci_upper
result["acrt_d_se"] = acrt_d_se
result["acrt_d_ci_lower"] = acrt_d_ci_lower
result["acrt_d_ci_upper"] = acrt_d_ci_upper
result["att_d_p"] = att_d_p
result["acrt_d_p"] = acrt_d_p
# Overall
se, ci, p = compute_effect_bootstrap_stats(
original_att,
boot_att_glob,
alpha=self.alpha,
context="overall ATT_glob",
)
result["overall_att_se"] = se
result["overall_att_ci"] = ci
result["overall_att_p"] = p
se, ci, p = compute_effect_bootstrap_stats(
original_acrt,
boot_acrt_glob,
alpha=self.alpha,
context="overall ACRT_glob",
)
result["overall_acrt_se"] = se
result["overall_acrt_ci"] = ci
result["overall_acrt_p"] = p
# Event study SEs
if event_study_effects is not None:
es_se = {}
es_ci = {}
es_p = {}
for e in es_keys:
se_e, ci_e, p_e = compute_effect_bootstrap_stats(
event_study_effects[e]["effect"],
boot_es[e],
alpha=self.alpha,
context=f"event study e={e}",
)
es_se[e] = se_e
es_ci[e] = ci_e
es_p[e] = p_e
result["es_se"] = es_se
result["es_ci"] = es_ci
result["es_p"] = es_p
return result