# Source code for diff_diff.trop_results

"""
Result containers for the Triply Robust Panel (TROP) estimator.

This module contains the TROPResults dataclass, _PrecomputedStructures TypedDict,
and _LAMBDA_INF sentinel value. Extracted from trop.py for module size management.
"""

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

try:
    from typing import TypedDict
except ImportError:
    from typing_extensions import TypedDict

from diff_diff.results import _format_survey_block, _get_significance_stars

# Explicit export list. The underscore-prefixed helpers are named here on
# purpose: listing them in __all__ is what lets ``import *`` expose them.
__all__ = ["_LAMBDA_INF", "_PrecomputedStructures", "TROPResults"]


# λ_nn value that switches the factor model off during the LOOCV grid search:
# per the paper's footnote 2, λ_nn = ∞ forces the low-rank term L to zero.
# λ_time and λ_unit are "disabled" differently -- by 0.0 rather than ∞ --
# because Eq. 3 then yields exp(-0 × dist) = 1 for every distance, i.e.
# uniform weights.
_LAMBDA_INF: float = float("+inf")


class _PrecomputedStructures(TypedDict):
    """Bundle of arrays and index lists shared by every LOOCV iteration.

    ``_precompute_structures()`` builds one of these up front; the LOOCV loop
    and the final estimation then read from it instead of recomputing the
    same quantities on every pass.
    """

    #: Pairwise distances between units, shape (n_units, n_units).
    unit_dist_matrix: np.ndarray
    #: Time distances with entry [t, s] = |t - s|, shape (n_periods, n_periods).
    time_dist_matrix: np.ndarray
    #: Boolean mask marking control observations (D == 0).
    control_mask: np.ndarray
    #: Boolean mask marking treated observations (D == 1).
    treated_mask: np.ndarray
    #: (t, i) index pairs of the treated observations.
    treated_observations: List[Tuple[int, int]]
    #: (t, i) index pairs of the valid control observations.
    control_obs: List[Tuple[int, int]]
    #: Indices of never-treated units (kept for backward compatibility).
    control_unit_idx: np.ndarray
    #: Treatment indicator matrix, shape (n_periods, n_units); used to build
    #: dynamic control sets.
    D: np.ndarray
    #: Outcome matrix, shape (n_periods, n_units).
    Y: np.ndarray
    #: Number of units.
    n_units: int
    #: Number of time periods.
    n_periods: int


