Source code for diff_diff.sun_abraham

"""
Sun-Abraham Interaction-Weighted Estimator for staggered DiD.

Implements the estimator from Sun & Abraham (2021), "Estimating dynamic
treatment effects in event studies with heterogeneous treatment effects",
Journal of Econometrics.

This provides an alternative to Callaway-Sant'Anna using a saturated
regression with cohort × relative-time interactions.
"""

import warnings
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

from diff_diff.bootstrap_utils import compute_effect_bootstrap_stats
from diff_diff.linalg import LinearRegression
from diff_diff.results import _format_survey_block, _get_significance_stars
from diff_diff.utils import (
    safe_inference,
)
from diff_diff.utils import (
    within_transform as _within_transform_util,
)


@dataclass
class SunAbrahamResults:
    """
    Results from Sun-Abraham (2021) interaction-weighted estimation.

    Attributes
    ----------
    event_study_effects : dict
        Dictionary mapping relative time to effect dictionaries with keys:
        'effect', 'se', 't_stat', 'p_value', 'conf_int', 'n_groups'.
    overall_att : float
        Overall average treatment effect (weighted average of post-treatment
        effects).
    overall_se : float
        Standard error of overall ATT.
    overall_t_stat : float
        T-statistic for overall ATT.
    overall_p_value : float
        P-value for overall ATT.
    overall_conf_int : tuple
        Confidence interval for overall ATT.
    cohort_weights : dict
        Dictionary mapping relative time to cohort weight dictionaries.
    groups : list
        List of treatment cohorts (first treatment periods).
    time_periods : list
        List of all time periods.
    n_obs : int
        Total number of observations.
    n_treated_units : int
        Number of ever-treated units.
    n_control_units : int
        Number of never-treated units.
    alpha : float
        Significance level used for confidence intervals.
    control_group : str
        Type of control group used.
    anticipation : int
        Anticipation periods used at fit time.
    bootstrap_results : SABootstrapResults, optional
        Bootstrap inference results, when bootstrap was requested.
    cohort_effects : dict, optional
        Mapping ``(cohort, relative_period)`` to dictionaries with keys
        'effect', 'se' and 'weight'.
    survey_metadata : object, optional
        Survey design metadata attached at fit time.
    """

    event_study_effects: Dict[int, Dict[str, Any]]
    overall_att: float
    overall_se: float
    overall_t_stat: float
    overall_p_value: float
    overall_conf_int: Tuple[float, float]
    cohort_weights: Dict[int, Dict[Any, float]]
    groups: List[Any]
    time_periods: List[Any]
    n_obs: int
    n_treated_units: int
    n_control_units: int
    alpha: float = 0.05
    control_group: str = "never_treated"
    # Anticipation periods (``k``) used at fit time. Persisted so
    # downstream diagnostics (``BusinessReport`` / ``DiagnosticReport``
    # / ``compute_pretrends_power``) can classify pre-period vs
    # anticipation-window coefficients without re-plumbing the kwarg
    # through every caller.
    anticipation: int = 0
    bootstrap_results: Optional["SABootstrapResults"] = field(default=None, repr=False)
    cohort_effects: Optional[Dict[Tuple[Any, int], Dict[str, Any]]] = field(
        default=None, repr=False
    )
    # Survey design metadata (SurveyMetadata instance from diff_diff.survey)
    survey_metadata: Optional[Any] = field(default=None)

    # --- Inference-field aliases (balance/external-adapter compatibility) ---
    @property
    def att(self) -> float:
        """Alias for ``overall_att``."""
        return self.overall_att

    @property
    def se(self) -> float:
        """Alias for ``overall_se``."""
        return self.overall_se

    @property
    def conf_int(self) -> Tuple[float, float]:
        """Alias for ``overall_conf_int``."""
        return self.overall_conf_int

    @property
    def p_value(self) -> float:
        """Alias for ``overall_p_value``."""
        return self.overall_p_value

    @property
    def t_stat(self) -> float:
        """Alias for ``overall_t_stat``."""
        return self.overall_t_stat

    def __repr__(self) -> str:
        """Concise string representation."""
        sig = _get_significance_stars(self.overall_p_value)
        n_rel_periods = len(self.event_study_effects)
        return (
            f"SunAbrahamResults(ATT={self.overall_att:.4f}{sig}, "
            f"SE={self.overall_se:.4f}, "
            f"n_groups={len(self.groups)}, "
            f"n_rel_periods={n_rel_periods})"
        )

    @property
    def coef_var(self) -> float:
        """Coefficient of variation: SE / abs(overall ATT).

        NaN when the SE is non-finite or negative, or when the ATT is
        non-finite or exactly 0.
        """
        if not (np.isfinite(self.overall_se) and self.overall_se >= 0):
            return np.nan
        if not np.isfinite(self.overall_att) or self.overall_att == 0:
            return np.nan
        return self.overall_se / abs(self.overall_att)

    def summary(self, alpha: Optional[float] = None) -> str:
        """
        Generate formatted summary of estimation results.

        Parameters
        ----------
        alpha : float, optional
            Significance level used for the confidence-level label.
            Defaults to the alpha used in estimation. The interval bounds
            printed are always the stored ``overall_conf_int``; they are not
            recomputed for a different ``alpha``.

        Returns
        -------
        str
            Formatted summary.
        """
        # `is None` rather than `or`, so an explicit value is never silently
        # replaced by the fit-time alpha.
        alpha = self.alpha if alpha is None else alpha
        # ':g' avoids truncating fractional levels (e.g. 97.5) and hides
        # floating-point noise such as 94.99999999999999.
        conf_level = f"{(1 - alpha) * 100:g}"

        lines = [
            "=" * 85,
            "Sun-Abraham Interaction-Weighted Estimator Results".center(85),
            "=" * 85,
            "",
            f"{'Total observations:':<30} {self.n_obs:>10}",
            f"{'Treated units:':<30} {self.n_treated_units:>10}",
            f"{'Control units:':<30} {self.n_control_units:>10}",
            f"{'Treatment cohorts:':<30} {len(self.groups):>10}",
            f"{'Time periods:':<30} {len(self.time_periods):>10}",
            f"{'Control group:':<30} {self.control_group:>10}",
            "",
        ]

        # Add survey design info
        if self.survey_metadata is not None:
            sm = self.survey_metadata
            lines.extend(_format_survey_block(sm, 85))

        # Overall ATT
        lines.extend(
            [
                "-" * 85,
                "Overall Average Treatment Effect on the Treated".center(85),
                "-" * 85,
                f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} "
                f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
                "-" * 85,
                f"{'ATT':<15} {self.overall_att:>12.4f} {self.overall_se:>12.4f} "
                f"{self.overall_t_stat:>10.3f} {self.overall_p_value:>10.4f} "
                f"{_get_significance_stars(self.overall_p_value):>6}",
                "-" * 85,
                "",
                f"{conf_level}% Confidence Interval: "
                f"[{self.overall_conf_int[0]:.4f}, {self.overall_conf_int[1]:.4f}]",
            ]
        )

        cv = self.coef_var
        if np.isfinite(cv):
            lines.append(f"{'CV (SE/abs(ATT)):':<25} {cv:>10.4f}")

        lines.append("")

        # Event study effects
        lines.extend(
            [
                "-" * 85,
                "Event Study (Dynamic) Effects".center(85),
                "-" * 85,
                f"{'Rel. Period':<15} {'Estimate':>12} {'Std. Err.':>12} "
                f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}",
                "-" * 85,
            ]
        )

        for rel_t in sorted(self.event_study_effects.keys()):
            eff = self.event_study_effects[rel_t]
            sig = _get_significance_stars(eff["p_value"])
            lines.append(
                f"{rel_t:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} "
                f"{eff['t_stat']:>10.3f} {eff['p_value']:>10.4f} {sig:>6}"
            )

        lines.extend(["-" * 85, ""])

        lines.extend(
            [
                "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1",
                "=" * 85,
            ]
        )

        return "\n".join(lines)

    def print_summary(self, alpha: Optional[float] = None) -> None:
        """Print summary to stdout."""
        print(self.summary(alpha))

    def to_dataframe(self, level: str = "event_study") -> pd.DataFrame:
        """
        Convert results to DataFrame.

        Parameters
        ----------
        level : str, default="event_study"
            Level of aggregation: "event_study" or "cohort".

        Returns
        -------
        pd.DataFrame
            Results as DataFrame, one row per relative period (or per
            cohort/relative-period pair), sorted by key. The columns are
            present even when there are no rows.

        Raises
        ------
        ValueError
            If ``level`` is unknown, or if ``level="cohort"`` and no cohort
            effects are stored.
        """
        if level == "event_study":
            # Explicit column list keeps the schema stable for empty results.
            columns = [
                "relative_period",
                "effect",
                "se",
                "t_stat",
                "p_value",
                "conf_int_lower",
                "conf_int_upper",
            ]
            rows = [
                {
                    "relative_period": rel_t,
                    "effect": data["effect"],
                    "se": data["se"],
                    "t_stat": data["t_stat"],
                    "p_value": data["p_value"],
                    "conf_int_lower": data["conf_int"][0],
                    "conf_int_upper": data["conf_int"][1],
                }
                for rel_t, data in sorted(self.event_study_effects.items())
            ]
            return pd.DataFrame(rows, columns=columns)

        if level == "cohort":
            if self.cohort_effects is None:
                raise ValueError(
                    "Cohort-level effects not available. "
                    "They are computed internally but not stored by default."
                )

            columns = ["cohort", "relative_period", "effect", "se", "weight"]
            rows = [
                {
                    "cohort": cohort,
                    "relative_period": rel_t,
                    "effect": data["effect"],
                    "se": data["se"],
                    "weight": data.get("weight", np.nan),
                }
                for (cohort, rel_t), data in sorted(self.cohort_effects.items())
            ]
            return pd.DataFrame(rows, columns=columns)

        raise ValueError(f"Unknown level: {level}. Use 'event_study' or 'cohort'.")

    @property
    def is_significant(self) -> bool:
        """Check if overall ATT is significant."""
        return bool(self.overall_p_value < self.alpha)

    @property
    def significance_stars(self) -> str:
        """Significance stars for overall ATT."""
        return _get_significance_stars(self.overall_p_value)
