Source code for diff_diff.trop

"""
Triply Robust Panel (TROP) estimator.

Implements the TROP estimator from Athey, Imbens, Qu & Viviano (2025).
TROP combines three robustness components:
1. Nuclear norm regularized factor model (interactive fixed effects)
2. Exponential distance-based unit weights
3. Exponential time decay weights

The estimator uses leave-one-out cross-validation for tuning parameter
selection and provides robust treatment effect estimates under factor
confounding.

References
----------
Athey, S., Imbens, G. W., Qu, Z., & Viviano, D. (2025). Triply Robust Panel
Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536
"""

import logging
import warnings
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

logger = logging.getLogger(__name__)

from diff_diff._backend import (
    HAS_RUST_BACKEND,
    _rust_unit_distance_matrix,
    _rust_loocv_grid_search,
    _rust_bootstrap_trop_variance,
    _rust_loocv_grid_search_joint,
    _rust_bootstrap_trop_variance_joint,
)
from diff_diff.trop_results import (
    _LAMBDA_INF,
    _PrecomputedStructures,
    TROPResults,
)
from diff_diff.utils import safe_inference


[docs] class TROP: """ Triply Robust Panel (TROP) estimator. Implements the exact methodology from Athey, Imbens, Qu & Viviano (2025). TROP combines three robustness components: 1. **Nuclear norm regularized factor model**: Estimates interactive fixed effects L_it via matrix completion with nuclear norm penalty ||L||_* 2. **Exponential distance-based unit weights**: ω_j = exp(-λ_unit × d(j,i)) where d(j,i) is the RMSE of outcome differences between units 3. **Exponential time decay weights**: θ_s = exp(-λ_time × |s-t|) weighting pre-treatment periods by proximity to treatment Tuning parameters (λ_time, λ_unit, λ_nn) are selected via leave-one-out cross-validation on control observations. Parameters ---------- method : str, default='twostep' Estimation method to use: - 'twostep': Per-observation model fitting following Algorithm 2 of Athey et al. (2025). Computes observation-specific weights and fits a model for each treated observation, averaging the individual treatment effects. More flexible but computationally intensive. - 'joint': Joint weighted least squares optimization. Estimates a single scalar treatment effect τ along with fixed effects and optional low-rank factor adjustment. Faster but assumes homogeneous treatment effects. Uses alternating minimization when nuclear norm penalty is finite. lambda_time_grid : list, optional Grid of time weight decay parameters. 0.0 = uniform weights (disabled). Must not contain inf. Default: [0, 0.1, 0.5, 1, 2, 5]. lambda_unit_grid : list, optional Grid of unit weight decay parameters. 0.0 = uniform weights (disabled). Must not contain inf. Default: [0, 0.1, 0.5, 1, 2, 5]. lambda_nn_grid : list, optional Grid of nuclear norm regularization parameters. inf = factor model disabled (L=0). Default: [0, 0.01, 0.1, 1]. max_iter : int, default=100 Maximum iterations for nuclear norm optimization. tol : float, default=1e-6 Convergence tolerance for optimization. alpha : float, default=0.05 Significance level for confidence intervals. n_bootstrap : int, default=200 Number of bootstrap replications for variance estimation. Must be >= 2. seed : int, optional Random seed for reproducibility. Attributes ---------- results_ : TROPResults Estimation results after calling fit(). is_fitted_ : bool Whether the model has been fitted. Examples -------- >>> from diff_diff import TROP >>> trop = TROP() >>> results = trop.fit( ... data, ... outcome='outcome', ... treatment='treated', ... unit='unit', ... time='period', ... ) >>> results.print_summary() References ---------- Athey, S., Imbens, G. W., Qu, Z., & Viviano, D. (2025). Triply Robust Panel Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536 """ # Class constants CONVERGENCE_TOL_SVD: float = 1e-10 """Tolerance for singular value truncation in soft-thresholding. Singular values below this threshold after soft-thresholding are treated as zero to improve numerical stability. """
[docs] def __init__( self, method: str = "twostep", lambda_time_grid: Optional[List[float]] = None, lambda_unit_grid: Optional[List[float]] = None, lambda_nn_grid: Optional[List[float]] = None, max_iter: int = 100, tol: float = 1e-6, alpha: float = 0.05, n_bootstrap: int = 200, seed: Optional[int] = None, ): # Validate method parameter valid_methods = ("twostep", "joint") if method not in valid_methods: raise ValueError( f"method must be one of {valid_methods}, got '{method}'" ) self.method = method # Default grids from paper self.lambda_time_grid = lambda_time_grid or [0.0, 0.1, 0.5, 1.0, 2.0, 5.0] self.lambda_unit_grid = lambda_unit_grid or [0.0, 0.1, 0.5, 1.0, 2.0, 5.0] self.lambda_nn_grid = lambda_nn_grid or [0.0, 0.01, 0.1, 1.0, 10.0] if n_bootstrap < 2: raise ValueError( "n_bootstrap must be >= 2 for TROP (bootstrap variance " "estimation is always used)" ) self.max_iter = max_iter self.tol = tol self.alpha = alpha self.n_bootstrap = n_bootstrap self.seed = seed # Validate that time/unit grids do not contain inf. # Per Athey et al. (2025) Eq. 3, λ_time=0 and λ_unit=0 give uniform # weights (exp(-0 × dist) = 1). Using inf is a misunderstanding of # the paper's convention. Only λ_nn=∞ is valid (disables factor model). for grid_name, grid_vals in [ ("lambda_time_grid", self.lambda_time_grid), ("lambda_unit_grid", self.lambda_unit_grid), ]: if any(np.isinf(v) for v in grid_vals): raise ValueError( f"{grid_name} must not contain inf. Use 0.0 for uniform " f"weights (disabled) per Athey et al. (2025) Eq. 3: " f"exp(-0 × dist) = 1 for all distances." ) # Internal state self.results_: Optional[TROPResults] = None self.is_fitted_: bool = False self._optimal_lambda: Optional[Tuple[float, float, float]] = None # Pre-computed structures (set during fit) self._precomputed: Optional[_PrecomputedStructures] = None
def _precompute_structures( self, Y: np.ndarray, D: np.ndarray, control_unit_idx: np.ndarray, n_units: int, n_periods: int, ) -> _PrecomputedStructures: """ Pre-compute data structures that are reused across LOOCV and estimation. This method computes once what would otherwise be computed repeatedly: - Pairwise unit distance matrix - Time distance vectors - Masks and indices Parameters ---------- Y : np.ndarray Outcome matrix (n_periods x n_units). D : np.ndarray Treatment indicator matrix (n_periods x n_units). control_unit_idx : np.ndarray Indices of control units. n_units : int Number of units. n_periods : int Number of periods. Returns ------- _PrecomputedStructures Pre-computed structures for efficient reuse. """ # Compute pairwise unit distances (for all observation-specific weights) # Following Equation 3 (page 7): RMSE between units over pre-treatment if HAS_RUST_BACKEND and _rust_unit_distance_matrix is not None: # Use Rust backend for parallel distance computation (4-8x speedup) unit_dist_matrix = _rust_unit_distance_matrix(Y, D.astype(np.float64)) else: unit_dist_matrix = self._compute_all_unit_distances(Y, D, n_units, n_periods) # Pre-compute time distance vectors for each target period # Time distance: |t - s| for all s and each target t time_dist_matrix = np.abs( np.arange(n_periods)[:, np.newaxis] - np.arange(n_periods)[np.newaxis, :] ) # (n_periods, n_periods) where [t, s] = |t - s| # Control and treatment masks control_mask = D == 0 treated_mask = D == 1 # Identify treated observations treated_observations = list(zip(*np.where(treated_mask))) # Control observations for LOOCV control_obs = [(t, i) for t in range(n_periods) for i in range(n_units) if control_mask[t, i] and not np.isnan(Y[t, i])] return { "unit_dist_matrix": unit_dist_matrix, "time_dist_matrix": time_dist_matrix, "control_mask": control_mask, "treated_mask": treated_mask, "treated_observations": treated_observations, "control_obs": control_obs, "control_unit_idx": control_unit_idx, "D": D, "Y": Y, "n_units": n_units, "n_periods": n_periods, } def _compute_all_unit_distances( self, Y: np.ndarray, D: np.ndarray, n_units: int, n_periods: int, ) -> np.ndarray: """ Compute pairwise unit distance matrix using vectorized operations. Following Equation 3 (page 7): dist_unit_{-t}(j, i) = sqrt(Σ_u (Y_{iu} - Y_{ju})² / n_valid) For efficiency, we compute a base distance matrix excluding all treated observations, which provides a good approximation. The exact per-observation distances are refined when needed. Uses vectorized numpy operations with masked arrays for O(n²) complexity but with highly optimized inner loops via numpy/BLAS. Parameters ---------- Y : np.ndarray Outcome matrix (n_periods x n_units). D : np.ndarray Treatment indicator matrix (n_periods x n_units). n_units : int Number of units. n_periods : int Number of periods. Returns ------- np.ndarray Pairwise distance matrix (n_units x n_units). """ # Mask for valid observations: control periods only (D=0), non-NaN valid_mask = (D == 0) & ~np.isnan(Y) # Replace invalid values with NaN for masked computation Y_masked = np.where(valid_mask, Y, np.nan) # Transpose to (n_units, n_periods) for easier broadcasting Y_T = Y_masked.T # (n_units, n_periods) # Compute pairwise squared differences using broadcasting # Y_T[:, np.newaxis, :] has shape (n_units, 1, n_periods) # Y_T[np.newaxis, :, :] has shape (1, n_units, n_periods) # diff has shape (n_units, n_units, n_periods) diff = Y_T[:, np.newaxis, :] - Y_T[np.newaxis, :, :] sq_diff = diff ** 2 # Count valid (non-NaN) observations per pair # A difference is valid only if both units have valid observations valid_diff = ~np.isnan(sq_diff) n_valid = np.sum(valid_diff, axis=2) # (n_units, n_units) # Compute sum of squared differences (treating NaN as 0) sq_diff_sum = np.nansum(sq_diff, axis=2) # (n_units, n_units) # Compute RMSE distance: sqrt(sum / n_valid) # Avoid division by zero with np.errstate(divide='ignore', invalid='ignore'): dist_matrix = np.sqrt(sq_diff_sum / n_valid) # Set pairs with no valid observations to inf dist_matrix = np.where(n_valid > 0, dist_matrix, np.inf) # Ensure diagonal is 0 (same unit distance) np.fill_diagonal(dist_matrix, 0.0) return dist_matrix def _compute_unit_distance_for_obs( self, Y: np.ndarray, D: np.ndarray, j: int, i: int, target_period: int, ) -> float: """ Compute observation-specific pairwise distance from unit j to unit i. This is the exact computation from Equation 3, excluding the target period. Used when the base distance matrix approximation is insufficient. Parameters ---------- Y : np.ndarray Outcome matrix (n_periods x n_units). D : np.ndarray Treatment indicator matrix. j : int Control unit index. i : int Treated unit index. target_period : int Target period to exclude. Returns ------- float Pairwise RMSE distance. """ n_periods = Y.shape[0] # Mask: exclude target period, both units must be untreated, non-NaN valid = np.ones(n_periods, dtype=bool) valid[target_period] = False valid &= (D[:, i] == 0) & (D[:, j] == 0) valid &= ~np.isnan(Y[:, i]) & ~np.isnan(Y[:, j]) if np.any(valid): sq_diffs = (Y[valid, i] - Y[valid, j]) ** 2 return np.sqrt(np.mean(sq_diffs)) else: return np.inf def _univariate_loocv_search( self, Y: np.ndarray, D: np.ndarray, control_mask: np.ndarray, control_unit_idx: np.ndarray, n_units: int, n_periods: int, param_name: str, grid: List[float], fixed_params: Dict[str, float], ) -> Tuple[float, float]: """ Search over one parameter with others fixed. Following paper's footnote 2, this performs a univariate grid search for one tuning parameter while holding others fixed. The fixed_params use 0.0 for disabled time/unit weights and _LAMBDA_INF for disabled factor model: - lambda_nn = inf: Skip nuclear norm regularization (L=0) - lambda_time = 0.0: Uniform time weights (exp(-0×dist)=1) - lambda_unit = 0.0: Uniform unit weights (exp(-0×dist)=1) Parameters ---------- Y : np.ndarray Outcome matrix (n_periods x n_units). D : np.ndarray Treatment indicator matrix (n_periods x n_units). control_mask : np.ndarray Boolean mask for control observations. control_unit_idx : np.ndarray Indices of control units. n_units : int Number of units. n_periods : int Number of periods. param_name : str Name of parameter to search: 'lambda_time', 'lambda_unit', or 'lambda_nn'. grid : List[float] Grid of values to search over. fixed_params : Dict[str, float] Fixed values for other parameters. May include _LAMBDA_INF for lambda_nn. Returns ------- Tuple[float, float] (best_value, best_score) for the searched parameter. """ best_score = np.inf best_value = grid[0] if grid else 0.0 for value in grid: params = {**fixed_params, param_name: value} lambda_time = params.get('lambda_time', 0.0) lambda_unit = params.get('lambda_unit', 0.0) lambda_nn = params.get('lambda_nn', 0.0) # Convert λ_nn=∞ → large finite value (factor model disabled, L≈0) # λ_time and λ_unit use 0.0 for uniform weights per Eq. 3 (no inf conversion needed) if np.isinf(lambda_nn): lambda_nn = 1e10 try: score = self._loocv_score_obs_specific( Y, D, control_mask, control_unit_idx, lambda_time, lambda_unit, lambda_nn, n_units, n_periods ) if score < best_score: best_score = score best_value = value except (np.linalg.LinAlgError, ValueError): continue return best_value, best_score def _cycling_parameter_search( self, Y: np.ndarray, D: np.ndarray, control_mask: np.ndarray, control_unit_idx: np.ndarray, n_units: int, n_periods: int, initial_lambda: Tuple[float, float, float], max_cycles: int = 10, ) -> Tuple[float, float, float]: """ Cycle through parameters until convergence (coordinate descent). Following paper's footnote 2 (Stage 2), this iteratively optimizes each tuning parameter while holding the others fixed, until convergence. Parameters ---------- Y : np.ndarray Outcome matrix (n_periods x n_units). D : np.ndarray Treatment indicator matrix (n_periods x n_units). control_mask : np.ndarray Boolean mask for control observations. control_unit_idx : np.ndarray Indices of control units. n_units : int Number of units. n_periods : int Number of periods. initial_lambda : Tuple[float, float, float] Initial values (lambda_time, lambda_unit, lambda_nn). max_cycles : int, default=10 Maximum number of coordinate descent cycles. Returns ------- Tuple[float, float, float] Optimized (lambda_time, lambda_unit, lambda_nn). """ lambda_time, lambda_unit, lambda_nn = initial_lambda prev_score = np.inf for cycle in range(max_cycles): # Optimize λ_unit (fix λ_time, λ_nn) lambda_unit, _ = self._univariate_loocv_search( Y, D, control_mask, control_unit_idx, n_units, n_periods, 'lambda_unit', self.lambda_unit_grid, {'lambda_time': lambda_time, 'lambda_nn': lambda_nn} ) # Optimize λ_time (fix λ_unit, λ_nn) lambda_time, _ = self._univariate_loocv_search( Y, D, control_mask, control_unit_idx, n_units, n_periods, 'lambda_time', self.lambda_time_grid, {'lambda_unit': lambda_unit, 'lambda_nn': lambda_nn} ) # Optimize λ_nn (fix λ_unit, λ_time) lambda_nn, score = self._univariate_loocv_search( Y, D, control_mask, control_unit_idx, n_units, n_periods, 'lambda_nn', self.lambda_nn_grid, {'lambda_unit': lambda_unit, 'lambda_time': lambda_time} ) # Check convergence if abs(score - prev_score) < 1e-6: logger.debug( "Cycling search converged after %d cycles with score %.6f", cycle + 1, score ) break prev_score = score return lambda_time, lambda_unit, lambda_nn # ========================================================================= # Joint estimation method # ========================================================================= def _compute_joint_weights( self, Y: np.ndarray, D: np.ndarray, lambda_time: float, lambda_unit: float, treated_periods: int, n_units: int, n_periods: int, ) -> np.ndarray: """ Compute distance-based weights for joint estimation. Following the reference implementation, weights are computed based on: - Time distance: distance to center of treated block - Unit distance: RMSE to average treated trajectory over pre-periods Parameters ---------- Y : np.ndarray Outcome matrix (n_periods x n_units). D : np.ndarray Treatment indicator matrix (n_periods x n_units). lambda_time : float Time weight decay parameter. lambda_unit : float Unit weight decay parameter. treated_periods : int Number of post-treatment periods. n_units : int Number of units. n_periods : int Number of periods. Returns ------- np.ndarray Weight matrix (n_periods x n_units). """ # Identify treated units (ever treated) treated_mask = np.any(D == 1, axis=0) treated_unit_idx = np.where(treated_mask)[0] if len(treated_unit_idx) == 0: raise ValueError("No treated units found") # Time weights: distance to center of treated block # Following reference: center = T - treated_periods/2 center = n_periods - treated_periods / 2.0 dist_time = np.abs(np.arange(n_periods, dtype=float) - center) delta_time = np.exp(-lambda_time * dist_time) # Unit weights: RMSE to average treated trajectory over pre-periods # Compute average treated trajectory (use nanmean to handle NaN) average_treated = np.nanmean(Y[:, treated_unit_idx], axis=1) # Pre-period mask: 1 in pre, 0 in post pre_mask = np.ones(n_periods, dtype=float) pre_mask[-treated_periods:] = 0.0 # Compute RMS distance for each unit # dist_unit[i] = sqrt(sum_pre(avg_tr - Y_i)^2 / n_pre) # Use NaN-safe operations: treat NaN differences as 0 (excluded) diff = average_treated[:, np.newaxis] - Y diff_sq = np.where(np.isfinite(diff), diff ** 2, 0.0) * pre_mask[:, np.newaxis] # Count valid observations per unit in pre-period # Must check diff is finite (both Y and average_treated finite) # to match the periods contributing to diff_sq valid_count = np.sum( np.isfinite(diff) * pre_mask[:, np.newaxis], axis=0 ) sum_sq = np.sum(diff_sq, axis=0) n_pre = np.sum(pre_mask) if n_pre == 0: raise ValueError("No pre-treatment periods") # Track units with no valid pre-period data no_valid_pre = valid_count == 0 # Use valid count per unit (avoid division by zero for calculation) valid_count_safe = np.maximum(valid_count, 1) dist_unit = np.sqrt(sum_sq / valid_count_safe) # Units with no valid pre-period data get zero weight # (dist is undefined, so we set it to inf -> delta_unit = exp(-inf) = 0) delta_unit = np.exp(-lambda_unit * dist_unit) delta_unit[no_valid_pre] = 0.0 # Outer product: (n_periods x n_units) delta = np.outer(delta_time, delta_unit) return delta def _loocv_score_joint( self, Y: np.ndarray, D: np.ndarray, control_obs: List[Tuple[int, int]], lambda_time: float, lambda_unit: float, lambda_nn: float, treated_periods: int, n_units: int, n_periods: int, ) -> float: """ Compute LOOCV score for joint method with specific parameter combination. Following paper's Equation 5: Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² For joint method, we exclude each control observation, fit the joint model on remaining data, and compute the pseudo-treatment effect at the excluded obs. Parameters ---------- Y : np.ndarray Outcome matrix (n_periods x n_units). D : np.ndarray Treatment indicator matrix (n_periods x n_units). control_obs : List[Tuple[int, int]] List of (t, i) control observations for LOOCV. lambda_time : float Time weight decay parameter. lambda_unit : float Unit weight decay parameter. lambda_nn : float Nuclear norm regularization parameter. treated_periods : int Number of post-treatment periods. n_units : int Number of units. n_periods : int Number of periods. Returns ------- float LOOCV score (sum of squared pseudo-treatment effects). """ # Compute global weights (same for all LOOCV iterations) delta = self._compute_joint_weights( Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods ) tau_sq_sum = 0.0 n_valid = 0 for t_ex, i_ex in control_obs: # Create modified delta with excluded observation zeroed out delta_ex = delta.copy() delta_ex[t_ex, i_ex] = 0.0 try: # Fit joint model excluding this observation if lambda_nn >= 1e10: mu, alpha, beta, tau = self._solve_joint_no_lowrank(Y, D, delta_ex) L = np.zeros((n_periods, n_units)) else: mu, alpha, beta, L, tau = self._solve_joint_with_lowrank( Y, D, delta_ex, lambda_nn, self.max_iter, self.tol ) # Pseudo treatment effect: τ = Y - μ - α - β - L if np.isfinite(Y[t_ex, i_ex]): tau_loocv = Y[t_ex, i_ex] - mu - alpha[i_ex] - beta[t_ex] - L[t_ex, i_ex] tau_sq_sum += tau_loocv ** 2 n_valid += 1 except (np.linalg.LinAlgError, ValueError): # Any failure means this λ combination is invalid per Equation 5 return np.inf if n_valid == 0: return np.inf return tau_sq_sum def _solve_joint_no_lowrank( self, Y: np.ndarray, D: np.ndarray, delta: np.ndarray, ) -> Tuple[float, np.ndarray, np.ndarray, float]: """ Solve joint TWFE + treatment via weighted least squares (no low-rank). Solves: min Σ δ_{it}(Y_{it} - μ - α_i - β_t - τ*W_{it})² Parameters ---------- Y : np.ndarray Outcome matrix (n_periods x n_units). D : np.ndarray Treatment indicator matrix (n_periods x n_units). delta : np.ndarray Weight matrix (n_periods x n_units). Returns ------- Tuple[float, np.ndarray, np.ndarray, float] (mu, alpha, beta, tau) estimated parameters. """ n_periods, n_units = Y.shape # Flatten matrices for regression y = Y.flatten() # length n_periods * n_units w = D.flatten() weights = delta.flatten() # Handle NaN values: zero weight for NaN outcomes/weights, impute with 0 # This ensures NaN observations don't contribute to estimation valid_y = np.isfinite(y) valid_w = np.isfinite(weights) valid_mask = valid_y & valid_w weights = np.where(valid_mask, weights, 0.0) y = np.where(valid_mask, y, 0.0) sqrt_weights = np.sqrt(np.maximum(weights, 0)) # Check for all-zero weights (matches Rust's sum_w < 1e-10 check) sum_w = np.sum(weights) if sum_w < 1e-10: raise ValueError("All weights are zero - cannot estimate") # Build design matrix: [intercept, unit_dummies, time_dummies, treatment] # Total columns: 1 + n_units + n_periods + 1 # But we need to drop one unit and one time dummy for identification # Drop first unit (unit 0) and first time (time 0) n_obs = n_periods * n_units n_params = 1 + (n_units - 1) + (n_periods - 1) + 1 X = np.zeros((n_obs, n_params)) X[:, 0] = 1.0 # intercept # Unit dummies (skip unit 0) for i in range(1, n_units): for t in range(n_periods): X[t * n_units + i, i] = 1.0 # Time dummies (skip time 0) for t in range(1, n_periods): for i in range(n_units): X[t * n_units + i, (n_units - 1) + t] = 1.0 # Treatment indicator X[:, -1] = w # Apply weights X_weighted = X * sqrt_weights[:, np.newaxis] y_weighted = y * sqrt_weights # Solve weighted least squares try: coeffs, _, _, _ = np.linalg.lstsq(X_weighted, y_weighted, rcond=None) except np.linalg.LinAlgError: # Fallback: use pseudo-inverse coeffs = np.dot(np.linalg.pinv(X_weighted), y_weighted) # Extract parameters mu = coeffs[0] alpha = np.zeros(n_units) alpha[1:] = coeffs[1:n_units] beta = np.zeros(n_periods) beta[1:] = coeffs[n_units:(n_units + n_periods - 1)] tau = coeffs[-1] return float(mu), alpha, beta, float(tau) def _solve_joint_with_lowrank( self, Y: np.ndarray, D: np.ndarray, delta: np.ndarray, lambda_nn: float, max_iter: int = 100, tol: float = 1e-6, ) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray, float]: """ Solve joint TWFE + treatment + low-rank via alternating minimization. Solves: min Σ δ_{it}(Y_{it} - μ - α_i - β_t - L_{it} - τ*W_{it})² + λ_nn||L||_* Parameters ---------- Y : np.ndarray Outcome matrix (n_periods x n_units). D : np.ndarray Treatment indicator matrix (n_periods x n_units). delta : np.ndarray Weight matrix (n_periods x n_units). lambda_nn : float Nuclear norm regularization parameter. max_iter : int, default=100 Maximum iterations for alternating minimization. tol : float, default=1e-6 Convergence tolerance. Returns ------- Tuple[float, np.ndarray, np.ndarray, np.ndarray, float] (mu, alpha, beta, L, tau) estimated parameters. """ n_periods, n_units = Y.shape # Handle NaN values: impute with 0 for computations # The solver will also zero weights for NaN observations Y_safe = np.where(np.isfinite(Y), Y, 0.0) # Mask delta to exclude NaN outcomes from estimation # This ensures NaN observations don't contribute to the gradient step nan_mask = ~np.isfinite(Y) delta_masked = delta.copy() delta_masked[nan_mask] = 0.0 # Initialize L = 0 L = np.zeros((n_periods, n_units)) for iteration in range(max_iter): L_old = L.copy() # Step 1: Fix L, solve for (mu, alpha, beta, tau) # Adjusted outcome: Y - L (using NaN-safe Y) # Pass masked delta to exclude NaN observations from WLS Y_adj = Y_safe - L mu, alpha, beta, tau = self._solve_joint_no_lowrank(Y_adj, D, delta_masked) # Step 2: Fix (mu, alpha, beta, tau), update L # Residual: R = Y - mu - alpha - beta - tau*D (using NaN-safe Y) R = Y_safe - mu - alpha[np.newaxis, :] - beta[:, np.newaxis] - tau * D # Weighted proximal step for L (soft-threshold SVD) # Normalize weights (using masked delta to exclude NaN observations) delta_max = np.max(delta_masked) if delta_max > 0: delta_norm = delta_masked / delta_max else: delta_norm = delta_masked # Weighted average between current L and target R # L_next = L + delta_norm * (R - L), then soft-threshold # NaN observations have delta_norm=0, so they don't influence L update gradient_step = L + delta_norm * (R - L) # Soft-threshold singular values # Use eta * lambda_nn for proper proximal step size (matches Rust) eta = 1.0 / delta_max if delta_max > 0 else 1.0 L = self._soft_threshold_svd(gradient_step, eta * lambda_nn) # Check convergence if np.max(np.abs(L - L_old)) < tol: break return mu, alpha, beta, L, tau def _fit_joint( self, data: pd.DataFrame, outcome: str, treatment: str, unit: str, time: str, ) -> TROPResults: """ Fit TROP using joint weighted least squares method. This method estimates a single scalar treatment effect τ along with fixed effects and optional low-rank factor adjustment. Parameters ---------- data : pd.DataFrame Panel data. outcome : str Outcome variable column name. treatment : str Treatment indicator column name. unit : str Unit identifier column name. time : str Time period column name. Returns ------- TROPResults Estimation results. Notes ----- Bootstrap variance estimation assumes simultaneous treatment adoption (fixed `treated_periods` across resamples). The treatment timing is inferred from the data once and held constant for all bootstrap iterations. For staggered adoption designs where treatment timing varies across units, use `method="twostep"` which computes observation-specific weights that naturally handle heterogeneous timing. """ # Data setup (same as twostep method) all_units = sorted(data[unit].unique()) all_periods = sorted(data[time].unique()) n_units = len(all_units) n_periods = len(all_periods) idx_to_unit = {i: u for i, u in enumerate(all_units)} idx_to_period = {i: p for i, p in enumerate(all_periods)} # Create matrices Y = ( data.pivot(index=time, columns=unit, values=outcome) .reindex(index=all_periods, columns=all_units) .values ) D_raw = ( data.pivot(index=time, columns=unit, values=treatment) .reindex(index=all_periods, columns=all_units) ) missing_mask = pd.isna(D_raw).values D = D_raw.fillna(0).astype(int).values # Validate absorbing state violating_units = [] for unit_idx in range(n_units): observed_mask = ~missing_mask[:, unit_idx] observed_d = D[observed_mask, unit_idx] if len(observed_d) > 1 and np.any(np.diff(observed_d) < 0): violating_units.append(all_units[unit_idx]) if violating_units: raise ValueError( f"Treatment indicator is not an absorbing state for units: {violating_units}. " f"D[t, unit] must be monotonic non-decreasing." ) # Identify treated observations treated_mask = D == 1 n_treated_obs = np.sum(treated_mask) if n_treated_obs == 0: raise ValueError("No treated observations found") # Identify treated and control units unit_ever_treated = np.any(D == 1, axis=0) treated_unit_idx = np.where(unit_ever_treated)[0] control_unit_idx = np.where(~unit_ever_treated)[0] if len(control_unit_idx) == 0: raise ValueError("No control units found") # Determine pre/post periods first_treat_period = None for t in range(n_periods): if np.any(D[t, :] == 1): first_treat_period = t break if first_treat_period is None: raise ValueError("Could not infer post-treatment periods from D matrix") n_pre_periods = first_treat_period treated_periods = n_periods - first_treat_period n_post_periods = int(np.sum(np.any(D[first_treat_period:, :] == 1, axis=1))) if n_pre_periods < 2: raise ValueError("Need at least 2 pre-treatment periods") # Check for staggered adoption (joint method requires simultaneous treatment) # Use only observed periods (skip missing) to avoid false positives on unbalanced panels first_treat_by_unit = [] for i in treated_unit_idx: observed_mask = ~missing_mask[:, i] # Get D values for observed periods only observed_d = D[observed_mask, i] observed_periods = np.where(observed_mask)[0] # Find first treatment among observed periods treated_idx = np.where(observed_d == 1)[0] if len(treated_idx) > 0: first_treat_by_unit.append(observed_periods[treated_idx[0]]) unique_starts = sorted(set(first_treat_by_unit)) if len(unique_starts) > 1: raise ValueError( f"method='joint' requires simultaneous treatment adoption, but your data " f"shows staggered adoption (units first treated at periods {unique_starts}). " f"Use method='twostep' which properly handles staggered adoption designs." ) # LOOCV grid search for tuning parameters # Use Rust backend when available for parallel LOOCV (5-10x speedup) best_lambda = None best_score = np.inf control_mask = D == 0 if HAS_RUST_BACKEND and _rust_loocv_grid_search_joint is not None: try: # Prepare inputs for Rust function control_mask_u8 = control_mask.astype(np.uint8) lambda_time_arr = np.array(self.lambda_time_grid, dtype=np.float64) lambda_unit_arr = np.array(self.lambda_unit_grid, dtype=np.float64) lambda_nn_arr = np.array(self.lambda_nn_grid, dtype=np.float64) result = _rust_loocv_grid_search_joint( Y, D.astype(np.float64), control_mask_u8, lambda_time_arr, lambda_unit_arr, lambda_nn_arr, self.max_iter, self.tol, ) # Unpack result - 7 values including optional first_failed_obs best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed_obs = result # Only accept finite scores - infinite means all fits failed if np.isfinite(best_score): best_lambda = (best_lt, best_lu, best_ln) # Emit warnings consistent with Python implementation if n_valid == 0: obs_info = "" if first_failed_obs is not None: t_idx, i_idx = first_failed_obs obs_info = f" First failure at observation ({t_idx}, {i_idx})." warnings.warn( f"LOOCV: All {n_attempted} fits failed for " f"λ=({best_lt}, {best_lu}, {best_ln}). " f"Returning infinite score.{obs_info}", UserWarning ) elif n_attempted > 0 and (n_attempted - n_valid) > 0.1 * n_attempted: n_failed = n_attempted - n_valid obs_info = "" if first_failed_obs is not None: t_idx, i_idx = first_failed_obs obs_info = f" First failure at observation ({t_idx}, {i_idx})." warnings.warn( f"LOOCV: {n_failed}/{n_attempted} fits failed for " f"λ=({best_lt}, {best_lu}, {best_ln}). " f"This may indicate numerical instability.{obs_info}", UserWarning ) except Exception as e: # Fall back to Python implementation on error logger.debug( "Rust LOOCV grid search (joint) failed, falling back to Python: %s", e ) best_lambda = None best_score = np.inf # Fall back to Python implementation if Rust unavailable or failed if best_lambda is None: # Get control observations for LOOCV control_obs = [ (t, i) for t in range(n_periods) for i in range(n_units) if control_mask[t, i] and not np.isnan(Y[t, i]) ] # Grid search with true LOOCV for lambda_time_val in self.lambda_time_grid: for lambda_unit_val in self.lambda_unit_grid: for lambda_nn_val in self.lambda_nn_grid: # Convert λ_nn=∞ → large finite value (factor model disabled) lt = lambda_time_val lu = lambda_unit_val ln = 1e10 if np.isinf(lambda_nn_val) else lambda_nn_val try: score = self._loocv_score_joint( Y, D, control_obs, lt, lu, ln, treated_periods, n_units, n_periods ) if score < best_score: best_score = score best_lambda = (lambda_time_val, lambda_unit_val, lambda_nn_val) except (np.linalg.LinAlgError, ValueError): continue if best_lambda is None: warnings.warn( "All tuning parameter combinations failed. Using defaults.", UserWarning ) best_lambda = (1.0, 1.0, 0.1) best_score = np.nan # Final estimation with best parameters lambda_time, lambda_unit, lambda_nn = best_lambda original_lambda_nn = lambda_nn # Convert λ_nn=∞ → large finite value (factor model disabled, L≈0) # λ_time and λ_unit use 0.0 for uniform weights directly (no conversion needed) if np.isinf(lambda_nn): lambda_nn = 1e10 # Compute final weights and fit delta = self._compute_joint_weights( Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods ) if lambda_nn >= 1e10: mu, alpha, beta, tau = self._solve_joint_no_lowrank(Y, D, delta) L = np.zeros((n_periods, n_units)) else: mu, alpha, beta, L, tau = self._solve_joint_with_lowrank( Y, D, delta, lambda_nn, self.max_iter, self.tol ) # ATT is the scalar treatment effect att = tau # Compute individual treatment effects for reporting (same τ for all) treatment_effects = {} for t in range(n_periods): for i in range(n_units): if D[t, i] == 1: unit_id = idx_to_unit[i] time_id = idx_to_period[t] treatment_effects[(unit_id, time_id)] = tau # Compute effective rank of L _, s, _ = np.linalg.svd(L, full_matrices=False) if s[0] > 0: effective_rank = np.sum(s) / s[0] else: effective_rank = 0.0 # Bootstrap variance estimation effective_lambda = (lambda_time, lambda_unit, lambda_nn) se, bootstrap_dist = self._bootstrap_variance_joint( data, outcome, treatment, unit, time, effective_lambda, treated_periods ) # Compute test statistics df_trop = max(1, n_treated_obs - 1) t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha, df=df_trop) # Create results dictionaries unit_effects_dict = {idx_to_unit[i]: alpha[i] for i in range(n_units)} time_effects_dict = {idx_to_period[t]: beta[t] for t in range(n_periods)} self.results_ = TROPResults( att=float(att), se=float(se), t_stat=float(t_stat) if np.isfinite(t_stat) else t_stat, p_value=float(p_value) if np.isfinite(p_value) else p_value, conf_int=conf_int, n_obs=len(data), n_treated=len(treated_unit_idx), n_control=len(control_unit_idx), n_treated_obs=int(n_treated_obs), unit_effects=unit_effects_dict, time_effects=time_effects_dict, treatment_effects=treatment_effects, lambda_time=lambda_time, lambda_unit=lambda_unit, lambda_nn=original_lambda_nn, factor_matrix=L, effective_rank=effective_rank, loocv_score=best_score, alpha=self.alpha, n_pre_periods=n_pre_periods, n_post_periods=n_post_periods, n_bootstrap=self.n_bootstrap, bootstrap_distribution=bootstrap_dist if len(bootstrap_dist) > 0 else None, ) self.is_fitted_ = True return self.results_ def _bootstrap_variance_joint( self, data: pd.DataFrame, outcome: str, treatment: str, unit: str, time: str, optimal_lambda: Tuple[float, float, float], treated_periods: int, ) -> Tuple[float, np.ndarray]: """ Compute bootstrap standard error for joint method. Uses Rust backend when available for parallel bootstrap (5-15x speedup). Parameters ---------- data : pd.DataFrame Original data. outcome : str Outcome column name. treatment : str Treatment column name. unit : str Unit column name. time : str Time column name. optimal_lambda : tuple Optimal tuning parameters. treated_periods : int Number of post-treatment periods. Returns ------- Tuple[float, np.ndarray] (se, bootstrap_estimates). """ lambda_time, lambda_unit, lambda_nn = optimal_lambda # Try Rust backend for parallel bootstrap (5-15x speedup) if HAS_RUST_BACKEND and _rust_bootstrap_trop_variance_joint is not None: try: # Create matrices for Rust function all_units = sorted(data[unit].unique()) all_periods = sorted(data[time].unique()) Y = ( data.pivot(index=time, columns=unit, values=outcome) .reindex(index=all_periods, columns=all_units) .values ) D = ( data.pivot(index=time, columns=unit, values=treatment) .reindex(index=all_periods, columns=all_units) .fillna(0) .astype(np.float64) .values ) bootstrap_estimates, se = _rust_bootstrap_trop_variance_joint( Y, D, lambda_time, lambda_unit, lambda_nn, self.n_bootstrap, self.max_iter, self.tol, self.seed if self.seed is not None else 0 ) if len(bootstrap_estimates) < 10: warnings.warn( f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.", UserWarning ) if len(bootstrap_estimates) == 0: return 0.0, np.array([]) return float(se), np.array(bootstrap_estimates) except Exception as e: logger.debug( "Rust bootstrap (joint) failed, falling back to Python: %s", e ) # Python fallback implementation rng = np.random.default_rng(self.seed) # Stratified bootstrap sampling unit_ever_treated = data.groupby(unit)[treatment].max() treated_units = np.array(unit_ever_treated[unit_ever_treated == 1].index.tolist()) control_units = np.array(unit_ever_treated[unit_ever_treated == 0].index.tolist()) n_treated_units = len(treated_units) n_control_units = len(control_units) bootstrap_estimates_list: List[float] = [] for _ in range(self.n_bootstrap): # Stratified sampling if n_control_units > 0: sampled_control = rng.choice( control_units, size=n_control_units, replace=True ) else: sampled_control = np.array([], dtype=object) if n_treated_units > 0: sampled_treated = rng.choice( treated_units, size=n_treated_units, replace=True ) else: sampled_treated = np.array([], dtype=object) sampled_units = np.concatenate([sampled_control, sampled_treated]) # Create bootstrap sample boot_data = pd.concat([ data[data[unit] == u].assign(**{unit: f"{u}_{idx}"}) for idx, u in enumerate(sampled_units) ], ignore_index=True) try: tau = self._fit_joint_with_fixed_lambda( boot_data, outcome, treatment, unit, time, optimal_lambda, treated_periods ) bootstrap_estimates_list.append(tau) except (ValueError, np.linalg.LinAlgError, KeyError): continue bootstrap_estimates = np.array(bootstrap_estimates_list) if len(bootstrap_estimates) < 10: warnings.warn( f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.", UserWarning ) if len(bootstrap_estimates) == 0: return 0.0, np.array([]) se = np.std(bootstrap_estimates, ddof=1) return float(se), bootstrap_estimates def _fit_joint_with_fixed_lambda( self, data: pd.DataFrame, outcome: str, treatment: str, unit: str, time: str, fixed_lambda: Tuple[float, float, float], treated_periods: int, ) -> float: """ Fit joint model with fixed tuning parameters. Returns only the treatment effect τ. """ lambda_time, lambda_unit, lambda_nn = fixed_lambda all_units = sorted(data[unit].unique()) all_periods = sorted(data[time].unique()) n_units = len(all_units) n_periods = len(all_periods) Y = ( data.pivot(index=time, columns=unit, values=outcome) .reindex(index=all_periods, columns=all_units) .values ) D = ( data.pivot(index=time, columns=unit, values=treatment) .reindex(index=all_periods, columns=all_units) .fillna(0) .astype(int) .values ) # Compute weights delta = self._compute_joint_weights( Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods ) # Fit model if lambda_nn >= 1e10: _, _, _, tau = self._solve_joint_no_lowrank(Y, D, delta) else: _, _, _, _, tau = self._solve_joint_with_lowrank( Y, D, delta, lambda_nn, self.max_iter, self.tol ) return tau
[docs] def fit( self, data: pd.DataFrame, outcome: str, treatment: str, unit: str, time: str, ) -> TROPResults: """ Fit the TROP model. Parameters ---------- data : pd.DataFrame Panel data with observations for multiple units over multiple time periods. outcome : str Name of the outcome variable column. treatment : str Name of the treatment indicator column (0/1). IMPORTANT: This should be an ABSORBING STATE indicator, not a treatment timing indicator. For each unit, D=1 for ALL periods during and after treatment: - D[t, i] = 0 for all t < g_i (pre-treatment periods) - D[t, i] = 1 for all t >= g_i (treatment and post-treatment) where g_i is the treatment start time for unit i. For staggered adoption, different units can have different g_i. The ATT averages over ALL D=1 cells per Equation 1 of the paper. unit : str Name of the unit identifier column. time : str Name of the time period column. Returns ------- TROPResults Object containing the ATT estimate, standard error, factor estimates, and tuning parameters. The lambda_* attributes show the selected grid values. For λ_time and λ_unit, 0.0 means uniform weights; inf is not accepted. For λ_nn, ∞ is converted to 1e10 (factor model disabled). """ # Validate inputs required_cols = [outcome, treatment, unit, time] missing = [c for c in required_cols if c not in data.columns] if missing: raise ValueError(f"Missing columns: {missing}") # Dispatch based on estimation method if self.method == "joint": return self._fit_joint(data, outcome, treatment, unit, time) # Below is the twostep method (default) # Get unique units and periods all_units = sorted(data[unit].unique()) all_periods = sorted(data[time].unique()) n_units = len(all_units) n_periods = len(all_periods) # Create mappings unit_to_idx = {u: i for i, u in enumerate(all_units)} period_to_idx = {p: i for i, p in enumerate(all_periods)} idx_to_unit = {i: u for u, i in unit_to_idx.items()} idx_to_period = {i: p for p, i in period_to_idx.items()} # Create outcome matrix Y (n_periods x n_units) and treatment matrix D # Vectorized: use pivot for O(1) reshaping instead of O(n) iterrows loop Y = ( data.pivot(index=time, columns=unit, values=outcome) .reindex(index=all_periods, columns=all_units) .values ) # For D matrix, track missing values BEFORE fillna to support unbalanced panels # Issue 3 fix: Missing observations should not trigger spurious violations D_raw = ( data.pivot(index=time, columns=unit, values=treatment) .reindex(index=all_periods, columns=all_units) ) missing_mask = pd.isna(D_raw).values # True where originally missing D = D_raw.fillna(0).astype(int).values # Validate D is monotonic non-decreasing per unit (absorbing state) # D[t, i] must satisfy: once D=1, it must stay 1 for all subsequent periods # Issue 3 fix (round 10): Check each unit's OBSERVED D sequence for monotonicity # This catches 1→0 violations that span missing period gaps # Example: D[2]=1, missing [3,4], D[5]=0 is a real violation even though # adjacent period transitions don't show it (the gap hides the transition) violating_units = [] for unit_idx in range(n_units): # Get observed D values for this unit (where not missing) observed_mask = ~missing_mask[:, unit_idx] observed_d = D[observed_mask, unit_idx] # Check if observed sequence is monotonically non-decreasing if len(observed_d) > 1 and np.any(np.diff(observed_d) < 0): violating_units.append(all_units[unit_idx]) if violating_units: raise ValueError( f"Treatment indicator is not an absorbing state for units: {violating_units}. " f"D[t, unit] must be monotonic non-decreasing (once treated, always treated). " f"If this is event-study style data, convert to absorbing state: " f"D[t, i] = 1 for all t >= first treatment period." ) # Identify treated observations treated_mask = D == 1 n_treated_obs = np.sum(treated_mask) if n_treated_obs == 0: raise ValueError("No treated observations found") # Identify treated and control units unit_ever_treated = np.any(D == 1, axis=0) treated_unit_idx = np.where(unit_ever_treated)[0] control_unit_idx = np.where(~unit_ever_treated)[0] if len(control_unit_idx) == 0: raise ValueError("No control units found") # Determine pre/post periods from treatment indicator D # D matrix is the sole input for treatment timing per the paper first_treat_period = None for t in range(n_periods): if np.any(D[t, :] == 1): first_treat_period = t break if first_treat_period is None: raise ValueError("Could not infer post-treatment periods from D matrix") n_pre_periods = first_treat_period # Count periods where D=1 is actually observed (matches docstring) # Per docstring: "Number of post-treatment periods (periods with D=1 observations)" n_post_periods = int(np.sum(np.any(D[first_treat_period:, :] == 1, axis=1))) if n_pre_periods < 2: raise ValueError("Need at least 2 pre-treatment periods") # Step 1: Grid search with LOOCV for tuning parameters best_lambda = None best_score = np.inf # Control observations mask (for LOOCV) control_mask = D == 0 # Pre-compute structures that are reused across LOOCV iterations self._precomputed = self._precompute_structures( Y, D, control_unit_idx, n_units, n_periods ) # Use Rust backend for parallel LOOCV grid search (10-50x speedup) if HAS_RUST_BACKEND and _rust_loocv_grid_search is not None: try: # Prepare inputs for Rust function control_mask_u8 = control_mask.astype(np.uint8) time_dist_matrix = self._precomputed["time_dist_matrix"].astype(np.int64) lambda_time_arr = np.array(self.lambda_time_grid, dtype=np.float64) lambda_unit_arr = np.array(self.lambda_unit_grid, dtype=np.float64) lambda_nn_arr = np.array(self.lambda_nn_grid, dtype=np.float64) result = _rust_loocv_grid_search( Y, D.astype(np.float64), control_mask_u8, time_dist_matrix, lambda_time_arr, lambda_unit_arr, lambda_nn_arr, self.max_iter, self.tol, ) # Unpack result - 7 values including optional first_failed_obs best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed_obs = result # Only accept finite scores - infinite means all fits failed if np.isfinite(best_score): best_lambda = (best_lt, best_lu, best_ln) # else: best_lambda stays None, triggering defaults fallback # Emit warnings consistent with Python implementation if n_valid == 0: # Include failed observation coordinates if available (Issue 2 fix) obs_info = "" if first_failed_obs is not None: t_idx, i_idx = first_failed_obs obs_info = f" First failure at observation ({t_idx}, {i_idx})." warnings.warn( f"LOOCV: All {n_attempted} fits failed for " f"λ=({best_lt}, {best_lu}, {best_ln}). " f"Returning infinite score.{obs_info}", UserWarning ) elif n_attempted > 0 and (n_attempted - n_valid) > 0.1 * n_attempted: n_failed = n_attempted - n_valid # Include failed observation coordinates if available obs_info = "" if first_failed_obs is not None: t_idx, i_idx = first_failed_obs obs_info = f" First failure at observation ({t_idx}, {i_idx})." warnings.warn( f"LOOCV: {n_failed}/{n_attempted} fits failed for " f"λ=({best_lt}, {best_lu}, {best_ln}). " f"This may indicate numerical instability.{obs_info}", UserWarning ) except Exception as e: # Fall back to Python implementation on error logger.debug( "Rust LOOCV grid search failed, falling back to Python: %s", e ) best_lambda = None best_score = np.inf # Fall back to Python implementation if Rust unavailable or failed # Uses two-stage approach per paper's footnote 2: # Stage 1: Univariate searches for initial values # Stage 2: Cycling (coordinate descent) until convergence if best_lambda is None: # Stage 1: Univariate searches with extreme fixed values # Following paper's footnote 2 for initial bounds # λ_time search: fix λ_unit=0, λ_nn=∞ (disabled - no factor adjustment) lambda_time_init, _ = self._univariate_loocv_search( Y, D, control_mask, control_unit_idx, n_units, n_periods, 'lambda_time', self.lambda_time_grid, {'lambda_unit': 0.0, 'lambda_nn': _LAMBDA_INF} ) # λ_nn search: fix λ_time=0 (uniform time weights), λ_unit=0 lambda_nn_init, _ = self._univariate_loocv_search( Y, D, control_mask, control_unit_idx, n_units, n_periods, 'lambda_nn', self.lambda_nn_grid, {'lambda_time': 0.0, 'lambda_unit': 0.0} ) # λ_unit search: fix λ_nn=∞, λ_time=0 lambda_unit_init, _ = self._univariate_loocv_search( Y, D, control_mask, control_unit_idx, n_units, n_periods, 'lambda_unit', self.lambda_unit_grid, {'lambda_nn': _LAMBDA_INF, 'lambda_time': 0.0} ) # Stage 2: Cycling refinement (coordinate descent) lambda_time, lambda_unit, lambda_nn = self._cycling_parameter_search( Y, D, control_mask, control_unit_idx, n_units, n_periods, (lambda_time_init, lambda_unit_init, lambda_nn_init) ) # Compute final score for the optimized parameters try: best_score = self._loocv_score_obs_specific( Y, D, control_mask, control_unit_idx, lambda_time, lambda_unit, lambda_nn, n_units, n_periods ) # Only accept finite scores - infinite means all fits failed if np.isfinite(best_score): best_lambda = (lambda_time, lambda_unit, lambda_nn) # else: best_lambda stays None, triggering defaults fallback except (np.linalg.LinAlgError, ValueError): # If even the optimized parameters fail, best_lambda stays None pass if best_lambda is None: warnings.warn( "All tuning parameter combinations failed. Using defaults.", UserWarning ) best_lambda = (1.0, 1.0, 0.1) best_score = np.nan self._optimal_lambda = best_lambda lambda_time, lambda_unit, lambda_nn = best_lambda # Store original λ_nn for results (only λ_nn needs original→effective conversion). # λ_time and λ_unit use 0.0 for uniform weights directly per Eq. 3. original_lambda_nn = lambda_nn # Convert λ_nn=∞ → large finite value (factor model disabled, L≈0) if np.isinf(lambda_nn): lambda_nn = 1e10 # effective_lambda with converted λ_nn for ALL downstream computation # (variance estimation uses the same parameters as point estimation) effective_lambda = (lambda_time, lambda_unit, lambda_nn) # Step 2: Final estimation - per-observation model fitting following Algorithm 2 # For each treated (i,t): compute observation-specific weights, fit model, compute τ̂_{it} treatment_effects = {} tau_values = [] alpha_estimates = [] beta_estimates = [] L_estimates = [] # Use pre-computed treated observations treated_observations = self._precomputed["treated_observations"] for t, i in treated_observations: # Compute observation-specific weights for this (i, t) weight_matrix = self._compute_observation_weights( Y, D, i, t, lambda_time, lambda_unit, control_unit_idx, n_units, n_periods ) # Fit model with these weights alpha_hat, beta_hat, L_hat = self._estimate_model( Y, control_mask, weight_matrix, lambda_nn, n_units, n_periods ) # Compute treatment effect: τ̂_{it} = Y_{it} - α̂_i - β̂_t - L̂_{it} tau_it = Y[t, i] - alpha_hat[i] - beta_hat[t] - L_hat[t, i] unit_id = idx_to_unit[i] time_id = idx_to_period[t] treatment_effects[(unit_id, time_id)] = tau_it tau_values.append(tau_it) # Store for averaging alpha_estimates.append(alpha_hat) beta_estimates.append(beta_hat) L_estimates.append(L_hat) # Average ATT att = np.mean(tau_values) # Average parameter estimates for output (representative) alpha_hat = np.mean(alpha_estimates, axis=0) if alpha_estimates else np.zeros(n_units) beta_hat = np.mean(beta_estimates, axis=0) if beta_estimates else np.zeros(n_periods) L_hat = np.mean(L_estimates, axis=0) if L_estimates else np.zeros((n_periods, n_units)) # Compute effective rank _, s, _ = np.linalg.svd(L_hat, full_matrices=False) if s[0] > 0: effective_rank = np.sum(s) / s[0] else: effective_rank = 0.0 # Step 4: Variance estimation # Use effective_lambda (converted values) to ensure SE is computed with same # parameters as point estimation. This fixes the variance inconsistency issue. se, bootstrap_dist = self._bootstrap_variance( data, outcome, treatment, unit, time, effective_lambda, Y=Y, D=D, control_unit_idx=control_unit_idx ) # Compute test statistics df_trop = max(1, n_treated_obs - 1) t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha, df=df_trop) # Create results dictionaries unit_effects_dict = {idx_to_unit[i]: alpha_hat[i] for i in range(n_units)} time_effects_dict = {idx_to_period[t]: beta_hat[t] for t in range(n_periods)} # Store results self.results_ = TROPResults( att=att, se=se, t_stat=t_stat, p_value=p_value, conf_int=conf_int, n_obs=len(data), n_treated=len(treated_unit_idx), n_control=len(control_unit_idx), n_treated_obs=n_treated_obs, unit_effects=unit_effects_dict, time_effects=time_effects_dict, treatment_effects=treatment_effects, lambda_time=lambda_time, lambda_unit=lambda_unit, lambda_nn=original_lambda_nn, factor_matrix=L_hat, effective_rank=effective_rank, loocv_score=best_score, alpha=self.alpha, n_pre_periods=n_pre_periods, n_post_periods=n_post_periods, n_bootstrap=self.n_bootstrap, bootstrap_distribution=bootstrap_dist if len(bootstrap_dist) > 0 else None, ) self.is_fitted_ = True return self.results_
def _compute_observation_weights( self, Y: np.ndarray, D: np.ndarray, i: int, t: int, lambda_time: float, lambda_unit: float, control_unit_idx: np.ndarray, n_units: int, n_periods: int, ) -> np.ndarray: """ Compute observation-specific weight matrix for treated observation (i, t). Following the paper's Algorithm 2 (page 27) and Equation 2 (page 7): - Time weights θ_s^{i,t} = exp(-λ_time × |t - s|) - Unit weights ω_j^{i,t} = exp(-λ_unit × dist_unit_{-t}(j, i)) IMPORTANT (Issue A fix): The paper's objective sums over ALL observations where (1 - W_js) is non-zero, which includes pre-treatment observations of eventually-treated units since W_js = 0 for those. This method computes weights for ALL units where D[t, j] = 0 at the target period, not just never-treated units. Uses pre-computed structures when available for efficiency. Parameters ---------- Y : np.ndarray Outcome matrix (n_periods x n_units). D : np.ndarray Treatment indicator matrix (n_periods x n_units). i : int Treated unit index. t : int Treatment period index. lambda_time : float Time weight decay parameter. lambda_unit : float Unit weight decay parameter. control_unit_idx : np.ndarray Indices of never-treated units (for backward compatibility, but not used for weight computation - we use D matrix directly). n_units : int Number of units. n_periods : int Number of periods. Returns ------- np.ndarray Weight matrix (n_periods x n_units) for observation (i, t). """ # Use pre-computed structures when available if self._precomputed is not None: # Time weights from pre-computed time distance matrix # time_dist_matrix[t, s] = |t - s| time_weights = np.exp(-lambda_time * self._precomputed["time_dist_matrix"][t, :]) # Unit weights - computed for ALL units where D[t, j] = 0 # (Issue A fix: includes pre-treatment obs of eventually-treated units) unit_weights = np.zeros(n_units) D_stored = self._precomputed["D"] Y_stored = self._precomputed["Y"] # Valid control units at time t: D[t, j] == 0 valid_control_at_t = D_stored[t, :] == 0 if lambda_unit == 0: # Uniform weights when lambda_unit = 0 # All units not treated at time t get weight 1 unit_weights[valid_control_at_t] = 1.0 else: # Use observation-specific distances with target period excluded # (Issue B fix: compute exact per-observation distance) for j in range(n_units): if valid_control_at_t[j] and j != i: # Compute distance excluding target period t dist = self._compute_unit_distance_for_obs(Y_stored, D_stored, j, i, t) if np.isinf(dist): unit_weights[j] = 0.0 else: unit_weights[j] = np.exp(-lambda_unit * dist) # Treated unit i gets weight 1 unit_weights[i] = 1.0 # Weight matrix: outer product (n_periods x n_units) return np.outer(time_weights, unit_weights) # Fallback: compute from scratch (used in bootstrap) # Time distance: |t - s| following paper's Equation 3 (page 7) dist_time = np.abs(np.arange(n_periods) - t) time_weights = np.exp(-lambda_time * dist_time) # Unit weights - computed for ALL units where D[t, j] = 0 # (Issue A fix: includes pre-treatment obs of eventually-treated units) unit_weights = np.zeros(n_units) # Valid control units at time t: D[t, j] == 0 valid_control_at_t = D[t, :] == 0 if lambda_unit == 0: # Uniform weights when lambda_unit = 0 unit_weights[valid_control_at_t] = 1.0 else: for j in range(n_units): if valid_control_at_t[j] and j != i: # Compute distance excluding target period t (Issue B fix) dist = self._compute_unit_distance_for_obs(Y, D, j, i, t) if np.isinf(dist): unit_weights[j] = 0.0 else: unit_weights[j] = np.exp(-lambda_unit * dist) # Treated unit i gets weight 1 (or could be omitted since we fit on controls) # We include treated unit's own observation for model fitting unit_weights[i] = 1.0 # Weight matrix: outer product (n_periods x n_units) W = np.outer(time_weights, unit_weights) return W def _soft_threshold_svd( self, M: np.ndarray, threshold: float, ) -> np.ndarray: """ Apply soft-thresholding to singular values (proximal operator for nuclear norm). Parameters ---------- M : np.ndarray Input matrix. threshold : float Soft-thresholding parameter. Returns ------- np.ndarray Matrix with soft-thresholded singular values. """ if threshold <= 0: return M # Handle NaN/Inf values in input if not np.isfinite(M).all(): M = np.nan_to_num(M, nan=0.0, posinf=0.0, neginf=0.0) try: U, s, Vt = np.linalg.svd(M, full_matrices=False) except np.linalg.LinAlgError: # SVD failed, return zero matrix return np.zeros_like(M) # Check for numerical issues in SVD output if not (np.isfinite(U).all() and np.isfinite(s).all() and np.isfinite(Vt).all()): # SVD produced non-finite values, return zero matrix return np.zeros_like(M) s_thresh = np.maximum(s - threshold, 0) # Use truncated reconstruction with only non-zero singular values nonzero_mask = s_thresh > self.CONVERGENCE_TOL_SVD if not np.any(nonzero_mask): return np.zeros_like(M) # Truncate to non-zero components for numerical stability U_trunc = U[:, nonzero_mask] s_trunc = s_thresh[nonzero_mask] Vt_trunc = Vt[nonzero_mask, :] # Compute result, suppressing expected numerical warnings from # ill-conditioned matrices during alternating minimization with np.errstate(divide='ignore', over='ignore', invalid='ignore'): result = (U_trunc * s_trunc) @ Vt_trunc # Replace any NaN/Inf in result with zeros if not np.isfinite(result).all(): result = np.nan_to_num(result, nan=0.0, posinf=0.0, neginf=0.0) return result def _weighted_nuclear_norm_solve( self, Y: np.ndarray, W: np.ndarray, L_init: np.ndarray, alpha: np.ndarray, beta: np.ndarray, lambda_nn: float, max_inner_iter: int = 20, ) -> np.ndarray: """ Solve weighted nuclear norm problem using iterative weighted soft-impute. Issue C fix: Implements the weighted nuclear norm optimization from the paper's Equation 2 (page 7). The full objective is: min_L Σ W_{ti}(R_{ti} - L_{ti})² + λ_nn||L||_* This uses a proximal gradient / soft-impute approach (Mazumder et al. 2010): L_{k+1} = prox_{λ||·||_*}(L_k + W ⊙ (R - L_k)) where W ⊙ denotes element-wise multiplication with normalized weights. IMPORTANT: For observations with W=0 (treated observations), we keep L values from the previous iteration rather than setting L = R, which would absorb the treatment effect. Parameters ---------- Y : np.ndarray Outcome matrix (n_periods x n_units). W : np.ndarray Weight matrix (n_periods x n_units), non-negative. W=0 indicates observations that should not be used for fitting (treated obs). L_init : np.ndarray Initial estimate of L matrix. alpha : np.ndarray Current unit fixed effects estimate. beta : np.ndarray Current time fixed effects estimate. lambda_nn : float Nuclear norm regularization parameter. max_inner_iter : int, default=20 Maximum inner iterations for the proximal algorithm. Returns ------- np.ndarray Updated L matrix estimate. """ # Compute target residual R = Y - α - β R = Y - alpha[np.newaxis, :] - beta[:, np.newaxis] # Handle invalid values R = np.nan_to_num(R, nan=0.0, posinf=0.0, neginf=0.0) # For observations with W=0 (treated obs), keep L_init instead of R # This prevents L from absorbing the treatment effect valid_obs_mask = W > 0 R_masked = np.where(valid_obs_mask, R, L_init) if lambda_nn <= 0: # No regularization - just return masked residual # Use soft-thresholding with threshold=0 which returns the input return R_masked # Normalize weights so max is 1 (for step size stability) W_max = np.max(W) if W_max > 0: W_norm = W / W_max else: W_norm = W # Initialize L L = L_init.copy() # Proximal gradient iteration with weighted soft-impute # This solves: min_L ||W^{1/2} ⊙ (R - L)||_F^2 + λ||L||_* # Using: L_{k+1} = prox_{λ/η}(L_k + W ⊙ (R - L_k)) # where η is the step size (we use η = 1 with normalized weights) for _ in range(max_inner_iter): L_old = L.copy() # Gradient step: L_k + W ⊙ (R - L_k) # For W=0 observations, this keeps L_k unchanged gradient_step = L + W_norm * (R_masked - L) # Proximal step: soft-threshold singular values L = self._soft_threshold_svd(gradient_step, lambda_nn) # Check convergence if np.max(np.abs(L - L_old)) < self.tol: break return L def _estimate_model( self, Y: np.ndarray, control_mask: np.ndarray, weight_matrix: np.ndarray, lambda_nn: float, n_units: int, n_periods: int, exclude_obs: Optional[Tuple[int, int]] = None, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Estimate the model: Y = α + β + L + τD + ε with nuclear norm penalty on L. Uses alternating minimization with vectorized operations: 1. Fix L, solve for α, β via weighted means 2. Fix α, β, solve for L via soft-thresholding Parameters ---------- Y : np.ndarray Outcome matrix (n_periods x n_units). control_mask : np.ndarray Boolean mask for control observations. weight_matrix : np.ndarray Pre-computed global weight matrix (n_periods x n_units). lambda_nn : float Nuclear norm regularization parameter. n_units : int Number of units. n_periods : int Number of periods. exclude_obs : tuple, optional (t, i) observation to exclude (for LOOCV). Returns ------- tuple (alpha, beta, L) estimated parameters. """ W = weight_matrix # Mask for estimation (control obs only, excluding LOOCV obs if specified) est_mask = control_mask.copy() if exclude_obs is not None: t_ex, i_ex = exclude_obs est_mask[t_ex, i_ex] = False # Handle missing values valid_mask = ~np.isnan(Y) & est_mask # Initialize alpha = np.zeros(n_units) beta = np.zeros(n_periods) L = np.zeros((n_periods, n_units)) # Pre-compute masked weights for vectorized operations # Set weights to 0 where not valid W_masked = W * valid_mask # Pre-compute weight sums per unit and per time (for denominator) # shape: (n_units,) and (n_periods,) weight_sum_per_unit = np.sum(W_masked, axis=0) # sum over periods weight_sum_per_time = np.sum(W_masked, axis=1) # sum over units # Handle units/periods with zero weight sum unit_has_obs = weight_sum_per_unit > 0 time_has_obs = weight_sum_per_time > 0 # Create safe denominators (avoid division by zero) safe_unit_denom = np.where(unit_has_obs, weight_sum_per_unit, 1.0) safe_time_denom = np.where(time_has_obs, weight_sum_per_time, 1.0) # Replace NaN in Y with 0 for computation (mask handles exclusion) Y_safe = np.where(np.isnan(Y), 0.0, Y) # Alternating minimization following Algorithm 1 (page 9) # Minimize: Σ W_{ti}(Y_{ti} - α_i - β_t - L_{ti})² + λ_nn||L||_* for _ in range(self.max_iter): alpha_old = alpha.copy() beta_old = beta.copy() L_old = L.copy() # Step 1: Update α and β (weighted least squares) # Following Equation 2 (page 7), fix L and solve for α, β # R = Y - L (residual without fixed effects) R = Y_safe - L # Alpha update (unit fixed effects): # α_i = argmin_α Σ_t W_{ti}(R_{ti} - α - β_t)² # Solution: α_i = Σ_t W_{ti}(R_{ti} - β_t) / Σ_t W_{ti} R_minus_beta = R - beta[:, np.newaxis] # (n_periods, n_units) weighted_R_minus_beta = W_masked * R_minus_beta alpha_numerator = np.sum(weighted_R_minus_beta, axis=0) # (n_units,) alpha = np.where(unit_has_obs, alpha_numerator / safe_unit_denom, 0.0) # Beta update (time fixed effects): # β_t = argmin_β Σ_i W_{ti}(R_{ti} - α_i - β)² # Solution: β_t = Σ_i W_{ti}(R_{ti} - α_i) / Σ_i W_{ti} R_minus_alpha = R - alpha[np.newaxis, :] # (n_periods, n_units) weighted_R_minus_alpha = W_masked * R_minus_alpha beta_numerator = np.sum(weighted_R_minus_alpha, axis=1) # (n_periods,) beta = np.where(time_has_obs, beta_numerator / safe_time_denom, 0.0) # Step 2: Update L with weighted nuclear norm penalty # Issue C fix: Use weighted soft-impute to properly account for # observation weights in the nuclear norm optimization. # Following Equation 2 (page 7): min_L Σ W_{ti}(Y - α - β - L)² + λ||L||_* L = self._weighted_nuclear_norm_solve( Y_safe, W_masked, L, alpha, beta, lambda_nn, max_inner_iter=10 ) # Check convergence alpha_diff = np.max(np.abs(alpha - alpha_old)) beta_diff = np.max(np.abs(beta - beta_old)) L_diff = np.max(np.abs(L - L_old)) if max(alpha_diff, beta_diff, L_diff) < self.tol: break return alpha, beta, L def _loocv_score_obs_specific( self, Y: np.ndarray, D: np.ndarray, control_mask: np.ndarray, control_unit_idx: np.ndarray, lambda_time: float, lambda_unit: float, lambda_nn: float, n_units: int, n_periods: int, ) -> float: """ Compute leave-one-out cross-validation score with observation-specific weights. Following the paper's Equation 5 (page 8): Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² For each control observation (j, s), treat it as pseudo-treated, compute observation-specific weights, fit model excluding (j, s), and sum squared pseudo-treatment effects. Uses pre-computed structures when available for efficiency. Parameters ---------- Y : np.ndarray Outcome matrix (n_periods x n_units). D : np.ndarray Treatment indicator matrix (n_periods x n_units). control_mask : np.ndarray Boolean mask for control observations. control_unit_idx : np.ndarray Indices of control units. lambda_time : float Time weight decay parameter. lambda_unit : float Unit weight decay parameter. lambda_nn : float Nuclear norm regularization parameter. n_units : int Number of units. n_periods : int Number of periods. Returns ------- float LOOCV score (lower is better). """ # Use pre-computed control observations if available if self._precomputed is not None: control_obs = self._precomputed["control_obs"] else: # Get all control observations control_obs = [(t, i) for t in range(n_periods) for i in range(n_units) if control_mask[t, i] and not np.isnan(Y[t, i])] # Empty control set check: if no control observations, return infinity # A score of 0.0 would incorrectly "win" over legitimate parameters if len(control_obs) == 0: warnings.warn( f"LOOCV: No valid control observations for " f"λ=({lambda_time}, {lambda_unit}, {lambda_nn}). " "Returning infinite score.", UserWarning ) return np.inf tau_squared_sum = 0.0 n_valid = 0 for t, i in control_obs: try: # Compute observation-specific weights for pseudo-treated (i, t) # Uses pre-computed distance matrices when available weight_matrix = self._compute_observation_weights( Y, D, i, t, lambda_time, lambda_unit, control_unit_idx, n_units, n_periods ) # Estimate model excluding observation (t, i) alpha, beta, L = self._estimate_model( Y, control_mask, weight_matrix, lambda_nn, n_units, n_periods, exclude_obs=(t, i) ) # Pseudo treatment effect tau_ti = Y[t, i] - alpha[i] - beta[t] - L[t, i] tau_squared_sum += tau_ti ** 2 n_valid += 1 except (np.linalg.LinAlgError, ValueError): # Per Equation 5: Q(λ) must sum over ALL D==0 cells # Any failure means this λ cannot produce valid estimates for all cells warnings.warn( f"LOOCV: Fit failed for observation ({t}, {i}) with " f"λ=({lambda_time}, {lambda_unit}, {lambda_nn}). " "Returning infinite score per Equation 5.", UserWarning ) return np.inf # Return SUM of squared pseudo-treatment effects per Equation 5 (page 8): # Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² return tau_squared_sum def _bootstrap_variance( self, data: pd.DataFrame, outcome: str, treatment: str, unit: str, time: str, optimal_lambda: Tuple[float, float, float], Y: Optional[np.ndarray] = None, D: Optional[np.ndarray] = None, control_unit_idx: Optional[np.ndarray] = None, ) -> Tuple[float, np.ndarray]: """ Compute bootstrap standard error using unit-level block bootstrap. When the optional Rust backend is available and the matrix parameters (Y, D, control_unit_idx) are provided, uses parallelized Rust implementation for 5-15x speedup. Falls back to Python implementation if Rust is unavailable or if matrix parameters are not provided. Parameters ---------- data : pd.DataFrame Original data in long format with unit, time, outcome, and treatment. outcome : str Name of the outcome column in data. treatment : str Name of the treatment indicator column in data. unit : str Name of the unit identifier column in data. time : str Name of the time period column in data. optimal_lambda : tuple of float Optimal tuning parameters (lambda_time, lambda_unit, lambda_nn) from cross-validation. Used for model estimation in each bootstrap. Y : np.ndarray, optional Outcome matrix of shape (n_periods, n_units). Required for Rust backend acceleration. If None, falls back to Python implementation. D : np.ndarray, optional Treatment indicator matrix of shape (n_periods, n_units) where D[t,i]=1 indicates unit i is treated at time t. Required for Rust backend acceleration. control_unit_idx : np.ndarray, optional Array of indices for control units (never-treated). Required for Rust backend acceleration. Returns ------- se : float Bootstrap standard error of the ATT estimate. bootstrap_estimates : np.ndarray Array of ATT estimates from each bootstrap iteration. Length may be less than n_bootstrap if some iterations failed. Notes ----- Uses unit-level block bootstrap where entire unit time series are resampled with replacement. This preserves within-unit correlation structure and is appropriate for panel data. """ lambda_time, lambda_unit, lambda_nn = optimal_lambda # Try Rust backend for parallel bootstrap (5-15x speedup) if (HAS_RUST_BACKEND and _rust_bootstrap_trop_variance is not None and self._precomputed is not None and Y is not None and D is not None): try: control_mask = self._precomputed["control_mask"] time_dist_matrix = self._precomputed["time_dist_matrix"].astype(np.int64) bootstrap_estimates, se = _rust_bootstrap_trop_variance( Y, D.astype(np.float64), control_mask.astype(np.uint8), time_dist_matrix, lambda_time, lambda_unit, lambda_nn, self.n_bootstrap, self.max_iter, self.tol, self.seed if self.seed is not None else 0 ) if len(bootstrap_estimates) >= 10: return float(se), bootstrap_estimates # Fall through to Python if too few bootstrap samples logger.debug( "Rust bootstrap returned only %d samples, falling back to Python", len(bootstrap_estimates) ) except Exception as e: logger.debug( "Rust bootstrap variance failed, falling back to Python: %s", e ) # Python implementation (fallback) rng = np.random.default_rng(self.seed) # Issue D fix: Stratified bootstrap sampling # Paper's Algorithm 3 (page 27) specifies sampling N_0 control rows # and N_1 treated rows separately to preserve treatment ratio unit_ever_treated = data.groupby(unit)[treatment].max() treated_units = np.array(unit_ever_treated[unit_ever_treated == 1].index) control_units = np.array(unit_ever_treated[unit_ever_treated == 0].index) n_treated_units = len(treated_units) n_control_units = len(control_units) bootstrap_estimates_list = [] for _ in range(self.n_bootstrap): # Stratified sampling: sample control and treated units separately # This preserves the treatment ratio in each bootstrap sample if n_control_units > 0: sampled_control = rng.choice( control_units, size=n_control_units, replace=True ) else: sampled_control = np.array([], dtype=control_units.dtype) if n_treated_units > 0: sampled_treated = rng.choice( treated_units, size=n_treated_units, replace=True ) else: sampled_treated = np.array([], dtype=treated_units.dtype) # Combine stratified samples sampled_units = np.concatenate([sampled_control, sampled_treated]) # Create bootstrap sample with unique unit IDs boot_data = pd.concat([ data[data[unit] == u].assign(**{unit: f"{u}_{idx}"}) for idx, u in enumerate(sampled_units) ], ignore_index=True) try: # Fit with fixed lambda (skip LOOCV for speed) att = self._fit_with_fixed_lambda( boot_data, outcome, treatment, unit, time, optimal_lambda ) bootstrap_estimates_list.append(att) except (ValueError, np.linalg.LinAlgError, KeyError): continue bootstrap_estimates = np.array(bootstrap_estimates_list) if len(bootstrap_estimates) < 10: warnings.warn( f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded. " "Standard errors may be unreliable.", UserWarning ) if len(bootstrap_estimates) == 0: return 0.0, np.array([]) se = np.std(bootstrap_estimates, ddof=1) return float(se), bootstrap_estimates def _fit_with_fixed_lambda( self, data: pd.DataFrame, outcome: str, treatment: str, unit: str, time: str, fixed_lambda: Tuple[float, float, float], ) -> float: """ Fit model with fixed tuning parameters (for bootstrap). Uses observation-specific weights following Algorithm 2. Returns only the ATT estimate. """ lambda_time, lambda_unit, lambda_nn = fixed_lambda # Setup matrices all_units = sorted(data[unit].unique()) all_periods = sorted(data[time].unique()) n_units = len(all_units) n_periods = len(all_periods) unit_to_idx = {u: i for i, u in enumerate(all_units)} period_to_idx = {p: i for i, p in enumerate(all_periods)} # Vectorized: use pivot for O(1) reshaping instead of O(n) iterrows loop Y = ( data.pivot(index=time, columns=unit, values=outcome) .reindex(index=all_periods, columns=all_units) .values ) D = ( data.pivot(index=time, columns=unit, values=treatment) .reindex(index=all_periods, columns=all_units) .fillna(0) .astype(int) .values ) control_mask = D == 0 # Get control unit indices unit_ever_treated = np.any(D == 1, axis=0) control_unit_idx = np.where(~unit_ever_treated)[0] # Get list of treated observations treated_observations = [(t, i) for t in range(n_periods) for i in range(n_units) if D[t, i] == 1] if not treated_observations: raise ValueError("No treated observations") # Compute ATT using observation-specific weights (Algorithm 2) tau_values = [] for t, i in treated_observations: # Compute observation-specific weights for this (i, t) weight_matrix = self._compute_observation_weights( Y, D, i, t, lambda_time, lambda_unit, control_unit_idx, n_units, n_periods ) # Fit model with these weights alpha, beta, L = self._estimate_model( Y, control_mask, weight_matrix, lambda_nn, n_units, n_periods ) # Compute treatment effect: τ̂_{it} = Y_{it} - α̂_i - β̂_t - L̂_{it} tau = Y[t, i] - alpha[i] - beta[t] - L[t, i] tau_values.append(tau) return np.mean(tau_values)
[docs] def get_params(self) -> Dict[str, Any]: """Get estimator parameters.""" return { "method": self.method, "lambda_time_grid": self.lambda_time_grid, "lambda_unit_grid": self.lambda_unit_grid, "lambda_nn_grid": self.lambda_nn_grid, "max_iter": self.max_iter, "tol": self.tol, "alpha": self.alpha, "n_bootstrap": self.n_bootstrap, "seed": self.seed, }
[docs] def set_params(self, **params) -> "TROP": """Set estimator parameters.""" for key, value in params.items(): if hasattr(self, key): setattr(self, key, value) else: raise ValueError(f"Unknown parameter: {key}") return self
[docs] def trop( data: pd.DataFrame, outcome: str, treatment: str, unit: str, time: str, **kwargs, ) -> TROPResults: """ Convenience function for TROP estimation. Parameters ---------- data : pd.DataFrame Panel data. outcome : str Outcome variable column name. treatment : str Treatment indicator column name (0/1). IMPORTANT: This should be an ABSORBING STATE indicator, not a treatment timing indicator. For each unit, D=1 for ALL periods during and after treatment (D[t,i]=0 for t < g_i, D[t,i]=1 for t >= g_i where g_i is the treatment start time for unit i). unit : str Unit identifier column name. time : str Time period column name. **kwargs Additional arguments passed to TROP constructor. Returns ------- TROPResults Estimation results. Examples -------- >>> from diff_diff import trop >>> results = trop(data, 'y', 'treated', 'unit', 'time') >>> print(f"ATT: {results.att:.3f}") """ estimator = TROP(**kwargs) return estimator.fit(data, outcome, treatment, unit, time)