Source code for diff_diff.wooldridge_results

"""Results class for WooldridgeDiD (ETWFE) estimator."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

from diff_diff.utils import safe_inference


[docs] @dataclass class WooldridgeDiDResults: """Results from WooldridgeDiD.fit(). Core output is ``group_time_effects``: a dict keyed by (cohort_g, time_t) with per-cell ATT estimates and inference. Call ``.aggregate(type)`` to compute any of the four jwdid_estat aggregation types. """ # ------------------------------------------------------------------ # # Core cohort×time estimates # # ------------------------------------------------------------------ # group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]] """key=(g,t), value={att, se, t_stat, p_value, conf_int}""" # ------------------------------------------------------------------ # # Simple (overall) aggregation — always populated at fit time # # ------------------------------------------------------------------ # overall_att: float overall_se: float overall_t_stat: float overall_p_value: float overall_conf_int: Tuple[float, float] # ------------------------------------------------------------------ # # Other aggregations — populated by .aggregate() # # ------------------------------------------------------------------ # group_effects: Optional[Dict[Any, Dict]] = field(default=None, repr=False) calendar_effects: Optional[Dict[Any, Dict]] = field(default=None, repr=False) event_study_effects: Optional[Dict[int, Dict]] = field(default=None, repr=False) # ------------------------------------------------------------------ # # Metadata # # ------------------------------------------------------------------ # method: str = "ols" control_group: str = "not_yet_treated" groups: List[Any] = field(default_factory=list) time_periods: List[Any] = field(default_factory=list) n_obs: int = 0 n_treated_units: int = 0 n_control_units: int = 0 alpha: float = 0.05 anticipation: int = 0 survey_metadata: Optional[Any] = field(default=None, repr=False) # ------------------------------------------------------------------ # # Internal — used by aggregate() for delta-method SEs # # ------------------------------------------------------------------ # _gt_weights: Dict[Tuple[Any, Any], int] = field(default_factory=dict, repr=False) _gt_vcov: Optional[np.ndarray] = field(default=None, repr=False) """Full vcov of all β_{g,t} coefficients (ordered same as sorted group_time_effects keys).""" _gt_keys: List[Tuple[Any, Any]] = field(default_factory=list, repr=False) """Ordered list of (g,t) keys corresponding to _gt_vcov columns.""" _df_survey: Optional[int] = field(default=None, repr=False) """Survey degrees of freedom for t-distribution inference.""" # ------------------------------------------------------------------ # # Public methods # # ------------------------------------------------------------------ #
[docs] def aggregate(self, type: str) -> "WooldridgeDiDResults": # noqa: A002 """Compute and store one of the four jwdid_estat aggregation types. Parameters ---------- type : "simple" | "group" | "calendar" | "event" Returns self for chaining. """ valid = ("simple", "group", "calendar", "event") if type not in valid: raise ValueError(f"type must be one of {valid}, got {type!r}") gt = self.group_time_effects weights = self._gt_weights vcov = self._gt_vcov keys_ordered = self._gt_keys if self._gt_keys else sorted(gt.keys()) def _agg_se(w_vec: np.ndarray) -> float: """Delta-method SE for a linear combination w'β given full vcov.""" if vcov is None or len(w_vec) != vcov.shape[0]: return float("nan") return float(np.sqrt(max(w_vec @ vcov @ w_vec, 0.0))) def _build_effect(att: float, se: float) -> Dict[str, Any]: t_stat, p_value, conf_int = safe_inference( att, se, alpha=self.alpha, df=self._df_survey ) return { "att": att, "se": se, "t_stat": t_stat, "p_value": p_value, "conf_int": conf_int, } if type == "simple": # Re-compute overall using delta method (already stored in overall_* fields) # This is a no-op but keeps the method callable. pass elif type == "group": result: Dict[Any, Dict] = {} for g in self.groups: cells = [(g2, t) for (g2, t) in keys_ordered if g2 == g and t >= g] if not cells: continue w_total = sum(weights.get(c, 0) for c in cells) if w_total == 0: continue att = sum(weights.get(c, 0) * gt[c]["att"] for c in cells) / w_total # delta-method weights vector over all keys_ordered w_vec = np.array( [weights.get(c, 0) / w_total if c in cells else 0.0 for c in keys_ordered] ) se = _agg_se(w_vec) result[g] = _build_effect(att, se) self.group_effects = result elif type == "calendar": result = {} for t in self.time_periods: cells = [(g, t2) for (g, t2) in keys_ordered if t2 == t and t >= g] if not cells: continue w_total = sum(weights.get(c, 0) for c in cells) if w_total == 0: continue att = sum(weights.get(c, 0) * gt[c]["att"] for c in cells) / w_total w_vec = np.array( [weights.get(c, 0) / w_total if c in cells else 0.0 for c in keys_ordered] ) se = _agg_se(w_vec) result[t] = _build_effect(att, se) self.calendar_effects = result elif type == "event": all_k = sorted({t - g for (g, t) in keys_ordered}) result = {} for k in all_k: cells = [(g, t) for (g, t) in keys_ordered if t - g == k] if not cells: continue w_total = sum(weights.get(c, 0) for c in cells) if w_total == 0: continue att = sum(weights.get(c, 0) * gt[c]["att"] for c in cells) / w_total w_vec = np.array( [weights.get(c, 0) / w_total if c in cells else 0.0 for c in keys_ordered] ) se = _agg_se(w_vec) result[k] = _build_effect(att, se) self.event_study_effects = result return self
[docs] def summary(self, aggregation: str = "simple") -> str: """Print formatted summary table. Parameters ---------- aggregation : which aggregation to display ("simple", "group", "calendar", "event") """ lines = [ "=" * 70, " Wooldridge Extended Two-Way Fixed Effects (ETWFE) Results", "=" * 70, f"Method: {self.method}", f"Control group: {self.control_group}", f"Observations: {self.n_obs}", f"Treated units: {self.n_treated_units}", f"Control units: {self.n_control_units}", "-" * 70, ] if self.survey_metadata is not None: from diff_diff.results import _format_survey_block lines.extend(_format_survey_block(self.survey_metadata, 70)) lines.append("-" * 70) def _fmt_row(label: str, att: float, se: float, t: float, p: float, ci: Tuple) -> str: from diff_diff.results import _get_significance_stars # type: ignore stars = _get_significance_stars(p) if not np.isnan(p) else "" ci_lo = f"{ci[0]:.4f}" if not np.isnan(ci[0]) else "NaN" ci_hi = f"{ci[1]:.4f}" if not np.isnan(ci[1]) else "NaN" return ( f"{label:<22} {att:>10.4f} {se:>10.4f} {t:>8.3f} " f"{p:>8.4f}{stars} [{ci_lo}, {ci_hi}]" ) ci_pct = f"{(1 - self.alpha) * 100:.0f}%" header = ( f"{'Parameter':<22} {'Estimate':>10} {'Std. Err.':>10} " f"{'t-stat':>8} {'P>|t|':>8} [{ci_pct} CI]" ) lines.append(header) lines.append("-" * 70) if aggregation == "simple": lines.append( _fmt_row( "ATT (simple)", self.overall_att, self.overall_se, self.overall_t_stat, self.overall_p_value, self.overall_conf_int, ) ) elif aggregation == "group" and self.group_effects: for g, eff in sorted(self.group_effects.items()): lines.append( _fmt_row( f"ATT(g={g})", eff["att"], eff["se"], eff["t_stat"], eff["p_value"], eff["conf_int"], ) ) elif aggregation == "calendar" and self.calendar_effects: for t, eff in sorted(self.calendar_effects.items()): lines.append( _fmt_row( f"ATT(t={t})", eff["att"], eff["se"], eff["t_stat"], eff["p_value"], eff["conf_int"], ) ) elif aggregation == "event" and self.event_study_effects: for k, eff in sorted(self.event_study_effects.items()): if k < -self.anticipation: suffix = " [pre]" elif k < 0: suffix = " [antic]" else: suffix = "" label = f"ATT(k={k})" + suffix lines.append( _fmt_row( label, eff["att"], eff["se"], eff["t_stat"], eff["p_value"], eff["conf_int"], ) ) else: lines.append(f" (call .aggregate({aggregation!r}) first)") lines.append("=" * 70) return "\n".join(lines)
[docs] def to_dataframe(self, aggregation: str = "event") -> pd.DataFrame: """Export aggregated effects to a DataFrame. Parameters ---------- aggregation : "simple" | "group" | "calendar" | "event" | "gt" Use "gt" to export raw group-time effects. """ if aggregation == "gt": rows = [] for (g, t), eff in sorted(self.group_time_effects.items()): row = {"cohort": g, "time": t, "relative_period": t - g} row.update(eff) rows.append(row) return pd.DataFrame(rows) mapping = { "simple": [ { "label": "ATT", "att": self.overall_att, "se": self.overall_se, "t_stat": self.overall_t_stat, "p_value": self.overall_p_value, "conf_int_lo": self.overall_conf_int[0], "conf_int_hi": self.overall_conf_int[1], } ], "group": [ { "cohort": g, **{k: v for k, v in eff.items() if k != "conf_int"}, "conf_int_lo": eff["conf_int"][0], "conf_int_hi": eff["conf_int"][1], } for g, eff in sorted((self.group_effects or {}).items()) ], "calendar": [ { "time": t, **{k: v for k, v in eff.items() if k != "conf_int"}, "conf_int_lo": eff["conf_int"][0], "conf_int_hi": eff["conf_int"][1], } for t, eff in sorted((self.calendar_effects or {}).items()) ], "event": [ { "relative_period": k, **{kk: vv for kk, vv in eff.items() if kk != "conf_int"}, "conf_int_lo": eff["conf_int"][0], "conf_int_hi": eff["conf_int"][1], } for k, eff in sorted((self.event_study_effects or {}).items()) ], } rows = mapping.get(aggregation, []) return pd.DataFrame(rows)
[docs] def plot_event_study(self, **kwargs) -> None: """Event study plot. Calls aggregate('event') if needed.""" if self.event_study_effects is None: self.aggregate("event") from diff_diff.visualization import plot_event_study # type: ignore effects = {k: v["att"] for k, v in (self.event_study_effects or {}).items()} se = {k: v["se"] for k, v in (self.event_study_effects or {}).items()} plot_event_study(effects=effects, se=se, alpha=self.alpha, **kwargs)
# --- Inference-field aliases (balance/external-adapter compatibility) --- @property def att(self) -> float: return self.overall_att @property def se(self) -> float: return self.overall_se @property def conf_int(self) -> Tuple[float, float]: return self.overall_conf_int @property def p_value(self) -> float: return self.overall_p_value @property def t_stat(self) -> float: return self.overall_t_stat def __repr__(self) -> str: n_gt = len(self.group_time_effects) att_str = f"{self.overall_att:.4f}" if not np.isnan(self.overall_att) else "NaN" se_str = f"{self.overall_se:.4f}" if not np.isnan(self.overall_se) else "NaN" p_str = f"{self.overall_p_value:.4f}" if not np.isnan(self.overall_p_value) else "NaN" return ( f"WooldridgeDiDResults(" f"ATT={att_str}, SE={se_str}, p={p_str}, " f"n_gt={n_gt}, method={self.method!r})" )