@dataclass
class TROPResults:
    """
    Results from a Triply Robust Panel (TROP) estimation.

    TROP combines nuclear norm regularized factor estimation with
    exponential distance-based unit weights and time decay weights.

    Attributes
    ----------
    att : float
        Average Treatment effect on the Treated (ATT).
    se : float
        Standard error of the ATT estimate.
    t_stat : float
        T-statistic for the ATT estimate.
    p_value : float
        P-value for the null hypothesis that ATT = 0.
    conf_int : tuple[float, float]
        Confidence interval for the ATT, computed at level ``1 - alpha``.
    n_obs : int
        Number of observations used in estimation.
    n_treated : int
        Number of treated units.
    n_control : int
        Number of control units.
    n_treated_obs : int
        Number of treated unit-time observations.
    unit_effects : dict
        Estimated unit fixed effects (alpha_i).
    time_effects : dict
        Estimated time fixed effects (beta_t).
    treatment_effects : dict
        Individual treatment effects for each treated (unit, time) pair.
    lambda_time : float
        Selected time weight decay parameter from grid. 0.0 = uniform
        time weights (disabled) per Eq. 3.
    lambda_unit : float
        Selected unit weight decay parameter from grid. 0.0 = uniform
        unit weights (disabled) per Eq. 3.
    lambda_nn : float
        Selected nuclear norm regularization parameter from grid. inf = factor
        model disabled (L=0); converted to 1e10 internally for computation.
    factor_matrix : np.ndarray
        Estimated low-rank factor matrix L (n_periods x n_units).
    effective_rank : float
        Effective rank of the factor matrix (sum of singular values / max).
    loocv_score : float
        Leave-one-out cross-validation score for selected parameters.
    alpha : float
        Significance level for confidence interval.
    n_pre_periods : int
        Number of pre-treatment periods.
    n_post_periods : int
        Number of post-treatment periods (periods with D=1 observations).
    n_bootstrap : int, optional
        Number of bootstrap replications (if bootstrap variance).
    bootstrap_distribution : np.ndarray, optional
        Bootstrap distribution of estimates.
    survey_metadata : object, optional
        Survey design metadata (``SurveyMetadata`` from ``diff_diff.survey``);
        when present it is rendered in ``summary()`` and its fields are
        added to ``to_dict()``.
    """

    att: float
    se: float
    t_stat: float
    p_value: float
    conf_int: Tuple[float, float]
    n_obs: int
    n_treated: int
    n_control: int
    n_treated_obs: int
    unit_effects: Dict[Any, float]
    time_effects: Dict[Any, float]
    treatment_effects: Dict[Tuple[Any, Any], float]
    lambda_time: float
    lambda_unit: float
    lambda_nn: float
    factor_matrix: np.ndarray
    effective_rank: float
    loocv_score: float
    alpha: float = 0.05
    n_pre_periods: int = 0
    n_post_periods: int = 0
    n_bootstrap: Optional[int] = field(default=None)
    bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)
    # Survey design metadata (SurveyMetadata instance from diff_diff.survey)
    survey_metadata: Optional[Any] = field(default=None)

    def __repr__(self) -> str:
        """Concise string representation."""
        sig = _get_significance_stars(self.p_value)
        return (
            f"TROPResults(ATT={self.att:.4f}{sig}, "
            f"SE={self.se:.4f}, "
            f"eff_rank={self.effective_rank:.1f}, "
            f"p={self.p_value:.4f})"
        )

    @property
    def coef_var(self) -> float:
        """Coefficient of variation: SE / abs(ATT).

        NaN when SE is non-finite or negative, or when ATT is non-finite or 0.
        """
        if not (np.isfinite(self.se) and self.se >= 0):
            return np.nan
        if not np.isfinite(self.att) or self.att == 0:
            return np.nan
        return self.se / abs(self.att)

    @staticmethod
    def _format_conf_level(alpha: float) -> str:
        """Format the confidence level ``1 - alpha`` as a percentage string.

        ``:g`` keeps non-integer levels intact (0.025 -> ``"97.5"``) and avoids
        the truncation ``int()`` applies to values like 56.99999999999999,
        while standard levels still render as before (0.05 -> ``"95"``).
        """
        return f"{(1 - alpha) * 100:g}"

    def summary(self, alpha: Optional[float] = None) -> str:
        """
        Generate a formatted summary of the estimation results.

        Parameters
        ----------
        alpha : float, optional
            Significance level the caller would like the confidence interval
            reported at. Defaults to the alpha used during estimation. The
            stored ``conf_int`` exists only at the estimation alpha, so the
            interval is always labelled with ``self.alpha``; when a different
            ``alpha`` is requested, a note is appended instead of mislabelling
            the interval.

        Returns
        -------
        str
            Formatted summary table.
        """
        requested_alpha = self.alpha if alpha is None else alpha
        # conf_int was computed with self.alpha at estimation time, so its
        # label must come from self.alpha -- not from the requested alpha.
        conf_level = self._format_conf_level(self.alpha)

        lines = [
            "=" * 75,
            "Triply Robust Panel (TROP) Estimation Results".center(75),
            "Athey, Imbens, Qu & Viviano (2025)".center(75),
            "=" * 75,
            "",
            f"{'Observations:':<25} {self.n_obs:>10}",
            f"{'Treated units:':<25} {self.n_treated:>10}",
            f"{'Control units:':<25} {self.n_control:>10}",
            f"{'Treated observations:':<25} {self.n_treated_obs:>10}",
            f"{'Pre-treatment periods:':<25} {self.n_pre_periods:>10}",
            f"{'Post-treatment periods:':<25} {self.n_post_periods:>10}",
            "",
            "-" * 75,
            "Tuning Parameters (selected via LOOCV)".center(75),
            "-" * 75,
            f"{'Lambda (time decay):':<25} {self.lambda_time:>10.4f}",
            f"{'Lambda (unit distance):':<25} {self.lambda_unit:>10.4f}",
            f"{'Lambda (nuclear norm):':<25} {self.lambda_nn:>10.4f}",
            f"{'Effective rank:':<25} {self.effective_rank:>10.2f}",
            f"{'LOOCV score:':<25} {self.loocv_score:>10.6f}",
        ]

        # Variance info
        if self.n_bootstrap is not None:
            lines.append(f"{'Bootstrap replications:':<25} {self.n_bootstrap:>10}")

        # Survey design info (only present when estimation used a survey design)
        if self.survey_metadata is not None:
            lines.extend(_format_survey_block(self.survey_metadata, 75))

        lines.extend(
            [
                "",
                "-" * 75,
                f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} "
                f"{'t-stat':>10} {'P>|t|':>10} {'':>5}",
                "-" * 75,
                f"{'ATT':<15} {self.att:>12.4f} {self.se:>12.4f} "
                f"{self.t_stat:>10.3f} {self.p_value:>10.4f} {self.significance_stars:>5}",
                "-" * 75,
                "",
                f"{conf_level}% Confidence Interval: "
                f"[{self.conf_int[0]:.4f}, {self.conf_int[1]:.4f}]",
            ]
        )

        if requested_alpha != self.alpha:
            # The stored interval cannot be re-derived here; say so explicitly.
            lines.append(
                f"Note: CI computed at estimation alpha={self.alpha:g}; re-fit with "
                f"alpha={requested_alpha:g} for a "
                f"{self._format_conf_level(requested_alpha)}% interval."
            )

        cv = self.coef_var
        if np.isfinite(cv):
            lines.append(f"{'CV (SE/abs(ATT)):':<25} {cv:>10.4f}")

        # Add significance codes
        lines.extend(
            [
                "",
                "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1",
                "=" * 75,
            ]
        )

        return "\n".join(lines)

    def print_summary(self, alpha: Optional[float] = None) -> None:
        """Print the summary to stdout."""
        print(self.summary(alpha))

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert results to a dictionary.

        Returns
        -------
        Dict[str, Any]
            Dictionary containing all estimation results; survey design
            fields are included only when ``survey_metadata`` is set.
        """
        result = {
            "att": self.att,
            "se": self.se,
            "t_stat": self.t_stat,
            "p_value": self.p_value,
            "conf_int_lower": self.conf_int[0],
            "conf_int_upper": self.conf_int[1],
            "n_obs": self.n_obs,
            "n_treated": self.n_treated,
            "n_control": self.n_control,
            "n_treated_obs": self.n_treated_obs,
            "n_pre_periods": self.n_pre_periods,
            "n_post_periods": self.n_post_periods,
            "lambda_time": self.lambda_time,
            "lambda_unit": self.lambda_unit,
            "lambda_nn": self.lambda_nn,
            "effective_rank": self.effective_rank,
            "loocv_score": self.loocv_score,
        }
        if self.survey_metadata is not None:
            sm = self.survey_metadata
            result["weight_type"] = sm.weight_type
            result["effective_n"] = sm.effective_n
            result["design_effect"] = sm.design_effect
            result["sum_weights"] = sm.sum_weights
            result["n_strata"] = sm.n_strata
            result["n_psu"] = sm.n_psu
            result["df_survey"] = sm.df_survey
        return result

    def to_dataframe(self) -> pd.DataFrame:
        """
        Convert results to a pandas DataFrame.

        Returns
        -------
        pd.DataFrame
            Single-row DataFrame built from ``to_dict()``.
        """
        return pd.DataFrame([self.to_dict()])

    def get_treatment_effects_df(self) -> pd.DataFrame:
        """
        Get individual treatment effects as a DataFrame.

        Returns
        -------
        pd.DataFrame
            DataFrame with unit, time, and treatment effect columns. The
            columns are present even when there are no treatment effects.
        """
        return pd.DataFrame(
            [
                {"unit": unit, "time": time, "effect": effect}
                for (unit, time), effect in self.treatment_effects.items()
            ],
            # Explicit columns keep the schema stable for an empty dict.
            columns=["unit", "time", "effect"],
        )

    def get_unit_effects_df(self) -> pd.DataFrame:
        """
        Get unit fixed effects as a DataFrame.

        Returns
        -------
        pd.DataFrame
            DataFrame with unit and effect columns (present even when empty).
        """
        return pd.DataFrame(
            [{"unit": unit, "effect": effect} for unit, effect in self.unit_effects.items()],
            columns=["unit", "effect"],
        )

    def get_time_effects_df(self) -> pd.DataFrame:
        """
        Get time fixed effects as a DataFrame.

        Returns
        -------
        pd.DataFrame
            DataFrame with time and effect columns (present even when empty).
        """
        return pd.DataFrame(
            [{"time": time, "effect": effect} for time, effect in self.time_effects.items()],
            columns=["time", "effect"],
        )

    @property
    def is_significant(self) -> bool:
        """Check if the ATT is statistically significant at the alpha level."""
        return bool(self.p_value < self.alpha)

    @property
    def significance_stars(self) -> str:
        """Return significance stars based on p-value."""
        return _get_significance_stars(self.p_value)