@dataclass
class SABootstrapResults:
    """
    Container for the output of Sun-Abraham bootstrap inference.

    Overall-ATT quantities are stored as scalars/tuples; event-study
    quantities are stored as dictionaries keyed by relative period.

    Attributes
    ----------
    n_bootstrap : int
        How many bootstrap iterations were run.
    weight_type : str
        Label of the bootstrap scheme (always "pairs" for pairs bootstrap).
    alpha : float
        Significance level the confidence intervals were built for.
    overall_att_se : float
        Bootstrap standard error of the overall ATT.
    overall_att_ci : Tuple[float, float]
        Bootstrap confidence interval of the overall ATT.
    overall_att_p_value : float
        Bootstrap p-value of the overall ATT.
    event_study_ses : Dict[int, float]
        Bootstrap standard errors of the event-study effects.
    event_study_cis : Dict[int, Tuple[float, float]]
        Bootstrap confidence intervals of the event-study effects.
    event_study_p_values : Dict[int, float]
        Bootstrap p-values of the event-study effects.
    bootstrap_distribution : Optional[np.ndarray]
        Full bootstrap distribution of the overall ATT; left out of the
        generated ``repr``.
    """

    # --- Run configuration ---
    n_bootstrap: int
    weight_type: str
    alpha: float

    # --- Overall ATT inference ---
    overall_att_se: float
    overall_att_ci: Tuple[float, float]
    overall_att_p_value: float

    # --- Event-study inference, keyed by relative period ---
    event_study_ses: Dict[int, float]
    event_study_cis: Dict[int, Tuple[float, float]]
    event_study_p_values: Dict[int, float]

    # --- Raw draws (optional) ---
    bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)
