Source code for diff_diff.staggered

"""
Staggered Difference-in-Differences estimators.

Implements modern methods for DiD with variation in treatment timing,
including the Callaway-Sant'Anna (2021) estimator.
"""

import warnings
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from scipy import optimize

from diff_diff.linalg import solve_ols
from diff_diff.utils import safe_inference

# Import from split modules
from diff_diff.staggered_results import (
    GroupTimeEffect,
    CallawaySantAnnaResults,
)
from diff_diff.staggered_bootstrap import (
    CSBootstrapResults,
    CallawaySantAnnaBootstrapMixin,
)
from diff_diff.staggered_aggregation import (
    CallawaySantAnnaAggregationMixin,
)

# Public API of this module. The results, bootstrap-results and group-time
# effect classes are defined in the split modules imported above and are
# re-exported here for backward compatibility.
__all__ = [
    "CallawaySantAnna",
    "CallawaySantAnnaResults",
    "CSBootstrapResults",
    "GroupTimeEffect",
]

# Type alias for the dictionary of pre-computed structures built by
# CallawaySantAnna._precompute_structures() and read by
# CallawaySantAnna._compute_att_gt_fast().
PrecomputedData = Dict[str, Any]


def _logistic_regression(
    X: np.ndarray,
    y: np.ndarray,
    max_iter: int = 100,
    tol: float = 1e-6,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Fit logistic regression using scipy optimize.

    Parameters
    ----------
    X : np.ndarray
        Feature matrix (n_samples, n_features). Intercept added automatically.
    y : np.ndarray
        Binary outcome (0/1).
    max_iter : int
        Maximum iterations.
    tol : float
        Convergence tolerance.

    Returns
    -------
    beta : np.ndarray
        Fitted coefficients (including intercept).
    probs : np.ndarray
        Predicted probabilities.
    """
    n, p = X.shape
    # Add intercept
    X_with_intercept = np.column_stack([np.ones(n), X])

    def neg_log_likelihood(beta: np.ndarray) -> float:
        z = np.dot(X_with_intercept, beta)
        # Clip to prevent overflow
        z = np.clip(z, -500, 500)
        log_lik = np.sum(y * z - np.log(1 + np.exp(z)))
        return -log_lik

    def gradient(beta: np.ndarray) -> np.ndarray:
        z = np.dot(X_with_intercept, beta)
        z = np.clip(z, -500, 500)
        probs = 1 / (1 + np.exp(-z))
        return -np.dot(X_with_intercept.T, y - probs)

    # Initialize with zeros
    beta_init = np.zeros(p + 1)

    result = optimize.minimize(
        neg_log_likelihood,
        beta_init,
        method='BFGS',
        jac=gradient,
        options={'maxiter': max_iter, 'gtol': tol}
    )

    beta = result.x
    z = np.dot(X_with_intercept, beta)
    z = np.clip(z, -500, 500)
    probs = 1 / (1 + np.exp(-z))

    return beta, probs


def _linear_regression(
    X: np.ndarray,
    y: np.ndarray,
    rank_deficient_action: str = "warn",
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Run an ordinary least squares fit of ``y`` on ``X`` plus a constant.

    Parameters
    ----------
    X : np.ndarray
        Feature matrix (n_samples, n_features). A constant column is
        prepended internally, so it must not be included by the caller.
    y : np.ndarray
        Outcome variable.
    rank_deficient_action : str, default "warn"
        Forwarded to the OLS backend; action when the design matrix is
        rank-deficient:
        - "warn": Issue warning and drop linearly dependent columns (default)
        - "error": Raise ValueError
        - "silent": Drop columns silently without warning

    Returns
    -------
    beta : np.ndarray
        Fitted coefficients; the first entry is the intercept.
    residuals : np.ndarray
        Residuals from the fit.
    """
    # Prepend a column of ones so the first coefficient is the intercept
    design = np.column_stack([np.ones(X.shape[0]), X])

    # Delegate to the shared OLS backend; the variance-covariance matrix
    # is not requested, so the third return value is discarded
    coefficients, resid, _unused_vcov = solve_ols(
        design,
        y,
        return_vcov=False,
        rank_deficient_action=rank_deficient_action,
    )
    return coefficients, resid


class CallawaySantAnna(
    CallawaySantAnnaBootstrapMixin,
    CallawaySantAnnaAggregationMixin,
):
    """
    Callaway-Sant'Anna (2021) estimator for staggered Difference-in-Differences.

    This estimator handles DiD designs with variation in treatment timing
    (staggered adoption) and heterogeneous treatment effects. It avoids the
    bias of traditional two-way fixed effects (TWFE) estimators by:

    1. Computing group-time average treatment effects ATT(g,t) for each
       cohort g (units first treated in period g) and time t.
    2. Aggregating these to summary measures (overall ATT, event study, etc.)
       using appropriate weights.

    Parameters
    ----------
    control_group : str, default="never_treated"
        Which units to use as controls:
        - "never_treated": Use only never-treated units (recommended)
        - "not_yet_treated": Use never-treated and not-yet-treated units
    anticipation : int, default=0
        Number of periods before treatment where effects may occur.
        Set to > 0 if treatment effects can begin before the official
        treatment date.
    estimation_method : str, default="dr"
        Estimation method:
        - "dr": Doubly robust (recommended)
        - "ipw": Inverse probability weighting
        - "reg": Outcome regression
    alpha : float, default=0.05
        Significance level for confidence intervals.
    cluster : str, optional
        Column name for cluster-robust standard errors.
        Defaults to unit-level clustering.
    n_bootstrap : int, default=0
        Number of bootstrap iterations for inference.
        If 0, uses analytical standard errors.
        Recommended: 999 or more for reliable inference.

        .. note:: Memory Usage
            The bootstrap stores all weights in memory as a
            (n_bootstrap, n_units) float64 array. For large datasets,
            this can be significant:
            - 1K bootstrap × 10K units = ~80 MB
            - 10K bootstrap × 100K units = ~8 GB
            Consider reducing n_bootstrap if memory is constrained.
    bootstrap_weights : str, default="rademacher"
        Type of weights for multiplier bootstrap:
        - "rademacher": +1/-1 with equal probability (standard choice)
        - "mammen": Two-point distribution (asymptotically valid, matches skewness)
        - "webb": Six-point distribution (recommended when n_clusters < 20)
    bootstrap_weight_type : str, optional
        .. deprecated:: 1.0.1
            Use ``bootstrap_weights`` instead. Will be removed in v3.0.
    seed : int, optional
        Random seed for reproducibility.
    rank_deficient_action : str, default="warn"
        Action when design matrix is rank-deficient (linearly dependent columns):
        - "warn": Issue warning and drop linearly dependent columns (default)
        - "error": Raise ValueError
        - "silent": Drop columns silently without warning
    base_period : str, default="varying"
        Method for selecting the base (reference) period for computing
        ATT(g,t). Options:
        - "varying": For pre-treatment periods (t < g - anticipation), use
          t-1 as base (consecutive comparisons). For post-treatment, use
          g-1-anticipation. Requires t-1 to exist in data.
        - "universal": Always use g-1-anticipation as base period.
        Both produce identical post-treatment effects. Matches R's
        did::att_gt() base_period parameter.

    Attributes
    ----------
    results_ : CallawaySantAnnaResults
        Estimation results after calling fit().
    is_fitted_ : bool
        Whether the model has been fitted.

    Examples
    --------
    Basic usage:

    >>> import pandas as pd
    >>> from diff_diff import CallawaySantAnna
    >>>
    >>> # Panel data with staggered treatment
    >>> # 'first_treat' = period when unit was first treated (0 if never treated)
    >>> data = pd.DataFrame({
    ...     'unit': [...],
    ...     'time': [...],
    ...     'outcome': [...],
    ...     'first_treat': [...]  # 0 for never-treated, else first treatment period
    ... })
    >>>
    >>> cs = CallawaySantAnna()
    >>> results = cs.fit(data, outcome='outcome', unit='unit',
    ...                  time='time', first_treat='first_treat')
    >>>
    >>> results.print_summary()

    With event study aggregation:

    >>> cs = CallawaySantAnna()
    >>> results = cs.fit(data, outcome='outcome', unit='unit',
    ...                  time='time', first_treat='first_treat',
    ...                  aggregate='event_study')
    >>>
    >>> # Plot event study
    >>> from diff_diff import plot_event_study
    >>> plot_event_study(results)

    With covariate adjustment (conditional parallel trends):

    >>> # When parallel trends only holds conditional on covariates
    >>> cs = CallawaySantAnna(estimation_method='dr')  # doubly robust
    >>> results = cs.fit(data, outcome='outcome', unit='unit',
    ...                  time='time', first_treat='first_treat',
    ...                  covariates=['age', 'income'])
    >>>
    >>> # DR is recommended: consistent if either outcome model
    >>> # or propensity model is correctly specified

    Notes
    -----
    The key innovation of Callaway & Sant'Anna (2021) is the disaggregated
    approach: instead of estimating a single treatment effect, they estimate
    ATT(g,t) for each cohort-time pair. This avoids the "forbidden comparison"
    problem where already-treated units act as controls.

    The ATT(g,t) is identified under parallel trends conditional on covariates:

        E[Y(0)_t - Y(0)_g-1 | G=g] = E[Y(0)_t - Y(0)_g-1 | C=1]

    where G=g indicates treatment cohort g and C=1 indicates control units.
    This uses g-1 as the base period, which applies to post-treatment (t >= g).
    With base_period="varying" (default), pre-treatment uses t-1 as base for
    consecutive comparisons useful in parallel trends diagnostics.

    References
    ----------
    Callaway, B., & Sant'Anna, P. H. (2021). Difference-in-Differences with
    multiple time periods. Journal of Econometrics, 225(2), 200-230.
    """

    def __init__(
        self,
        control_group: str = "never_treated",
        anticipation: int = 0,
        estimation_method: str = "dr",
        alpha: float = 0.05,
        cluster: Optional[str] = None,
        n_bootstrap: int = 0,
        bootstrap_weights: Optional[str] = None,
        bootstrap_weight_type: Optional[str] = None,
        seed: Optional[int] = None,
        rank_deficient_action: str = "warn",
        base_period: str = "varying",
    ):
        """
        Validate the configuration and store it on the instance.

        Parameter meanings are described in the class docstring.

        Raises
        ------
        ValueError
            If ``control_group``, ``estimation_method``, ``bootstrap_weights``,
            ``rank_deficient_action`` or ``base_period`` is not one of its
            accepted values.
        """
        if control_group not in ["never_treated", "not_yet_treated"]:
            raise ValueError(
                f"control_group must be 'never_treated' or 'not_yet_treated', "
                f"got '{control_group}'"
            )
        if estimation_method not in ["dr", "ipw", "reg"]:
            raise ValueError(
                f"estimation_method must be 'dr', 'ipw', or 'reg', "
                f"got '{estimation_method}'"
            )

        # Handle bootstrap_weight_type deprecation: the old name is honoured
        # only when the new one was not supplied
        if bootstrap_weight_type is not None:
            warnings.warn(
                "bootstrap_weight_type is deprecated and will be removed in v3.0. "
                "Use bootstrap_weights instead.",
                DeprecationWarning,
                stacklevel=2
            )
            if bootstrap_weights is None:
                bootstrap_weights = bootstrap_weight_type

        # Default to rademacher if neither specified
        if bootstrap_weights is None:
            bootstrap_weights = "rademacher"

        if bootstrap_weights not in ["rademacher", "mammen", "webb"]:
            raise ValueError(
                f"bootstrap_weights must be 'rademacher', 'mammen', or 'webb', "
                f"got '{bootstrap_weights}'"
            )

        if rank_deficient_action not in ["warn", "error", "silent"]:
            raise ValueError(
                f"rank_deficient_action must be 'warn', 'error', or 'silent', "
                f"got '{rank_deficient_action}'"
            )

        if base_period not in ["varying", "universal"]:
            raise ValueError(
                f"base_period must be 'varying' or 'universal', "
                f"got '{base_period}'"
            )

        self.control_group = control_group
        self.anticipation = anticipation
        self.estimation_method = estimation_method
        self.alpha = alpha
        self.cluster = cluster
        self.n_bootstrap = n_bootstrap
        self.bootstrap_weights = bootstrap_weights
        # Keep bootstrap_weight_type for backward compatibility
        self.bootstrap_weight_type = bootstrap_weights
        self.seed = seed
        self.rank_deficient_action = rank_deficient_action
        self.base_period = base_period

        self.is_fitted_ = False
        self.results_: Optional[CallawaySantAnnaResults] = None

    def _precompute_structures(
        self,
        df: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        first_treat: str,
        covariates: Optional[List[str]],
        time_periods: List[Any],
        treatment_groups: List[Any],
    ) -> PrecomputedData:
        """
        Pre-compute data structures for efficient ATT(g,t) computation.

        This pivots data to wide format and pre-computes:
        - Outcome matrix (units x time periods)
        - Covariate matrices (units x covariates), one per time period
        - Unit cohort membership masks
        - Control unit masks

        Parameters
        ----------
        df : pd.DataFrame
            Long-format panel as prepared by fit().
        outcome, unit, time, first_treat : str
            Column names in ``df``.
        covariates : list of str, optional
            Covariate column names; when falsy no covariate matrices are built.
        time_periods : list
            Time periods for which covariate matrices are built.
        treatment_groups : list
            Cohort values for which membership masks are built.

        Returns
        -------
        PrecomputedData
            Dictionary with pre-computed structures.
        """
        # Get unique units and their cohort assignments
        unit_info = df.groupby(unit)[first_treat].first()
        all_units = unit_info.index.values
        unit_cohorts = unit_info.values

        # Create unit index mapping for fast lookups
        unit_to_idx = {u: i for i, u in enumerate(all_units)}

        # Pivot outcome to wide format: rows = units, columns = time periods
        outcome_wide = df.pivot(index=unit, columns=time, values=outcome)

        # Reindex to ensure all units are present (handles unbalanced panels)
        outcome_wide = outcome_wide.reindex(all_units)
        outcome_matrix = outcome_wide.values  # Shape: (n_units, n_periods)
        period_to_col = {t: i for i, t in enumerate(outcome_wide.columns)}

        # Pre-compute cohort masks (boolean arrays)
        cohort_masks = {g: (unit_cohorts == g) for g in treatment_groups}

        # Never-treated mask
        # np.inf was normalized to 0 in fit(), so the np.inf check is defensive only
        never_treated_mask = (unit_cohorts == 0) | (unit_cohorts == np.inf)

        # Pre-compute covariate matrices by time period if needed
        # (covariates are retrieved from the base period of each comparison)
        covariate_by_period = None
        if covariates:
            covariate_by_period = {}
            for t in time_periods:
                period_data = df[df[time] == t].set_index(unit)
                period_cov = period_data.reindex(all_units)[covariates]
                covariate_by_period[t] = period_cov.values  # Shape: (n_units, n_covariates)

        return {
            'all_units': all_units,
            'unit_to_idx': unit_to_idx,
            'unit_cohorts': unit_cohorts,
            'outcome_matrix': outcome_matrix,
            'period_to_col': period_to_col,
            'cohort_masks': cohort_masks,
            'never_treated_mask': never_treated_mask,
            'covariate_by_period': covariate_by_period,
            'time_periods': time_periods,
        }

    def _compute_att_gt_fast(
        self,
        precomputed: PrecomputedData,
        g: Any,
        t: Any,
        covariates: Optional[List[str]],
    ) -> Tuple[Optional[float], float, int, int, Optional[Dict[str, Any]]]:
        """
        Compute ATT(g,t) using pre-computed data structures (fast version).

        Uses vectorized numpy operations on pre-pivoted outcome matrix
        instead of repeated pandas filtering.

        Parameters
        ----------
        precomputed : PrecomputedData
            Output of _precompute_structures().
        g : Any
            Treatment cohort (first treated period).
        t : Any
            Time period at which the effect is evaluated.
        covariates : list of str, optional
            Covariate names; when falsy the estimate is unconditional.

        Returns
        -------
        tuple
            ``(att, se, n_treated, n_control, inf_func_info)``. When the
            base period or ``t`` is not observed, or no treated or no
            control unit has data in both periods, the tuple is
            ``(None, 0.0, 0, 0, None)``.
        """
        period_to_col = precomputed['period_to_col']
        outcome_matrix = precomputed['outcome_matrix']
        cohort_masks = precomputed['cohort_masks']
        never_treated_mask = precomputed['never_treated_mask']
        unit_cohorts = precomputed['unit_cohorts']
        all_units = precomputed['all_units']
        covariate_by_period = precomputed['covariate_by_period']

        # Base period selection based on mode
        if self.base_period == "universal":
            # Universal: always use g - 1 - anticipation
            base_period_val = g - 1 - self.anticipation
        else:  # varying
            if t < g - self.anticipation:
                # Pre-treatment: use t - 1 (consecutive comparison)
                base_period_val = t - 1
            else:
                # Post-treatment: use g - 1 - anticipation
                base_period_val = g - 1 - self.anticipation

        # Both the base period and the evaluation period must exist in the
        # data; there is no fallback base period, to maintain methodological
        # consistency
        if base_period_val not in period_to_col or t not in period_to_col:
            return None, 0.0, 0, 0, None

        base_col = period_to_col[base_period_val]
        post_col = period_to_col[t]

        # Get treated units mask (cohort g)
        treated_mask = cohort_masks[g]

        # Get control units mask
        if self.control_group == "never_treated":
            control_mask = never_treated_mask
        else:  # not_yet_treated
            # Not yet treated at time t: never-treated OR (first_treat > t AND not cohort g)
            # Must exclude cohort g since they are the treated group for this ATT(g,t)
            control_mask = never_treated_mask | (
                (unit_cohorts > t + self.anticipation) & (unit_cohorts != g)
            )

        # Extract outcomes for base and post periods
        y_base = outcome_matrix[:, base_col]
        y_post = outcome_matrix[:, post_col]

        # Compute outcome changes (vectorized)
        outcome_change = y_post - y_base

        # Filter to units with valid data (no NaN in either period)
        valid_mask = ~(np.isnan(y_base) | np.isnan(y_post))

        # Get treated and control with valid data
        treated_valid = treated_mask & valid_mask
        control_valid = control_mask & valid_mask

        n_treated = np.sum(treated_valid)
        n_control = np.sum(control_valid)

        if n_treated == 0 or n_control == 0:
            return None, 0.0, 0, 0, None

        # Extract outcome changes for treated and control
        treated_change = outcome_change[treated_valid]
        control_change = outcome_change[control_valid]

        # Get unit IDs for influence function
        treated_units = all_units[treated_valid]
        control_units = all_units[control_valid]

        # Get covariates if specified (from the base period)
        X_treated = None
        X_control = None
        if covariates and covariate_by_period is not None:
            cov_matrix = covariate_by_period[base_period_val]
            X_treated = cov_matrix[treated_valid]
            X_control = cov_matrix[control_valid]

            # Check for missing values
            if np.any(np.isnan(X_treated)) or np.any(np.isnan(X_control)):
                warnings.warn(
                    f"Missing values in covariates for group {g}, time {t}. "
                    "Falling back to unconditional estimation.",
                    UserWarning,
                    stacklevel=3,
                )
                X_treated = None
                X_control = None

        # Estimation method
        if self.estimation_method == "reg":
            att_gt, se_gt, inf_func = self._outcome_regression(
                treated_change, control_change, X_treated, X_control
            )
        elif self.estimation_method == "ipw":
            att_gt, se_gt, inf_func = self._ipw_estimation(
                treated_change, control_change,
                int(n_treated), int(n_control),
                X_treated, X_control
            )
        else:  # doubly robust
            att_gt, se_gt, inf_func = self._doubly_robust(
                treated_change, control_change, X_treated, X_control
            )

        # Package influence function info with unit IDs for bootstrap.
        # The estimators return treated entries first, then control entries.
        n_t = int(n_treated)
        inf_func_info = {
            'treated_units': list(treated_units),
            'control_units': list(control_units),
            'treated_inf': inf_func[:n_t],
            'control_inf': inf_func[n_t:],
        }

        return att_gt, se_gt, int(n_treated), int(n_control), inf_func_info

    def fit(
        self,
        data: pd.DataFrame,
        outcome: str,
        unit: str,
        time: str,
        first_treat: str,
        covariates: Optional[List[str]] = None,
        aggregate: Optional[str] = None,
        balance_e: Optional[int] = None,
    ) -> CallawaySantAnnaResults:
        """
        Fit the Callaway-Sant'Anna estimator.

        Parameters
        ----------
        data : pd.DataFrame
            Panel data with unit and time identifiers.
        outcome : str
            Name of outcome variable column.
        unit : str
            Name of unit identifier column.
        time : str
            Name of time period column.
        first_treat : str
            Name of column indicating when unit was first treated.
            Use 0 (or np.inf) for never-treated units.
        covariates : list, optional
            List of covariate column names for conditional parallel trends.
        aggregate : str, optional
            How to aggregate group-time effects:
            - None: Only compute ATT(g,t) (default)
            - "simple": Simple weighted average (overall ATT)
            - "event_study": Aggregate by relative time (event study)
            - "group": Aggregate by treatment cohort
            - "all": Compute all aggregations
        balance_e : int, optional
            For event study, balance the panel at relative time e.
            Ensures all groups contribute to each relative period.

        Returns
        -------
        CallawaySantAnnaResults
            Object containing all estimation results.

        Raises
        ------
        ValueError
            If required columns are missing or data validation fails.
        """
        # Validate inputs
        required_cols = [outcome, unit, time, first_treat]
        if covariates:
            required_cols.extend(covariates)

        missing = [c for c in required_cols if c not in data.columns]
        if missing:
            raise ValueError(f"Missing columns: {missing}")

        # Create working copy
        df = data.copy()

        # Ensure numeric types
        df[time] = pd.to_numeric(df[time])
        df[first_treat] = pd.to_numeric(df[first_treat])

        # Never-treated indicator (must precede treatment_groups to exclude np.inf)
        df['_never_treated'] = (df[first_treat] == 0) | (df[first_treat] == np.inf)
        # Normalize np.inf → 0 so all downstream `> 0` checks exclude never-treated
        df.loc[df[first_treat] == np.inf, first_treat] = 0

        # Standardize the first_treat column name for internal use
        # This avoids hardcoding column names in internal methods.
        # The copy is taken AFTER the np.inf normalization so the standardized
        # column can never carry np.inf when the user's column has another name.
        df['first_treat'] = df[first_treat]

        # Identify groups and time periods
        time_periods = sorted(df[time].unique())
        treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0])

        # Get unique units
        unit_info = df.groupby(unit).agg({
            first_treat: 'first',
            '_never_treated': 'first'
        }).reset_index()

        n_treated_units = (unit_info[first_treat] > 0).sum()
        n_control_units = (unit_info['_never_treated']).sum()

        if n_control_units == 0:
            raise ValueError("No never-treated units found. Check 'first_treat' column.")

        # Pre-compute data structures for efficient ATT(g,t) computation
        precomputed = self._precompute_structures(
            df, outcome, unit, time, first_treat,
            covariates, time_periods, treatment_groups
        )

        # Compute ATT(g,t) for each group-time combination
        group_time_effects = {}
        influence_func_info = {}  # Store influence functions for bootstrap

        # Get minimum period for determining valid pre-treatment periods
        min_period = min(time_periods)

        for g in treatment_groups:
            # Compute valid periods including pre-treatment
            if self.base_period == "universal":
                # Universal: all periods except the base period (which is normalized to 0)
                universal_base = g - 1 - self.anticipation
                valid_periods = [t for t in time_periods if t != universal_base]
            else:
                # Varying: post-treatment + pre-treatment where t-1 exists
                valid_periods = [
                    t for t in time_periods
                    if t >= g - self.anticipation or t > min_period
                ]

            for t in valid_periods:
                att_gt, se_gt, n_treat, n_ctrl, inf_info = self._compute_att_gt_fast(
                    precomputed, g, t, covariates
                )

                if att_gt is not None:
                    t_stat, p_val, ci = safe_inference(att_gt, se_gt, alpha=self.alpha)

                    group_time_effects[(g, t)] = {
                        'effect': att_gt,
                        'se': se_gt,
                        't_stat': t_stat,
                        'p_value': p_val,
                        'conf_int': ci,
                        'n_treated': n_treat,
                        'n_control': n_ctrl,
                    }

                    if inf_info is not None:
                        influence_func_info[(g, t)] = inf_info

        if not group_time_effects:
            raise ValueError(
                "Could not estimate any group-time effects. "
                "Check that data has sufficient observations."
            )

        # Compute overall ATT (simple aggregation)
        overall_att, overall_se = self._aggregate_simple(
            group_time_effects, influence_func_info, df, unit, precomputed
        )
        overall_t, overall_p, overall_ci = safe_inference(
            overall_att, overall_se, alpha=self.alpha
        )

        # Compute additional aggregations if requested
        event_study_effects = None
        group_effects = None

        if aggregate in ["event_study", "all"]:
            event_study_effects = self._aggregate_event_study(
                group_time_effects, influence_func_info,
                treatment_groups, time_periods, balance_e
            )

        if aggregate in ["group", "all"]:
            group_effects = self._aggregate_by_group(
                group_time_effects, influence_func_info, treatment_groups
            )

        # Run bootstrap inference if requested
        bootstrap_results = None
        if self.n_bootstrap > 0 and influence_func_info:
            bootstrap_results = self._run_multiplier_bootstrap(
                group_time_effects=group_time_effects,
                influence_func_info=influence_func_info,
                aggregate=aggregate,
                balance_e=balance_e,
                treatment_groups=treatment_groups,
                time_periods=time_periods,
            )

            # Update estimates with bootstrap inference
            overall_se = bootstrap_results.overall_att_se
            overall_t = safe_inference(overall_att, overall_se, alpha=self.alpha)[0]
            overall_p = bootstrap_results.overall_att_p_value
            overall_ci = bootstrap_results.overall_att_ci

            # Update group-time effects with bootstrap SEs
            for gt in group_time_effects:
                if gt in bootstrap_results.group_time_ses:
                    group_time_effects[gt]['se'] = bootstrap_results.group_time_ses[gt]
                    group_time_effects[gt]['conf_int'] = bootstrap_results.group_time_cis[gt]
                    group_time_effects[gt]['p_value'] = bootstrap_results.group_time_p_values[gt]
                    effect = float(group_time_effects[gt]['effect'])
                    se = float(group_time_effects[gt]['se'])
                    group_time_effects[gt]['t_stat'] = safe_inference(
                        effect, se, alpha=self.alpha
                    )[0]

            # Update event study effects with bootstrap SEs
            if (event_study_effects is not None
                    and bootstrap_results.event_study_ses is not None
                    and bootstrap_results.event_study_cis is not None
                    and bootstrap_results.event_study_p_values is not None):
                for e in event_study_effects:
                    if e in bootstrap_results.event_study_ses:
                        event_study_effects[e]['se'] = bootstrap_results.event_study_ses[e]
                        event_study_effects[e]['conf_int'] = bootstrap_results.event_study_cis[e]
                        p_val = bootstrap_results.event_study_p_values[e]
                        event_study_effects[e]['p_value'] = p_val
                        effect = float(event_study_effects[e]['effect'])
                        se = float(event_study_effects[e]['se'])
                        event_study_effects[e]['t_stat'] = safe_inference(
                            effect, se, alpha=self.alpha
                        )[0]

            # Update group effects with bootstrap SEs
            if (group_effects is not None
                    and bootstrap_results.group_effect_ses is not None
                    and bootstrap_results.group_effect_cis is not None
                    and bootstrap_results.group_effect_p_values is not None):
                for g in group_effects:
                    if g in bootstrap_results.group_effect_ses:
                        group_effects[g]['se'] = bootstrap_results.group_effect_ses[g]
                        group_effects[g]['conf_int'] = bootstrap_results.group_effect_cis[g]
                        group_effects[g]['p_value'] = bootstrap_results.group_effect_p_values[g]
                        effect = float(group_effects[g]['effect'])
                        se = float(group_effects[g]['se'])
                        group_effects[g]['t_stat'] = safe_inference(
                            effect, se, alpha=self.alpha
                        )[0]

        # Store results
        self.results_ = CallawaySantAnnaResults(
            group_time_effects=group_time_effects,
            overall_att=overall_att,
            overall_se=overall_se,
            overall_t_stat=overall_t,
            overall_p_value=overall_p,
            overall_conf_int=overall_ci,
            groups=treatment_groups,
            time_periods=time_periods,
            n_obs=len(df),
            n_treated_units=n_treated_units,
            n_control_units=n_control_units,
            alpha=self.alpha,
            control_group=self.control_group,
            base_period=self.base_period,
            event_study_effects=event_study_effects,
            group_effects=group_effects,
            bootstrap_results=bootstrap_results,
        )

        self.is_fitted_ = True
        return self.results_

    def _outcome_regression(
        self,
        treated_change: np.ndarray,
        control_change: np.ndarray,
        X_treated: Optional[np.ndarray] = None,
        X_control: Optional[np.ndarray] = None,
    ) -> Tuple[float, float, np.ndarray]:
        """
        Estimate ATT using outcome regression.

        With covariates:
        1. Regress outcome changes on covariates for control group
        2. Predict counterfactual for treated using their covariates
        3. ATT = mean(treated_change) - mean(predicted_counterfactual)

        Without covariates:
        Simple difference in means.

        Returns
        -------
        att : float
            Estimated ATT.
        se : float
            Standard error (0.0 when a group is empty).
        inf_func : np.ndarray
            Influence function values, treated units first, then controls.
        """
        n_t = len(treated_change)
        n_c = len(control_change)

        if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
            # Covariate-adjusted outcome regression
            # Fit regression on control units: E[Delta Y | X, D=0]
            beta, residuals = _linear_regression(
                X_control, control_change,
                rank_deficient_action=self.rank_deficient_action,
            )

            # Predict counterfactual for treated units
            X_treated_with_intercept = np.column_stack([np.ones(n_t), X_treated])
            predicted_control = np.dot(X_treated_with_intercept, beta)

            # ATT = mean(observed treated change - predicted counterfactual)
            att = np.mean(treated_change - predicted_control)

            # Standard error using sandwich estimator
            # Variance from treated: Var(Y_1 - m(X))
            treated_residuals = treated_change - predicted_control
            var_t = np.var(treated_residuals, ddof=1) if n_t > 1 else 0.0

            # Variance from control regression (residual variance)
            var_c = np.var(residuals, ddof=1) if n_c > 1 else 0.0

            # Approximate SE (ignoring estimation error in beta for simplicity)
            se = np.sqrt(var_t / n_t + var_c / n_c) if (n_t > 0 and n_c > 0) else 0.0

            # Influence function
            inf_treated = (treated_residuals - np.mean(treated_residuals)) / n_t
            inf_control = -residuals / n_c
            inf_func = np.concatenate([inf_treated, inf_control])
        else:
            # Simple difference in means (no covariates)
            att = np.mean(treated_change) - np.mean(control_change)

            var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
            var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0

            se = np.sqrt(var_t / n_t + var_c / n_c) if (n_t > 0 and n_c > 0) else 0.0

            # Influence function (for aggregation)
            inf_treated = treated_change - np.mean(treated_change)
            inf_control = control_change - np.mean(control_change)
            inf_func = np.concatenate([inf_treated / n_t, -inf_control / n_c])

        return att, se, inf_func

    def _ipw_estimation(
        self,
        treated_change: np.ndarray,
        control_change: np.ndarray,
        n_treated: int,
        n_control: int,
        X_treated: Optional[np.ndarray] = None,
        X_control: Optional[np.ndarray] = None,
    ) -> Tuple[float, float, np.ndarray]:
        """
        Estimate ATT using inverse probability weighting.

        With covariates:
        1. Estimate propensity score P(D=1|X) using logistic regression
        2. Reweight control units to match treated covariate distribution
        3. ATT = mean(treated) - weighted_mean(control)

        Without covariates:
        Simple difference in means with unconditional propensity weighting.

        Parameters
        ----------
        treated_change, control_change : np.ndarray
            Outcome changes for treated and control units.
        n_treated, n_control : int
            Group sizes; used for the unconditional propensity score in the
            no-covariate branch.
        X_treated, X_control : np.ndarray, optional
            Covariate matrices for treated and control units.

        Returns
        -------
        att : float
            Estimated ATT.
        se : float
            Standard error.
        inf_func : np.ndarray
            Influence function values, treated units first, then controls.
        """
        n_t = len(treated_change)
        n_c = len(control_change)
        n_total = n_treated + n_control

        if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
            # Covariate-adjusted IPW estimation
            # Stack covariates and create treatment indicator
            X_all = np.vstack([X_treated, X_control])
            D = np.concatenate([np.ones(n_t), np.zeros(n_c)])

            # Estimate propensity scores using logistic regression
            try:
                _, pscore = _logistic_regression(X_all, D)
            except (np.linalg.LinAlgError, ValueError):
                # Fallback to unconditional if logistic regression fails
                warnings.warn(
                    "Propensity score estimation failed. "
                    "Falling back to unconditional estimation.",
                    UserWarning,
                    stacklevel=4,
                )
                pscore = np.full(len(D), n_t / (n_t + n_c))

            # Only the control-unit scores enter the weights; clip them to
            # avoid extreme weights
            pscore_control = np.clip(pscore[n_t:], 0.01, 0.99)

            # IPW weights for control units: p(X) / (1 - p(X))
            # This reweights controls to have same covariate distribution as treated
            weights_control = pscore_control / (1 - pscore_control)
            weights_control = weights_control / np.sum(weights_control)  # normalize

            # ATT = mean(treated) - weighted_mean(control)
            att = np.mean(treated_change) - np.sum(weights_control * control_change)

            # Compute standard error
            # Variance of treated mean
            var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0

            # Variance of weighted control mean
            weighted_var_c = np.sum(
                weights_control
                * (control_change - np.sum(weights_control * control_change)) ** 2
            )

            se = np.sqrt(var_t / n_t + weighted_var_c) if (n_t > 0 and n_c > 0) else 0.0

            # Influence function
            inf_treated = (treated_change - np.mean(treated_change)) / n_t
            inf_control = -weights_control * (
                control_change - np.sum(weights_control * control_change)
            )
            inf_func = np.concatenate([inf_treated, inf_control])
        else:
            # Unconditional IPW (reduces to difference in means)
            p_treat = n_treated / n_total  # unconditional propensity score

            att = np.mean(treated_change) - np.mean(control_change)

            var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
            var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0

            # Adjusted variance for IPW
            se = (
                np.sqrt(var_t / n_t + var_c * (1 - p_treat) / (n_c * p_treat))
                if (n_t > 0 and n_c > 0 and p_treat > 0)
                else 0.0
            )

            # Influence function (for aggregation)
            inf_treated = (treated_change - np.mean(treated_change)) / n_t
            inf_control = (control_change - np.mean(control_change)) / n_c
            inf_func = np.concatenate([inf_treated, -inf_control])

        return att, se, inf_func

    def _doubly_robust(
        self,
        treated_change: np.ndarray,
        control_change: np.ndarray,
        X_treated: Optional[np.ndarray] = None,
        X_control: Optional[np.ndarray] = None,
    ) -> Tuple[float, float, np.ndarray]:
        """
        Estimate ATT using doubly robust estimation.

        With covariates:
        Combines outcome regression and IPW for double robustness.
        The estimator is consistent if either the outcome model OR
        the propensity model is correctly specified.

        ATT_DR = (1/n_t) * sum_i[D_i * (Y_i - m(X_i))]
                 + (1/n_t) * sum_i[(1-D_i) * w_i * (m(X_i) - Y_i)]

        where m(X) is the outcome model and w_i are IPW weights.

        Without covariates:
        Reduces to simple difference in means.

        Returns
        -------
        att : float
            Estimated ATT.
        se : float
            Standard error.
        inf_func : np.ndarray
            Influence function values, treated units first, then controls.
        """
        n_t = len(treated_change)
        n_c = len(control_change)

        if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
            # Doubly robust estimation with covariates
            # Step 1: Outcome regression - fit E[Delta Y | X] on control
            beta, _ = _linear_regression(
                X_control, control_change,
                rank_deficient_action=self.rank_deficient_action,
            )

            # Predict counterfactual for both treated and control
            X_treated_with_intercept = np.column_stack([np.ones(n_t), X_treated])
            X_control_with_intercept = np.column_stack([np.ones(n_c), X_control])
            m_treated = np.dot(X_treated_with_intercept, beta)
            m_control = np.dot(X_control_with_intercept, beta)

            # Step 2: Propensity score estimation
            X_all = np.vstack([X_treated, X_control])
            D = np.concatenate([np.ones(n_t), np.zeros(n_c)])

            try:
                _, pscore = _logistic_regression(X_all, D)
            except (np.linalg.LinAlgError, ValueError):
                # Fallback to unconditional if logistic regression fails.
                # Warn exactly as _ipw_estimation does, so the fallback is
                # never silent.
                warnings.warn(
                    "Propensity score estimation failed. "
                    "Falling back to unconditional estimation.",
                    UserWarning,
                    stacklevel=4,
                )
                pscore = np.full(len(D), n_t / (n_t + n_c))

            pscore_control = pscore[n_t:]

            # Clip propensity scores
            pscore_control = np.clip(pscore_control, 0.01, 0.99)

            # IPW weights for control: p(X) / (1 - p(X))
            weights_control = pscore_control / (1 - pscore_control)

            # Step 3: Doubly robust ATT
            # ATT = mean(treated - m(X_treated))
            #     + weighted_mean_control((m(X) - Y) * weight)
            att_treated_part = np.mean(treated_change - m_treated)

            # Augmentation term from control
            augmentation = np.sum(weights_control * (m_control - control_change)) / n_t

            att = att_treated_part + augmentation

            # Step 4: Standard error using influence function
            # Influence function for DR estimator
            psi_treated = (treated_change - m_treated - att) / n_t
            psi_control = (weights_control * (m_control - control_change)) / n_t

            # Variance is sum of squared influence functions
            var_psi = np.sum(psi_treated ** 2) + np.sum(psi_control ** 2)
            se = np.sqrt(var_psi) if var_psi > 0 else 0.0

            # Full influence function
            inf_func = np.concatenate([psi_treated, psi_control])
        else:
            # Without covariates, DR simplifies to difference in means
            att = np.mean(treated_change) - np.mean(control_change)

            var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0
            var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0

            se = np.sqrt(var_t / n_t + var_c / n_c) if (n_t > 0 and n_c > 0) else 0.0

            # Influence function for DR estimator
            inf_treated = (treated_change - np.mean(treated_change)) / n_t
            inf_control = (control_change - np.mean(control_change)) / n_c
            inf_func = np.concatenate([inf_treated, -inf_control])

        return att, se, inf_func

    def get_params(self) -> Dict[str, Any]:
        """Get estimator parameters (sklearn-compatible)."""
        return {
            "control_group": self.control_group,
            "anticipation": self.anticipation,
            "estimation_method": self.estimation_method,
            "alpha": self.alpha,
            "cluster": self.cluster,
            "n_bootstrap": self.n_bootstrap,
            "bootstrap_weights": self.bootstrap_weights,
            # Deprecated but kept for backward compatibility
            "bootstrap_weight_type": self.bootstrap_weight_type,
            "seed": self.seed,
            "rank_deficient_action": self.rank_deficient_action,
            "base_period": self.base_period,
        }

    def set_params(self, **params) -> "CallawaySantAnna":
        """
        Set estimator parameters (sklearn-compatible).

        Only names returned by get_params() are accepted, so methods and
        fitted state (``results_``, ``is_fitted_``) cannot be overwritten
        through this method. All names are checked before anything is
        assigned, so a rejected call leaves the estimator unchanged.

        Parameters
        ----------
        **params
            Parameter names mapped to their new values. Values are stored
            as given; they are not re-validated here.

        Returns
        -------
        CallawaySantAnna
            ``self``, to allow chaining.

        Raises
        ------
        ValueError
            If a name is not an estimator parameter.
        """
        valid_names = self.get_params()
        for key in params:
            if key not in valid_names:
                raise ValueError(f"Unknown parameter: {key}")

        for key, value in params.items():
            setattr(self, key, value)

        # __init__ keeps the deprecated alias equal to bootstrap_weights;
        # preserve that here. The canonical name wins when both are given.
        if "bootstrap_weights" in params:
            self.bootstrap_weight_type = self.bootstrap_weights
        elif params.get("bootstrap_weight_type") is not None:
            self.bootstrap_weights = self.bootstrap_weight_type

        return self

    def summary(self) -> str:
        """
        Get summary of estimation results.

        Raises
        ------
        RuntimeError
            If the model has not been fitted.
        """
        if not self.is_fitted_:
            raise RuntimeError("Model must be fitted before calling summary()")
        # Narrows Optional[...] for type checkers; fit() sets results_
        # before it sets is_fitted_
        assert self.results_ is not None
        return self.results_.summary()

    def print_summary(self) -> None:
        """Print summary to stdout."""
        print(self.summary())