Source code for diff_diff.staggered

"""
Staggered Difference-in-Differences estimators.

Implements modern methods for DiD with variation in treatment timing,
including the Callaway-Sant'Anna (2021) estimator.
"""

import warnings
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from scipy import linalg as scipy_linalg

from diff_diff.linalg import (
    _check_propensity_diagnostics,
    _detect_rank_deficiency,
    _format_dropped_columns,
    solve_logit,
    solve_ols,
)
from diff_diff.staggered_aggregation import (
    CallawaySantAnnaAggregationMixin,
)
from diff_diff.staggered_bootstrap import (
    CallawaySantAnnaBootstrapMixin,
    CSBootstrapResults,
)

# Import from split modules
from diff_diff.staggered_results import (
    CallawaySantAnnaResults,
    GroupTimeEffect,
)
from diff_diff.utils import safe_inference, safe_inference_batch

# Re-export for backward compatibility
__all__ = [
    "CallawaySantAnna",
    "CallawaySantAnnaResults",
    "CSBootstrapResults",
    "GroupTimeEffect",
]

# Type alias for pre-computed structures
PrecomputedData = Dict[str, Any]


def _linear_regression(
    X: np.ndarray,
    y: np.ndarray,
    rank_deficient_action: str = "warn",
    weights: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Fit OLS regression.

    Parameters
    ----------
    X : np.ndarray
        Feature matrix (n_samples, n_features). Intercept added automatically.
    y : np.ndarray
        Outcome variable.
    rank_deficient_action : str, default "warn"
        Action when design matrix is rank-deficient:
        - "warn": Issue warning and drop linearly dependent columns (default)
        - "error": Raise ValueError
        - "silent": Drop columns silently without warning
    weights : np.ndarray, optional
        Observation weights for WLS. When None, OLS is used.

    Returns
    -------
    beta : np.ndarray
        Fitted coefficients (including intercept).
    residuals : np.ndarray
        Residuals from the fit.
    """
    n = X.shape[0]
    # Add intercept
    X_with_intercept = np.column_stack([np.ones(n), X])

    # Use unified OLS backend (no vcov needed)
    beta, residuals, _ = solve_ols(
        X_with_intercept,
        y,
        return_vcov=False,
        rank_deficient_action=rank_deficient_action,
        weights=weights,
    )

    return beta, residuals


def _safe_inv(
    A: np.ndarray,
    tracker: Optional[list] = None,
) -> np.ndarray:
    """Invert a square matrix with lstsq fallback for near-singular cases.

    Parameters
    ----------
    A : np.ndarray
        Square matrix to invert.
    tracker : list, optional
        When provided, one condition-number sample of ``A`` is appended on
        every LinAlgError fallback. ``CallawaySantAnna.fit()`` initializes
        a list and emits a single aggregate `UserWarning` after the fit
        finishes, rather than surfacing a separate warning per fallback.
        Sibling of finding #17 in the Phase 2 silent-failures audit.
    """
    try:
        return np.linalg.solve(A, np.eye(A.shape[0]))
    except np.linalg.LinAlgError:
        if tracker is not None:
            with np.errstate(invalid="ignore", over="ignore"):
                tracker.append(float(np.linalg.cond(A)))
        return np.linalg.lstsq(A, np.eye(A.shape[0]), rcond=None)[0]


[docs] class CallawaySantAnna( CallawaySantAnnaBootstrapMixin, CallawaySantAnnaAggregationMixin, ): """ Callaway-Sant'Anna (2021) estimator for staggered Difference-in-Differences. This estimator handles DiD designs with variation in treatment timing (staggered adoption) and heterogeneous treatment effects. It avoids the bias of traditional two-way fixed effects (TWFE) estimators by: 1. Computing group-time average treatment effects ATT(g,t) for each cohort g (units first treated in period g) and time t. 2. Aggregating these to summary measures (overall ATT, event study, etc.) using appropriate weights. Parameters ---------- control_group : str, default="never_treated" Which units to use as controls: - "never_treated": Use only never-treated units (recommended) - "not_yet_treated": Use never-treated and not-yet-treated units anticipation : int, default=0 Number of periods before treatment where effects may occur. Set to > 0 if treatment effects can begin before the official treatment date. estimation_method : str, default="dr" Estimation method: - "dr": Doubly robust (recommended) - "ipw": Inverse probability weighting - "reg": Outcome regression alpha : float, default=0.05 Significance level for confidence intervals. cluster : str, optional Column name for cluster-robust standard errors. Defaults to unit-level clustering. n_bootstrap : int, default=0 Number of bootstrap iterations for inference. If 0, uses analytical standard errors. Recommended: 999 or more for reliable inference. .. note:: Memory Usage The bootstrap stores all weights in memory as a (n_bootstrap, n_units) float64 array. For large datasets, this can be significant: - 1K bootstrap × 10K units = ~80 MB - 10K bootstrap × 100K units = ~8 GB Consider reducing n_bootstrap if memory is constrained. bootstrap_weights : str, default="rademacher" Type of weights for multiplier bootstrap: - "rademacher": +1/-1 with equal probability (standard choice) - "mammen": Two-point distribution (asymptotically valid, matches skewness) - "webb": Six-point distribution (recommended when n_clusters < 20) seed : int, optional Random seed for reproducibility. rank_deficient_action : str, default="warn" Action when design matrix is rank-deficient (linearly dependent columns): - "warn": Issue warning and drop linearly dependent columns (default) - "error": Raise ValueError - "silent": Drop columns silently without warning base_period : str, default="varying" Method for selecting the base (reference) period for computing ATT(g,t). Options: - "varying": For pre-treatment periods (t < g - anticipation), use t-1 as base (consecutive comparisons). For post-treatment, use g-1-anticipation. Requires t-1 to exist in data. - "universal": Always use g-1-anticipation as base period. Both produce identical post-treatment effects. Matches R's did::att_gt() base_period parameter. cband : bool, default=True Whether to compute simultaneous confidence bands (sup-t) for event study aggregation. Requires ``n_bootstrap > 0``. When True, results include ``cband_crit_value`` and per-event-time ``cband_conf_int`` entries controlling family-wise error rate. pscore_trim : float, default=0.01 Trimming bound for propensity scores. Scores are clipped to ``[pscore_trim, 1 - pscore_trim]`` before weight computation in IPW and DR estimation. Must be in ``(0, 0.5)``. panel : bool, default=True Whether the data is a balanced/unbalanced panel (units observed across multiple time periods). Set to ``False`` for stationary repeated cross-sections where each observation has a unique unit ID and units do not repeat across periods. Requires that the cross-sectional samples are drawn from the same population in each period (stationarity). Uses cross-sectional DRDID (Sant'Anna & Zhao 2020, Section 4) with per-observation influence functions. epv_threshold : float, default=10 Events Per Variable threshold for propensity score logit. When the ratio of minority-class observations to predictor variables (excluding intercept) falls below this value, a warning is emitted (or ``ValueError`` raised if ``rank_deficient_action="error"``). Based on Peduzzi et al. (1996). Only applies to IPW and DR estimation methods. Use ``diagnose_propensity()`` for a pre-estimation check across all cohorts. pscore_fallback : str, default="error" Action when propensity score estimation fails entirely (``LinAlgError`` or ``ValueError`` from IRLS): - "error": Raise the exception (default). Ensures the user is aware of estimation failures. - "unconditional": Fall back to unconditional propensity with a warning. For IPW, this drops all covariates. For DR, the propensity model becomes unconditional but outcome regression still uses covariates. When ``rank_deficient_action="error"``, errors are always re-raised regardless of this setting. Attributes ---------- results_ : CallawaySantAnnaResults Estimation results after calling fit(). is_fitted_ : bool Whether the model has been fitted. Examples -------- Basic usage: >>> import pandas as pd >>> from diff_diff import CallawaySantAnna >>> >>> # Panel data with staggered treatment >>> # 'first_treat' = period when unit was first treated (0 if never treated) >>> data = pd.DataFrame({ ... 'unit': [...], ... 'time': [...], ... 'outcome': [...], ... 'first_treat': [...] # 0 for never-treated, else first treatment period ... }) >>> >>> cs = CallawaySantAnna() >>> results = cs.fit(data, outcome='outcome', unit='unit', ... time='time', first_treat='first_treat') >>> >>> results.print_summary() With event study aggregation: >>> cs = CallawaySantAnna() >>> results = cs.fit(data, outcome='outcome', unit='unit', ... time='time', first_treat='first_treat', ... aggregate='event_study') >>> >>> # Plot event study >>> from diff_diff import plot_event_study >>> plot_event_study(results) With covariate adjustment (conditional parallel trends): >>> # When parallel trends only holds conditional on covariates >>> cs = CallawaySantAnna(estimation_method='dr') # doubly robust >>> results = cs.fit(data, outcome='outcome', unit='unit', ... time='time', first_treat='first_treat', ... covariates=['age', 'income']) >>> >>> # DR is recommended: consistent if either outcome model >>> # or propensity model is correctly specified Notes ----- The key innovation of Callaway & Sant'Anna (2021) is the disaggregated approach: instead of estimating a single treatment effect, they estimate ATT(g,t) for each cohort-time pair. This avoids the "forbidden comparison" problem where already-treated units act as controls. The ATT(g,t) is identified under parallel trends conditional on covariates: E[Y(0)_t - Y(0)_g-1 | G=g] = E[Y(0)_t - Y(0)_g-1 | C=1] where G=g indicates treatment cohort g and C=1 indicates control units. This uses g-1 as the base period, which applies to post-treatment (t >= g). With base_period="varying" (default), pre-treatment uses t-1 as base for consecutive comparisons useful in parallel trends diagnostics. References ---------- Callaway, B., & Sant'Anna, P. H. (2021). Difference-in-Differences with multiple time periods. Journal of Econometrics, 225(2), 200-230. """
[docs] def __init__( self, control_group: str = "never_treated", anticipation: int = 0, estimation_method: str = "dr", alpha: float = 0.05, cluster: Optional[str] = None, n_bootstrap: int = 0, bootstrap_weights: Optional[str] = None, seed: Optional[int] = None, rank_deficient_action: str = "warn", base_period: str = "varying", cband: bool = True, pscore_trim: float = 0.01, panel: bool = True, epv_threshold: float = 10, pscore_fallback: str = "error", ): import warnings if control_group not in ["never_treated", "not_yet_treated"]: raise ValueError( f"control_group must be 'never_treated' or 'not_yet_treated', " f"got '{control_group}'" ) if estimation_method not in ["dr", "ipw", "reg"]: raise ValueError( f"estimation_method must be 'dr', 'ipw', or 'reg', " f"got '{estimation_method}'" ) if not (0 < pscore_trim < 0.5): raise ValueError(f"pscore_trim must be in (0, 0.5), got {pscore_trim}") if epv_threshold <= 0: raise ValueError(f"epv_threshold must be > 0, got {epv_threshold}") if pscore_fallback not in ["error", "unconditional"]: raise ValueError( f"pscore_fallback must be 'error' or 'unconditional', " f"got '{pscore_fallback}'" ) # Default to rademacher if not specified if bootstrap_weights is None: bootstrap_weights = "rademacher" if bootstrap_weights not in ["rademacher", "mammen", "webb"]: raise ValueError( f"bootstrap_weights must be 'rademacher', 'mammen', or 'webb', " f"got '{bootstrap_weights}'" ) if rank_deficient_action not in ["warn", "error", "silent"]: raise ValueError( f"rank_deficient_action must be 'warn', 'error', or 'silent', " f"got '{rank_deficient_action}'" ) if base_period not in ["varying", "universal"]: raise ValueError( f"base_period must be 'varying' or 'universal', " f"got '{base_period}'" ) self.control_group = control_group self.anticipation = anticipation self.estimation_method = estimation_method self.alpha = alpha self.cluster = cluster self.n_bootstrap = n_bootstrap self.bootstrap_weights = bootstrap_weights self.seed = seed self.rank_deficient_action = rank_deficient_action self.base_period = base_period self.cband = cband self.pscore_trim = pscore_trim self.panel = panel self.epv_threshold = epv_threshold self.pscore_fallback = pscore_fallback self.is_fitted_ = False self.results_: Optional[CallawaySantAnnaResults] = None
[docs] def diagnose_propensity( self, df: pd.DataFrame, outcome: str, unit: str, time: str, first_treat: str, covariates: Optional[List[str]] = None, ) -> pd.DataFrame: """ Check Events Per Variable (EPV) across all cohorts without estimation. Examines the data to identify cohorts where propensity score logit may be unreliable due to too few events per covariate. Based on Peduzzi et al. (1996). This is a raw-count heuristic: it uses total cohort/control unit counts without filtering for missing outcomes, zero survey weights, or period-specific validity. The actual fit-time EPV (stored in ``results.epv_diagnostics``) may be lower because ``fit()`` operates on the valid base/post outcome pair and the positive-weight effective sample. Use this method as a quick pre-check; rely on ``results.epv_diagnostics`` for authoritative per-cell EPV. Parameters ---------- df, outcome, unit, time, first_treat, covariates Same arguments as ``fit()``. Returns ------- pd.DataFrame Per-cohort EPV diagnostics with columns: group, n_treated, n_control, n_covariates, n_params, epv, status. """ if not self.panel: raise NotImplementedError( "diagnose_propensity() is not yet supported for repeated " "cross-section data (panel=False). Use fit() with covariates " "and check results.epv_diagnostics instead." ) if self.control_group == "not_yet_treated": raise NotImplementedError( "diagnose_propensity() is not yet supported for " "control_group='not_yet_treated' because the control set " "varies per (g, t) cell. Use fit() with covariates and " "check results.epv_diagnostics instead." ) if self.estimation_method == "reg": return pd.DataFrame( columns=[ "group", "n_treated", "n_control", "n_covariates", "n_params", "epv", "status", ] ) if not covariates: return pd.DataFrame( columns=[ "group", "n_treated", "n_control", "n_covariates", "n_params", "epv", "status", ] ) # Normalize np.inf → 0 for never-treated encoding (same as fit()) df = df.copy() _inf_mask_diag = df[first_treat].isin([np.inf, float("inf")]) if _inf_mask_diag.any(): n_inf_units = df.loc[_inf_mask_diag, unit].nunique() warnings.warn( f"{n_inf_units} unit(s) have first_treat=inf; recoding to 0 " f"(never-treated). Use first_treat=0 to suppress this warning.", UserWarning, stacklevel=2, ) df[first_treat] = df[first_treat].replace([np.inf, float("inf")], 0) # Compute time_periods and treatment_groups (same logic as fit()) time_periods = sorted(df[time].unique()) treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0]) precomputed = self._precompute_structures( df, outcome, unit, time, first_treat, covariates, time_periods=time_periods, treatment_groups=treatment_groups, ) cohort_masks = precomputed["cohort_masks"] never_treated_mask = precomputed["never_treated_mask"] unit_cohorts = precomputed["unit_cohorts"] n_covariates = len(covariates) n_params = n_covariates # predictor count, excluding intercept (Peduzzi convention) rows = [] for g in sorted(cohort_masks.keys()): treated_mask = cohort_masks[g] if self.control_group == "never_treated": control_mask = never_treated_mask else: base_period_val = g - 1 - self.anticipation nyt_threshold = base_period_val + self.anticipation control_mask = never_treated_mask | ( (unit_cohorts > nyt_threshold) & (unit_cohorts != g) ) n_treated = int(np.sum(treated_mask)) n_control = int(np.sum(control_mask)) n_events = min(n_treated, n_control) epv = n_events / n_params if n_params > 0 else float("inf") if epv >= self.epv_threshold: status = "ok" elif epv >= 2: status = "low" else: status = "critical" rows.append( { "group": g, "n_treated": n_treated, "n_control": n_control, "n_covariates": n_covariates, "n_params": n_params, "epv": round(epv, 1), "status": status, } ) return pd.DataFrame(rows)
@staticmethod def _collapse_survey_to_unit_level(resolved_survey, df, unit_col, all_units): """Create unit-level ResolvedSurveyDesign for panel IF-based variance. Survey design columns are constant within units (validated upstream). This extracts one row per unit, aligned to ``all_units`` ordering. """ from diff_diff.survey import collapse_survey_to_unit_level return collapse_survey_to_unit_level(resolved_survey, df, unit_col, all_units) def _precompute_structures( self, df: pd.DataFrame, outcome: str, unit: str, time: str, first_treat: str, covariates: Optional[List[str]], time_periods: List[Any], treatment_groups: List[Any], resolved_survey=None, ) -> PrecomputedData: """ Pre-compute data structures for efficient ATT(g,t) computation. This pivots data to wide format and pre-computes: - Outcome matrix (units x time periods) - Covariate matrix (units x covariates) from base period - Unit cohort membership masks - Control unit masks Returns ------- PrecomputedData Dictionary with pre-computed structures. """ # Get unique units and their cohort assignments unit_info = df.groupby(unit)[first_treat].first() all_units = unit_info.index.values unit_cohorts = unit_info.values # Create unit index mapping for fast lookups unit_to_idx = {u: i for i, u in enumerate(all_units)} # Pivot outcome to wide format: rows = units, columns = time periods outcome_wide = df.pivot(index=unit, columns=time, values=outcome) # Reindex to ensure all units are present (handles unbalanced panels) outcome_wide = outcome_wide.reindex(all_units) outcome_matrix = outcome_wide.values # Shape: (n_units, n_periods) period_to_col = {t: i for i, t in enumerate(outcome_wide.columns)} # Pre-compute cohort masks (boolean arrays) cohort_masks = {} for g in treatment_groups: cohort_masks[g] = unit_cohorts == g # Never-treated mask # np.inf was normalized to 0 in fit(), so the np.inf check is defensive only never_treated_mask = (unit_cohorts == 0) | (unit_cohorts == np.inf) # Pre-compute covariate matrices by time period if needed # (covariates are retrieved from the base period of each comparison) covariate_by_period = None if covariates: covariate_by_period = {} for t in time_periods: period_data = df[df[time] == t].set_index(unit) period_cov = period_data.reindex(all_units)[covariates] covariate_by_period[t] = period_cov.values # Shape: (n_units, n_covariates) is_balanced = not np.any(np.isnan(outcome_matrix)) # Extract per-unit survey weights (one weight per unit) if resolved_survey is not None: sw_by_unit = ( pd.Series(resolved_survey.weights, index=df.index).groupby(df[unit]).first() ) survey_weights_arr = sw_by_unit.reindex(all_units).values else: survey_weights_arr = None resolved_survey_unit = ( self._collapse_survey_to_unit_level(resolved_survey, df, unit, all_units) if resolved_survey is not None else None ) return { "all_units": all_units, "unit_to_idx": unit_to_idx, "unit_cohorts": unit_cohorts, "outcome_matrix": outcome_matrix, "period_to_col": period_to_col, "cohort_masks": cohort_masks, "never_treated_mask": never_treated_mask, "covariate_by_period": covariate_by_period, "time_periods": time_periods, "is_balanced": is_balanced, "is_panel": True, "canonical_size": len(all_units), "survey_weights": survey_weights_arr, "resolved_survey": resolved_survey, "resolved_survey_unit": resolved_survey_unit, "df_survey": ( resolved_survey_unit.df_survey if resolved_survey_unit is not None else None ), } def _compute_att_gt_fast( self, precomputed: PrecomputedData, g: Any, t: Any, covariates: Optional[List[str]], pscore_cache: Optional[Dict] = None, cho_cache: Optional[Dict] = None, epv_diagnostics: Optional[Dict] = None, ) -> Tuple[Optional[float], float, int, int, Optional[Dict[str, Any]], Optional[float]]: """ Compute ATT(g,t) using pre-computed data structures (fast version). Uses vectorized numpy operations on pre-pivoted outcome matrix instead of repeated pandas filtering. Returns ------- att_gt : float or None se_gt : float n_treated : int n_control : int inf_func_info : dict or None survey_weight_sum : float or None Sum of survey weights for treated units (for aggregation weighting). """ period_to_col = precomputed["period_to_col"] outcome_matrix = precomputed["outcome_matrix"] cohort_masks = precomputed["cohort_masks"] never_treated_mask = precomputed["never_treated_mask"] unit_cohorts = precomputed["unit_cohorts"] covariate_by_period = precomputed["covariate_by_period"] # Base period selection based on mode if self.base_period == "universal": # Universal: always use g - 1 - anticipation base_period_val = g - 1 - self.anticipation else: # varying if t < g - self.anticipation: # Pre-treatment: use t - 1 (consecutive comparison) base_period_val = t - 1 else: # Post-treatment: use g - 1 - anticipation base_period_val = g - 1 - self.anticipation if base_period_val not in period_to_col: # Base period must exist; no fallback to maintain methodological consistency return None, 0.0, 0, 0, None, None # Check if periods exist in the data if base_period_val not in period_to_col or t not in period_to_col: return None, 0.0, 0, 0, None, None base_col = period_to_col[base_period_val] post_col = period_to_col[t] # Get treated units mask (cohort g) treated_mask = cohort_masks[g] # Get control units mask if self.control_group == "never_treated": control_mask = never_treated_mask else: # not_yet_treated # Not yet treated at BOTH time t and the base period: # Controls must be untreated at whichever is later, otherwise # their outcome at the base period is contaminated by treatment. nyt_threshold = max(t, base_period_val) + self.anticipation control_mask = never_treated_mask | ( (unit_cohorts > nyt_threshold) & (unit_cohorts != g) ) # Extract outcomes for base and post periods y_base = outcome_matrix[:, base_col] y_post = outcome_matrix[:, post_col] # Compute outcome changes (vectorized) outcome_change = y_post - y_base # Filter to units with valid data (no NaN in either period) valid_mask = ~(np.isnan(y_base) | np.isnan(y_post)) # Get treated and control with valid data treated_valid = treated_mask & valid_mask control_valid = control_mask & valid_mask n_treated = np.sum(treated_valid) n_control = np.sum(control_valid) if n_treated == 0 or n_control == 0: return None, 0.0, 0, 0, None, None # Extract outcome changes for treated and control treated_change = outcome_change[treated_valid] control_change = outcome_change[control_valid] # Extract survey weights for treated and control survey_w = precomputed.get("survey_weights") sw_treated = survey_w[treated_valid] if survey_w is not None else None sw_control = survey_w[control_valid] if survey_w is not None else None # Guard against zero effective mass after subpopulation filtering if sw_treated is not None and np.sum(sw_treated) <= 0: return None, 0.0, 0, 0, None, None if sw_control is not None and np.sum(sw_control) <= 0: return None, 0.0, 0, 0, None, None # Get covariates if specified (from the base period) X_treated = None X_control = None if covariates and covariate_by_period is not None: cov_matrix = covariate_by_period[base_period_val] X_treated = cov_matrix[treated_valid] X_control = cov_matrix[control_valid] # Check for missing values if np.any(np.isnan(X_treated)) or np.any(np.isnan(X_control)): warnings.warn( f"Missing values in covariates for group {g}, time {t}. " "Falling back to unconditional estimation.", UserWarning, stacklevel=3, ) X_treated = None X_control = None # Compute cache key for propensity score reuse pscore_key = None if pscore_cache is not None and X_treated is not None: is_balanced = precomputed.get("is_balanced", False) if is_balanced and self.control_group == "never_treated": pscore_key = (g, base_period_val) else: pscore_key = (g, base_period_val, t) # Compute cache key for Cholesky reuse (DR outcome regression) cho_key = None if cho_cache is not None and X_control is not None: is_balanced = precomputed.get("is_balanced", False) if is_balanced and self.control_group == "never_treated": cho_key = base_period_val else: cho_key = (g, base_period_val, t) # Estimation method if self.estimation_method == "reg": att_gt, se_gt, inf_func = self._outcome_regression( treated_change, control_change, X_treated, X_control, sw_treated=sw_treated, sw_control=sw_control, ) elif self.estimation_method == "ipw": sw_all = np.concatenate([sw_treated, sw_control]) if sw_treated is not None else None epv_diag: dict = {} att_gt, se_gt, inf_func = self._ipw_estimation( treated_change, control_change, int(n_treated), int(n_control), X_treated, X_control, pscore_cache=pscore_cache, pscore_key=pscore_key, sw_treated=sw_treated, sw_control=sw_control, sw_all=sw_all, context_label=f"cohort g={g}", epv_diagnostics_out=epv_diag, ) if epv_diagnostics is not None and epv_diag: epv_diagnostics[(g, t)] = epv_diag else: # doubly robust sw_all = np.concatenate([sw_treated, sw_control]) if sw_treated is not None else None epv_diag = {} att_gt, se_gt, inf_func = self._doubly_robust( treated_change, control_change, X_treated, X_control, pscore_cache=pscore_cache, pscore_key=pscore_key, cho_cache=cho_cache, cho_key=cho_key, sw_treated=sw_treated, sw_control=sw_control, sw_all=sw_all, context_label=f"cohort g={g}", epv_diagnostics_out=epv_diag, ) if epv_diagnostics is not None and epv_diag: epv_diagnostics[(g, t)] = epv_diag # Package influence function info with index arrays (positions into # precomputed['all_units']) for O(1) downstream lookups instead of # O(n) Python dict lookups. n_t = int(n_treated) all_units = precomputed["all_units"] treated_positions = np.where(treated_valid)[0] control_positions = np.where(control_valid)[0] inf_func_info = { "treated_idx": treated_positions, "control_idx": control_positions, "treated_units": all_units[treated_positions], "control_units": all_units[control_positions], "treated_inf": inf_func[:n_t], "control_inf": inf_func[n_t:], } sw_sum = float(np.sum(sw_treated)) if sw_treated is not None else None return att_gt, se_gt, int(n_treated), int(n_control), inf_func_info, sw_sum def _compute_all_att_gt_vectorized( self, precomputed: PrecomputedData, treatment_groups: List[Any], time_periods: List[Any], min_period: Any, ) -> Tuple[Dict, Dict, Dict]: """ Vectorized computation of all ATT(g,t) for the no-covariates regression case. This inlines the simple difference-in-means path from _outcome_regression() and eliminates per-(g,t) Python function call overhead. Returns ------- group_time_effects : dict Mapping (g, t) -> effect dict. influence_func_info : dict Mapping (g, t) -> influence function info dict. """ period_to_col = precomputed["period_to_col"] outcome_matrix = precomputed["outcome_matrix"] cohort_masks = precomputed["cohort_masks"] never_treated_mask = precomputed["never_treated_mask"] unit_cohorts = precomputed["unit_cohorts"] survey_w = precomputed.get("survey_weights") group_time_effects = {} influence_func_info = {} skipped_missing_period: List[Tuple] = [] skipped_empty_cell: List[Tuple] = [] # Collect all valid (g, t, base_col, post_col) tuples tasks = [] for g in treatment_groups: if self.base_period == "universal": universal_base = g - 1 - self.anticipation valid_periods = [t for t in time_periods if t != universal_base] else: valid_periods = [ t for t in time_periods if t >= g - self.anticipation or t > min_period ] for t in valid_periods: # Base period selection if self.base_period == "universal": base_period_val = g - 1 - self.anticipation else: if t < g - self.anticipation: base_period_val = t - 1 else: base_period_val = g - 1 - self.anticipation if base_period_val not in period_to_col or t not in period_to_col: skipped_missing_period.append((g, t)) continue tasks.append( (g, t, period_to_col[base_period_val], period_to_col[t], base_period_val) ) # Process all tasks atts = [] ses = [] task_keys = [] for g, t, base_col, post_col, base_period_val in tasks: treated_mask = cohort_masks[g] if self.control_group == "never_treated": control_mask = never_treated_mask else: # Controls must be untreated at both t and base_period_val nyt_threshold = max(t, base_period_val) + self.anticipation control_mask = never_treated_mask | ( (unit_cohorts > nyt_threshold) & (unit_cohorts != g) ) y_base = outcome_matrix[:, base_col] y_post = outcome_matrix[:, post_col] outcome_change = y_post - y_base valid_mask = ~(np.isnan(y_base) | np.isnan(y_post)) treated_valid = treated_mask & valid_mask control_valid = control_mask & valid_mask n_treated = np.sum(treated_valid) n_control = np.sum(control_valid) if n_treated == 0 or n_control == 0: skipped_empty_cell.append((g, t)) continue treated_change = outcome_change[treated_valid] control_change = outcome_change[control_valid] n_t = int(n_treated) n_c = int(n_control) # Inline no-covariates regression (difference in means) if survey_w is not None: sw_t = survey_w[treated_valid] sw_c = survey_w[control_valid] # Guard against zero effective mass if np.sum(sw_t) <= 0 or np.sum(sw_c) <= 0: skipped_empty_cell.append((g, t)) continue sw_t_norm = sw_t / np.sum(sw_t) sw_c_norm = sw_c / np.sum(sw_c) mu_t = float(np.sum(sw_t_norm * treated_change)) mu_c = float(np.sum(sw_c_norm * control_change)) att = mu_t - mu_c # Influence function (survey-weighted) inf_treated = sw_t_norm * (treated_change - mu_t) inf_control = -sw_c_norm * (control_change - mu_c) # SE derived from IF: sum(IF_i^2) se = ( float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2))) if (n_t > 0 and n_c > 0) else 0.0 ) sw_sum = float(np.sum(sw_t)) else: att = float(np.mean(treated_change) - np.mean(control_change)) var_t = float(np.var(treated_change, ddof=1)) if n_t > 1 else 0.0 var_c = float(np.var(control_change, ddof=1)) if n_c > 1 else 0.0 se = float(np.sqrt(var_t / n_t + var_c / n_c)) if (n_t > 0 and n_c > 0) else 0.0 # Influence function inf_treated = (treated_change - np.mean(treated_change)) / n_t inf_control = -(control_change - np.mean(control_change)) / n_c sw_sum = None gte_entry = { "effect": att, "se": se, # t_stat, p_value, conf_int filled by batch inference below "t_stat": np.nan, "p_value": np.nan, "conf_int": (np.nan, np.nan), "n_treated": n_t, "n_control": n_c, } if sw_sum is not None: gte_entry["survey_weight_sum"] = sw_sum group_time_effects[(g, t)] = gte_entry all_units = precomputed["all_units"] treated_positions = np.where(treated_valid)[0] control_positions = np.where(control_valid)[0] influence_func_info[(g, t)] = { "treated_idx": treated_positions, "control_idx": control_positions, "treated_units": all_units[treated_positions], "control_units": all_units[control_positions], "treated_inf": inf_treated, "control_inf": inf_control, } atts.append(att) ses.append(se) task_keys.append((g, t)) # Batch inference for all (g,t) pairs at once if task_keys: df_survey_val = precomputed.get("df_survey") # Guard: replicate design with undefined df → NaN inference if ( df_survey_val is None and precomputed.get("resolved_survey_unit") is not None and hasattr(precomputed["resolved_survey_unit"], "uses_replicate_variance") and precomputed["resolved_survey_unit"].uses_replicate_variance ): df_survey_val = 0 t_stats, p_values, ci_lowers, ci_uppers = safe_inference_batch( np.array(atts), np.array(ses), alpha=self.alpha, df=df_survey_val, ) for idx, key in enumerate(task_keys): group_time_effects[key]["t_stat"] = float(t_stats[idx]) group_time_effects[key]["p_value"] = float(p_values[idx]) group_time_effects[key]["conf_int"] = (float(ci_lowers[idx]), float(ci_uppers[idx])) skip_info = { "missing_period": skipped_missing_period, "empty_cell": skipped_empty_cell, } return group_time_effects, influence_func_info, skip_info def _compute_all_att_gt_covariate_reg( self, precomputed: PrecomputedData, treatment_groups: List[Any], time_periods: List[Any], min_period: Any, ) -> Tuple[Dict, Dict, Dict]: """ Optimized computation of all ATT(g,t) for the covariate regression case. Groups (g,t) pairs by their control regression key to reuse Cholesky factorizations of X^T X across pairs that share the same control design matrix. Returns ------- group_time_effects : dict Mapping (g, t) -> effect dict. influence_func_info : dict Mapping (g, t) -> influence function info dict. """ period_to_col = precomputed["period_to_col"] outcome_matrix = precomputed["outcome_matrix"] cohort_masks = precomputed["cohort_masks"] never_treated_mask = precomputed["never_treated_mask"] unit_cohorts = precomputed["unit_cohorts"] covariate_by_period = precomputed["covariate_by_period"] is_balanced = precomputed["is_balanced"] group_time_effects = {} influence_func_info = {} atts = [] ses = [] task_keys = [] n_nan_cells = 0 skipped_missing_period: List[Tuple] = [] skipped_empty_cell: List[Tuple] = [] # Collect all valid (g, t) tasks with their base periods tasks_by_group = {} # control_key -> list of (g, t, base_period_val, base_col, post_col) for g in treatment_groups: if self.base_period == "universal": universal_base = g - 1 - self.anticipation valid_periods = [t for t in time_periods if t != universal_base] else: valid_periods = [ t for t in time_periods if t >= g - self.anticipation or t > min_period ] for t in valid_periods: if self.base_period == "universal": base_period_val = g - 1 - self.anticipation else: if t < g - self.anticipation: base_period_val = t - 1 else: base_period_val = g - 1 - self.anticipation if base_period_val not in period_to_col or t not in period_to_col: skipped_missing_period.append((g, t)) continue # Determine control regression grouping key. # For balanced panels with never_treated control, X_control depends # only on base_period_val (control mask is time-invariant). # For not_yet_treated, the control mask excludes cohort g, so include g. if is_balanced and self.control_group == "never_treated": control_key = base_period_val else: control_key = (g, base_period_val, t) tasks_by_group.setdefault(control_key, []).append( (g, t, base_period_val, period_to_col[base_period_val], period_to_col[t]) ) # Process each group of tasks sharing the same control regression for control_key, tasks in tasks_by_group.items(): # Use the first task to build X_control (same for all in the group) first_g, first_t, base_period_val, first_base_col, first_post_col = tasks[0] cov_matrix = covariate_by_period[base_period_val] # Build control mask (same for all tasks in this group) if self.control_group == "never_treated": control_mask = never_treated_mask else: # Controls must be untreated at both t and base_period_val nyt_threshold = max(first_t, base_period_val) + self.anticipation control_mask = never_treated_mask | ( (unit_cohorts > nyt_threshold) & (unit_cohorts != first_g) ) # For balanced panels, valid_mask is all True so control_valid = control_mask if is_balanced: control_valid_base = control_mask else: y_base_first = outcome_matrix[:, first_base_col] y_post_first = outcome_matrix[:, first_post_col] valid_first = ~(np.isnan(y_base_first) | np.isnan(y_post_first)) control_valid_base = control_mask & valid_first X_ctrl_raw = cov_matrix[control_valid_base] # Check for NaN in control covariates ctrl_has_nan = bool(np.any(np.isnan(X_ctrl_raw))) # Build X_ctrl with intercept n_c_base = int(np.sum(control_valid_base)) if n_c_base == 0: skipped_empty_cell.extend((g, t) for g, t, *_ in tasks) continue X_ctrl = None cho = None kept_cols = None if not ctrl_has_nan: X_ctrl = np.column_stack([np.ones(n_c_base), X_ctrl_raw]) # One-time rank check for this control group rank, dropped_cols, _ = _detect_rank_deficiency(X_ctrl) if len(dropped_cols) > 0: # Rank-deficient: force lstsq for both "warn" and "silent". # Cholesky on near-singular XtX could yield unstable coefficients. if self.rank_deficient_action == "warn": col_info = _format_dropped_columns(dropped_cols) warnings.warn( f"Rank-deficient covariate design (control_key={control_key}): " f"dropped columns {col_info}. Rank {rank} < {X_ctrl.shape[1]}. " "Using minimum-norm least-squares solution.", UserWarning, stacklevel=2, ) cho = None # Force lstsq path for ALL rank-deficient cases kept_cols = np.array( [i for i in range(X_ctrl.shape[1]) if i not in dropped_cols] ) else: kept_cols = None # Full rank — use all columns with np.errstate(all="ignore"): XtX = X_ctrl.T @ X_ctrl try: cho = scipy_linalg.cho_factor(XtX) except np.linalg.LinAlgError: cho = None # Process each (g, t) pair in this group for g, t, bp_val, base_col, post_col in tasks: treated_mask = cohort_masks[g] # Recompute control mask for not_yet_treated (varies by g, t) if self.control_group == "not_yet_treated": # Controls must be untreated at both t and base period nyt_threshold = max(t, bp_val) + self.anticipation control_mask = never_treated_mask | ( (unit_cohorts > nyt_threshold) & (unit_cohorts != g) ) y_base = outcome_matrix[:, base_col] y_post = outcome_matrix[:, post_col] outcome_change = y_post - y_base if is_balanced: valid_mask_pair = np.ones(len(y_base), dtype=bool) else: valid_mask_pair = ~(np.isnan(y_base) | np.isnan(y_post)) treated_valid = treated_mask & valid_mask_pair # For balanced + never_treated, control_valid is same as control_valid_base if is_balanced and self.control_group == "never_treated": control_valid = control_valid_base else: control_valid = control_mask & valid_mask_pair n_t = int(np.sum(treated_valid)) n_c = int(np.sum(control_valid)) if n_t == 0 or n_c == 0: skipped_empty_cell.append((g, t)) continue treated_change = outcome_change[treated_valid] control_change = outcome_change[control_valid] X_treated_pair = cov_matrix[treated_valid] X_control_pair = cov_matrix[control_valid] # Check for NaN in this pair's covariates if np.any(np.isnan(X_treated_pair)) or np.any(np.isnan(X_control_pair)): # Fall back to unconditional (difference in means) warnings.warn( f"Missing values in covariates for group {g}, time {t}. " "Falling back to unconditional estimation.", UserWarning, stacklevel=3, ) att = float(np.mean(treated_change) - np.mean(control_change)) var_t = float(np.var(treated_change, ddof=1)) if n_t > 1 else 0.0 var_c = float(np.var(control_change, ddof=1)) if n_c > 1 else 0.0 se = float(np.sqrt(var_t / n_t + var_c / n_c)) inf_treated = (treated_change - np.mean(treated_change)) / n_t inf_control = -(control_change - np.mean(control_change)) / n_c else: # Build per-pair X_ctrl if control_valid differs from base if is_balanced and self.control_group == "never_treated" and X_ctrl is not None: pair_X_ctrl = X_ctrl pair_n_c = n_c_base else: pair_X_ctrl = np.column_stack([np.ones(n_c), X_control_pair]) pair_n_c = n_c # Solve for beta beta = None with np.errstate(all="ignore"): if ( cho is not None and is_balanced and self.control_group == "never_treated" ): # Use cached Cholesky Xty = pair_X_ctrl.T @ control_change beta = scipy_linalg.cho_solve(cho, Xty) else: # Compute per-pair Cholesky or lstsq fallback if kept_cols is not None: # Rank-deficient: skip Cholesky, use reduced lstsq pass else: pair_XtX = pair_X_ctrl.T @ pair_X_ctrl try: pair_cho = scipy_linalg.cho_factor(pair_XtX) Xty = pair_X_ctrl.T @ control_change beta = scipy_linalg.cho_solve(pair_cho, Xty) except np.linalg.LinAlgError: pass if beta is None or np.any(~np.isfinite(beta)): if kept_cols is not None: # Reduced solve for rank-deficient design result = scipy_linalg.lstsq( pair_X_ctrl[:, kept_cols], control_change, cond=1e-07, ) beta = np.zeros(pair_X_ctrl.shape[1]) beta[kept_cols] = result[0] else: # Full-rank lstsq fallback (Cholesky numerical failure) result = scipy_linalg.lstsq( pair_X_ctrl, control_change, cond=1e-07, ) beta = result[0] nan_cell = False if beta is None or np.any(~np.isfinite(beta)): nan_cell = True n_nan_cells += 1 if not nan_cell: X_treated_w_intercept = np.column_stack([np.ones(n_t), X_treated_pair]) with np.errstate(all="ignore"): predicted_control = X_treated_w_intercept @ beta treated_residuals = treated_change - predicted_control if np.any(~np.isfinite(predicted_control)): nan_cell = True n_nan_cells += 1 if not nan_cell: att = float(np.mean(treated_residuals)) with np.errstate(all="ignore"): residuals = control_change - pair_X_ctrl @ beta if np.any(~np.isfinite(residuals)): nan_cell = True n_nan_cells += 1 if nan_cell: att = np.nan se = np.nan inf_treated = np.zeros(n_t) inf_control = np.zeros(n_c) else: var_t = float(np.var(treated_residuals, ddof=1)) if n_t > 1 else 0.0 var_c = float(np.var(residuals, ddof=1)) if pair_n_c > 1 else 0.0 se = float(np.sqrt(var_t / n_t + var_c / pair_n_c)) inf_treated = (treated_residuals - np.mean(treated_residuals)) / n_t inf_control = -residuals / pair_n_c group_time_effects[(g, t)] = { "effect": att, "se": se, "t_stat": np.nan, "p_value": np.nan, "conf_int": (np.nan, np.nan), "n_treated": n_t, "n_control": n_c, } all_units = precomputed["all_units"] treated_positions = np.where(treated_valid)[0] control_positions = np.where(control_valid)[0] influence_func_info[(g, t)] = { "treated_idx": treated_positions, "control_idx": control_positions, "treated_units": all_units[treated_positions], "control_units": all_units[control_positions], "treated_inf": inf_treated, "control_inf": inf_control, } atts.append(att) ses.append(se) task_keys.append((g, t)) if n_nan_cells > 0: warnings.warn( f"{n_nan_cells} group-time cell(s) have non-finite regression results " "(near-singular covariates). These cells are preserved with NaN inference.", UserWarning, stacklevel=2, ) # Batch inference if task_keys: # Use survey df for replicate designs (propagated from precomputed) _ipw_dr_df = precomputed.get("df_survey") if precomputed is not None else None # Guard: replicate design with undefined df → NaN inference if ( _ipw_dr_df is None and precomputed is not None and precomputed.get("resolved_survey_unit") is not None and hasattr(precomputed["resolved_survey_unit"], "uses_replicate_variance") and precomputed["resolved_survey_unit"].uses_replicate_variance ): _ipw_dr_df = 0 t_stats, p_values, ci_lowers, ci_uppers = safe_inference_batch( np.array(atts), np.array(ses), alpha=self.alpha, df=_ipw_dr_df ) for idx, key in enumerate(task_keys): group_time_effects[key]["t_stat"] = float(t_stats[idx]) group_time_effects[key]["p_value"] = float(p_values[idx]) group_time_effects[key]["conf_int"] = (float(ci_lowers[idx]), float(ci_uppers[idx])) skip_info = { "missing_period": skipped_missing_period, "empty_cell": skipped_empty_cell, } return group_time_effects, influence_func_info, skip_info
[docs] def fit( self, data: pd.DataFrame, outcome: str, unit: str, time: str, first_treat: str, covariates: Optional[List[str]] = None, aggregate: Optional[str] = None, balance_e: Optional[int] = None, survey_design: object = None, ) -> CallawaySantAnnaResults: """ Fit the Callaway-Sant'Anna estimator. Parameters ---------- data : pd.DataFrame Panel data with unit and time identifiers. For repeated cross-sections (``panel=False``), each observation should have a unique unit ID — units do not repeat across periods. outcome : str Name of outcome variable column. unit : str Name of unit identifier column. time : str Name of time period column. first_treat : str Name of column indicating when unit was first treated. Use 0 (or np.inf) for never-treated units. covariates : list, optional List of covariate column names for conditional parallel trends. aggregate : str, optional How to aggregate group-time effects: - None: Only compute ATT(g,t) (default) - "simple": Simple weighted average (overall ATT) - "event_study": Aggregate by relative time (event study) - "group": Aggregate by treatment cohort - "all": Compute all aggregations balance_e : int, optional For event study, balance the panel at relative time e. Ensures all groups contribute to each relative period. survey_design : SurveyDesign, optional Survey design specification. Supports pweight with strata/PSU/FPC. Aggregated SEs (overall, event study, group) use design-based variance via compute_survey_if_variance(). All estimation methods (reg, ipw, dr) support covariates + survey. For repeated cross-sections (``panel=False``), survey weights are per-observation (no unit-level collapse). Returns ------- CallawaySantAnnaResults Object containing all estimation results. Raises ------ ValueError If required columns are missing or data validation fails. """ # Validate pscore_trim (may have been changed via set_params) if not (0 < self.pscore_trim < 0.5): raise ValueError(f"pscore_trim must be in (0, 0.5), got {self.pscore_trim}") # Reset stale state from prior fit (prevents leaking event-study VCV) self._event_study_vcov = None # Tracker for _safe_inv lstsq fallbacks across all analytical SE # paths (PS Hessian, OR bread, event-study bread, etc.). Emit ONE # aggregate warning at the end of fit rather than fanning out per # cell. Sibling of PR #9 finding #17. self._safe_inv_tracker: List[float] = [] if not self.panel: warnings.warn( "panel=False uses repeated cross-section DRDID estimators " "(Sant'Anna & Zhao 2020, Section 4) which assume stationary " "cross-sectional sampling: the population distribution of " "(Y, X, G) must be stable across periods. This assumption " "is not data-checkable.", UserWarning, stacklevel=2, ) # Validate unique unit IDs for panel=False if not self.panel: if data[unit].duplicated().any(): raise ValueError( "panel=False requires unique unit IDs (one observation per unit). " "Found duplicate unit IDs. If your data is a panel, use panel=True." ) # Normalize empty covariates list to None if covariates is not None and len(covariates) == 0: covariates = None # Resolve survey design if provided from diff_diff.survey import ( _resolve_survey_for_fit, _validate_unit_constant_survey, ) resolved_survey, survey_weights, survey_weight_type, survey_metadata = ( _resolve_survey_for_fit(survey_design, data, "analytical") ) # Validate within-unit constancy for panel survey designs if resolved_survey is not None: if self.panel: _validate_unit_constant_survey(data, unit, survey_design) if resolved_survey.weight_type != "pweight": raise ValueError( f"CallawaySantAnna survey support requires weight_type='pweight', " f"got '{resolved_survey.weight_type}'. The survey variance math " f"assumes probability weights (pweight)." ) # Note: strata/PSU/FPC are now supported — aggregated SEs use # compute_survey_if_variance() for design-based inference. # Bootstrap + survey is now supported via PSU-level multiplier bootstrap. # Validate inputs required_cols = [outcome, unit, time, first_treat] if covariates: required_cols.extend(covariates) missing = [c for c in required_cols if c not in data.columns] if missing: raise ValueError(f"Missing columns: {missing}") # Create working copy df = data.copy() # Ensure numeric types df[time] = pd.to_numeric(df[time]) df[first_treat] = pd.to_numeric(df[first_treat]) # Standardize the first_treat column name for internal use # This avoids hardcoding column names in internal methods df["first_treat"] = df[first_treat] # Never-treated indicator (must precede treatment_groups to exclude np.inf) df["_never_treated"] = (df[first_treat] == 0) | (df[first_treat] == np.inf) # Normalize np.inf → 0 so all downstream `> 0` checks exclude never-treated _inf_mask = df[first_treat] == np.inf if _inf_mask.any(): n_inf_units = df.loc[_inf_mask, unit].nunique() warnings.warn( f"{n_inf_units} unit(s) have first_treat=inf; recoding to 0 " f"(never-treated). Use first_treat=0 to suppress this warning.", UserWarning, stacklevel=2, ) df.loc[_inf_mask, first_treat] = 0 # Identify groups and time periods time_periods = sorted(df[time].unique()) treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0]) if self.panel: # Panel: count unique units unit_info = ( df.groupby(unit) .agg({first_treat: "first", "_never_treated": "first"}) .reset_index() ) n_treated_units = (unit_info[first_treat] > 0).sum() n_control_units = (unit_info["_never_treated"]).sum() else: # RCS: count observations per cohort (no unit tracking) n_treated_units = int((df[first_treat] > 0).sum()) n_control_units = int(df["_never_treated"].sum()) if n_control_units == 0 and self.control_group == "never_treated": raise ValueError( "No never-treated units found. Check 'first_treat' column. " "Use control_group='not_yet_treated' if all units are eventually treated." ) if n_control_units == 0 and self.control_group == "not_yet_treated": # With not_yet_treated, controls are units not yet treated at each # (g, t) pair — never-treated units are not required. if len(treatment_groups) < 2: raise ValueError( "not_yet_treated control group requires at least 2 treatment " "cohorts when there are no never-treated units." ) # Note: CallawaySantAnna supports survey weights, strata, PSU, and FPC. # Per-cell SEs use IF-based variance; aggregated SEs use design-based # variance via compute_survey_if_variance() or PSU-level bootstrap. # Pre-compute data structures for efficient ATT(g,t) computation if self.panel: precomputed = self._precompute_structures( df, outcome, unit, time, first_treat, covariates, time_periods, treatment_groups, resolved_survey=resolved_survey, ) else: precomputed = self._precompute_structures_rc( df, outcome, unit, time, first_treat, covariates, time_periods, treatment_groups, resolved_survey=resolved_survey, ) # Recompute survey metadata from the unit-level resolved survey so # that n_psu and df_survey reflect the actual survey design (explicit # PSU/strata) rather than hard-coding n_units. if resolved_survey is not None and survey_metadata is not None: resolved_survey_unit = precomputed.get("resolved_survey_unit") if resolved_survey_unit is not None: from diff_diff.survey import compute_survey_metadata unit_w = resolved_survey_unit.weights survey_metadata = compute_survey_metadata(resolved_survey_unit, unit_w) # Survey df for safe_inference calls — use the unit-level resolved # survey df computed in _precompute_structures for consistency. df_survey = precomputed.get("df_survey") # Guard: replicate design with undefined df (rank <= 1) → NaN inference if ( df_survey is None and resolved_survey is not None and hasattr(resolved_survey, "uses_replicate_variance") and resolved_survey.uses_replicate_variance ): df_survey = 0 # Compute ATT(g,t) for each group-time combination min_period = min(time_periods) has_survey = resolved_survey is not None _skip_info = {"missing_period": [], "empty_cell": []} _n_skipped_other = 0 if not self.panel: # --- Repeated cross-section path --- # No vectorized/Cholesky fast paths (panel-only optimizations). # Loop using _compute_att_gt_rc() for each (g,t). group_time_effects = {} influence_func_info = {} epv_diagnostics = ( {} if (covariates and self.estimation_method in ("ipw", "dr")) else None ) for g in treatment_groups: if self.base_period == "universal": universal_base = g - 1 - self.anticipation valid_periods = [t for t in time_periods if t != universal_base] else: valid_periods = [ t for t in time_periods if t >= g - self.anticipation or t > min_period ] for t in valid_periods: rc_result = self._compute_att_gt_rc( precomputed, g, t, covariates, epv_diagnostics=epv_diagnostics, ) att_gt, se_gt, n_treat, n_ctrl, inf_info, sw_sum = rc_result[:6] agg_w = rc_result[6] if len(rc_result) > 6 else n_treat if att_gt is not None: t_stat, p_val, ci = safe_inference( att_gt, se_gt, alpha=self.alpha, df=df_survey, ) gte_entry = { "effect": att_gt, "se": se_gt, "t_stat": t_stat, "p_value": p_val, "conf_int": ci, "n_treated": n_treat, "n_control": n_ctrl, "agg_weight": agg_w, } if sw_sum is not None: gte_entry["survey_weight_sum"] = sw_sum group_time_effects[(g, t)] = gte_entry if inf_info is not None: influence_func_info[(g, t)] = inf_info else: _n_skipped_other += 1 elif covariates is None and self.estimation_method == "reg": # Fast vectorized path for the common no-covariates regression case group_time_effects, influence_func_info, _skip_info = ( self._compute_all_att_gt_vectorized( precomputed, treatment_groups, time_periods, min_period ) ) epv_diagnostics = None # No logit in this path elif ( covariates is not None and self.estimation_method == "reg" and self.rank_deficient_action != "error" and not has_survey # Cholesky cache uses X'X; survey needs X'WX ): # Optimized covariate regression path with Cholesky caching group_time_effects, influence_func_info, _skip_info = ( self._compute_all_att_gt_covariate_reg( precomputed, treatment_groups, time_periods, min_period ) ) epv_diagnostics = None # No logit in this path else: # General path: IPW, DR, rank_deficient_action="error", or edge cases group_time_effects = {} influence_func_info = {} # Propensity score cache for IPW/DR with covariates pscore_cache = {} if (covariates and self.estimation_method in ("ipw", "dr")) else None # Cholesky cache for DR outcome regression component # Skip cache when survey weights present (X'WX differs from X'X) cho_cache = ( {} if ( covariates and self.estimation_method == "dr" and self.rank_deficient_action != "error" and not has_survey ) else None ) epv_diagnostics = ( {} if (covariates and self.estimation_method in ("ipw", "dr")) else None ) for g in treatment_groups: if self.base_period == "universal": universal_base = g - 1 - self.anticipation valid_periods = [t for t in time_periods if t != universal_base] else: valid_periods = [ t for t in time_periods if t >= g - self.anticipation or t > min_period ] for t in valid_periods: att_gt, se_gt, n_treat, n_ctrl, inf_info, sw_sum = self._compute_att_gt_fast( precomputed, g, t, covariates, pscore_cache=pscore_cache, cho_cache=cho_cache, epv_diagnostics=epv_diagnostics, ) if att_gt is not None: t_stat, p_val, ci = safe_inference( att_gt, se_gt, alpha=self.alpha, df=df_survey, ) gte_entry = { "effect": att_gt, "se": se_gt, "t_stat": t_stat, "p_value": p_val, "conf_int": ci, "n_treated": n_treat, "n_control": n_ctrl, } if sw_sum is not None: gte_entry["survey_weight_sum"] = sw_sum group_time_effects[(g, t)] = gte_entry if inf_info is not None: influence_func_info[(g, t)] = inf_info else: _n_skipped_other += 1 if not group_time_effects: raise ValueError( "Could not estimate any group-time effects. " "Check that data has sufficient observations." ) # Consolidated EPV summary warning if epv_diagnostics: low_epv = {k: v for k, v in epv_diagnostics.items() if v.get("is_low")} if low_epv: n_affected = len(low_epv) n_total = len(epv_diagnostics) min_entry = min(low_epv.values(), key=lambda v: v["epv"]) min_g = min(low_epv.keys(), key=lambda k: low_epv[k]["epv"]) warnings.warn( f"Low Events Per Variable (EPV) detected in propensity " f"score estimation for {n_affected} of {n_total} cell(s). " f"Minimum EPV = {min_entry['epv']:.1f} " f"(cohort g={min_g[0]}). " f"Consider estimation_method='reg' (avoids propensity " f"scores) or reducing the number of covariates. " f"See results.epv_summary() for details.", UserWarning, stacklevel=2, ) # Consolidated (g,t) cell skip warning (all paths) _n_missing = len(_skip_info.get("missing_period", [])) _n_empty = len(_skip_info.get("empty_cell", [])) _n_total_skipped = _n_missing + _n_empty + _n_skipped_other if _n_total_skipped > 0: _parts = [] if _n_missing: _parts.append( f"{_n_missing} due to missing base/post period " f"in panel structure" ) if _n_empty: _parts.append(f"{_n_empty} due to zero treated or control " f"observations") if _n_skipped_other: _parts.append( f"{_n_skipped_other} due to insufficient data or " f"non-estimable cells" ) warnings.warn( f"{_n_total_skipped} (group, time) cell(s) could not be " f"estimated: {'; '.join(_parts)}.", UserWarning, stacklevel=2, ) # Compute overall ATT (simple aggregation) overall_att, overall_se, overall_effective_df = self._aggregate_simple( group_time_effects, influence_func_info, df, unit, precomputed ) # Use per-statistic effective df from replicate aggregation if available; # otherwise fall back to the original df from the survey design. if overall_effective_df is not None: df_survey = overall_effective_df # Propagate to survey_metadata for display consistency if survey_metadata is not None: survey_metadata.df_survey = df_survey # Guard: replicate design with undefined df (rank <= 1) → NaN inference if ( df_survey is None and resolved_survey is not None and hasattr(resolved_survey, "uses_replicate_variance") and resolved_survey.uses_replicate_variance ): df_survey = 0 overall_t, overall_p, overall_ci = safe_inference( overall_att, overall_se, alpha=self.alpha, df=df_survey, ) # Compute additional aggregations if requested event_study_effects = None group_effects = None if aggregate in ["event_study", "all"]: event_study_effects = self._aggregate_event_study( group_time_effects, influence_func_info, treatment_groups, time_periods, balance_e, df, unit, precomputed, ) if aggregate in ["group", "all"]: group_effects = self._aggregate_by_group( group_time_effects, influence_func_info, treatment_groups, precomputed=precomputed, df=df, unit=unit, ) # Reject replicate-weight designs for bootstrap — replicate variance # is an analytical alternative, not compatible with bootstrap if ( self.n_bootstrap > 0 and resolved_survey is not None and hasattr(resolved_survey, "uses_replicate_variance") and resolved_survey.uses_replicate_variance ): raise NotImplementedError( "CallawaySantAnna bootstrap (n_bootstrap > 0) is not supported " "with replicate-weight survey designs. Replicate weights provide " "analytical variance; use n_bootstrap=0 instead." ) # Run bootstrap inference if requested bootstrap_results = None if self.n_bootstrap > 0 and influence_func_info: bootstrap_results = self._run_multiplier_bootstrap( group_time_effects=group_time_effects, influence_func_info=influence_func_info, aggregate=aggregate, balance_e=balance_e, treatment_groups=treatment_groups, time_periods=time_periods, df=df, unit=unit, precomputed=precomputed, cband=self.cband, ) # Update estimates with bootstrap inference overall_se = bootstrap_results.overall_att_se overall_t = safe_inference(overall_att, overall_se, alpha=self.alpha)[0] overall_p = bootstrap_results.overall_att_p_value overall_ci = bootstrap_results.overall_att_ci # Update group-time effects with bootstrap SEs (batched) gt_keys = [gt for gt in group_time_effects if gt in bootstrap_results.group_time_ses] if gt_keys: gt_effects_arr = np.array( [float(group_time_effects[gt]["effect"]) for gt in gt_keys] ) gt_ses_arr = np.array( [float(bootstrap_results.group_time_ses[gt]) for gt in gt_keys] ) gt_t_stats, _, _, _ = safe_inference_batch( gt_effects_arr, gt_ses_arr, alpha=self.alpha ) for idx, gt in enumerate(gt_keys): group_time_effects[gt]["se"] = bootstrap_results.group_time_ses[gt] group_time_effects[gt]["conf_int"] = bootstrap_results.group_time_cis[gt] group_time_effects[gt]["p_value"] = bootstrap_results.group_time_p_values[gt] group_time_effects[gt]["t_stat"] = float(gt_t_stats[idx]) # Update event study effects with bootstrap SEs (batched) if ( event_study_effects is not None and bootstrap_results.event_study_ses is not None and bootstrap_results.event_study_cis is not None and bootstrap_results.event_study_p_values is not None ): es_keys = [e for e in event_study_effects if e in bootstrap_results.event_study_ses] if es_keys: es_effects_arr = np.array( [float(event_study_effects[e]["effect"]) for e in es_keys] ) es_ses_arr = np.array( [float(bootstrap_results.event_study_ses[e]) for e in es_keys] ) es_t_stats, _, _, _ = safe_inference_batch( es_effects_arr, es_ses_arr, alpha=self.alpha ) for idx, e in enumerate(es_keys): event_study_effects[e]["se"] = bootstrap_results.event_study_ses[e] event_study_effects[e]["conf_int"] = bootstrap_results.event_study_cis[e] event_study_effects[e]["p_value"] = bootstrap_results.event_study_p_values[ e ] event_study_effects[e]["t_stat"] = float(es_t_stats[idx]) # Update group effects with bootstrap SEs (batched) if ( group_effects is not None and bootstrap_results.group_effect_ses is not None and bootstrap_results.group_effect_cis is not None and bootstrap_results.group_effect_p_values is not None ): grp_keys = [g for g in group_effects if g in bootstrap_results.group_effect_ses] if grp_keys: grp_effects_arr = np.array( [float(group_effects[g]["effect"]) for g in grp_keys] ) grp_ses_arr = np.array( [float(bootstrap_results.group_effect_ses[g]) for g in grp_keys] ) grp_t_stats, _, _, _ = safe_inference_batch( grp_effects_arr, grp_ses_arr, alpha=self.alpha ) for idx, g in enumerate(grp_keys): group_effects[g]["se"] = bootstrap_results.group_effect_ses[g] group_effects[g]["conf_int"] = bootstrap_results.group_effect_cis[g] group_effects[g]["p_value"] = bootstrap_results.group_effect_p_values[g] group_effects[g]["t_stat"] = float(grp_t_stats[idx]) # Compute simultaneous confidence band CIs if cband is available cband_crit_value = None if bootstrap_results is not None: cband_crit_value = bootstrap_results.cband_crit_value if cband_crit_value is not None and event_study_effects is not None: for e, eff_data in event_study_effects.items(): se_val = eff_data["se"] if np.isfinite(se_val) and se_val > 0: eff_data["cband_conf_int"] = ( eff_data["effect"] - cband_crit_value * se_val, eff_data["effect"] + cband_crit_value * se_val, ) # Consolidated _safe_inv lstsq-fallback warning (sibling of PR #9 # finding #17). Rank-deficient PS Hessian / OR bread matrices in the # analytical SE paths previously fell back to np.linalg.lstsq # silently per cell. Now aggregated here into ONE UserWarning so # a bad design surface doesn't quietly degrade analytical SEs. if self._safe_inv_tracker: n_fallbacks = len(self._safe_inv_tracker) finite_conds = [c for c in self._safe_inv_tracker if np.isfinite(c)] max_cond = max(finite_conds) if finite_conds else float("inf") warnings.warn( f"Rank-deficient matrix encountered {n_fallbacks} time(s) " f"in analytical SE paths (propensity-score Hessian or " f"outcome-regression bread); fell back to np.linalg.lstsq. " f"Max condition number of affected matrix: {max_cond:.2e}. " f"Analytical SEs may be numerically unstable; consider " f"dropping collinear covariates or using n_bootstrap > 0.", UserWarning, stacklevel=2, ) # Store results # Retrieve event-study VCV from aggregation mixin (Phase 7d). # Clear it when bootstrap overwrites event-study SEs to prevent # HonestDiD from mixing analytical VCV with bootstrap SEs. event_study_vcov = getattr(self, "_event_study_vcov", None) event_study_vcov_index = getattr(self, "_event_study_vcov_index", None) if bootstrap_results is not None and event_study_vcov is not None: event_study_vcov = None event_study_vcov_index = None self.results_ = CallawaySantAnnaResults( group_time_effects=group_time_effects, overall_att=overall_att, overall_se=overall_se, overall_t_stat=overall_t, overall_p_value=overall_p, overall_conf_int=overall_ci, groups=treatment_groups, time_periods=time_periods, n_obs=len(df), n_treated_units=n_treated_units, n_control_units=n_control_units, alpha=self.alpha, control_group=self.control_group, base_period=self.base_period, anticipation=self.anticipation, event_study_effects=event_study_effects, group_effects=group_effects, bootstrap_results=bootstrap_results, cband_crit_value=cband_crit_value, pscore_trim=self.pscore_trim, survey_metadata=survey_metadata, event_study_vcov=event_study_vcov, event_study_vcov_index=event_study_vcov_index, panel=self.panel, epv_diagnostics=epv_diagnostics if epv_diagnostics else None, epv_threshold=self.epv_threshold, pscore_fallback=self.pscore_fallback, ) self.is_fitted_ = True return self.results_
def _outcome_regression( self, treated_change: np.ndarray, control_change: np.ndarray, X_treated: Optional[np.ndarray] = None, X_control: Optional[np.ndarray] = None, sw_treated: Optional[np.ndarray] = None, sw_control: Optional[np.ndarray] = None, ) -> Tuple[float, float, np.ndarray]: """ Estimate ATT using outcome regression. With covariates: 1. Regress outcome changes on covariates for control group 2. Predict counterfactual for treated using their covariates 3. ATT = mean(treated_change) - mean(predicted_counterfactual) Without covariates: Simple difference in means. Parameters ---------- sw_treated, sw_control : np.ndarray, optional Survey weights for treated and control units. """ n_t = len(treated_change) n_c = len(control_change) if X_treated is not None and X_control is not None and X_treated.shape[1] > 0: # Covariate-adjusted outcome regression # Fit regression on control units: E[Delta Y | X, D=0] beta, residuals = _linear_regression( X_control, control_change, rank_deficient_action=self.rank_deficient_action, weights=sw_control, ) # Zero NaN coefficients for prediction (dropped rank-deficient columns # contribute 0 to the column space projection, matching DR path convention) beta = np.where(np.isfinite(beta), beta, 0.0) # Predict counterfactual for treated units X_treated_with_intercept = np.column_stack([np.ones(n_t), X_treated]) predicted_control = np.dot(X_treated_with_intercept, beta) # ATT: survey-weighted mean of treated residuals treated_residuals = treated_change - predicted_control if sw_treated is not None: sw_t_sum = float(np.sum(sw_treated)) sw_c_sum = float(np.sum(sw_control)) sw_t_norm = sw_treated / sw_t_sum sw_c_norm = sw_control / sw_c_sum att = float(np.sum(sw_t_norm * treated_residuals)) # Survey-weighted OR influence function. # Mirrors unweighted: inf_treated = (resid-ATT)/n_t, # inf_control = -resid/n_c. Survey: w_i/sum(w_group). # WLS residuals are orthogonal to W*X by construction. X_c_int = np.column_stack([np.ones(n_c), X_control]) resid_c = control_change - np.dot(X_c_int, beta) inf_treated = sw_t_norm * (treated_residuals - att) inf_control = -sw_c_norm * resid_c inf_func = np.concatenate([inf_treated, inf_control]) # SE: survey-weighted variance matching unweighted var_t/n_t + var_c/n_c var_t = float(np.sum(sw_t_norm * (treated_residuals - att) ** 2)) var_c = float(np.sum(sw_c_norm * resid_c**2)) se = float(np.sqrt(var_t + var_c)) if (n_t > 0 and n_c > 0) else 0.0 else: att = float(np.mean(treated_residuals)) # Standard error using sandwich estimator var_t = np.var(treated_residuals, ddof=1) if n_t > 1 else 0.0 var_c = np.var(residuals, ddof=1) if n_c > 1 else 0.0 se = float(np.sqrt(var_t / n_t + var_c / n_c)) if (n_t > 0 and n_c > 0) else 0.0 # Influence function inf_treated = (treated_residuals - np.mean(treated_residuals)) / n_t inf_control = -residuals / n_c inf_func = np.concatenate([inf_treated, inf_control]) else: # Simple difference in means (no covariates) if sw_treated is not None: sw_t_norm = sw_treated / np.sum(sw_treated) sw_c_norm = sw_control / np.sum(sw_control) mu_t = float(np.sum(sw_t_norm * treated_change)) mu_c = float(np.sum(sw_c_norm * control_change)) att = mu_t - mu_c # Influence function (survey-weighted) inf_treated = sw_t_norm * (treated_change - mu_t) inf_control = -sw_c_norm * (control_change - mu_c) inf_func = np.concatenate([inf_treated, inf_control]) # SE from influence function variance se = ( float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2))) if (n_t > 0 and n_c > 0) else 0.0 ) else: att = float(np.mean(treated_change) - np.mean(control_change)) var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0 var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0 se = float(np.sqrt(var_t / n_t + var_c / n_c)) if (n_t > 0 and n_c > 0) else 0.0 # Influence function (for aggregation) inf_treated = treated_change - np.mean(treated_change) inf_control = control_change - np.mean(control_change) inf_func = np.concatenate([inf_treated / n_t, -inf_control / n_c]) return att, se, inf_func def _ipw_estimation( self, treated_change: np.ndarray, control_change: np.ndarray, n_treated: int, n_control: int, X_treated: Optional[np.ndarray] = None, X_control: Optional[np.ndarray] = None, pscore_cache: Optional[Dict] = None, pscore_key: Optional[Any] = None, sw_treated: Optional[np.ndarray] = None, sw_control: Optional[np.ndarray] = None, sw_all: Optional[np.ndarray] = None, context_label: str = "", epv_diagnostics_out: Optional[dict] = None, ) -> Tuple[float, float, np.ndarray]: """ Estimate ATT using inverse probability weighting. With covariates: 1. Estimate propensity score P(D=1|X) using logistic regression 2. Reweight control units to match treated covariate distribution 3. ATT = mean(treated) - weighted_mean(control) Without covariates: Simple difference in means with unconditional propensity weighting. Parameters ---------- sw_treated, sw_control, sw_all : np.ndarray, optional Survey weights for treated, control, and all units. """ n_t = len(treated_change) n_c = len(control_change) n_total = n_treated + n_control if X_treated is not None and X_control is not None and X_treated.shape[1] > 0: # Covariate-adjusted IPW estimation ps_fallback_used = False # Check propensity score cache cached_pscore = None if pscore_cache is not None and pscore_key is not None: cached_pscore = pscore_cache.get(pscore_key) if cached_pscore is not None: # Use cached propensity scores (beta coefficients + EPV diag) beta_logistic, cached_diag = cached_pscore X_all = np.vstack([X_treated, X_control]) X_all_with_intercept = np.column_stack([np.ones(n_t + n_c), X_all]) z = np.dot(X_all_with_intercept, beta_logistic) z = np.clip(z, -500, 500) pscore = 1 / (1 + np.exp(-z)) if epv_diagnostics_out is not None and cached_diag: epv_diagnostics_out.update(cached_diag) else: # Stack covariates and create treatment indicator X_all = np.vstack([X_treated, X_control]) D = np.concatenate([np.ones(n_t), np.zeros(n_c)]) # Estimate propensity scores using IRLS logistic regression diag = {} try: beta_logistic, pscore = solve_logit( X_all, D, rank_deficient_action=self.rank_deficient_action, weights=sw_all, epv_threshold=self.epv_threshold, context_label=context_label, diagnostics_out=diag, ) _check_propensity_diagnostics(pscore, self.pscore_trim) # Cache the fitted coefficients (zero-fill NaN from # dropped rank-deficient columns to prevent NaN # propagation on cache reuse) alongside EPV diagnostics if pscore_cache is not None and pscore_key is not None: beta_clean = np.where(np.isfinite(beta_logistic), beta_logistic, 0.0) pscore_cache[pscore_key] = (beta_clean, diag) except (np.linalg.LinAlgError, ValueError): if self.pscore_fallback == "error" or self.rank_deficient_action == "error": raise # Fallback to unconditional if logistic regression fails ctx = f" for {context_label}" if context_label else "" warnings.warn( f"Propensity score estimation failed{ctx}. " f"Falling back to unconditional propensity " f"(all covariates dropped for this cell). " f"Consider estimation_method='reg' to avoid " f"propensity scores entirely.", UserWarning, stacklevel=4, ) if sw_all is not None: pos = sw_all > 0 p_uc = float(np.average(D[pos], weights=sw_all[pos])) else: p_uc = n_t / (n_t + n_c) pscore = np.full(len(D), p_uc) ps_fallback_used = True if epv_diagnostics_out is not None and diag: epv_diagnostics_out.update(diag) # Propensity scores for treated and control pscore_treated = pscore[:n_t] pscore_control = pscore[n_t:] # Clip propensity scores to avoid extreme weights pscore_control = np.clip(pscore_control, self.pscore_trim, 1 - self.pscore_trim) pscore_treated = np.clip(pscore_treated, self.pscore_trim, 1 - self.pscore_trim) if sw_treated is not None: # IPW weights compose with survey weights: # w_i = sw_i * p(X_i) / (1 - p(X_i)) weights_control = sw_control * pscore_control / (1 - pscore_control) weights_control_norm = weights_control / np.sum(weights_control) # ATT: survey-weighted treated mean minus composite-weighted control mean sw_t_norm = sw_treated / np.sum(sw_treated) mu_t = float(np.sum(sw_t_norm * treated_change)) att = mu_t - float(np.sum(weights_control_norm * control_change)) # Influence function (survey-weighted) inf_treated = sw_t_norm * (treated_change - mu_t) inf_control = -weights_control_norm * ( control_change - np.sum(weights_control_norm * control_change) ) inf_func = np.concatenate([inf_treated, inf_control]) if not ps_fallback_used: # Propensity score IF correction # Accounts for estimation uncertainty in logistic regression coefficients X_all_int = np.column_stack([np.ones(n_t + n_c), X_all]) pscore_all = np.concatenate([pscore_treated, pscore_control]) # PS IF correction — compute in R's psi convention, convert to phi n_all_panel = n_t + n_c W_ps = pscore_all * (1 - pscore_all) if sw_all is not None: W_ps = W_ps * sw_all # R: Hessian.ps = crossprod(X * sqrt(W)) / n H_psi = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all_panel H_psi_inv = _safe_inv(H_psi, tracker=self._safe_inv_tracker) D_all = np.concatenate([np.ones(n_t), np.zeros(n_c)]) score_ps = (D_all - pscore_all)[:, None] * X_all_int if sw_all is not None: score_ps = score_ps * sw_all[:, None] # R: asy.lin.rep.ps = score.ps %*% Hessian.ps (psi scale, O(1) per obs) asy_lin_rep_psi = score_ps @ H_psi_inv att_control_weighted = np.sum(weights_control_norm * control_change) # R: M2 = colMeans(w.cont * (y - att) * X) / mean(w.cont) # np.sum (not mean): subset sum with normalized weights matches # R's full-sample colMeans/mean(w) after cancellation M2 = np.sum( (weights_control_norm * (control_change - att_control_weighted))[:, None] * X_all_int[n_t:], axis=0, ) # psi-scale correction, convert to phi for storage # Subtract: R adds PS correction to inf.control, then att = treat - control inf_func = inf_func - (asy_lin_rep_psi @ M2) / n_all_panel # SE from influence function variance var_psi = np.sum(inf_func**2) se = float(np.sqrt(var_psi)) if var_psi > 0 else 0.0 else: # IPW weights for control units: p(X) / (1 - p(X)) # This reweights controls to have same covariate distribution as treated weights_control = pscore_control / (1 - pscore_control) weights_control = weights_control / np.sum(weights_control) # normalize # ATT = mean(treated) - weighted_mean(control) att = float(np.mean(treated_change) - np.sum(weights_control * control_change)) # Compute standard error var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0 weighted_var_c = np.sum( weights_control * (control_change - np.sum(weights_control * control_change)) ** 2 ) se = float(np.sqrt(var_t / n_t + weighted_var_c)) if (n_t > 0 and n_c > 0) else 0.0 # Influence function inf_treated = (treated_change - np.mean(treated_change)) / n_t inf_control = -weights_control * ( control_change - np.sum(weights_control * control_change) ) inf_func = np.concatenate([inf_treated, inf_control]) else: # Unconditional IPW (reduces to difference in means) if sw_treated is not None: # Survey-weighted difference in means sw_t_norm = sw_treated / np.sum(sw_treated) sw_c_norm = sw_control / np.sum(sw_control) mu_t = float(np.sum(sw_t_norm * treated_change)) mu_c = float(np.sum(sw_c_norm * control_change)) att = mu_t - mu_c inf_treated = sw_t_norm * (treated_change - mu_t) inf_control = -sw_c_norm * (control_change - mu_c) inf_func = np.concatenate([inf_treated, inf_control]) se = ( float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2))) if (n_t > 0 and n_c > 0) else 0.0 ) else: p_treat = n_treated / n_total # unconditional propensity score att = float(np.mean(treated_change) - np.mean(control_change)) var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0 var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0 # Adjusted variance for IPW se = float( np.sqrt(var_t / n_t + var_c * (1 - p_treat) / (n_c * p_treat)) if (n_t > 0 and n_c > 0 and p_treat > 0) else 0.0 ) # Influence function (for aggregation) inf_treated = (treated_change - np.mean(treated_change)) / n_t inf_control = (control_change - np.mean(control_change)) / n_c inf_func = np.concatenate([inf_treated, -inf_control]) return att, se, inf_func def _doubly_robust( self, treated_change: np.ndarray, control_change: np.ndarray, X_treated: Optional[np.ndarray] = None, X_control: Optional[np.ndarray] = None, pscore_cache: Optional[Dict] = None, pscore_key: Optional[Any] = None, cho_cache: Optional[Dict] = None, cho_key: Optional[Any] = None, sw_treated: Optional[np.ndarray] = None, sw_control: Optional[np.ndarray] = None, sw_all: Optional[np.ndarray] = None, context_label: str = "", epv_diagnostics_out: Optional[dict] = None, ) -> Tuple[float, float, np.ndarray]: """ Estimate ATT using doubly robust estimation. With covariates: Combines outcome regression and IPW for double robustness. The estimator is consistent if either the outcome model OR the propensity model is correctly specified. ATT_DR = (1/n_t) * sum_i[D_i * (Y_i - m(X_i))] + (1/n_t) * sum_i[(1-D_i) * w_i * (m(X_i) - Y_i)] where m(X) is the outcome model and w_i are IPW weights. Without covariates: Reduces to simple difference in means. Parameters ---------- sw_treated, sw_control, sw_all : np.ndarray, optional Survey weights for treated, control, and all units. """ n_t = len(treated_change) n_c = len(control_change) if X_treated is not None and X_control is not None and X_treated.shape[1] > 0: # Doubly robust estimation with covariates ps_fallback_used = False # Step 1: Outcome regression - fit E[Delta Y | X] on control # Try Cholesky cache for outcome regression (disabled when survey weights present) beta = None X_control_with_intercept = np.column_stack([np.ones(n_c), X_control]) if cho_cache is not None and cho_key is not None: cached_cho = cho_cache.get(cho_key) if cached_cho is False: # Rank-deficient sentinel: skip Cholesky, fall through pass elif cached_cho is not None: Xty = X_control_with_intercept.T @ control_change beta = scipy_linalg.cho_solve(cached_cho, Xty) if np.any(~np.isfinite(beta)): beta = None else: # First time for this cho_key: check rank before Cholesky rank_info = _detect_rank_deficiency(X_control_with_intercept) if len(rank_info[1]) > 0: cho_cache[cho_key] = False # Sentinel else: XtX = X_control_with_intercept.T @ X_control_with_intercept try: cho_factor = scipy_linalg.cho_factor(XtX) cho_cache[cho_key] = cho_factor Xty = X_control_with_intercept.T @ control_change beta = scipy_linalg.cho_solve(cho_factor, Xty) if np.any(~np.isfinite(beta)): beta = None except np.linalg.LinAlgError: pass if beta is None: beta, _ = _linear_regression( X_control, control_change, rank_deficient_action=self.rank_deficient_action, weights=sw_control, ) # Zero NaN coefficients for prediction only — dropped columns # contribute 0 to the column space projection. Note: solve_ols # deliberately uses NaN (R's lm() convention) for inference, but # here we only need beta for prediction (m_treated, m_control). beta = np.where(np.isfinite(beta), beta, 0.0) # Predict counterfactual for both treated and control X_treated_with_intercept = np.column_stack([np.ones(n_t), X_treated]) m_treated = np.dot(X_treated_with_intercept, beta) m_control = np.dot(X_control_with_intercept, beta) # Step 2: Propensity score estimation # Check propensity score cache cached_pscore = None if pscore_cache is not None and pscore_key is not None: cached_pscore = pscore_cache.get(pscore_key) if cached_pscore is not None: beta_logistic, cached_diag = cached_pscore X_all = np.vstack([X_treated, X_control]) X_all_with_intercept = np.column_stack([np.ones(n_t + n_c), X_all]) z = np.dot(X_all_with_intercept, beta_logistic) z = np.clip(z, -500, 500) pscore = 1 / (1 + np.exp(-z)) if epv_diagnostics_out is not None and cached_diag: epv_diagnostics_out.update(cached_diag) else: X_all = np.vstack([X_treated, X_control]) D = np.concatenate([np.ones(n_t), np.zeros(n_c)]) diag = {} try: beta_logistic, pscore = solve_logit( X_all, D, rank_deficient_action=self.rank_deficient_action, weights=sw_all, epv_threshold=self.epv_threshold, context_label=context_label, diagnostics_out=diag, ) _check_propensity_diagnostics(pscore, self.pscore_trim) if pscore_cache is not None and pscore_key is not None: beta_clean = np.where(np.isfinite(beta_logistic), beta_logistic, 0.0) pscore_cache[pscore_key] = (beta_clean, diag) except (np.linalg.LinAlgError, ValueError): if self.pscore_fallback == "error" or self.rank_deficient_action == "error": raise # Fallback to unconditional if logistic regression fails ctx = f" for {context_label}" if context_label else "" warnings.warn( f"Propensity score estimation failed{ctx}. " f"Falling back to unconditional propensity " f"(propensity model ignores covariates; outcome " f"regression still uses them). " f"Consider estimation_method='reg' to avoid " f"propensity scores entirely.", UserWarning, stacklevel=4, ) if sw_all is not None: pos = sw_all > 0 p_uc = float(np.average(D[pos], weights=sw_all[pos])) else: p_uc = n_t / (n_t + n_c) pscore = np.full(len(D), p_uc) ps_fallback_used = True if epv_diagnostics_out is not None and diag: epv_diagnostics_out.update(diag) pscore_control = pscore[n_t:] # Clip propensity scores pscore_control = np.clip(pscore_control, self.pscore_trim, 1 - self.pscore_trim) if sw_treated is not None: # IPW weights compose with survey weights weights_control = sw_control * pscore_control / (1 - pscore_control) # Step 3: DR ATT (survey-weighted) sw_t_sum = np.sum(sw_treated) att_treated_part = float( np.sum(sw_treated * (treated_change - m_treated)) / sw_t_sum ) augmentation = float( np.sum(weights_control * (m_control - control_change)) / sw_t_sum ) att = att_treated_part + augmentation # Step 4: Influence function (survey-weighted DR) # Start with plug-in IF, then add nuisance parameter corrections # (Sant'Anna & Zhao 2020, Theorem 3.1) psi_treated = (sw_treated / sw_t_sum) * (treated_change - m_treated - att) psi_control = (weights_control / sw_t_sum) * (m_control - control_change) inf_func = np.concatenate([psi_treated, psi_control]) if X_treated is not None and X_control is not None and X_treated.shape[1] > 0: if not ps_fallback_used: # --- PS IF correction (mirrors IPW L1929-1961) --- # Accounts for propensity score estimation uncertainty X_all_int = np.column_stack([np.ones(n_t + n_c), X_all]) pscore_treated_clipped = np.clip( pscore[:n_t], self.pscore_trim, 1 - self.pscore_trim ) pscore_all = np.concatenate([pscore_treated_clipped, pscore_control]) # PS IF correction — psi convention, convert to phi n_all_panel = n_t + n_c W_ps = pscore_all * (1 - pscore_all) if sw_all is not None: W_ps = W_ps * sw_all H_psi = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all_panel H_psi_inv = _safe_inv(H_psi, tracker=self._safe_inv_tracker) D_all = np.concatenate([np.ones(n_t), np.zeros(n_c)]) score_ps = (D_all - pscore_all)[:, None] * X_all_int if sw_all is not None: score_ps = score_ps * sw_all[:, None] # R: asy.lin.rep.ps = score.ps %*% Hessian.ps (psi scale) asy_lin_rep_psi = score_ps @ H_psi_inv dr_resid_control = m_control - control_change M2_dr = np.sum( ((weights_control / sw_t_sum) * dr_resid_control)[:, None] * X_all_int[n_t:], axis=0, ) inf_func = inf_func + (asy_lin_rep_psi @ M2_dr) / n_all_panel # --- OR IF correction --- # Accounts for outcome regression estimation uncertainty X_c_int = X_control_with_intercept W_diag = sw_control if sw_control is not None else np.ones(n_c) XtWX = X_c_int.T @ (W_diag[:, None] * X_c_int) bread = _safe_inv(XtWX, tracker=self._safe_inv_tracker) # M1: dATT/dbeta — gradient of DR ATT w.r.t. OR parameters X_t_int = X_treated_with_intercept M1 = ( -np.sum(sw_treated[:, None] * X_t_int, axis=0) + np.sum(weights_control[:, None] * X_c_int, axis=0) ) / sw_t_sum # OR asymptotic linear representation (control-only) resid_c = control_change - m_control asy_lin_rep_or = (W_diag * resid_c)[:, None] * X_c_int @ bread # Apply to control portion only (treated contribute zero) inf_func[n_t:] += asy_lin_rep_or @ M1 # Recompute SE from corrected IF var_psi = np.sum(inf_func**2) se = float(np.sqrt(var_psi)) if var_psi > 0 else 0.0 else: # IPW weights for control: p(X) / (1 - p(X)) weights_control = pscore_control / (1 - pscore_control) # Step 3: Doubly robust ATT att_treated_part = float(np.mean(treated_change - m_treated)) augmentation = float(np.sum(weights_control * (m_control - control_change)) / n_t) att = att_treated_part + augmentation # Step 4: Influence function with nuisance IF corrections psi_treated = (treated_change - m_treated - att) / n_t psi_control = (weights_control * (m_control - control_change)) / n_t inf_func = np.concatenate([psi_treated, psi_control]) if X_treated is not None and X_control is not None and X_treated.shape[1] > 0: if not ps_fallback_used: # --- PS IF correction — psi convention, convert to phi --- n_all_panel = n_t + n_c X_all_int = np.column_stack([np.ones(n_all_panel), X_all]) pscore_treated_clipped = np.clip( pscore[:n_t], self.pscore_trim, 1 - self.pscore_trim ) pscore_all = np.concatenate([pscore_treated_clipped, pscore_control]) W_ps = pscore_all * (1 - pscore_all) H_psi = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all_panel H_psi_inv = _safe_inv(H_psi, tracker=self._safe_inv_tracker) D_all = np.concatenate([np.ones(n_t), np.zeros(n_c)]) score_ps = (D_all - pscore_all)[:, None] * X_all_int # R: asy.lin.rep.ps = score.ps %*% Hessian.ps (psi scale) asy_lin_rep_psi = score_ps @ H_psi_inv dr_resid_control = m_control - control_change M2_dr = np.sum( ((weights_control / n_t) * dr_resid_control)[:, None] * X_all_int[n_t:], axis=0, ) inf_func = inf_func + (asy_lin_rep_psi @ M2_dr) / n_all_panel # --- OR IF correction --- X_c_int = X_control_with_intercept XtX = X_c_int.T @ X_c_int bread = _safe_inv(XtX, tracker=self._safe_inv_tracker) X_t_int = X_treated_with_intercept M1 = ( -np.sum(X_t_int, axis=0) + np.sum(weights_control[:, None] * X_c_int, axis=0) ) / n_t resid_c = control_change - m_control asy_lin_rep_or = resid_c[:, None] * X_c_int @ bread inf_func[n_t:] += asy_lin_rep_or @ M1 # Recompute SE from corrected IF var_psi = np.sum(inf_func**2) se = float(np.sqrt(var_psi)) if var_psi > 0 else 0.0 else: # Without covariates, DR simplifies to difference in means if sw_treated is not None: sw_t_norm = sw_treated / np.sum(sw_treated) sw_c_norm = sw_control / np.sum(sw_control) mu_t = float(np.sum(sw_t_norm * treated_change)) mu_c = float(np.sum(sw_c_norm * control_change)) att = mu_t - mu_c inf_treated = sw_t_norm * (treated_change - mu_t) inf_control = -sw_c_norm * (control_change - mu_c) inf_func = np.concatenate([inf_treated, inf_control]) se = ( float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2))) if (n_t > 0 and n_c > 0) else 0.0 ) else: att = float(np.mean(treated_change) - np.mean(control_change)) var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0 var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0 se = float(np.sqrt(var_t / n_t + var_c / n_c)) if (n_t > 0 and n_c > 0) else 0.0 # Influence function for DR estimator inf_treated = (treated_change - np.mean(treated_change)) / n_t inf_control = (control_change - np.mean(control_change)) / n_c inf_func = np.concatenate([inf_treated, -inf_control]) return att, se, inf_func # ========================================================================= # Repeated Cross-Section (RCS) methods # ========================================================================= def _precompute_structures_rc( self, df: pd.DataFrame, outcome: str, unit: str, time: str, first_treat: str, covariates: Optional[List[str]], time_periods: List[Any], treatment_groups: List[Any], resolved_survey=None, ) -> PrecomputedData: """ Pre-compute observation-level structures for repeated cross-section. Unlike the panel path, RCS does not pivot to wide format. Each observation is treated independently (no within-unit differencing). Returns ------- PrecomputedData Dictionary with pre-computed structures (observation-level). """ n_obs = len(df) # Observation-level arrays (no pivot) obs_time = df[time].values obs_outcome = df[outcome].values unit_cohorts = df[first_treat].values # "all_units" key holds integer observation indices for backward # compatibility with aggregation code all_units = np.arange(n_obs) # Pre-compute cohort masks (boolean arrays, observation-level) cohort_masks = {} for g in treatment_groups: cohort_masks[g] = unit_cohorts == g # Never-treated mask never_treated_mask = (unit_cohorts == 0) | (unit_cohorts == np.inf) # Period-to-column mapping (identity for RCS — used for base period checks) period_to_col = {t: i for i, t in enumerate(sorted(time_periods))} # Covariates (observation-level, not per-period) obs_covariates = None if covariates: obs_covariates = df[covariates].values # Survey weights (already per-observation for RCS) if resolved_survey is not None: survey_weights_arr = resolved_survey.weights.copy() else: survey_weights_arr = None # For RCS, the resolved survey is already per-observation resolved_survey_rc = resolved_survey # Fixed cohort masses: total observations per cohort across all periods. # Used as aggregation weights so that n_treated is consistent with WIF. rcs_cohort_masses = {} for g in treatment_groups: rcs_cohort_masses[g] = int(np.sum(unit_cohorts == g)) return { "all_units": all_units, "unit_to_idx": None, # RCS: obs indices are positions "unit_cohorts": unit_cohorts, "canonical_size": n_obs, "is_panel": False, "obs_time": obs_time, "obs_outcome": obs_outcome, "obs_covariates": obs_covariates, "cohort_masks": cohort_masks, "never_treated_mask": never_treated_mask, "time_periods": time_periods, "period_to_col": period_to_col, "is_balanced": False, "survey_weights": survey_weights_arr, "resolved_survey": resolved_survey, "resolved_survey_unit": resolved_survey_rc, "df_survey": ( resolved_survey_rc.df_survey if resolved_survey_rc is not None and hasattr(resolved_survey_rc, "df_survey") else None ), "rcs_cohort_masses": rcs_cohort_masses, } def _compute_att_gt_rc( self, precomputed: PrecomputedData, g: Any, t: Any, covariates: Optional[List[str]], epv_diagnostics: Optional[Dict] = None, ) -> Tuple[Optional[float], float, int, int, Optional[Dict[str, Any]], Optional[float]]: """ Compute ATT(g,t) for repeated cross-section data. For RCS, the 2x2 DiD compares outcomes across two independent cross-sections (periods t and base period s) rather than within-unit changes. Returns ------- att_gt : float or None se_gt : float n_treated : int (treated obs at period t) n_control : int (control obs at period t) inf_func_info : dict or None survey_weight_sum : float or None """ cohort_masks = precomputed["cohort_masks"] never_treated_mask = precomputed["never_treated_mask"] unit_cohorts = precomputed["unit_cohorts"] obs_time = precomputed["obs_time"] obs_outcome = precomputed["obs_outcome"] period_to_col = precomputed["period_to_col"] # Base period selection (same logic as panel) if self.base_period == "universal": base_period_val = g - 1 - self.anticipation else: # varying if t < g - self.anticipation: base_period_val = t - 1 else: base_period_val = g - 1 - self.anticipation if base_period_val not in period_to_col or t not in period_to_col: return None, 0.0, 0, 0, None, None # Treated mask = cohort g treated_mask = cohort_masks[g] # Control mask (same logic as panel) if self.control_group == "never_treated": control_mask = never_treated_mask else: # not_yet_treated nyt_threshold = max(t, base_period_val) + self.anticipation control_mask = never_treated_mask | ( (unit_cohorts > nyt_threshold) & (unit_cohorts != g) ) # Period masks at_t = obs_time == t at_s = obs_time == base_period_val # 4 groups of observations treated_t = treated_mask & at_t treated_s = treated_mask & at_s control_t = control_mask & at_t control_s = control_mask & at_s n_gt = int(np.sum(treated_t)) n_gs = int(np.sum(treated_s)) n_ct = int(np.sum(control_t)) n_cs = int(np.sum(control_s)) if n_gt == 0 or n_ct == 0 or n_gs == 0 or n_cs == 0: return None, 0.0, 0, 0, None, None # Extract outcomes for each group y_gt = obs_outcome[treated_t] y_gs = obs_outcome[treated_s] y_ct = obs_outcome[control_t] y_cs = obs_outcome[control_s] # Survey weights survey_w = precomputed.get("survey_weights") sw_gt = survey_w[treated_t] if survey_w is not None else None sw_gs = survey_w[treated_s] if survey_w is not None else None sw_ct = survey_w[control_t] if survey_w is not None else None sw_cs = survey_w[control_s] if survey_w is not None else None # Guard against zero effective mass if sw_gt is not None: if np.sum(sw_gt) <= 0 or np.sum(sw_gs) <= 0: return None, 0.0, 0, 0, None, None if np.sum(sw_ct) <= 0 or np.sum(sw_cs) <= 0: return None, 0.0, 0, 0, None, None # Get covariates if specified obs_covariates = precomputed.get("obs_covariates") has_covariates = covariates is not None and obs_covariates is not None if has_covariates: X_gt = obs_covariates[treated_t] X_gs = obs_covariates[treated_s] X_ct = obs_covariates[control_t] X_cs = obs_covariates[control_s] # Check for NaN in covariates if ( np.any(np.isnan(X_gt)) or np.any(np.isnan(X_gs)) or np.any(np.isnan(X_ct)) or np.any(np.isnan(X_cs)) ): warnings.warn( f"Missing values in covariates for group {g}, time {t} (RCS). " "Falling back to unconditional estimation.", UserWarning, stacklevel=3, ) has_covariates = False if has_covariates and self.estimation_method == "reg": att, se, inf_func_all, idx_all = self._outcome_regression_rc( y_gt, y_gs, y_ct, y_cs, X_gt, X_gs, X_ct, X_cs, sw_gt=sw_gt, sw_gs=sw_gs, sw_ct=sw_ct, sw_cs=sw_cs, ) elif has_covariates and self.estimation_method == "ipw": epv_diag: dict = {} att, se, inf_func_all, idx_all = self._ipw_estimation_rc( y_gt, y_gs, y_ct, y_cs, X_gt, X_gs, X_ct, X_cs, sw_gt=sw_gt, sw_gs=sw_gs, sw_ct=sw_ct, sw_cs=sw_cs, context_label=f"cohort g={g}", epv_diagnostics_out=epv_diag, ) if epv_diagnostics is not None and epv_diag: epv_diagnostics[(g, t)] = epv_diag elif has_covariates and self.estimation_method == "dr": epv_diag = {} att, se, inf_func_all, idx_all = self._doubly_robust_rc( y_gt, y_gs, y_ct, y_cs, X_gt, X_gs, X_ct, X_cs, sw_gt=sw_gt, sw_gs=sw_gs, sw_ct=sw_ct, sw_cs=sw_cs, context_label=f"cohort g={g}", epv_diagnostics_out=epv_diag, ) if epv_diagnostics is not None and epv_diag: epv_diagnostics[(g, t)] = epv_diag else: # No-covariates 2x2 DiD (all methods reduce to same) att, se, inf_func_all, idx_all = self._rc_2x2_did( y_gt, y_gs, y_ct, y_cs, treated_t, treated_s, control_t, control_s, sw_gt=sw_gt, sw_gs=sw_gs, sw_ct=sw_ct, sw_cs=sw_cs, ) # Build influence function info # For RCS, treated_idx/control_idx combine obs from BOTH periods treated_idx = np.concatenate([np.where(treated_t)[0], np.where(treated_s)[0]]) control_idx = np.concatenate([np.where(control_t)[0], np.where(control_s)[0]]) n_treated_combined = len(treated_idx) inf_func_info = { "treated_idx": treated_idx, "control_idx": control_idx, "treated_units": treated_idx, # For RCS, obs indices = "units" "control_units": control_idx, "treated_inf": inf_func_all[:n_treated_combined], "control_inf": inf_func_all[n_treated_combined:], } sw_sum = float(np.sum(sw_gt)) if sw_gt is not None else None # n_treated = per-cell treated count at period t (for display). # cohort_mass = total treated across all periods (for aggregation weights). cohort_mass = precomputed.get("rcs_cohort_masses", {}).get(g, n_gt) return att, se, n_gt, n_ct, inf_func_info, sw_sum, cohort_mass def _rc_2x2_did( self, y_gt, y_gs, y_ct, y_cs, mask_gt, mask_gs, mask_ct, mask_cs, sw_gt=None, sw_gs=None, sw_ct=None, sw_cs=None, ): """ Compute the basic 2x2 DiD for RCS (no covariates). ATT = (mean(Y_treated_t) - mean(Y_control_t)) - (mean(Y_treated_s) - mean(Y_control_s)) Returns (att, se, inf_func_concat, idx_concat) where inf_func_concat has treated obs (both periods) first, then control obs (both periods). """ n_gt = len(y_gt) n_gs = len(y_gs) n_ct = len(y_ct) n_cs = len(y_cs) if sw_gt is not None: sw_gt_norm = sw_gt / np.sum(sw_gt) sw_gs_norm = sw_gs / np.sum(sw_gs) sw_ct_norm = sw_ct / np.sum(sw_ct) sw_cs_norm = sw_cs / np.sum(sw_cs) mu_gt = float(np.sum(sw_gt_norm * y_gt)) mu_gs = float(np.sum(sw_gs_norm * y_gs)) mu_ct = float(np.sum(sw_ct_norm * y_ct)) mu_cs = float(np.sum(sw_cs_norm * y_cs)) att = (mu_gt - mu_ct) - (mu_gs - mu_cs) # Influence function for 4 groups (survey-weighted) inf_gt = sw_gt_norm * (y_gt - mu_gt) inf_ct = -sw_ct_norm * (y_ct - mu_ct) inf_gs = -sw_gs_norm * (y_gs - mu_gs) inf_cs = sw_cs_norm * (y_cs - mu_cs) else: mu_gt = float(np.mean(y_gt)) mu_gs = float(np.mean(y_gs)) mu_ct = float(np.mean(y_ct)) mu_cs = float(np.mean(y_cs)) att = (mu_gt - mu_ct) - (mu_gs - mu_cs) # Influence function for 4 groups inf_gt = (y_gt - mu_gt) / n_gt inf_ct = -(y_ct - mu_ct) / n_ct inf_gs = -(y_gs - mu_gs) / n_gs inf_cs = (y_cs - mu_cs) / n_cs # Concatenate: treated (t then s), control (t then s) inf_treated = np.concatenate([inf_gt, inf_gs]) inf_control = np.concatenate([inf_ct, inf_cs]) inf_all = np.concatenate([inf_treated, inf_control]) # SE from influence function se = float(np.sqrt(np.sum(inf_all**2))) idx_all = np.concatenate( [ np.where(mask_gt)[0], np.where(mask_gs)[0], np.where(mask_ct)[0], np.where(mask_cs)[0], ] ) return att, se, inf_all, idx_all def _outcome_regression_rc( self, y_gt, y_gs, y_ct, y_cs, X_gt, X_gs, X_ct, X_cs, sw_gt=None, sw_gs=None, sw_ct=None, sw_cs=None, ): """ Cross-sectional outcome regression for ATT(g,t). Matches R DRDID::reg_did_rc (Sant'Anna & Zhao 2020, Eq 2.2). Two OLS models fit on controls (period t and base period s). Predictions made for ALL treated (both periods). OR correction pools ALL treated observations across both periods. IF convention ------------- Intermediate terms use R's unnormalized psi_i convention throughout. R computes SE as ``sd(psi) / sqrt(n)``; with mean(psi) approx 0 this equals ``sqrt(sum(psi^2)) / n``. At the end we convert to the library's pre-scaled phi_i = psi_i / n convention where ``se = sqrt(sum(phi^2))``, used by the aggregation/bootstrap layer. Returns (att, se, inf_func_concat, idx_concat). """ n_gt = len(y_gt) n_gs = len(y_gs) n_ct = len(y_ct) n_cs = len(y_cs) n_all = n_gt + n_gs + n_ct + n_cs # --- Fit 2 OLS on control groups (period t and s separately) --- beta_t, resid_ct = _linear_regression( X_ct, y_ct, rank_deficient_action=self.rank_deficient_action, weights=sw_ct, ) beta_t = np.where(np.isfinite(beta_t), beta_t, 0.0) beta_s, resid_cs = _linear_regression( X_cs, y_cs, rank_deficient_action=self.rank_deficient_action, weights=sw_cs, ) beta_s = np.where(np.isfinite(beta_s), beta_s, 0.0) # --- Predict counterfactual for ALL treated (both periods) --- X_gt_int = np.column_stack([np.ones(n_gt), X_gt]) X_gs_int = np.column_stack([np.ones(n_gs), X_gs]) X_ct_int = np.column_stack([np.ones(n_ct), X_ct]) X_cs_int = np.column_stack([np.ones(n_cs), X_cs]) # mu_hat_{0,t}(X) and mu_hat_{0,s}(X) for each treated obs mu_post_gt = X_gt_int @ beta_t # treated-post predicted at post model mu_pre_gt = X_gt_int @ beta_s # treated-post predicted at pre model mu_post_gs = X_gs_int @ beta_t # treated-pre predicted at post model mu_pre_gs = X_gs_int @ beta_s # treated-pre predicted at pre model # --- Group weights (R: w.treat.pre, w.treat.post, w.cont = w.D) --- if sw_gt is not None: w_treat_post = sw_gt # treated at t w_treat_pre = sw_gs # treated at s w_D_gt = sw_gt # ALL treated: t portion w_D_gs = sw_gs # ALL treated: s portion else: w_treat_post = np.ones(n_gt) w_treat_pre = np.ones(n_gs) w_D_gt = np.ones(n_gt) w_D_gs = np.ones(n_gs) sum_w_treat_post = np.sum(w_treat_post) sum_w_treat_pre = np.sum(w_treat_pre) sum_w_D = np.sum(w_D_gt) + np.sum(w_D_gs) # pool ALL treated # R: mean(w.treat.post), mean(w.treat.pre), mean(w.cont) mean_w_treat_post = sum_w_treat_post / n_all mean_w_treat_pre = sum_w_treat_pre / n_all mean_w_D = sum_w_D / n_all # --- Treated means (period-specific Hajek means) --- eta_treat_post = np.sum(w_treat_post * y_gt) / sum_w_treat_post eta_treat_pre = np.sum(w_treat_pre * y_gs) / sum_w_treat_pre # --- OR correction: pools ALL treated --- # R: out.y.post - out.y.pre for each treated obs or_diff_gt = mu_post_gt - mu_pre_gt # treated at t or_diff_gs = mu_post_gs - mu_pre_gs # treated at s eta_cont = (np.sum(w_D_gt * or_diff_gt) + np.sum(w_D_gs * or_diff_gs)) / sum_w_D # --- Point estimate --- att = float(eta_treat_post - eta_treat_pre - eta_cont) # ================================================================= # Influence function in R's unnormalized psi convention # (R: reg_did_rc.R, psi = n * phi) # ================================================================= # --- Treated psi (R: eta.treat.post, eta.treat.pre) --- # R: w.treat.post * (y - eta.treat.post) / mean(w.treat.post) psi_treat_post = w_treat_post * (y_gt - eta_treat_post) / mean_w_treat_post # R: w.treat.pre * (y - eta.treat.pre) / mean(w.treat.pre) psi_treat_pre = w_treat_pre * (y_gs - eta_treat_pre) / mean_w_treat_pre # --- Control psi: leading term (R: inf.cont.1) --- # R: w.cont * (or_diff - eta.cont) [before /mean(w.cont)] psi_cont_1_gt = w_D_gt * (or_diff_gt - eta_cont) psi_cont_1_gs = w_D_gs * (or_diff_gs - eta_cont) # --- Control psi: estimation effect (R: inf.cont.2) --- # R: bread = solve(crossprod(X_ctrl, W * X_ctrl) / n) # Here bread is (X'WX)^{-1} (without /n), so asy_lin_rep already # absorbs the 1/n that R puts in its bread. We compensate by using # R's colMeans (= sum/n_all) for M1, matching the product exactly. W_ct = sw_ct if sw_ct is not None else np.ones(n_ct) W_cs = sw_cs if sw_cs is not None else np.ones(n_cs) bread_t = _safe_inv( X_ct_int.T @ (W_ct[:, None] * X_ct_int), tracker=self._safe_inv_tracker, ) bread_s = _safe_inv( X_cs_int.T @ (W_cs[:, None] * X_cs_int), tracker=self._safe_inv_tracker, ) # R: M1 = colMeans(w.cont * out.x) = sum(w_D * X) / n_all M1 = ( np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0) ) / n_all # R: asy.lin.rep.ols (per-obs OLS score * bread) asy_lin_rep_ols_t = (W_ct * resid_ct)[:, None] * X_ct_int @ bread_t asy_lin_rep_ols_s = (W_cs * resid_cs)[:, None] * X_cs_int @ bread_s # R: inf.cont.2.post = asy.lin.rep.ols_t %*% M1 psi_cont_2_ct = asy_lin_rep_ols_t @ M1 # (n_ct,) # R: inf.cont.2.pre = asy.lin.rep.ols_s %*% M1 psi_cont_2_cs = asy_lin_rep_ols_s @ M1 # (n_cs,) # --- Assemble per-group psi --- # R: inf.treat = inf.treat.post - inf.treat.pre (across groups) # R: inf.cont = (inf.cont.1 + inf.cont.2.post - inf.cont.2.pre) / mean(w.cont) # R: att.inf.func = inf.treat - inf.cont psi_gt = psi_treat_post - psi_cont_1_gt / mean_w_D psi_gs = -psi_treat_pre - psi_cont_1_gs / mean_w_D psi_ct = -psi_cont_2_ct / mean_w_D psi_cs = psi_cont_2_cs / mean_w_D psi_all = np.concatenate([psi_gt, psi_gs, psi_ct, psi_cs]) # ================================================================= # Convert to library convention: phi = psi / n_all # se = sqrt(sum(phi^2)) == sqrt(sum(psi^2)) / n_all # ================================================================= inf_all = psi_all / n_all se = float(np.sqrt(np.sum(inf_all**2))) idx_all = None # caller builds idx from masks return att, se, inf_all, idx_all def _ipw_estimation_rc( self, y_gt, y_gs, y_ct, y_cs, X_gt, X_gs, X_ct, X_cs, sw_gt=None, sw_gs=None, sw_ct=None, sw_cs=None, context_label: str = "", epv_diagnostics_out: Optional[dict] = None, ): """ Cross-sectional IPW estimation for ATT(g,t). Propensity score P(G=g | X) estimated on pooled treated+control observations from both periods. Reweight controls in each period. IF convention ------------- Intermediate terms use R's unnormalized psi_i convention throughout (R: ``ipw_did_rc``). R computes SE as ``sd(psi) / sqrt(n)``. At the end we convert to the library's pre-scaled phi_i = psi_i / n convention where ``se = sqrt(sum(phi^2))``, used by the aggregation/bootstrap layer. Returns (att, se, inf_func_concat, idx_concat). """ n_gt = len(y_gt) n_gs = len(y_gs) n_ct = len(y_ct) n_cs = len(y_cs) n_all = n_gt + n_gs + n_ct + n_cs # Pool treated and control for propensity score X_all = np.vstack([X_gt, X_gs, X_ct, X_cs]) D_all = np.concatenate([np.ones(n_gt + n_gs), np.zeros(n_ct + n_cs)]) sw_all = None if sw_gt is not None: sw_all = np.concatenate([sw_gt, sw_gs, sw_ct, sw_cs]) ps_fallback_used = False diag = {} try: beta_logistic, pscore = solve_logit( X_all, D_all, rank_deficient_action=self.rank_deficient_action, weights=sw_all, epv_threshold=self.epv_threshold, context_label=context_label, diagnostics_out=diag, ) _check_propensity_diagnostics(pscore, self.pscore_trim) except (np.linalg.LinAlgError, ValueError): if self.pscore_fallback == "error" or self.rank_deficient_action == "error": raise ctx = f" for {context_label}" if context_label else "" warnings.warn( f"Propensity score estimation failed{ctx} (RCS IPW). " f"Falling back to unconditional propensity " f"(all covariates dropped for this cell). " f"Consider estimation_method='reg' to avoid " f"propensity scores entirely.", UserWarning, stacklevel=4, ) if sw_all is not None: pos = sw_all > 0 p_treat = float(np.average(D_all[pos], weights=sw_all[pos])) else: p_treat = (n_gt + n_gs) / len(D_all) pscore = np.full(len(D_all), p_treat) ps_fallback_used = True if epv_diagnostics_out is not None and diag: epv_diagnostics_out.update(diag) # Clip propensity scores pscore = np.clip(pscore, self.pscore_trim, 1 - self.pscore_trim) # Split propensity scores (treated ps not used -- only control IPW weights) ps_ct = pscore[n_gt + n_gs : n_gt + n_gs + n_ct] ps_cs = pscore[n_gt + n_gs + n_ct :] # IPW weights for controls (R: w1.x = ps / (1 - ps)) w_ct = ps_ct / (1 - ps_ct) w_cs = ps_cs / (1 - ps_cs) if sw_gt is not None: w_ct = sw_ct * w_ct w_cs = sw_cs * w_cs # R: mean(w.treat.post), mean(w.treat.pre), mean(w.ipw.ct), mean(w.ipw.cs) if sw_gt is not None: sum_w_treat_post = np.sum(sw_gt) sum_w_treat_pre = np.sum(sw_gs) else: sum_w_treat_post = float(n_gt) sum_w_treat_pre = float(n_gs) mean_w_treat_post = sum_w_treat_post / n_all mean_w_treat_pre = sum_w_treat_pre / n_all sum_w_ct = np.sum(w_ct) sum_w_cs = np.sum(w_cs) mean_w_ct = sum_w_ct / n_all mean_w_cs = sum_w_cs / n_all # Hajek-normalized weights (R normalizes by sum for point estimate) w_ct_norm = w_ct / sum_w_ct if sum_w_ct > 0 else w_ct w_cs_norm = w_cs / sum_w_cs if sum_w_cs > 0 else w_cs if sw_gt is not None: sw_gt_norm = sw_gt / sum_w_treat_post sw_gs_norm = sw_gs / sum_w_treat_pre mu_gt = float(np.sum(sw_gt_norm * y_gt)) mu_gs = float(np.sum(sw_gs_norm * y_gs)) else: mu_gt = float(np.mean(y_gt)) mu_gs = float(np.mean(y_gs)) mu_ct_ipw = float(np.sum(w_ct_norm * y_ct)) mu_cs_ipw = float(np.sum(w_cs_norm * y_cs)) att = (mu_gt - mu_ct_ipw) - (mu_gs - mu_cs_ipw) # ================================================================= # Influence function in R's unnormalized psi convention # (R: ipw_did_rc.R, psi = n * phi) # ================================================================= # --- Treated psi (R: eta.treat.post, eta.treat.pre) --- # R: w.treat.post * (y - eta.treat.post) / mean(w.treat.post) if sw_gt is not None: psi_gt = sw_gt * (y_gt - mu_gt) / mean_w_treat_post psi_gs = -sw_gs * (y_gs - mu_gs) / mean_w_treat_pre else: psi_gt = (y_gt - mu_gt) / mean_w_treat_post psi_gs = -(y_gs - mu_gs) / mean_w_treat_pre # --- Control psi (R: eta.cont.post, eta.cont.pre) --- # R: w.ipw * (y - eta.cont) / mean(w.ipw) psi_ct = -w_ct * (y_ct - mu_ct_ipw) / mean_w_ct if mean_w_ct > 0 else np.zeros(n_ct) psi_cs = w_cs * (y_cs - mu_cs_ipw) / mean_w_cs if mean_w_cs > 0 else np.zeros(n_cs) psi_all = np.concatenate([psi_gt, psi_gs, psi_ct, psi_cs]) # Convert leading psi to phi: phi = psi / n_all inf_all = psi_all / n_all if not ps_fallback_used: # --- PS IF correction — psi convention, convert to phi --- X_all_int = np.column_stack([np.ones(n_all), X_all]) W_ps = pscore * (1 - pscore) if sw_all is not None: W_ps = W_ps * sw_all # R: Hessian.ps = crossprod(X * sqrt(W)) / n H_psi = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all H_psi_inv = _safe_inv(H_psi, tracker=self._safe_inv_tracker) score_ps = (D_all - pscore)[:, None] * X_all_int if sw_all is not None: score_ps = score_ps * sw_all[:, None] # R: asy.lin.rep.ps = score.ps %*% Hessian.ps (psi scale, O(1) per obs) asy_lin_rep_psi = score_ps @ H_psi_inv # PS nuisance correction in psi convention # R: M2 = colMeans(w_ipw * (y-mu) * X) ipw_resid_ct = w_ct_norm * (y_ct - mu_ct_ipw) ipw_resid_cs = w_cs_norm * (y_cs - mu_cs_ipw) ct_slice = slice(n_gt + n_gs, n_gt + n_gs + n_ct) cs_slice = slice(n_gt + n_gs + n_ct, None) M2 = np.zeros(X_all_int.shape[1]) M2 += np.sum(ipw_resid_ct[:, None] * X_all_int[ct_slice], axis=0) M2 -= np.sum(ipw_resid_cs[:, None] * X_all_int[cs_slice], axis=0) # psi-scale correction, convert to phi # Subtract: R adds PS correction to inf.control, then att = treat - control inf_all = inf_all - (asy_lin_rep_psi @ M2) / n_all # ================================================================= # SE from phi: se = sqrt(sum(phi^2)) # Equivalent to R's sqrt(sum(psi^2)) / n when mean(psi) approx 0. # ================================================================= se = float(np.sqrt(np.sum(inf_all**2))) idx_all = None return att, se, inf_all, idx_all def _doubly_robust_rc( self, y_gt, y_gs, y_ct, y_cs, X_gt, X_gs, X_ct, X_cs, sw_gt=None, sw_gs=None, sw_ct=None, sw_cs=None, context_label: str = "", epv_diagnostics_out: Optional[dict] = None, ): """ Cross-sectional doubly robust estimation for ATT(g,t). Matches R DRDID::drdid_rc (Sant'Anna & Zhao 2020, Eq 3.1). Locally efficient DR estimator with 4 OLS fits (control pre/post, treated pre/post) plus propensity score. IF convention ------------- Intermediate terms use R's unnormalized psi_i convention throughout (R: ``drdid_rc``). R computes SE as ``sd(psi) / sqrt(n)``. At the end we convert to the library's pre-scaled phi_i = psi_i / n convention where ``se = sqrt(sum(phi^2))``, used by the aggregation/bootstrap layer. Returns (att, se, inf_func_concat, idx_concat). """ n_gt = len(y_gt) n_gs = len(y_gs) n_ct = len(y_ct) n_cs = len(y_cs) n_all = n_gt + n_gs + n_ct + n_cs # ===================================================================== # 1. Outcome regression: 4 OLS fits # ===================================================================== # Control OLS: E[Y|X, D=0, T=t] and E[Y|X, D=0, T=s] beta_ct, resid_ct = _linear_regression( X_ct, y_ct, rank_deficient_action=self.rank_deficient_action, weights=sw_ct, ) beta_ct = np.where(np.isfinite(beta_ct), beta_ct, 0.0) beta_cs, resid_cs = _linear_regression( X_cs, y_cs, rank_deficient_action=self.rank_deficient_action, weights=sw_cs, ) beta_cs = np.where(np.isfinite(beta_cs), beta_cs, 0.0) # Treated OLS: E[Y|X, D=1, T=t] and E[Y|X, D=1, T=s] beta_gt, resid_gt = _linear_regression( X_gt, y_gt, rank_deficient_action=self.rank_deficient_action, weights=sw_gt, ) beta_gt = np.where(np.isfinite(beta_gt), beta_gt, 0.0) beta_gs, resid_gs = _linear_regression( X_gs, y_gs, rank_deficient_action=self.rank_deficient_action, weights=sw_gs, ) beta_gs = np.where(np.isfinite(beta_gs), beta_gs, 0.0) # Intercept-augmented design matrices X_gt_int = np.column_stack([np.ones(n_gt), X_gt]) X_gs_int = np.column_stack([np.ones(n_gs), X_gs]) X_ct_int = np.column_stack([np.ones(n_ct), X_ct]) X_cs_int = np.column_stack([np.ones(n_cs), X_cs]) # Control OR predictions for all groups mu0_post_gt = X_gt_int @ beta_ct # mu_{0,1}(X) for treated-post mu0_pre_gt = X_gt_int @ beta_cs # mu_{0,0}(X) for treated-post mu0_post_gs = X_gs_int @ beta_ct # mu_{0,1}(X) for treated-pre mu0_pre_gs = X_gs_int @ beta_cs # mu_{0,0}(X) for treated-pre mu0_post_ct = X_ct_int @ beta_ct # mu_{0,1}(X) for control-post mu0_pre_ct = X_ct_int @ beta_cs # mu_{0,0}(X) for control-post mu0_post_cs = X_cs_int @ beta_ct # mu_{0,1}(X) for control-pre mu0_pre_cs = X_cs_int @ beta_cs # mu_{0,0}(X) for control-pre # Treated OR predictions for all groups (for local efficiency adjustment) mu1_post_gt = X_gt_int @ beta_gt # mu_{1,1}(X) for treated-post mu1_pre_gt = X_gt_int @ beta_gs # mu_{1,0}(X) for treated-post mu1_post_gs = X_gs_int @ beta_gt # mu_{1,1}(X) for treated-pre mu1_pre_gs = X_gs_int @ beta_gs # mu_{1,0}(X) for treated-pre # mu_{0,Y}(T_i, X_i): control OR evaluated at own period mu0Y_gt = mu0_post_gt # treated-post: use post control model mu0Y_gs = mu0_pre_gs # treated-pre: use pre control model mu0Y_ct = mu0_post_ct # control-post: use post control model mu0Y_cs = mu0_pre_cs # control-pre: use pre control model # ===================================================================== # 2. Propensity score # ===================================================================== X_all = np.vstack([X_gt, X_gs, X_ct, X_cs]) D_all = np.concatenate([np.ones(n_gt + n_gs), np.zeros(n_ct + n_cs)]) sw_all = None if sw_gt is not None: sw_all = np.concatenate([sw_gt, sw_gs, sw_ct, sw_cs]) ps_fallback_used = False diag = {} try: beta_logistic, pscore = solve_logit( X_all, D_all, rank_deficient_action=self.rank_deficient_action, weights=sw_all, epv_threshold=self.epv_threshold, context_label=context_label, diagnostics_out=diag, ) _check_propensity_diagnostics(pscore, self.pscore_trim) except (np.linalg.LinAlgError, ValueError): if self.pscore_fallback == "error" or self.rank_deficient_action == "error": raise ctx = f" for {context_label}" if context_label else "" warnings.warn( f"Propensity score estimation failed{ctx} (RCS DR). " f"Falling back to unconditional propensity " f"(propensity model ignores covariates; outcome " f"regression still uses them). " f"Consider estimation_method='reg' to avoid " f"propensity scores entirely.", UserWarning, stacklevel=4, ) if sw_all is not None: pos = sw_all > 0 p_treat = float(np.average(D_all[pos], weights=sw_all[pos])) else: p_treat = (n_gt + n_gs) / len(D_all) pscore = np.full(len(D_all), p_treat) ps_fallback_used = True if epv_diagnostics_out is not None and diag: epv_diagnostics_out.update(diag) pscore = np.clip(pscore, self.pscore_trim, 1 - self.pscore_trim) # Split propensity scores per group ps_gt = pscore[:n_gt] ps_gs = pscore[n_gt : n_gt + n_gs] ps_ct = pscore[n_gt + n_gs : n_gt + n_gs + n_ct] ps_cs = pscore[n_gt + n_gs + n_ct :] # ===================================================================== # 3. Group weights and R-convention means # ===================================================================== if sw_gt is not None: w_treat_post = sw_gt w_treat_pre = sw_gs w_D_gt = sw_gt w_D_gs = sw_gs else: w_treat_post = np.ones(n_gt) w_treat_pre = np.ones(n_gs) w_D_gt = np.ones(n_gt) w_D_gs = np.ones(n_gs) sum_w_treat_post = np.sum(w_treat_post) sum_w_treat_pre = np.sum(w_treat_pre) sum_w_D = np.sum(w_D_gt) + np.sum(w_D_gs) # R: mean(w) = sum(w) / n -- used in psi normalizers mean_w_treat_post = sum_w_treat_post / n_all mean_w_treat_pre = sum_w_treat_pre / n_all mean_w_D = sum_w_D / n_all # IPW control weights: sw * ps/(1-ps) for controls w_ipw_ct = ps_ct / (1 - ps_ct) w_ipw_cs = ps_cs / (1 - ps_cs) if sw_ct is not None: w_ipw_ct = sw_ct * w_ipw_ct w_ipw_cs = sw_cs * w_ipw_cs sum_w_ipw_ct = np.sum(w_ipw_ct) sum_w_ipw_cs = np.sum(w_ipw_cs) mean_w_ipw_ct = sum_w_ipw_ct / n_all mean_w_ipw_cs = sum_w_ipw_cs / n_all # ===================================================================== # 4. Point estimate: tau_1 (AIPW using control ORs) # ===================================================================== # Hajek-normalized means of (y - mu0Y) per group eta_treat_post = np.sum(w_treat_post * (y_gt - mu0Y_gt)) / sum_w_treat_post eta_treat_pre = np.sum(w_treat_pre * (y_gs - mu0Y_gs)) / sum_w_treat_pre eta_cont_post = ( np.sum(w_ipw_ct * (y_ct - mu0Y_ct)) / sum_w_ipw_ct if sum_w_ipw_ct > 0 else 0.0 ) eta_cont_pre = ( np.sum(w_ipw_cs * (y_cs - mu0Y_cs)) / sum_w_ipw_cs if sum_w_ipw_cs > 0 else 0.0 ) tau_1 = (eta_treat_post - eta_cont_post) - (eta_treat_pre - eta_cont_pre) # ===================================================================== # 5. Point estimate: local efficiency adjustment (tau_2) # ===================================================================== # Differences mu_{1,t}(X) - mu_{0,t}(X) for treated obs or_diff_post_gt = mu1_post_gt - mu0_post_gt # at treated-post or_diff_post_gs = mu1_post_gs - mu0_post_gs # at treated-pre or_diff_pre_gt = mu1_pre_gt - mu0_pre_gt # at treated-post or_diff_pre_gs = mu1_pre_gs - mu0_pre_gs # at treated-pre # att_d_post = mean(w_D * (mu1_post - mu0_post)) / mean(w_D) -- all treated att_d_post = (np.sum(w_D_gt * or_diff_post_gt) + np.sum(w_D_gs * or_diff_post_gs)) / sum_w_D # att_dt1_post -- treated-post only att_dt1_post = np.sum(w_treat_post * or_diff_post_gt) / sum_w_treat_post # att_d_pre -- all treated att_d_pre = (np.sum(w_D_gt * or_diff_pre_gt) + np.sum(w_D_gs * or_diff_pre_gs)) / sum_w_D # att_dt0_pre -- treated-pre only att_dt0_pre = np.sum(w_treat_pre * or_diff_pre_gs) / sum_w_treat_pre tau_2 = (att_d_post - att_dt1_post) - (att_d_pre - att_dt0_pre) att = float(tau_1 + tau_2) # ===================================================================== # 6. Influence function in R's unnormalized psi convention # (R: drdid_rc.R, psi = n * phi) # ===================================================================== # --- tau_1: treated psi (R: eta.treat.post / mean(w.treat.post)) --- # R: w.treat.post * (y - mu0Y - eta.treat.post) / mean(w.treat.post) psi_treat_post = w_treat_post * (y_gt - mu0Y_gt - eta_treat_post) / mean_w_treat_post psi_treat_pre = w_treat_pre * (y_gs - mu0Y_gs - eta_treat_pre) / mean_w_treat_pre # --- tau_1: control psi (R: eta.cont.post / mean(w.ipw)) --- # R: w.ipw * (y - mu0Y - eta.cont) / mean(w.ipw) psi_cont_post_ct = ( w_ipw_ct * (y_ct - mu0Y_ct - eta_cont_post) / mean_w_ipw_ct if mean_w_ipw_ct > 0 else np.zeros(n_ct) ) psi_cont_pre_cs = ( w_ipw_cs * (y_cs - mu0Y_cs - eta_cont_pre) / mean_w_ipw_cs if mean_w_ipw_cs > 0 else np.zeros(n_cs) ) # tau_1 psi per group psi_gt_tau1 = psi_treat_post psi_gs_tau1 = -psi_treat_pre psi_ct_tau1 = -psi_cont_post_ct psi_cs_tau1 = psi_cont_pre_cs # ===================================================================== # 7. tau_2 leading terms (R: att.d.post, att.dt1.post, etc.) # ===================================================================== # R: w.D * (or_diff - att.d.post) / mean(w.D) psi_d_post_gt = w_D_gt * (or_diff_post_gt - att_d_post) / mean_w_D psi_d_post_gs = w_D_gs * (or_diff_post_gs - att_d_post) / mean_w_D # R: w.treat.post * (or_diff - att.dt1.post) / mean(w.treat.post) psi_dt1_post = w_treat_post * (or_diff_post_gt - att_dt1_post) / mean_w_treat_post # R: w.D * (or_diff_pre - att.d.pre) / mean(w.D) psi_d_pre_gt = w_D_gt * (or_diff_pre_gt - att_d_pre) / mean_w_D psi_d_pre_gs = w_D_gs * (or_diff_pre_gs - att_d_pre) / mean_w_D # R: w.treat.pre * (or_diff_pre - att.dt0.pre) / mean(w.treat.pre) psi_dt0_pre = w_treat_pre * (or_diff_pre_gs - att_dt0_pre) / mean_w_treat_pre # tau_2 psi per group (controls contribute zero) psi_gt_tau2 = (psi_d_post_gt - psi_dt1_post) - psi_d_pre_gt psi_gs_tau2 = psi_d_post_gs - (-psi_dt0_pre + psi_d_pre_gs) # ===================================================================== # 8. Combined plug-in psi (before nuisance corrections) # ===================================================================== psi_gt = psi_gt_tau1 + psi_gt_tau2 psi_gs = psi_gs_tau1 + psi_gs_tau2 psi_ct = psi_ct_tau1 psi_cs = psi_cs_tau1 psi_all = np.concatenate([psi_gt, psi_gs, psi_ct, psi_cs]) # ================================================================= # Convert leading psi to library phi convention: phi = psi / n_all # ================================================================= inf_all = psi_all / n_all # ===================================================================== # 9. PS nuisance correction — psi convention, convert to phi # ===================================================================== X_all_int = np.column_stack([np.ones(n_all), X_all]) if not ps_fallback_used: W_ps = pscore * (1 - pscore) if sw_all is not None: W_ps = W_ps * sw_all # R: Hessian.ps = crossprod(X * sqrt(W)) / n H_psi = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all H_psi_inv = _safe_inv(H_psi, tracker=self._safe_inv_tracker) score_ps = (D_all - pscore)[:, None] * X_all_int if sw_all is not None: score_ps = score_ps * sw_all[:, None] # R: asy.lin.rep.ps = score.ps %*% Hessian.ps (psi scale, O(1) per obs) asy_lin_rep_psi = score_ps @ H_psi_inv # R: M2 = colMeans(w_ipw * dr_resid / mean(w_ipw) * X) ct_slice = slice(n_gt + n_gs, n_gt + n_gs + n_ct) cs_slice = slice(n_gt + n_gs + n_ct, None) dr_resid_ct = y_ct - mu0Y_ct - eta_cont_post dr_resid_cs = y_cs - mu0Y_cs - eta_cont_pre M2 = np.zeros(X_all_int.shape[1]) if sum_w_ipw_ct > 0: M2 -= np.sum( ((w_ipw_ct * dr_resid_ct / sum_w_ipw_ct)[:, None] * X_all_int[ct_slice]), axis=0, ) if sum_w_ipw_cs > 0: M2 += np.sum( ((w_ipw_cs * dr_resid_cs / sum_w_ipw_cs)[:, None] * X_all_int[cs_slice]), axis=0, ) # psi-scale correction, convert to phi inf_all = inf_all + (asy_lin_rep_psi @ M2) / n_all # ===================================================================== # 10. Control OR nuisance corrections (phi-scale) # ===================================================================== W_ct_vals = sw_ct if sw_ct is not None else np.ones(n_ct) W_cs_vals = sw_cs if sw_cs is not None else np.ones(n_cs) bread_ct = _safe_inv( X_ct_int.T @ (W_ct_vals[:, None] * X_ct_int), tracker=self._safe_inv_tracker, ) bread_cs = _safe_inv( X_cs_int.T @ (W_cs_vals[:, None] * X_cs_int), tracker=self._safe_inv_tracker, ) # R: asy.lin.rep.ols (per-obs OLS score * bread) asy_lin_rep_ct = (W_ct_vals * resid_ct)[:, None] * X_ct_int @ bread_ct asy_lin_rep_cs = (W_cs_vals * resid_cs)[:, None] * X_cs_int @ bread_cs # M1 for control-post model (beta_ct): gradient from tau_1 + tau_2 # tau_1: -w_treat_post*X/sum_w_treat_post (eta_treat_post via mu0Y_gt) # +w_ipw_ct*X/sum_w_ipw_ct (eta_cont_post via mu0Y_ct) # tau_2: -w_D*X/sum_w_D (att_d_post via mu0_post at all treated) # +w_treat_post*X/sum_w_treat_post (att_dt1_post via mu0_post) M1_ct = np.zeros(X_all_int.shape[1]) M1_ct -= np.sum(w_treat_post[:, None] * X_gt_int, axis=0) / sum_w_treat_post if sum_w_ipw_ct > 0: M1_ct += np.sum(w_ipw_ct[:, None] * X_ct_int, axis=0) / sum_w_ipw_ct M1_ct -= ( np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0) ) / sum_w_D M1_ct += np.sum(w_treat_post[:, None] * X_gt_int, axis=0) / sum_w_treat_post # M1 for control-pre model (beta_cs) M1_cs = np.zeros(X_all_int.shape[1]) M1_cs += np.sum(w_treat_pre[:, None] * X_gs_int, axis=0) / sum_w_treat_pre if sum_w_ipw_cs > 0: M1_cs -= np.sum(w_ipw_cs[:, None] * X_cs_int, axis=0) / sum_w_ipw_cs M1_cs += ( np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0) ) / sum_w_D M1_cs -= np.sum(w_treat_pre[:, None] * X_gs_int, axis=0) / sum_w_treat_pre inf_all[n_gt + n_gs : n_gt + n_gs + n_ct] += asy_lin_rep_ct @ M1_ct inf_all[n_gt + n_gs + n_ct :] += asy_lin_rep_cs @ M1_cs # ===================================================================== # 11. Treated OR nuisance corrections (phi-scale) # ===================================================================== W_gt_vals = sw_gt if sw_gt is not None else np.ones(n_gt) W_gs_vals = sw_gs if sw_gs is not None else np.ones(n_gs) bread_gt = _safe_inv( X_gt_int.T @ (W_gt_vals[:, None] * X_gt_int), tracker=self._safe_inv_tracker, ) bread_gs = _safe_inv( X_gs_int.T @ (W_gs_vals[:, None] * X_gs_int), tracker=self._safe_inv_tracker, ) asy_lin_rep_gt = (W_gt_vals * resid_gt)[:, None] * X_gt_int @ bread_gt asy_lin_rep_gs = (W_gs_vals * resid_gs)[:, None] * X_gs_int @ bread_gs # M1 for treated-post model (beta_gt): mu_{1,1}(X) # From att_d_post: +w_D*X/sum_w_D (all treated) # From att_dt1_post: -w_treat_post*X/sum_w_treat_post (treated-post) M1_gt = np.zeros(X_all_int.shape[1]) M1_gt += ( np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0) ) / sum_w_D M1_gt -= np.sum(w_treat_post[:, None] * X_gt_int, axis=0) / sum_w_treat_post # M1 for treated-pre model (beta_gs): mu_{1,0}(X) # From att_d_pre: -w_D*X/sum_w_D # From att_dt0_pre: +w_treat_pre*X/sum_w_treat_pre M1_gs = np.zeros(X_all_int.shape[1]) M1_gs -= ( np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0) ) / sum_w_D M1_gs += np.sum(w_treat_pre[:, None] * X_gs_int, axis=0) / sum_w_treat_pre inf_all[:n_gt] += asy_lin_rep_gt @ M1_gt inf_all[n_gt : n_gt + n_gs] += asy_lin_rep_gs @ M1_gs # ================================================================= # SE from phi: se = sqrt(sum(phi^2)) # Equivalent to R's sqrt(sum(psi^2)) / n when mean(psi) approx 0. # ================================================================= se = float(np.sqrt(np.sum(inf_all**2))) idx_all = None return att, se, inf_all, idx_all
[docs] def get_params(self) -> Dict[str, Any]: """Get estimator parameters (sklearn-compatible).""" return { "control_group": self.control_group, "anticipation": self.anticipation, "estimation_method": self.estimation_method, "alpha": self.alpha, "cluster": self.cluster, "n_bootstrap": self.n_bootstrap, "bootstrap_weights": self.bootstrap_weights, "seed": self.seed, "rank_deficient_action": self.rank_deficient_action, "base_period": self.base_period, "cband": self.cband, "pscore_trim": self.pscore_trim, "panel": self.panel, "epv_threshold": self.epv_threshold, "pscore_fallback": self.pscore_fallback, }
[docs] def set_params(self, **params) -> "CallawaySantAnna": """Set estimator parameters (sklearn-compatible).""" for key, value in params.items(): if hasattr(self, key): setattr(self, key, value) else: raise ValueError(f"Unknown parameter: {key}") return self
[docs] def summary(self) -> str: """Get summary of estimation results.""" if not self.is_fitted_: raise RuntimeError("Model must be fitted before calling summary()") assert self.results_ is not None return self.results_.summary()
[docs] def print_summary(self) -> None: """Print summary to stdout.""" print(self.summary())