[docs] class SunAbraham: """ Sun-Abraham (2021) interaction-weighted estimator for staggered DiD. This estimator provides event-study coefficients using a saturated TWFE regression with cohort × relative-time interactions, following the methodology in Sun & Abraham (2021). The estimation procedure follows three steps: 1. Run a saturated TWFE regression with cohort × relative-time dummies 2. Compute cohort shares (weights) at each relative time 3. Aggregate cohort-specific effects using interaction weights This avoids the negative weighting problem of standard TWFE and provides consistent event-study estimates under treatment effect heterogeneity. Parameters ---------- control_group : str, default="never_treated" Which units to use as controls: - "never_treated": Use only never-treated units (recommended) - "not_yet_treated": Use never-treated and not-yet-treated units anticipation : int, default=0 Number of periods before treatment where effects may occur. alpha : float, default=0.05 Significance level for confidence intervals. cluster : str, optional Column name for cluster-robust standard errors. If None, clusters at the unit level by default. n_bootstrap : int, default=0 Number of bootstrap iterations for inference. If 0, uses analytical cluster-robust standard errors. seed : int, optional Random seed for reproducibility. rank_deficient_action : str, default="warn" Action when design matrix is rank-deficient (linearly dependent columns): - "warn": Issue warning and drop linearly dependent columns (default) - "error": Raise ValueError - "silent": Drop columns silently without warning Attributes ---------- results_ : SunAbrahamResults Estimation results after calling fit(). is_fitted_ : bool Whether the model has been fitted. Examples -------- Basic usage: >>> import pandas as pd >>> from diff_diff import SunAbraham >>> >>> # Panel data with staggered treatment >>> data = pd.DataFrame({ ... 'unit': [...], ... 'time': [...], ... 'outcome': [...], ... 
'first_treat': [...] # 0 for never-treated ... }) >>> >>> sa = SunAbraham() >>> results = sa.fit(data, outcome='outcome', unit='unit', ... time='time', first_treat='first_treat') >>> results.print_summary() With covariates: >>> sa = SunAbraham() >>> results = sa.fit(data, outcome='outcome', unit='unit', ... time='time', first_treat='first_treat', ... covariates=['age', 'income']) Notes ----- The Sun-Abraham estimator uses a saturated regression approach: Y_it = α_i + λ_t + Σ_g Σ_e [δ_{g,e} × 1(G_i=g) × D_{it}^e] + X'γ + ε_it where: - α_i = unit fixed effects - λ_t = time fixed effects - G_i = unit i's treatment cohort (first treatment period) - D_{it}^e = indicator for being e periods from treatment - δ_{g,e} = cohort-specific effect (CATT) at relative time e The event-study coefficients are then computed as: β_e = Σ_g w_{g,e} × δ_{g,e} where w_{g,e} is the share of cohort g in the treated population at relative time e (interaction weights). Compared to Callaway-Sant'Anna: - SA uses saturated regression; CS uses 2x2 DiD comparisons - SA can be more efficient when model is correctly specified - Both are consistent under heterogeneous treatment effects - Running both provides a useful robustness check References ---------- Sun, L., & Abraham, S. (2021). Estimating dynamic treatment effects in event studies with heterogeneous treatment effects. Journal of Econometrics, 225(2), 175-199. """
[docs] def __init__( self, control_group: str = "never_treated", anticipation: int = 0, alpha: float = 0.05, cluster: Optional[str] = None, n_bootstrap: int = 0, seed: Optional[int] = None, rank_deficient_action: str = "warn", ): if control_group not in ["never_treated", "not_yet_treated"]: raise ValueError( f"control_group must be 'never_treated' or 'not_yet_treated', " f"got '{control_group}'" ) if rank_deficient_action not in ["warn", "error", "silent"]: raise ValueError( f"rank_deficient_action must be 'warn', 'error', or 'silent', " f"got '{rank_deficient_action}'" ) self.control_group = control_group self.anticipation = anticipation self.alpha = alpha self.cluster = cluster self.n_bootstrap = n_bootstrap self.seed = seed self.rank_deficient_action = rank_deficient_action self.is_fitted_ = False self.results_: Optional[SunAbrahamResults] = None self._reference_period = -1 # Will be set during fit
def fit(
    self,
    data: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    covariates: Optional[List[str]] = None,
    survey_design: object = None,
) -> SunAbrahamResults:
    """
    Fit the Sun-Abraham estimator using saturated regression.

    The method validates the inputs, resolves the optional survey design,
    builds relative-time indicators on a working copy of ``data``, fits the
    saturated cohort x relative-time regression, aggregates the cohort
    effects into event-study effects and an overall ATT, and then
    (optionally) replaces the analytical inference with replicate-weight or
    bootstrap inference.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data with unit and time identifiers. Not modified; a copy is
        taken before any columns are added.
    outcome : str
        Name of outcome variable column.
    unit : str
        Name of unit identifier column.
    time : str
        Name of time period column.
    first_treat : str
        Name of column indicating when unit was first treated.
        Use 0 (or np.inf) for never-treated units.
    covariates : list, optional
        List of covariate column names to include in regression.
    survey_design : SurveyDesign, optional
        Survey design specification for design-based inference.
        Supports weighted estimation and Taylor series linearization
        variance with strata, PSU, and FPC.

    Returns
    -------
    SunAbrahamResults
        Object containing all estimation results. The same object is stored
        on ``self.results_``.

    Raises
    ------
    ValueError
        If required columns are missing, if ``n_bootstrap > 0`` is combined
        with a replicate-weight survey design, if no never-treated units are
        found, or if no treated units are found.

    Notes
    -----
    Side effects on the estimator: sets ``self._reference_period``,
    ``self.results_`` and ``self.is_fitted_``.
    """
    # Validate inputs
    required_cols = [outcome, unit, time, first_treat]
    if covariates:
        required_cols.extend(covariates)

    missing = [c for c in required_cols if c not in data.columns]
    if missing:
        raise ValueError(f"Missing columns: {missing}")

    # Resolve survey design if provided
    # (function-scope import of project-local survey helpers)
    from diff_diff.survey import (
        _resolve_effective_cluster,
        _resolve_survey_for_fit,
        _validate_unit_constant_survey,
    )

    resolved_survey, survey_weights, survey_weight_type, survey_metadata = (
        _resolve_survey_for_fit(survey_design, data, "analytical")
    )
    # Validate survey columns are constant within units (required for
    # unit-level collapse in Rao-Wu bootstrap)
    if resolved_survey is not None:
        _validate_unit_constant_survey(data, unit, survey_design)

    _uses_replicate_sa = resolved_survey is not None and resolved_survey.uses_replicate_variance
    if _uses_replicate_sa and self.n_bootstrap > 0:
        raise ValueError(
            "Cannot use n_bootstrap > 0 with replicate-weight survey designs. "
            "Replicate weights provide their own variance estimation."
        )

    # Bootstrap + survey supported via Rao-Wu rescaled bootstrap.
    # Determine Rao-Wu eligibility from the *original* survey_design
    # (before cluster-as-PSU injection which adds PSU to weights-only designs).
    _use_rao_wu = False
    if survey_design is not None and resolved_survey is not None:
        _has_explicit_strata = getattr(survey_design, "strata", None) is not None
        _has_explicit_psu = getattr(survey_design, "psu", None) is not None
        _has_explicit_fpc = getattr(survey_design, "fpc", None) is not None
        if _has_explicit_strata or _has_explicit_psu or _has_explicit_fpc:
            _use_rao_wu = True

    # Create working copy
    df = data.copy()

    # Ensure numeric types
    df[time] = pd.to_numeric(df[time])
    df[first_treat] = pd.to_numeric(df[first_treat])

    # Never-treated indicator (must precede treatment_groups to exclude np.inf)
    df["_never_treated"] = (df[first_treat] == 0) | (df[first_treat] == np.inf)
    # Normalize np.inf → 0 so all downstream `> 0` checks exclude never-treated
    df.loc[df[first_treat] == np.inf, first_treat] = 0

    # Identify groups and time periods
    time_periods = sorted(df[time].unique())
    treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0])

    # Get unique units
    # (takes the first row's values per unit for both columns)
    unit_info = (
        df.groupby(unit).agg({first_treat: "first", "_never_treated": "first"}).reset_index()
    )

    n_treated_units = int((unit_info[first_treat] > 0).sum())
    n_control_units = int((unit_info["_never_treated"]).sum())

    # NOTE: this check runs for both control_group settings, so
    # never-treated units are required even with "not_yet_treated".
    if n_control_units == 0:
        raise ValueError("No never-treated units found. Check 'first_treat' column.")

    if len(treatment_groups) == 0:
        raise ValueError("No treated units found. Check 'first_treat' column.")

    # Compute relative time for each observation (vectorized)
    # Never-treated rows (first_treat == 0 after normalization) get NaN.
    df["_rel_time"] = np.where(df[first_treat] > 0, df[time] - df[first_treat], np.nan)

    # Identify the range of relative time periods to estimate
    rel_times_by_cohort = {}
    for g in treatment_groups:
        g_times = df[df[first_treat] == g][time].unique()
        rel_times_by_cohort[g] = sorted([t - g for t in g_times])

    # Find all relative time values
    all_rel_times: set = set()
    for g, rel_times in rel_times_by_cohort.items():
        all_rel_times.update(rel_times)

    all_rel_times_sorted = sorted(all_rel_times)

    # Use full range of relative times (no artificial truncation, matches R's fixest::sunab())
    min_rel = min(all_rel_times_sorted)
    max_rel = max(all_rel_times_sorted)

    # Reference period: last pre-treatment period (typically -1)
    self._reference_period = -1 - self.anticipation

    # Get relative periods to estimate (excluding reference)
    # NOTE(review): min_rel/max_rel are the extremes of the list being
    # filtered, so the range test is always true; only the reference-period
    # exclusion has an effect.
    rel_periods_to_estimate = [
        e
        for e in all_rel_times_sorted
        if min_rel <= e <= max_rel and e != self._reference_period
    ]

    # Determine cluster variable
    cluster_var = self.cluster if self.cluster is not None else unit

    # Filter data based on control_group setting
    if self.control_group == "never_treated":
        # Only keep never-treated as controls
        # (rows flagged never-treated, plus rows with first_treat > 0)
        df_reg = df[df["_never_treated"] | (df[first_treat] > 0)].copy()
    else:
        # Keep all units (not_yet_treated will be handled by the regression)
        df_reg = df.copy()

    # Resolve effective cluster and inject cluster-as-PSU
    # The cluster column name is only forwarded when the user set
    # `cluster` explicitly; with the unit default, None is passed.
    cluster_ids_raw = df_reg[cluster_var].values if cluster_var in df_reg.columns else None
    effective_cluster_ids = _resolve_effective_cluster(
        resolved_survey, cluster_ids_raw, cluster_var if self.cluster is not None else None
    )
    if resolved_survey is not None and effective_cluster_ids is not None:
        from diff_diff.survey import _inject_cluster_as_psu, compute_survey_metadata

        resolved_survey = _inject_cluster_as_psu(resolved_survey, effective_cluster_ids)
        # Recompute the metadata from the updated design, using the raw
        # weight column (or unit weights when no weight column is named).
        if resolved_survey.psu is not None and survey_metadata is not None:
            raw_w = (
                data[survey_design.weights].values.astype(np.float64)
                if survey_design.weights
                else np.ones(len(data), dtype=np.float64)
            )
            survey_metadata = compute_survey_metadata(resolved_survey, raw_w)

    # Fit saturated regression
    (
        cohort_effects,
        cohort_ses,
        vcov_cohort,
        coef_index_map,
    ) = self._fit_saturated_regression(
        df_reg,
        outcome,
        unit,
        time,
        first_treat,
        treatment_groups,
        rel_periods_to_estimate,
        covariates,
        cluster_var,
        survey_weights=survey_weights,
        survey_weight_type=survey_weight_type,
        # For replicate designs: pass None to prevent LinearRegression from
        # computing bogus replicate vcov on already-demeaned data. We
        # override vcov_cohort below with the correct estimator-level refit.
        resolved_survey=None if _uses_replicate_sa else resolved_survey,
    )

    # Replicate variance override: fully refit the IW estimator per
    # replicate, including recomputing cohort-share aggregation weights
    # from w_r, so replicate SEs reflect the complete estimator.
    _n_valid_rep_sa = None
    if _uses_replicate_sa:
        from diff_diff.survey import compute_replicate_refit_variance

        # The refit returns [overall_att, es_e0, es_e1, ...] after
        # full re-aggregation with replicate-weighted cohort shares.
        _sa_rel_periods = list(rel_periods_to_estimate)

        # Closure handed to compute_replicate_refit_variance further down;
        # it re-estimates everything for one replicate weight vector.
        def _refit_sa(w_r):
            # Drop zero-weight obs for within-transform safety
            # NOTE(review): the mask built from w_r indexes df_reg, which
            # presumes w_r is aligned row-for-row with df_reg — confirm.
            nz = w_r > 0
            df_reg_nz = df_reg[nz] if not np.all(nz) else df_reg
            w_nz = w_r[nz] if not np.all(nz) else w_r
            ce_r, _, vcov_r, cim_r = self._fit_saturated_regression(
                df_reg_nz,
                outcome,
                unit,
                time,
                first_treat,
                treatment_groups,
                _sa_rel_periods,
                covariates,
                cluster_var,
                survey_weights=w_nz,
                survey_weight_type=survey_weight_type,
                resolved_survey=None,
            )
            # Create temp weight column for IW aggregation with w_r
            # Use full w_r (including zeros) for correct mass computation
            # This writes into the enclosing working copy `df`, so the
            # "_rep_wt" column stays on it after the refit.
            _wt_col = "_rep_wt"
            df[_wt_col] = w_r
            es_r, _ = self._compute_iw_effects(
                df,
                unit,
                first_treat,
                treatment_groups,
                _sa_rel_periods,
                ce_r,
                {},
                vcov_r,
                cim_r,
                survey_weight_col=_wt_col,
            )
            # NOTE(review): `_` is a live value here, not a throwaway. It
            # holds the second return value of _compute_iw_effects above and
            # is passed in the position where the full-sample call passes
            # `cohort_weights`.
            att_r, _ = self._compute_overall_att(
                df,
                first_treat,
                es_r,
                ce_r,
                _,
                vcov_r,
                cim_r,
                survey_weight_col=_wt_col,
            )
            # Output layout: [overall ATT, then one effect per relative
            # period], NaN where a period has no estimate.
            results = [att_r]
            for e in _sa_rel_periods:
                results.append(es_r[e]["effect"] if e in es_r else np.nan)
            return np.array(results)

    # Resolve survey weight column name for cohort aggregation
    survey_weight_col = (
        survey_design.weights
        if survey_design is not None
        and hasattr(survey_design, "weights")
        and survey_design.weights
        else None
    )

    # Survey degrees of freedom for t-distribution inference
    _sa_survey_df = (
        max(survey_metadata.df_survey, 1)
        if survey_metadata is not None and survey_metadata.df_survey is not None
        else None
    )
    # Replicate df: rank-deficient → NaN inference (dropped-replicate
    # override happens after replicate refit below)
    if _uses_replicate_sa and _sa_survey_df is None:
        _sa_survey_df = 0  # rank-deficient replicate → NaN inference

    # Compute interaction-weighted event study effects
    event_study_effects, cohort_weights = self._compute_iw_effects(
        df,
        unit,
        first_treat,
        treatment_groups,
        rel_periods_to_estimate,
        cohort_effects,
        cohort_ses,
        vcov_cohort,
        coef_index_map,
        survey_weight_col=survey_weight_col,
        survey_df=_sa_survey_df,
    )

    # Compute overall ATT (average of post-treatment effects)
    overall_att, overall_se = self._compute_overall_att(
        df,
        first_treat,
        event_study_effects,
        cohort_effects,
        cohort_weights,
        vcov_cohort,
        coef_index_map,
        survey_weight_col=survey_weight_col,
    )

    overall_t, overall_p, overall_ci = safe_inference(
        overall_att, overall_se, alpha=self.alpha, df=_sa_survey_df
    )

    # Replicate variance override: refit fully re-aggregated estimates
    if _uses_replicate_sa:
        # Build full-sample estimate vector from actual outputs
        # (same layout as the vector returned by _refit_sa)
        _full_est_sa = [overall_att]
        for e in _sa_rel_periods:
            _full_est_sa.append(
                event_study_effects[e]["effect"] if e in event_study_effects else np.nan
            )
        _vcov_sa, _n_valid_rep_sa = compute_replicate_refit_variance(
            _refit_sa, np.array(_full_est_sa), resolved_survey
        )
        # Override df if replicates dropped
        if _n_valid_rep_sa < resolved_survey.n_replicates:
            _sa_survey_df = _n_valid_rep_sa - 1 if _n_valid_rep_sa > 1 else 0
            if survey_metadata is not None:
                survey_metadata.df_survey = (
                    _sa_survey_df if _sa_survey_df and _sa_survey_df > 0 else None
                )
        # Override overall ATT SE
        # (variance is clamped at 0 before the square root)
        overall_se = float(np.sqrt(max(_vcov_sa[0, 0], 0.0)))
        overall_t, overall_p, overall_ci = safe_inference(
            overall_att, overall_se, alpha=self.alpha, df=_sa_survey_df
        )
        # Override event-study SEs
        # Index 1 + i skips the overall-ATT entry at position 0.
        for i, e in enumerate(_sa_rel_periods):
            if e in event_study_effects and np.isfinite(event_study_effects[e]["effect"]):
                se_e = float(np.sqrt(max(_vcov_sa[1 + i, 1 + i], 0.0)))
                eff_e = event_study_effects[e]["effect"]
                t_e, p_e, ci_e = safe_inference(eff_e, se_e, alpha=self.alpha, df=_sa_survey_df)
                event_study_effects[e]["se"] = se_e
                event_study_effects[e]["t_stat"] = t_e
                event_study_effects[e]["p_value"] = p_e
                event_study_effects[e]["conf_int"] = ci_e

        # Cohort-level replicate SEs: second refit for raw (g,e) coefficients
        # Keys are ordered by their coefficient index so vector positions
        # line up with coef_index_map.
        _keys_ordered = sorted(coef_index_map.keys(), key=lambda k: coef_index_map[k])
        _full_cohort_vec = np.array([cohort_effects.get(k, np.nan) for k in _keys_ordered])

        def _refit_sa_cohort(w_r):
            # Same zero-weight handling as _refit_sa, but returns the raw
            # cohort coefficients without aggregation.
            nz = w_r > 0
            df_reg_nz = df_reg[nz] if not np.all(nz) else df_reg
            w_nz = w_r[nz] if not np.all(nz) else w_r
            ce_r, _, _, _ = self._fit_saturated_regression(
                df_reg_nz,
                outcome,
                unit,
                time,
                first_treat,
                treatment_groups,
                _sa_rel_periods,
                covariates,
                cluster_var,
                survey_weights=w_nz,
                survey_weight_type=survey_weight_type,
                resolved_survey=None,
            )
            return np.array([ce_r.get(k, np.nan) for k in _keys_ordered])

        _vcov_cohort_rep, _ = compute_replicate_refit_variance(
            _refit_sa_cohort, _full_cohort_vec, resolved_survey
        )
        for key in _keys_ordered:
            idx = coef_index_map[key]
            cohort_ses[key] = float(np.sqrt(max(_vcov_cohort_rep[idx, idx], 0.0)))

    # Run bootstrap if requested
    bootstrap_results = None
    if self.n_bootstrap > 0:
        bootstrap_results = self._run_bootstrap(
            df=df_reg,
            outcome=outcome,
            unit=unit,
            time=time,
            first_treat=first_treat,
            treatment_groups=treatment_groups,
            rel_periods_to_estimate=rel_periods_to_estimate,
            covariates=covariates,
            cluster_var=cluster_var,
            original_event_study=event_study_effects,
            original_overall_att=overall_att,
            resolved_survey=resolved_survey,
            survey_weights=survey_weights,
            survey_weight_type=survey_weight_type,
            survey_weight_col=survey_weight_col,
            use_rao_wu=_use_rao_wu,
        )

        # Update results with bootstrap inference
        # SE, p-value and CI come from the bootstrap; only the t-statistic
        # is recomputed through safe_inference (called without `df`).
        overall_se = bootstrap_results.overall_att_se
        overall_t = safe_inference(overall_att, overall_se, alpha=self.alpha)[0]
        overall_p = bootstrap_results.overall_att_p_value
        overall_ci = bootstrap_results.overall_att_ci

        # Update event study effects
        for e in event_study_effects:
            if e in bootstrap_results.event_study_ses:
                event_study_effects[e]["se"] = bootstrap_results.event_study_ses[e]
                event_study_effects[e]["conf_int"] = bootstrap_results.event_study_cis[e]
                event_study_effects[e]["p_value"] = bootstrap_results.event_study_p_values[e]
                eff_val = event_study_effects[e]["effect"]
                se_val = event_study_effects[e]["se"]
                event_study_effects[e]["t_stat"] = safe_inference(
                    eff_val, se_val, alpha=self.alpha
                )[0]

    # Convert cohort effects to storage format
    # NOTE(review): a missing SE or weight falls back to 0.0 rather than
    # NaN. These SEs are the analytical or replicate ones; the bootstrap
    # branch above does not touch cohort_ses.
    cohort_effects_storage: Dict[Tuple[Any, int], Dict[str, Any]] = {}
    for (g, e), effect in cohort_effects.items():
        weight = cohort_weights.get(e, {}).get(g, 0.0)
        se = cohort_ses.get((g, e), 0.0)
        cohort_effects_storage[(g, e)] = {
            "effect": effect,
            "se": se,
            "weight": weight,
        }

    # Store results
    # n_obs counts rows of the full working copy `df`, not `df_reg`.
    self.results_ = SunAbrahamResults(
        event_study_effects=event_study_effects,
        overall_att=overall_att,
        overall_se=overall_se,
        overall_t_stat=overall_t,
        overall_p_value=overall_p,
        overall_conf_int=overall_ci,
        cohort_weights=cohort_weights,
        groups=treatment_groups,
        time_periods=time_periods,
        n_obs=len(df),
        n_treated_units=n_treated_units,
        n_control_units=n_control_units,
        alpha=self.alpha,
        control_group=self.control_group,
        anticipation=self.anticipation,
        bootstrap_results=bootstrap_results,
        cohort_effects=cohort_effects_storage,
        survey_metadata=survey_metadata,
    )

    self.is_fitted_ = True
    return self.results_
def _fit_saturated_regression(
    self,
    df: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    treatment_groups: List[Any],
    rel_periods: List[int],
    covariates: Optional[List[str]],
    cluster_var: str,
    survey_weights: Optional[np.ndarray] = None,
    survey_weight_type: str = "pweight",
    resolved_survey: object = None,
) -> Tuple[
    Dict[Tuple[Any, int], float],
    Dict[Tuple[Any, int], float],
    np.ndarray,
    Dict[Tuple[Any, int], int],
]:
    """
    Fit saturated TWFE regression with cohort × relative-time interactions.

    Y_it = α_i + λ_t + Σ_g Σ_e [δ_{g,e} × D_{g,e,it}] + X'γ + ε

    The outcome, the interaction dummies and any covariates are passed
    through ``_within_transform_util`` (unit and time), and the transformed
    columns are regressed without an intercept using ``LinearRegression``
    with cluster-robust inference.

    Parameters
    ----------
    df : pd.DataFrame
        Estimation sample. Must already contain a ``_rel_time`` column.
        The frame is copied; the caller's object is not modified.
    outcome, unit, time, first_treat : str
        Column names of the outcome, unit id, time period and cohort
        (first treatment period).
    treatment_groups : list
        Cohort values for which interaction dummies are built.
    rel_periods : list of int
        Relative periods for which interaction dummies are built (the
        caller is expected to have removed the reference period).
    covariates : list of str or None
        Optional covariate columns, demeaned and appended after the
        interaction columns in the design matrix.
    cluster_var : str
        Column holding the cluster ids.
    survey_weights : np.ndarray, optional
        Observation weights forwarded to both the within transformation
        and the regression.
    survey_weight_type : str, default "pweight"
        Forwarded to ``LinearRegression`` as ``weight_type``.
    resolved_survey : object, optional
        Forwarded to ``LinearRegression`` as ``survey_design``.

    Returns
    -------
    cohort_effects : dict
        Mapping (cohort, rel_period) -> effect estimate δ_{g,e}
    cohort_ses : dict
        Mapping (cohort, rel_period) -> standard error
    vcov : np.ndarray
        Variance-covariance matrix for cohort effects
    coef_index_map : dict
        Mapping (cohort, rel_period) -> index in coefficient vector

    Raises
    ------
    ValueError
        If no (cohort, relative period) cell contains any observation.
    """
    df = df.copy()

    # Build every cohort × relative-time dummy first and attach them in a
    # single concat to avoid DataFrame fragmentation.
    interaction_data: Dict[str, np.ndarray] = {}
    coef_index_map: Dict[Tuple[Any, int], int] = {}
    idx = 0
    for g in treatment_groups:
        in_cohort = df[first_treat] == g  # loop-invariant across e
        for e in rel_periods:
            # Indicator: unit is in cohort g AND at relative time e
            indicator = (in_cohort & (df["_rel_time"] == e)).astype(float)
            # Empty cells would only add all-zero columns; skip them.
            if indicator.sum() > 0:
                interaction_data[f"_D_{g}_{e}"] = indicator.values
                coef_index_map[(g, e)] = idx
                idx += 1

    interaction_cols = list(interaction_data)
    if not interaction_cols:
        raise ValueError(
            "No valid cohort × relative-time interactions found. "
            "Check your data structure."
        )
    df = pd.concat([df, pd.DataFrame(interaction_data, index=df.index)], axis=1)

    # Apply within-transformation for unit and time fixed effects
    variables_to_demean = [outcome] + interaction_cols
    if covariates:
        variables_to_demean.extend(covariates)
    df_demeaned = _within_transform_util(
        df, variables_to_demean, unit, time, suffix="_dm", weights=survey_weights
    )

    # Design matrix: interaction columns first, covariates after. The vcov
    # slice at the end of this method relies on that ordering.
    X_cols = [f"{col}_dm" for col in interaction_cols]
    if covariates:
        X_cols.extend(f"{cov}_dm" for cov in covariates)
    X = df_demeaned[X_cols].values
    y = df_demeaned[f"{outcome}_dm"].values

    cluster_ids = df_demeaned[cluster_var].values

    # Degrees of freedom adjustment for absorbed unit and time fixed effects
    n_units_fe = df[unit].nunique()
    n_times_fe = df[time].nunique()
    df_adj = n_units_fe + n_times_fe - 1

    reg = LinearRegression(
        include_intercept=False,  # Already demeaned, no intercept needed
        robust=True,
        cluster_ids=cluster_ids,
        rank_deficient_action=self.rank_deficient_action,
        weights=survey_weights,
        weight_type=survey_weight_type,
        survey_design=resolved_survey,
    ).fit(X, y, df_adjustment=df_adj)

    vcov = reg.vcov_

    # Extract cohort effects and standard errors using get_inference
    cohort_effects: Dict[Tuple[Any, int], float] = {}
    cohort_ses: Dict[Tuple[Any, int], float] = {}
    for key, coef_idx in coef_index_map.items():
        inference = reg.get_inference(coef_idx)
        cohort_effects[key] = inference.coefficient
        cohort_ses[key] = inference.se

    # Extract just the vcov for cohort effects (excluding covariates)
    assert vcov is not None
    n_interactions = len(interaction_cols)
    vcov_cohort = vcov[:n_interactions, :n_interactions]

    return cohort_effects, cohort_ses, vcov_cohort, coef_index_map


def _within_transform(
    self,
    df: pd.DataFrame,
    variables: List[str],
    unit: str,
    time: str,
) -> pd.DataFrame:
    """
    Apply two-way within transformation to remove unit and time fixed effects.

    y_it - y_i. - y_.t + y_..

    Thin unweighted wrapper around ``_within_transform_util``; the call
    passes ``suffix="_dm"``.
    """
    return _within_transform_util(df, variables, unit, time, suffix="_dm")


def _compute_iw_effects(
    self,
    df: pd.DataFrame,
    unit: str,
    first_treat: str,
    treatment_groups: List[Any],
    rel_periods: List[int],
    cohort_effects: Dict[Tuple[Any, int], float],
    cohort_ses: Dict[Tuple[Any, int], float],
    vcov_cohort: np.ndarray,
    coef_index_map: Dict[Tuple[Any, int], int],
    survey_weight_col: Optional[str] = None,
    survey_df: Optional[int] = None,
) -> Tuple[Dict[int, Dict[str, Any]], Dict[int, Dict[Any, float]]]:
    """
    Compute interaction-weighted event study effects.

    β_e = Σ_g w_{g,e} × δ_{g,e}

    where w_{g,e} = n_{g,e} / Σ_g n_{g,e} is the share of observations from
    cohort g at event-time e among all treated observations at that
    event-time.

    When survey weights are provided, n_{g,e} is the survey-weighted mass
    (sum of weights) rather than raw observation counts, so the estimand
    reflects the survey-weighted cohort composition.

    Parameters
    ----------
    df : pd.DataFrame
        Sample containing ``first_treat`` and ``_rel_time`` (and
        ``survey_weight_col`` when weighting is requested).
    unit : str
        Unit id column (not read by this method; kept for a uniform
        signature).
    first_treat : str
        Cohort column; rows with a value > 0 count as treated.
    treatment_groups, rel_periods : list
        Cohorts and relative periods to aggregate over.
    cohort_effects : dict
        (cohort, rel_period) -> δ_{g,e}. Only pairs present here are used.
    cohort_ses : dict
        (cohort, rel_period) -> standard error (not read here; the
        aggregated variance comes from ``vcov_cohort``).
    vcov_cohort : np.ndarray
        Covariance matrix of the cohort effects.
    coef_index_map : dict
        (cohort, rel_period) -> row/column index in ``vcov_cohort``.
    survey_weight_col : str, optional
        Weight column used for the cell masses; ignored if absent from
        ``df``.
    survey_df : int, optional
        Passed to ``safe_inference`` as ``df``.

    Returns
    -------
    event_study_effects : dict
        Dictionary mapping relative period to aggregated effect info.
    cohort_weights : dict
        Dictionary mapping relative period to cohort weight dictionary.
    """
    event_study_effects: Dict[int, Dict[str, Any]] = {}
    cohort_weights: Dict[int, Dict[Any, float]] = {}

    # Pre-compute per-cell observation mass n_{g,e}:
    # weighted sum with survey weights, raw row counts otherwise.
    treated = df[df[first_treat] > 0]
    cell_keys = [first_treat, "_rel_time"]
    if survey_weight_col is not None and survey_weight_col in df.columns:
        event_time_counts = treated.groupby(cell_keys)[survey_weight_col].sum()
    else:
        event_time_counts = treated.groupby(cell_keys).size()

    for e in rel_periods:
        # Cohorts with an estimated coefficient at this relative time
        cohorts_at_e = [g for g in treatment_groups if (g, e) in cohort_effects]
        if not cohorts_at_e:
            continue

        masses = {g: event_time_counts.get((g, e), 0) for g in cohorts_at_e}
        total_size = sum(masses.values())
        if total_size == 0:
            # No mass to form shares from; this period is left out.
            continue

        # IW weights: n_{g,e} / Σ_g n_{g,e}
        weights = {g: mass / total_size for g, mass in masses.items()}
        cohort_weights[e] = weights

        agg_effect = 0.0
        for g in cohorts_at_e:
            agg_effect += weights[g] * cohort_effects[(g, e)]

        # Delta method: Var(β_e) = w' Σ w on the relevant vcov submatrix
        indices = [coef_index_map[(g, e)] for g in cohorts_at_e]
        weight_vec = np.array([weights[g] for g in cohorts_at_e])
        vcov_subset = vcov_cohort[np.ix_(indices, indices)]
        agg_var = float(weight_vec @ vcov_subset @ weight_vec)
        # Clamp tiny negative values from floating-point error
        agg_se = np.sqrt(max(agg_var, 0))

        t_stat, p_val, ci = safe_inference(agg_effect, agg_se, alpha=self.alpha, df=survey_df)

        event_study_effects[e] = {
            "effect": agg_effect,
            "se": agg_se,
            "t_stat": t_stat,
            "p_value": p_val,
            "conf_int": ci,
            "n_groups": len(cohorts_at_e),
        }

    return event_study_effects, cohort_weights


def _compute_overall_att(
    self,
    df: pd.DataFrame,
    first_treat: str,
    event_study_effects: Dict[int, Dict[str, Any]],
    cohort_effects: Dict[Tuple[Any, int], float],
    cohort_weights: Dict[int, Dict[Any, float]],
    vcov_cohort: np.ndarray,
    coef_index_map: Dict[Tuple[Any, int], int],
    survey_weight_col: Optional[str] = None,
) -> Tuple[float, float]:
    """
    Compute overall ATT as weighted average of post-treatment effects.

    Each post-treatment period (e >= 0) is weighted by the mass of treated
    observations at that relative time. When survey weights are provided,
    the per-period weights use survey-weighted mass rather than raw
    observation counts.

    Parameters
    ----------
    df : pd.DataFrame
        Sample containing ``first_treat`` and ``_rel_time``.
    first_treat : str
        Cohort column; rows with a value > 0 count as treated.
    event_study_effects : dict
        Relative period -> dict with at least ``"effect"`` and ``"se"``.
    cohort_effects : dict
        (cohort, rel_period) -> estimate (not read here; kept for a
        uniform signature).
    cohort_weights : dict
        Relative period -> {cohort: share}.
    vcov_cohort : np.ndarray
        Covariance matrix of the cohort effects.
    coef_index_map : dict
        (cohort, rel_period) -> row/column index in ``vcov_cohort``.
    survey_weight_col : str, optional
        Weight column used for period masses; ignored if absent from ``df``.

    Returns
    -------
    (att, se) : tuple of float
        ``(nan, nan)`` when there is no post-treatment effect or when the
        total post-treatment mass is zero.
    """
    post_effects = [(e, eff) for e, eff in event_study_effects.items() if e >= 0]

    if not post_effects:
        return np.nan, np.nan

    use_survey = survey_weight_col is not None and survey_weight_col in df.columns
    treated = df[first_treat] > 0

    # Weight by (survey-weighted) mass of treated observations at each relative time
    post_weights = []
    post_estimates = []
    for e, eff in post_effects:
        mask = (df["_rel_time"] == e) & treated
        if use_survey:
            # No floor for survey weights — valid masses can be < 1
            n_at_e = df.loc[mask, survey_weight_col].sum()
            post_weights.append(n_at_e if n_at_e > 0 else 0.0)
        else:
            post_weights.append(max(int(mask.sum()), 1))
        post_estimates.append(eff["effect"])

    post_weights_arr = np.array(post_weights, dtype=float)
    total_mass = post_weights_arr.sum()
    if total_mass <= 0:
        # Only reachable on the survey path (the count path floors at 1).
        # Normalising would be 0/0, so report "not estimable" explicitly
        # instead of letting NaN and a RuntimeWarning leak out of the algebra.
        return np.nan, np.nan
    post_weights_arr = post_weights_arr / total_mass

    overall_att = float(np.sum(post_weights_arr * np.array(post_estimates)))

    # Delta method through the full weighting scheme:
    # ATT = Σ_e w_e × β_e = Σ_e w_e × Σ_g w_{g,e} × δ_{g,e}
    overall_weights_by_coef: Dict[Tuple[Any, int], float] = {}
    for period_weight, (e, _) in zip(post_weights_arr, post_effects):
        for g, cw in cohort_weights.get(e, {}).items():
            key = (g, e)
            if key in coef_index_map:
                overall_weights_by_coef[key] = (
                    overall_weights_by_coef.get(key, 0.0) + period_weight * cw
                )

    if not overall_weights_by_coef:
        # Fallback to simplified variance that ignores covariances between periods
        warnings.warn(
            "Could not construct full weight vector for overall ATT SE. "
            "Using simplified variance that ignores covariances between periods.",
            UserWarning,
            stacklevel=2,
        )
        overall_var = float(
            np.sum((post_weights_arr**2) * np.array([eff["se"] ** 2 for _, eff in post_effects]))
        )
        return overall_att, np.sqrt(overall_var)

    # Build full weight vector and compute variance
    indices = [coef_index_map[key] for key in overall_weights_by_coef]
    weight_vec = np.array(list(overall_weights_by_coef.values()))
    vcov_subset = vcov_cohort[np.ix_(indices, indices)]
    overall_var = float(weight_vec @ vcov_subset @ weight_vec)
    overall_se = np.sqrt(max(overall_var, 0))

    return overall_att, overall_se


def _summarize_bootstrap_draws(
    self,
    weight_type: str,
    rel_periods: List[int],
    bootstrap_effects: Dict[int, np.ndarray],
    bootstrap_overall: np.ndarray,
    original_event_study: Dict[int, Dict[str, Any]],
    original_overall_att: float,
) -> "SABootstrapResults":
    """Turn stored bootstrap draws into SEs, CIs and p-values (shared by both bootstrap paths)."""
    event_study_ses = {}
    event_study_cis = {}
    event_study_p_values = {}

    for e in rel_periods:
        se, ci, p_value = compute_effect_bootstrap_stats(
            original_event_study[e]["effect"],
            bootstrap_effects[e],
            alpha=self.alpha,
            context=f"event study e={e}",
        )
        event_study_ses[e] = se
        event_study_cis[e] = ci
        event_study_p_values[e] = p_value

    # Overall ATT statistics
    overall_se, overall_ci, overall_p = compute_effect_bootstrap_stats(
        original_overall_att,
        bootstrap_overall,
        alpha=self.alpha,
        context="overall ATT",
    )

    return SABootstrapResults(
        n_bootstrap=self.n_bootstrap,
        weight_type=weight_type,
        alpha=self.alpha,
        overall_att_se=overall_se,
        overall_att_ci=overall_ci,
        overall_att_p_value=overall_p,
        event_study_ses=event_study_ses,
        event_study_cis=event_study_cis,
        event_study_p_values=event_study_p_values,
        bootstrap_distribution=bootstrap_overall,
    )


def _run_bootstrap(
    self,
    df: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    treatment_groups: List[Any],
    rel_periods_to_estimate: List[int],
    covariates: Optional[List[str]],
    cluster_var: str,
    original_event_study: Dict[int, Dict[str, Any]],
    original_overall_att: float,
    resolved_survey: object = None,
    survey_weights: Optional[np.ndarray] = None,
    survey_weight_type: str = "pweight",
    survey_weight_col: Optional[str] = None,
    use_rao_wu: bool = False,
) -> "SABootstrapResults":
    """
    Run bootstrap for inference.

    When use_rao_wu is True (survey design with explicit strata/PSU/FPC),
    uses Rao-Wu rescaled bootstrap (weight perturbation). Otherwise, uses
    pairs bootstrap (resampling units with replacement).

    Parameters
    ----------
    df : pd.DataFrame
        Estimation sample. Rows are addressed by position, so any index
        is accepted.
    outcome, unit, time, first_treat, cluster_var : str
        Column names forwarded to ``_fit_saturated_regression``.
    treatment_groups, rel_periods_to_estimate, covariates
        Forwarded to the per-draw re-estimation.
    original_event_study : dict
        Point estimates by relative period; its keys define which periods
        are tracked.
    original_overall_att : float
        Point estimate of the overall ATT.
    resolved_survey : object, optional
        Only used by the Rao-Wu path.
    survey_weights : np.ndarray, optional
        Not read here; per-draw weights are taken from
        ``survey_weight_col`` of the resampled frame.
    survey_weight_type : str, default "pweight"
        Forwarded to the per-draw regression.
    survey_weight_col : str, optional
        Weight column used for per-draw weights and IW shares.
    use_rao_wu : bool, default False
        Delegate to ``_run_rao_wu_bootstrap``.

    Returns
    -------
    SABootstrapResults
    """
    if self.n_bootstrap < 50:
        warnings.warn(
            f"n_bootstrap={self.n_bootstrap} is low. Consider n_bootstrap >= 199 "
            "for reliable inference.",
            UserWarning,
            stacklevel=3,
        )

    # Created before branching so both paths draw from the same seeded stream.
    rng = np.random.default_rng(self.seed)

    if use_rao_wu:
        return self._run_rao_wu_bootstrap(
            df=df,
            outcome=outcome,
            unit=unit,
            time=time,
            first_treat=first_treat,
            treatment_groups=treatment_groups,
            rel_periods_to_estimate=rel_periods_to_estimate,
            covariates=covariates,
            cluster_var=cluster_var,
            original_event_study=original_event_study,
            original_overall_att=original_overall_att,
            resolved_survey=resolved_survey,
            survey_weight_type=survey_weight_type,
            survey_weight_col=survey_weight_col,
            rng=rng,
        )

    # --- Pairs bootstrap (non-survey or weights-only survey) ---
    all_units = df[unit].unique()
    n_units = len(all_units)

    # Pre-compute unit -> row POSITIONS once (avoids boolean scans inside the
    # loop). Positions, not index labels: the mapping is consumed by
    # ``df.iloc`` below, so labels would select the wrong rows (or raise)
    # whenever ``df`` does not carry a default RangeIndex.
    unit_row_positions = {
        u: np.flatnonzero((df[unit] == u).to_numpy()) for u in all_units
    }
    unit_row_counts = {u: len(pos) for u, pos in unit_row_positions.items()}

    has_weight_col = survey_weight_col is not None and survey_weight_col in df.columns

    # Store bootstrap samples
    rel_periods = sorted(original_event_study.keys())
    bootstrap_effects = {e: np.zeros(self.n_bootstrap) for e in rel_periods}
    bootstrap_overall = np.zeros(self.n_bootstrap)

    for b in range(self.n_bootstrap):
        # Resample units with replacement (pairs bootstrap)
        boot_units = rng.choice(all_units, size=n_units, replace=True)

        boot_positions = np.concatenate([unit_row_positions[u] for u in boot_units])
        df_b = df.iloc[boot_positions].copy()

        # Give every drawn copy its own unit id so repeated draws of the
        # same unit are treated as distinct units in the fixed effects.
        rows_per_unit = np.array([unit_row_counts[u] for u in boot_units])
        df_b[unit] = np.repeat(np.arange(n_units), rows_per_unit)

        # Recompute relative time (vectorized)
        df_b["_rel_time"] = np.where(
            df_b[first_treat] > 0, df_b[time] - df_b[first_treat], np.nan
        )
        # np.inf was normalized to 0 in fit(), so the np.inf check is defensive only
        df_b["_never_treated"] = (df_b[first_treat] == 0) | (df_b[first_treat] == np.inf)

        try:
            # Extract survey weights from resampled data if present
            boot_survey_weights = df_b[survey_weight_col].values if has_weight_col else None

            # Re-estimate saturated regression
            (
                cohort_effects_b,
                cohort_ses_b,
                vcov_b,
                coef_map_b,
            ) = self._fit_saturated_regression(
                df_b,
                outcome,
                unit,
                time,
                first_treat,
                treatment_groups,
                rel_periods_to_estimate,
                covariates,
                cluster_var,
                survey_weights=boot_survey_weights,
                survey_weight_type=survey_weight_type,
                resolved_survey=None,  # Use explicit weights, not stale design
            )

            # Compute IW effects for this bootstrap sample
            event_study_b, cohort_weights_b = self._compute_iw_effects(
                df_b,
                unit,
                first_treat,
                treatment_groups,
                rel_periods_to_estimate,
                cohort_effects_b,
                cohort_ses_b,
                vcov_b,
                coef_map_b,
                survey_weight_col=survey_weight_col,
            )

            # Store bootstrap estimates; periods missing from this draw
            # fall back to the original point estimate.
            for e in rel_periods:
                if e in event_study_b:
                    bootstrap_effects[e][b] = event_study_b[e]["effect"]
                else:
                    bootstrap_effects[e][b] = original_event_study[e]["effect"]

            # Compute overall ATT for this bootstrap sample
            overall_b, _ = self._compute_overall_att(
                df_b,
                first_treat,
                event_study_b,
                cohort_effects_b,
                cohort_weights_b,
                vcov_b,
                coef_map_b,
                survey_weight_col=survey_weight_col,
            )
            bootstrap_overall[b] = overall_b

        except (ValueError, np.linalg.LinAlgError) as exc:
            # If bootstrap iteration fails, use original.
            # NOTE(review): the Rao-Wu path stores NaN for failed draws to
            # avoid shrinking the bootstrap dispersion; this path keeps the
            # original-estimate fallback. Confirm whether the two should match.
            warnings.warn(
                f"Bootstrap iteration {b} failed: {exc}. Using original estimate.",
                UserWarning,
                stacklevel=2,
            )
            for e in rel_periods:
                bootstrap_effects[e][b] = original_event_study[e]["effect"]
            bootstrap_overall[b] = original_overall_att

    return self._summarize_bootstrap_draws(
        "pairs",
        rel_periods,
        bootstrap_effects,
        bootstrap_overall,
        original_event_study,
        original_overall_att,
    )


def _run_rao_wu_bootstrap(
    self,
    df: pd.DataFrame,
    outcome: str,
    unit: str,
    time: str,
    first_treat: str,
    treatment_groups: List[Any],
    rel_periods_to_estimate: List[int],
    covariates: Optional[List[str]],
    cluster_var: str,
    original_event_study: Dict[int, Dict[str, Any]],
    original_overall_att: float,
    resolved_survey: object,
    survey_weight_type: str,
    survey_weight_col: Optional[str],
    rng: np.random.Generator,
) -> "SABootstrapResults":
    """
    Run Rao-Wu rescaled bootstrap for survey-aware inference.

    Instead of physically resampling units, each iteration generates
    rescaled observation weights via Rao-Wu (1988) weight perturbation.
    The rescaled weights feed into the existing WLS regression path.

    Parameters
    ----------
    df : pd.DataFrame
        Estimation sample. Rows are addressed by position, so any index
        is accepted.
    outcome, unit, time, first_treat, cluster_var : str
        Column names forwarded to ``_fit_saturated_regression``.
    treatment_groups, rel_periods_to_estimate, covariates
        Forwarded to the per-draw re-estimation.
    original_event_study : dict
        Point estimates by relative period; its keys define which periods
        are tracked.
    original_overall_att : float
        Point estimate of the overall ATT.
    resolved_survey : object
        Survey design whose ``weights``, ``strata``, ``psu``, ``fpc``,
        ``weight_type``, ``n_strata``, ``n_psu`` and ``lonely_psu``
        attributes are read; the array attributes are treated as aligned
        row-for-row with ``df``.
    survey_weight_type : str
        Forwarded to the per-draw regression.
    survey_weight_col : str or None
        Not read here; per-draw shares use the rescaled weights.
    rng : np.random.Generator
        Random generator passed to ``generate_rao_wu_weights``.

    Returns
    -------
    SABootstrapResults
    """
    from diff_diff.bootstrap_utils import generate_rao_wu_weights
    from diff_diff.survey import ResolvedSurveyDesign

    # Column name for rescaled weights in the bootstrap DataFrame
    _rw_col = "__rw_boot_weight"

    # Collapse survey design to unit level so Rao-Wu respects panel
    # structure: each unit gets one set of weights regardless of how
    # many time periods it has. Without this, when there is no
    # explicit PSU, generate_rao_wu_weights treats each observation as
    # its own PSU and different obs of the same unit can get different
    # weights, breaking panel semantics.
    all_units = df[unit].unique()

    def _first_per_unit(values):
        # First value per unit, returned in ``all_units`` order.
        return (
            pd.Series(values, index=df.index)
            .groupby(df[unit])
            .first()
            .reindex(all_units)
            .values
        )

    weights_unit = _first_per_unit(resolved_survey.weights).astype(np.float64)
    strata_unit = (
        None if resolved_survey.strata is None else _first_per_unit(resolved_survey.strata)
    )
    psu_unit = None if resolved_survey.psu is None else _first_per_unit(resolved_survey.psu)
    fpc_unit = None if resolved_survey.fpc is None else _first_per_unit(resolved_survey.fpc)

    unit_resolved = ResolvedSurveyDesign(
        weights=weights_unit,
        weight_type=resolved_survey.weight_type,
        strata=strata_unit,
        psu=psu_unit,
        fpc=fpc_unit,
        n_strata=resolved_survey.n_strata,
        n_psu=resolved_survey.n_psu,
        lonely_psu=resolved_survey.lonely_psu,
    )

    # Row POSITIONS of every unit, in ``all_units`` order, for expanding
    # unit-level weights to observation level. Positions (not index labels)
    # because they index a plain NumPy array of length len(df).
    rows_by_unit = [np.flatnonzero((df[unit] == u).to_numpy()) for u in all_units]

    # Store bootstrap samples
    rel_periods = sorted(original_event_study.keys())
    bootstrap_effects = {e: np.full(self.n_bootstrap, np.nan) for e in rel_periods}
    bootstrap_overall = np.full(self.n_bootstrap, np.nan)

    for b in range(self.n_bootstrap):
        try:
            # Generate Rao-Wu rescaled weights at unit level
            unit_boot_weights = generate_rao_wu_weights(unit_resolved, rng)

            # Expand unit-level weights to observation level
            boot_weights = np.empty(len(df), dtype=np.float64)
            for unit_pos, rows in enumerate(rows_by_unit):
                boot_weights[rows] = unit_boot_weights[unit_pos]

            # Drop observations with zero weight (PSUs not drawn in this
            # iteration) to avoid NaN/Inf in within-transformation.
            positive_mask = boot_weights > 0
            if positive_mask.sum() < 2:
                # Too few observations with positive weight
                raise ValueError("Rao-Wu iteration produced < 2 positive weights")

            df_b = df[positive_mask].reset_index(drop=True)
            boot_weights_b = boot_weights[positive_mask]
            df_b[_rw_col] = boot_weights_b

            # Verify we still have both treated and control observations
            has_treated = (df_b[first_treat] > 0).any()
            has_control = ((df_b[first_treat] == 0) | (df_b[first_treat] == np.inf)).any()
            if not has_treated or not has_control:
                raise ValueError("Rao-Wu iteration dropped all treated or control units")

            # Re-estimate saturated regression with rescaled weights.
            # Pass resolved_survey=None since inference comes from the
            # bootstrap distribution, not from within-iteration vcov.
            (
                cohort_effects_b,
                cohort_ses_b,
                vcov_b,
                coef_map_b,
            ) = self._fit_saturated_regression(
                df_b,
                outcome,
                unit,
                time,
                first_treat,
                treatment_groups,
                rel_periods_to_estimate,
                covariates,
                cluster_var,
                survey_weights=boot_weights_b,
                survey_weight_type=survey_weight_type,
                resolved_survey=None,
            )

            # Compute IW effects using rescaled weights for cohort shares
            event_study_b, cohort_weights_b = self._compute_iw_effects(
                df_b,
                unit,
                first_treat,
                treatment_groups,
                rel_periods_to_estimate,
                cohort_effects_b,
                cohort_ses_b,
                vcov_b,
                coef_map_b,
                survey_weight_col=_rw_col,
            )

            # Store bootstrap estimates; periods missing from this draw
            # fall back to the original point estimate.
            for e in rel_periods:
                if e in event_study_b:
                    bootstrap_effects[e][b] = event_study_b[e]["effect"]
                else:
                    bootstrap_effects[e][b] = original_event_study[e]["effect"]

            # Compute overall ATT using rescaled weights
            overall_b, _ = self._compute_overall_att(
                df_b,
                first_treat,
                event_study_b,
                cohort_effects_b,
                cohort_weights_b,
                vcov_b,
                coef_map_b,
                survey_weight_col=_rw_col,
            )
            bootstrap_overall[b] = overall_b

        except (ValueError, np.linalg.LinAlgError) as exc:
            # Failed draws stored as NaN (not original estimate) to avoid
            # shrinking bootstrap dispersion. compute_effect_bootstrap_stats
            # handles NaN draws via nanstd.
            warnings.warn(
                f"Bootstrap iteration {b} failed: {exc}. Storing NaN.",
                UserWarning,
                stacklevel=2,
            )
            for e in rel_periods:
                bootstrap_effects[e][b] = np.nan
            bootstrap_overall[b] = np.nan

    return self._summarize_bootstrap_draws(
        "rao_wu",
        rel_periods,
        bootstrap_effects,
        bootstrap_overall,
        original_event_study,
        original_overall_att,
    )
def get_params(self) -> Dict[str, Any]:
    """Return the estimator's configuration as a dict (sklearn-compatible).

    Keys are emitted in a fixed order and each value is read from the
    attribute of the same name.
    """
    param_names = (
        "control_group",
        "anticipation",
        "alpha",
        "cluster",
        "n_bootstrap",
        "seed",
        "rank_deficient_action",
    )
    return {name: getattr(self, name) for name in param_names}
def set_params(self, **params) -> "SunAbraham":
    """Set estimator parameters (sklearn-compatible).

    Every key is validated before any attribute is assigned, so a call
    that contains an unknown key leaves the estimator untouched instead
    of half-updated.

    Parameters
    ----------
    **params
        Attribute names and their new values. A name is accepted when the
        estimator already has an attribute with that name.

    Returns
    -------
    SunAbraham
        ``self``, to allow call chaining.

    Raises
    ------
    ValueError
        If a key does not correspond to an existing attribute. The first
        offending key (in call order) is reported.
    """
    for key in params:
        if not hasattr(self, key):
            raise ValueError(f"Unknown parameter: {key}")
    for key, value in params.items():
        setattr(self, key, value)
    return self
def summary(self) -> str:
    """Return the text summary of the fitted results.

    Raises
    ------
    RuntimeError
        If the estimator has not been fitted yet.
    """
    if self.is_fitted_:
        results = self.results_
        assert results is not None
        return results.summary()
    raise RuntimeError("Model must be fitted before calling summary()")
def print_summary(self) -> None:
    """Write the text returned by ``summary()`` to stdout."""
    report = self.summary()
    print(report)