# Source code for diff_diff.utils

"""
Utility functions for difference-in-differences estimation.
"""

import warnings
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Optional, Tuple

import numpy as np
import pandas as pd
from scipy import stats

from diff_diff.linalg import compute_robust_vcov as _compute_robust_vcov_linalg
from diff_diff.linalg import solve_ols as _solve_ols_linalg

# Import Rust backend if available (from _backend to avoid circular imports)
from diff_diff._backend import (
    HAS_RUST_BACKEND,
    _rust_project_simplex,
    _rust_sdid_unit_weights,
    _rust_compute_time_weights,
    _rust_compute_noise_level,
    _rust_sc_weight_fw,
    _rust_sc_weight_fw_with_convergence,
    _rust_sc_weight_fw_weighted,
    _rust_sc_weight_fw_weighted_with_convergence,
)

# ---------------------------------------------------------------------------
# Numerical constants shared by the iterative weight-optimization routines.
# ---------------------------------------------------------------------------
# Upper bound on the number of iterations a weight optimizer may run.
_OPTIMIZATION_MAX_ITER: int = 1000
# Convergence tolerance used by the weight optimizers.
_OPTIMIZATION_TOL: float = 1e-8
# Tiny positive constant used to guard against division by zero.
_NUMERICAL_EPS: float = 1e-10

# Memo of two-sided critical values keyed by ``(alpha, df)`` (``df=None``
# meaning the standard normal). Filled lazily by ``_get_critical_value`` so
# repeated inference calls skip the scipy quantile evaluation.
_critical_value_cache: Dict[Tuple[float, Optional[int]], float] = {}


def _get_critical_value(alpha: float, df: Optional[int] = None) -> float:
    """Return the two-sided critical value for ``alpha``, memoized per ``(alpha, df)``.

    Parameters
    ----------
    alpha : float
        Significance level; the value returned is the ``1 - alpha / 2`` quantile.
    df : int, optional
        Degrees of freedom of Student's t. ``None`` selects the standard normal.

    Returns
    -------
    float
        The quantile, stored in ``_critical_value_cache`` on first computation.
    """
    cache_key = (alpha, df)
    try:
        return _critical_value_cache[cache_key]
    except KeyError:
        pass

    upper_quantile = 1 - alpha / 2
    if df is None:
        value = float(stats.norm.ppf(upper_quantile))
    else:
        value = float(stats.t.ppf(upper_quantile, df))
    _critical_value_cache[cache_key] = value
    return value


def validate_binary(arr: np.ndarray, name: str) -> None:
    """
    Validate that an array contains only binary values (0 or 1).

    Missing entries (``NaN``, ``None``, ``pd.NA``, ``NaT``) are ignored. Inputs of
    any dtype are accepted: non-numeric dtypes such as ``object`` (e.g. columns
    holding ``None`` or strings) are checked too, instead of failing inside
    ``np.isnan`` with an unrelated ``TypeError``.

    Parameters
    ----------
    arr : np.ndarray
        Array (or array-like, e.g. a pandas Series) to validate.
    name : str
        Name of the variable (for error messages).

    Raises
    ------
    ValueError
        If any non-missing value is not equal to 0 or 1.
    """
    values = np.asarray(arr)
    # pd.isna works for every dtype; np.isnan rejects object/string arrays.
    observed = values[~pd.isna(values)]
    try:
        unique_values = np.unique(observed)
    except TypeError:
        # Mutually unorderable objects (e.g. ints mixed with strings) cannot be
        # sorted; fall back to order-of-appearance uniques for the message.
        unique_values = pd.unique(observed)
    if not np.all(np.isin(unique_values, [0, 1])):
        raise ValueError(f"{name} must be binary (0 or 1). " f"Found values: {unique_values}")


def validate_covariate_names(
    covariates: Optional[List[str]],
    reserved_names: Iterable[str],
    *,
    estimator: str = "estimator",
) -> None:
    """
    Reject covariate names that would clobber a structural coefficient.

    Coefficients end up in a ``name -> value`` dict created by zipping the
    design's term names (structural terms first, user covariates appended
    unchanged) with the coefficient vector. Because dict assignment is
    last-write-wins, a covariate named like a structural term (intercept
    ``const``, treatment/time indicators, interaction, period dummies,
    fixed-effect dummies, internal working columns) silently replaces that
    coefficient. Repeated names inside ``covariates`` collapse in the same way.

    Matching is case-sensitive, mirroring column names and dict keys: ``Const``
    is distinct from ``const`` and therefore permitted.

    Parameters
    ----------
    covariates : list of str or None
        Covariate column names supplied by the user. ``None`` or an empty list
        means there is nothing to check.
    reserved_names : iterable of str
        Structural term names reserved by this particular estimator.
    estimator : str
        Estimator label used as the error-message prefix.

    Raises
    ------
    ValueError
        On the first problem found: a reserved-name collision is reported
        before any duplicate covariate names.
    """
    if not covariates:
        return

    reserved = set(reserved_names)

    # 1) Collisions with reserved structural names (reported first).
    collisions = sorted(set(covariates) & reserved)
    if collisions:
        raise ValueError(
            f"{estimator}: covariate name(s) {collisions} collide with reserved "
            f"structural term name(s). These names are used internally for the "
            f"intercept, the treatment/time indicators, the interaction term, "
            f"period dummies, fixed-effect dummies, or internal working columns, "
            f"and a colliding covariate would silently overwrite the structural "
            f"coefficient. Rename the covariate column(s). Reserved names for "
            f"this fit: {sorted(reserved)}."
        )

    # 2) Names repeated within the covariate list itself.
    already_listed: set = set()
    repeated: set = set()
    for column in covariates:
        if column in already_listed:
            repeated.add(column)
        else:
            already_listed.add(column)
    if repeated:
        raise ValueError(
            f"{estimator}: duplicate covariate name(s) {sorted(repeated)} "
            f"in `covariates`. Each covariate maps to one coefficient; duplicates "
            f"collapse to a single entry. Remove the duplicate(s)."
        )


def validate_design_term_names(
    var_names: Iterable[str],
    *,
    estimator: str = "estimator",
) -> None:
    """
    Fail loudly if the final design term-name list repeats any name.

    This is the data-dependent backstop to :func:`validate_covariate_names`.
    Even with clean covariates, a fixed-effect dummy named ``{fe}_{value}`` can
    coincide with a structural term (notably a ``MultiPeriodDiD`` ``period_{p}``
    key produced by a non-time fixed effect) or with another dummy. Since
    ``var_names`` is zipped into the result's ``coefficients`` dict, a repeat
    would silently overwrite a coefficient. The check runs on the assembled list
    (structural terms + covariates + fixed-effect dummies) just before that dict
    is built, when the actual names are known.

    Parameters
    ----------
    var_names : iterable of str
        Column names of the fully assembled design matrix.
    estimator : str
        Estimator label used as the error-message prefix.

    Raises
    ------
    ValueError
        When at least one name occurs more than once; all repeated names are
        listed in sorted order.
    """
    encountered: set = set()
    repeated: set = set()
    for term in var_names:
        if term in encountered:
            repeated.add(term)
        else:
            encountered.add(term)

    if not repeated:
        return
    raise ValueError(
        f"{estimator}: the fitted design has duplicate term name(s) "
        f"{sorted(repeated)} — a covariate or fixed-effect dummy name "
        f"collides with a structural term (intercept, treatment/time "
        f"indicators, the interaction, or period dummies) or with another "
        f"column. This would silently overwrite a coefficient in the result. "
        f"Rename the offending fixed-effect category or covariate column."
    )


def fe_dummy_names(col: pd.Series, prefix: str) -> List[str]:
    """
    Names of the fixed-effect dummies ``pd.get_dummies`` would keep, computed
    without building the dummy matrix.

    The result matches
    ``pd.get_dummies(col, prefix=prefix, drop_first=True).columns`` but needs
    only ``O(G)`` memory instead of a dense ``(n x G)`` block. That matters for
    the within-transform ``TwoWayFixedEffects`` path, whose whole point is to
    avoid materializing high-cardinality dummies, yet whose collision guard
    still needs the names. Category order follows ``pd.get_dummies``: the
    declared order for a ``Categorical`` column, otherwise the sorted unique
    values (``pd.Categorical(col).categories``). The first category is then
    discarded, as ``drop_first=True`` does.

    Parameters
    ----------
    col : pandas.Series
        Fixed-effect, unit, or time column.
    prefix : str
        Prefix placed before ``_<category>`` (``fe`` for ``fixed_effects``,
        ``_fe_{unit}`` / ``_fe_{time}`` for TWFE unit/time dummies).

    Returns
    -------
    list of str
        ``f"{prefix}_{category}"`` for every category after the first.
    """
    is_categorical = isinstance(col.dtype, pd.CategoricalDtype)
    levels = col.cat.categories if is_categorical else pd.Categorical(col).categories

    names: List[str] = []
    for position, level in enumerate(levels):
        if position == 0:
            continue  # drop_first=True discards the leading category
        names.append(f"{prefix}_{level}")
    return names


def warn_if_not_converged(
    converged: bool,
    method_name: str,
    max_iter: int,
    tol: Optional[float] = None,
    stacklevel: int = 3,
) -> None:
    """Warn (``UserWarning``) when an iterative solver stopped without converging.

    Shared by iterative loops that would otherwise return their last iterate
    without any signal that ``max_iter`` was exhausted.

    Parameters
    ----------
    converged : bool
        Whether the solver met its stopping criterion; truthy means no warning.
    method_name : str
        Solver name placed at the start of the message.
    max_iter : int
        Iteration budget, reported in the message.
    tol : float, optional
        Tolerance, appended as ``" (tol=...)"`` when given.
    stacklevel : int
        Forwarded unchanged to :func:`warnings.warn`.
    """
    if not converged:
        message = f"{method_name} did not converge in {max_iter} iterations"
        if tol is not None:
            message += f" (tol={tol})"
        message += ". Results may be inaccurate."
        warnings.warn(message, UserWarning, stacklevel=stacklevel)


def compute_robust_se(
    X: np.ndarray, residuals: np.ndarray, cluster_ids: Optional[np.ndarray] = None
) -> np.ndarray:
    """
    Heteroskedasticity-robust (HC1) or cluster-robust covariance matrix.

    Note: despite the name, this returns the full variance-covariance matrix,
    not the standard errors; take ``np.sqrt(np.diag(vcov))`` for SEs. Kept as
    a backwards-compatible thin wrapper that delegates all work to
    ``diff_diff.linalg.compute_robust_vcov``.

    Parameters
    ----------
    X : np.ndarray
        Design matrix, shape (n, k).
    residuals : np.ndarray
        Regression residuals, shape (n,).
    cluster_ids : np.ndarray, optional
        Cluster labels; when given, a cluster-robust matrix is requested.

    Returns
    -------
    np.ndarray
        Variance-covariance matrix of shape (k, k).
    """
    vcov = _compute_robust_vcov_linalg(X, residuals, cluster_ids)
    return vcov


def compute_confidence_interval(
    estimate: float, se: float, alpha: float = 0.05, df: Optional[int] = None
) -> Tuple[float, float]:
    """
    Symmetric two-sided confidence interval ``estimate ± crit * se``.

    Parameters
    ----------
    estimate : float
        Point estimate at the center of the interval.
    se : float
        Standard error of the estimate.
    alpha : float
        Significance level; 0.05 (the default) gives a 95% interval.
    df : int, optional
        Student-t degrees of freedom; ``None`` uses the standard normal.

    Returns
    -------
    tuple
        ``(lower_bound, upper_bound)``.
    """
    half_width = _get_critical_value(alpha, df) * se
    return (estimate - half_width, estimate + half_width)


def compute_p_value(t_stat: float, df: Optional[int] = None, two_sided: bool = True) -> float:
    """
    P-value of a t-statistic against a symmetric reference distribution.

    Parameters
    ----------
    t_stat : float
        Observed test statistic.
    df : int, optional
        Student-t degrees of freedom; ``None`` uses the standard normal.
    two_sided : bool
        ``True`` (default) doubles the upper-tail probability of ``|t_stat|``;
        ``False`` returns the single upper-tail probability of ``|t_stat|``.

    Returns
    -------
    float
        The p-value (NaN propagates from a NaN statistic).
    """
    magnitude = np.abs(t_stat)
    tail = stats.norm.sf(magnitude) if df is None else stats.t.sf(magnitude, df)
    return float(2 * tail if two_sided else tail)


def safe_inference(effect, se, alpha=0.05, df=None):
    """Scalar inference triple with NaN gating on unusable inputs.

    Returns NaN for every field when the standard error is non-finite, zero,
    or negative, or when ``df`` is given but not positive (e.g. a
    rank-deficient replicate design). This keeps misleading statistics out of
    the output. Scalars only; array callers should use
    :func:`safe_inference_batch`.

    Parameters
    ----------
    effect : float
        Point estimate (treatment effect or coefficient).
    se : float
        Standard error of ``effect``.
    alpha : float, optional
        Significance level for the confidence interval (default 0.05).
    df : int, optional
        Student-t degrees of freedom; ``None`` uses the standard normal.

    Returns
    -------
    tuple
        ``(t_stat, p_value, (ci_lower, ci_upper))``; all NaN when gated.
    """
    nan_result = (np.nan, np.nan, (np.nan, np.nan))

    # Guard clauses, evaluated in order: SE first, then degrees of freedom.
    if not np.isfinite(se) or not se > 0:
        return nan_result
    if df is not None and df <= 0:
        return nan_result

    t_value = effect / se
    p_value = compute_p_value(t_value, df=df)
    interval = compute_confidence_interval(effect, se, alpha, df=df)
    return t_value, p_value, interval


def safe_inference_batch(effects, ses, alpha=0.05, df=None):
    """Vectorized NaN-safe inference for arrays of effects and standard errors.

    Element-wise counterpart of :func:`safe_inference`. ``effects`` and ``ses``
    are broadcast together, so any shapes are supported, including N-d arrays,
    0-d scalars, and a scalar SE shared by a vector of effects. Outputs have
    the broadcast shape. (Previously the output length came from
    ``len(effects)``, which raised on scalars and mis-shaped N-d input.)

    Parameters
    ----------
    effects : array_like
        Point estimates.
    ses : array_like
        Standard errors, broadcastable against ``effects``.
    alpha : float, optional
        Significance level (default 0.05).
    df : int, optional
        Student-t degrees of freedom. If None, uses the normal distribution.

    Returns
    -------
    t_stats : np.ndarray
    p_values : np.ndarray
    ci_lowers : np.ndarray
    ci_uppers : np.ndarray
        Float arrays of the broadcast shape. Entries whose SE is non-finite,
        zero, or negative are NaN; every entry is NaN when ``df <= 0``.

    Raises
    ------
    ValueError
        If ``effects`` and ``ses`` cannot be broadcast together.
    """
    effects, ses = np.broadcast_arrays(
        np.asarray(effects, dtype=float), np.asarray(ses, dtype=float)
    )
    out_shape = effects.shape

    t_stats = np.full(out_shape, np.nan)
    p_values = np.full(out_shape, np.nan)
    ci_lowers = np.full(out_shape, np.nan)
    ci_uppers = np.full(out_shape, np.nan)

    # Undefined df (e.g., rank-deficient replicate design) → all NaN
    if df is not None and df <= 0:
        return t_stats, p_values, ci_lowers, ci_uppers

    valid = np.isfinite(ses) & (ses > 0)
    if not np.any(valid):
        return t_stats, p_values, ci_lowers, ci_uppers

    t_valid = effects[valid] / ses[valid]
    t_stats[valid] = t_valid

    abs_t = np.abs(t_valid)
    if df is not None:
        p_values[valid] = 2.0 * stats.t.sf(abs_t, df)
    else:
        p_values[valid] = 2.0 * stats.norm.sf(abs_t)

    crit = _get_critical_value(alpha, df)
    ci_lowers[valid] = effects[valid] - crit * ses[valid]
    ci_uppers[valid] = effects[valid] + crit * ses[valid]

    return t_stats, p_values, ci_lowers, ci_uppers


# =============================================================================
# Wild Cluster Bootstrap
# =============================================================================


@dataclass
class WildBootstrapResults:
    """
    Results from wild cluster bootstrap inference.

    Attributes
    ----------
    se : float
        Bootstrap standard error of the coefficient.
    p_value : float
        Bootstrap p-value (two-sided).
    t_stat_original : float
        Original t-statistic from the data.
    ci_lower : float
        Lower bound of the confidence interval.
    ci_upper : float
        Upper bound of the confidence interval.
    n_clusters : int
        Number of clusters in the data.
    n_bootstrap : int
        Number of bootstrap replications.
    weight_type : str
        Type of bootstrap weights used ("rademacher", "webb", or "mammen").
    alpha : float
        Significance level used for confidence interval.
    bootstrap_distribution : np.ndarray, optional
        Full bootstrap distribution of coefficients (if requested). Excluded
        from ``repr``.

    References
    ----------
    Cameron, A. C., Gelbach, J. B., & Miller, D. L. (2008).
    Bootstrap-Based Improvements for Inference with Clustered Errors.
    The Review of Economics and Statistics, 90(3), 414-427.
    """

    se: float
    p_value: float
    t_stat_original: float
    ci_lower: float
    ci_upper: float
    n_clusters: int
    n_bootstrap: int
    weight_type: str
    alpha: float = 0.05
    bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False)

    def summary(self) -> str:
        """Generate formatted summary of bootstrap results.

        NaN fields (a degenerate bootstrap) are rendered as ``nan``.
        """
        # ``:g`` instead of ``int()``: truncation mislabels levels such as
        # alpha=0.025 ("97%" instead of "97.5%"), and float round-off such as
        # (1 - 0.07) * 100 == 92.999... could print "92%" instead of "93%".
        conf_level = f"{(1 - self.alpha) * 100:g}"
        lines = [
            "Wild Cluster Bootstrap Results",
            "=" * 40,
            f"Bootstrap SE: {self.se:.6f}",
            f"Bootstrap p-value: {self.p_value:.4f}",
            f"Original t-stat: {self.t_stat_original:.4f}",
            f"CI ({conf_level}%): [{self.ci_lower:.6f}, {self.ci_upper:.6f}]",
            f"Number of clusters: {self.n_clusters}",
            f"Bootstrap reps: {self.n_bootstrap}",
            f"Weight type: {self.weight_type}",
        ]
        return "\n".join(lines)

    def print_summary(self) -> None:
        """Print formatted summary to stdout."""
        print(self.summary())
def _generate_rademacher_weights(n_clusters: int, rng: np.random.Generator) -> np.ndarray:
    """
    Draw Rademacher weights: each entry is -1.0 or +1.0 with probability 1/2.

    Parameters
    ----------
    n_clusters : int
        Number of weights to draw (one per cluster).
    rng : np.random.Generator
        Source of randomness; consumed through ``rng.choice``.

    Returns
    -------
    np.ndarray
        Float array of length ``n_clusters``.
    """
    two_point_support = np.array([-1.0, 1.0])
    draws = rng.choice(two_point_support, size=n_clusters)
    return np.asarray(draws)


def _generate_webb_weights(n_clusters: int, rng: np.random.Generator) -> np.ndarray:
    """
    Draw weights from Webb's symmetric six-point distribution.

    Support: ``±sqrt(1/2)``, ``±sqrt(2/2)`` (= ±1), ``±sqrt(3/2)``, each with
    probability 1/6, so ``E[w] = 0`` and ``Var(w) = (1/2 + 1 + 3/2) / 3 = 1``.
    Recommended when clusters are very few (G < 10), where it behaves better in
    finite samples than Rademacher weights.

    Equal 1/6 probabilities match R's ``did`` package and give unit variance,
    consistent with the other weight distributions in this module.

    Parameters
    ----------
    n_clusters : int
        Number of weights to draw (one per cluster).
    rng : np.random.Generator
        Source of randomness; consumed through ``rng.choice``.

    Returns
    -------
    np.ndarray
        Float array of length ``n_clusters``.

    References
    ----------
    Webb, M. D. (2014). Reworking wild bootstrap based inference for
    clustered errors. Queen's Economics Department Working Paper No. 1315.
    """
    # sqrt(1/2), sqrt(2/2), sqrt(3/2) in ascending order.
    magnitudes = np.sqrt(np.array([1.0, 2.0, 3.0]) / 2)
    # Negatives mirrored in front, ascending overall.
    six_point_support = np.concatenate((-magnitudes[::-1], magnitudes))
    draws = rng.choice(six_point_support, size=n_clusters)
    return np.asarray(draws)


def _generate_mammen_weights(n_clusters: int, rng: np.random.Generator) -> np.ndarray:
    """
    Draw weights from Mammen's skewed two-point distribution.

    Takes ``-(sqrt(5) - 1) / 2`` (about -0.618) with probability
    ``(sqrt(5) + 1) / (2 * sqrt(5))`` (about 0.724), and ``(sqrt(5) + 1) / 2``
    (about 1.618, the golden ratio) otherwise. The moments satisfy
    ``E[v] = 0``, ``E[v^2] = 1``, ``E[v^3] = 1``, which gives an asymptotic
    refinement when errors are skewed.

    Parameters
    ----------
    n_clusters : int
        Number of weights to draw (one per cluster).
    rng : np.random.Generator
        Source of randomness; consumed through ``rng.choice``.

    Returns
    -------
    np.ndarray
        Float array of length ``n_clusters``.

    References
    ----------
    Mammen, E. (1993). Bootstrap and Wild Bootstrap for High Dimensional
    Linear Models. The Annals of Statistics, 21(1), 255-285.
    """
    root5 = np.sqrt(5)
    low_value = -(root5 - 1) / 2
    high_value = (root5 + 1) / 2
    prob_low = (root5 + 1) / (2 * root5)
    draws = rng.choice(
        np.array([low_value, high_value]),
        size=n_clusters,
        p=[prob_low, 1 - prob_low],
    )
    return np.asarray(draws)
def wild_bootstrap_se(
    X: np.ndarray,
    y: np.ndarray,
    residuals: np.ndarray,
    cluster_ids: np.ndarray,
    coefficient_index: int,
    n_bootstrap: int = 999,
    weight_type: str = "rademacher",
    null_hypothesis: float = 0.0,
    alpha: float = 0.05,
    seed: Optional[int] = None,
    return_distribution: bool = False,
) -> WildBootstrapResults:
    """
    Compute wild cluster bootstrap standard errors and p-values.

    Implements a wild cluster bootstrap-t (Cameron, Gelbach & Miller, 2008).

    Bootstrap samples are generated from the OLS fit of ``y`` on ``X``::

        y*_b = fitted + u_hat * v_g(b)

    Here ``v_g(b)`` is a cluster-level weight shared by every observation in
    cluster ``g``. The fitted values and residuals come from the kept columns
    when nuisance columns are dropped for rank deficiency.

    Each bootstrap t-statistic is centered at the coefficient that generated
    the bootstrap data::

        t*_b = (beta*_b - beta_hat) / se*_b

    The two-sided p-value is the share of valid draws with ``|t*_b| >=
    |t_hat|``, where ``t_hat = (beta_hat - null_hypothesis) / se_hat``. It is
    floored at ``1 / (n_valid + 1)``.

    The standard error is the ``ddof=1`` standard deviation of the finite
    bootstrap coefficients. The confidence interval is their
    ``alpha/2`` / ``1 - alpha/2`` percentile interval.

    Why the centering matters: the earlier implementation regressed
    ``y - X[:, k] * null_hypothesis`` on the full ``X``. That leaves the OLS
    residuals unchanged, so the null was never imposed, yet the t* were centered
    at ``null_hypothesis``. Their distribution then sat around ``t_hat``
    instead of 0, which inflated p-values.

    Parameters
    ----------
    X : np.ndarray
        Design matrix of shape (n, k).
    y : np.ndarray
        Outcome vector of shape (n,).
    residuals : np.ndarray
        Unused. Accepted only for backward compatibility; the residuals that
        drive the bootstrap are recomputed internally from ``(X, y)``.
    cluster_ids : np.ndarray
        Cluster identifiers of shape (n,).
    coefficient_index : int
        Index of the coefficient for which to compute bootstrap inference.
        For DiD, this is typically 3 (the treatment*post interaction term).
    n_bootstrap : int, default=999
        Number of bootstrap replications. Odd numbers are recommended for
        exact p-value computation.
    weight_type : str, default="rademacher"
        Type of bootstrap weights:
        - "rademacher": +1 or -1 with equal probability (standard choice)
        - "webb": 6-point distribution (recommended for <10 clusters)
        - "mammen": Two-point distribution with skewness correction
    null_hypothesis : float, default=0.0
        Value of the coefficient under the null, used for ``t_hat`` and hence
        the p-value. It does not affect the SE, the CI, or the draws.
    alpha : float, default=0.05
        Significance level for confidence interval.
    seed : int, optional
        Random seed for reproducibility. If None (default), results will vary
        between runs.
    return_distribution : bool, default=False
        If True, include full bootstrap distribution of coefficients in results
        (NaN entries mark draws whose coefficient was not finite).

    Returns
    -------
    WildBootstrapResults
        Dataclass containing bootstrap SE, p-value, confidence interval, and
        other inference results.

        All-or-nothing NaN contract: se, p_value, both CI endpoints AND
        t_stat_original are all NaN together in two cases. The first is when
        fewer than 2 draws give a finite t* or a finite coefficient. The second
        is when the analytical t-statistic is itself non-finite.

    Raises
    ------
    ValueError
        If weight_type is not recognized or if there are fewer than 2 clusters.

    Warns
    -----
    UserWarning
        If the number of clusters is less than 5, as bootstrap inference may
        be unreliable.

    Examples
    --------
    >>> from diff_diff.utils import wild_bootstrap_se
    >>> results = wild_bootstrap_se(
    ...     X, y, residuals, cluster_ids,
    ...     coefficient_index=3,  # ATT coefficient
    ...     n_bootstrap=999,
    ...     weight_type="rademacher",
    ...     seed=42
    ... )
    >>> print(f"Bootstrap SE: {results.se:.4f}")
    >>> print(f"Bootstrap p-value: {results.p_value:.4f}")

    References
    ----------
    Cameron, A. C., Gelbach, J. B., & Miller, D. L. (2008).
    Bootstrap-Based Improvements for Inference with Clustered Errors.
    The Review of Economics and Statistics, 90(3), 414-427.

    MacKinnon, J. G., & Webb, M. D. (2018).
    The wild bootstrap for few (treated) clusters.
    The Econometrics Journal, 21(2), 114-135.
    """
    # --- Validate inputs ---------------------------------------------------
    valid_weight_types = ["rademacher", "webb", "mammen"]
    if weight_type not in valid_weight_types:
        raise ValueError(f"weight_type must be one of {valid_weight_types}, got '{weight_type}'")

    # ``cluster_pos[i]`` = position of observation i's cluster within the
    # sorted ``unique_clusters``; used to broadcast cluster weights to rows.
    unique_clusters, cluster_pos = np.unique(cluster_ids, return_inverse=True)
    cluster_pos = np.asarray(cluster_pos).reshape(-1)
    n_clusters = len(unique_clusters)

    if n_clusters < 2:
        raise ValueError(f"Wild cluster bootstrap requires at least 2 clusters, got {n_clusters}")

    if n_clusters < 5:
        warnings.warn(
            f"Only {n_clusters} clusters detected. Wild bootstrap inference may be "
            "unreliable with fewer than 5 clusters. Consider using Webb weights "
            "(weight_type='webb') for improved finite-sample properties.",
            UserWarning,
            stacklevel=2,
        )

    rng = np.random.default_rng(seed)
    weight_generators = {
        "rademacher": _generate_rademacher_weights,
        "webb": _generate_webb_weights,
        "mammen": _generate_mammen_weights,
    }
    generate_weights = weight_generators[weight_type]

    # --- Step 1: point estimate and analytical cluster-robust t -----------
    beta_hat, _, vcov_original = _solve_ols_linalg(X, y, cluster_ids=cluster_ids, return_vcov=True)
    assert vcov_original is not None
    original_coef = beta_hat[coefficient_index]
    se_original = np.sqrt(vcov_original[coefficient_index, coefficient_index])
    t_stat_original = (original_coef - null_hypothesis) / se_original

    # --- Step 2: bootstrap data-generating process -------------------------
    # return_fitted=True yields NaN-safe fitted values computed from the kept
    # columns when solve_ols drops rank-deficient nuisance columns (e.g. an
    # always-treated unit dummy collinear with treated*post on the full-dummy
    # TWFE HC2/HC2-BM path). ``X @ beta`` would instead propagate the NaN of
    # the dropped coefficient into every observation and poison the loop.
    _, base_residuals, base_fitted, _ = _solve_ols_linalg(
        X, y, return_vcov=False, return_fitted=True
    )

    # --- Step 3: bootstrap loop --------------------------------------------
    # Invalid draws keep a NaN sentinel and are filtered in step 4. Coercing
    # them to t* = 0 would bias the p-value.
    bootstrap_t_stats = np.full(n_bootstrap, np.nan)
    bootstrap_coefs = np.full(n_bootstrap, np.nan)

    for b in range(n_bootstrap):
        cluster_weights = generate_weights(n_clusters, rng)
        obs_weights = cluster_weights[cluster_pos]

        y_star = base_fitted + base_residuals * obs_weights

        beta_star, _, vcov_star = _solve_ols_linalg(
            X, y_star, cluster_ids=cluster_ids, return_vcov=True
        )
        assert vcov_star is not None
        coef_star = beta_star[coefficient_index]
        bootstrap_coefs[b] = coef_star
        se_star = np.sqrt(vcov_star[coefficient_index, coefficient_index])

        # Center at beta_hat: it is the true coefficient of the bootstrap DGP,
        # so t* approximates the null distribution of t_hat.
        if np.isfinite(se_star) and se_star > 0 and np.isfinite(coef_star):
            bootstrap_t_stats[b] = (coef_star - original_coef) / se_star

    # --- Step 4: inference from VALID (finite) draws only -------------------
    # All-or-nothing NaN contract (per feedback_bootstrap_nan_on_invalid_contract):
    # when the bootstrap output is degenerate (fewer than 2 finite t-stats or
    # 2 finite coefs), or the analytical t-statistic is itself non-finite,
    # return NaN across the full inference surface (se, p_value, both CI
    # endpoints, AND the surfaced t_stat_original). This avoids mixing
    # analytical and bootstrap inference families on the same coefficient.
    finite_mask = np.isfinite(bootstrap_t_stats)
    n_valid = int(finite_mask.sum())
    valid_coefs = bootstrap_coefs[np.isfinite(bootstrap_coefs)]

    lower_percentile = alpha / 2 * 100
    upper_percentile = (1 - alpha / 2) * 100

    if n_valid >= 2 and valid_coefs.size >= 2 and np.isfinite(t_stat_original):
        p_value = float(np.mean(np.abs(bootstrap_t_stats[finite_mask]) >= np.abs(t_stat_original)))
        # Ensure p-value is at least 1/(n_valid+1) to avoid exact zero.
        p_value = float(max(p_value, 1 / (n_valid + 1)))
        se_bootstrap = float(np.std(valid_coefs, ddof=1))
        ci_lower = float(np.percentile(valid_coefs, lower_percentile))
        ci_upper = float(np.percentile(valid_coefs, upper_percentile))
        surfaced_t_stat = float(t_stat_original)
    else:
        # Downstream consumers (estimator-level `_run_wild_bootstrap_inference`)
        # map these fields directly onto the result object; the
        # (se, t_stat, p_value, ci) quadruple moves together.
        p_value = float("nan")
        se_bootstrap = float("nan")
        ci_lower = float("nan")
        ci_upper = float("nan")
        surfaced_t_stat = float("nan")

    return WildBootstrapResults(
        se=se_bootstrap,
        p_value=p_value,
        t_stat_original=surfaced_t_stat,
        ci_lower=ci_lower,
        ci_upper=ci_upper,
        n_clusters=n_clusters,
        n_bootstrap=n_bootstrap,
        weight_type=weight_type,
        alpha=alpha,
        bootstrap_distribution=bootstrap_coefs if return_distribution else None,
    )
def _compute_outcome_changes( data: pd.DataFrame, outcome: str, time: str, treatment_group: str, unit: Optional[str] = None, caller_label: str = "parallel-trend diagnostic", ) -> Tuple[np.ndarray, np.ndarray]: """ Compute period-to-period outcome changes for treated and control groups. Parameters ---------- data : pd.DataFrame Panel data. outcome : str Outcome variable column. time : str Time period column. treatment_group : str Treatment group indicator column. unit : str, optional Unit identifier column. Returns ------- tuple (treated_changes, control_changes) as numpy arrays. """ if unit is not None: # Unit-level changes: compute change for each unit across periods data_sorted = data.sort_values([unit, time]) data_sorted["_outcome_change"] = data_sorted.groupby(unit)[outcome].diff() # Remove NaN from first period of each unit. The first period per unit # has no prior observation to diff against, so n_units drops are # expected. Anything beyond that is a silent side-effect of gaps or # NaN outcomes — surface the excess via warning (axis-E drop counter). n_units_observed = int(data_sorted[unit].nunique()) n_dropped = int(data_sorted["_outcome_change"].isna().sum()) n_unexpected_drops = max(0, n_dropped - n_units_observed) if n_unexpected_drops > 0: warnings.warn( f"{caller_label}: dropped {n_dropped} row(s) with NaN " f"first-differences; {n_units_observed} are the expected " f"first-period-per-unit drops, and {n_unexpected_drops} are " f"additional NaN first-differences (e.g. NaN outcomes or " f"unit-period gaps upstream). 
Parallel-trend statistics are " f"computed on the remaining rows.", UserWarning, stacklevel=3, ) changes_data = data_sorted.dropna(subset=["_outcome_change"]) treated_changes = changes_data[changes_data[treatment_group] == 1]["_outcome_change"].values control_changes = changes_data[changes_data[treatment_group] == 0]["_outcome_change"].values else: # Aggregate changes: compute mean change per period per group treated_data = data[data[treatment_group] == 1] control_data = data[data[treatment_group] == 0] # Compute period means treated_means = treated_data.groupby(time)[outcome].mean() control_means = control_data.groupby(time)[outcome].mean() # Compute changes between consecutive periods treated_changes = np.diff(treated_means.values) control_changes = np.diff(control_means.values) return treated_changes.astype(float), control_changes.astype(float) # compute_synthetic_weights and _compute_synthetic_weights_numpy removed in the # silent-failures audit post-cleanup (finding #22). The one caller # (`diff_diff.prep.rank_control_units`) inlines a single-pass, uncentered # Frank-Wolfe via the shared `_sc_weight_fw` dispatcher — a ranking heuristic, # NOT the canonical SDID/R `synthdid::sc.weight.fw` two-pass procedure # (intercept=True, 100-iter -> sparsify -> 10000-iter). Canonical SDID unit # weights go through `compute_sdid_unit_weights` (see `_sc_weight_fw_numpy` # below and REGISTRY.md SDID section). def _project_simplex(v: np.ndarray) -> np.ndarray: """ Project vector onto probability simplex (sum to 1, non-negative). Uses the algorithm from Duchi et al. (2008). Parameters ---------- v : np.ndarray Vector to project. Returns ------- np.ndarray Projected vector on the simplex. 
""" n = len(v) if n == 0: return v # Sort in descending order u = np.sort(v)[::-1] # Find the threshold cssv = np.cumsum(u) rho = np.where(u > (cssv - 1) / np.arange(1, n + 1))[0] if len(rho) == 0: # All elements are negative or zero rho_val = 0 else: rho_val = rho[-1] theta = (cssv[rho_val] - 1) / (rho_val + 1) return np.asarray(np.maximum(v - theta, 0)) # ============================================================================= # SDID Weight Optimization (Frank-Wolfe, matching R's synthdid) # ============================================================================= def _sum_normalize(v: np.ndarray) -> np.ndarray: """Normalize vector to sum to 1. Fallback to uniform if sum is zero. Matches R's synthdid ``sum_normalize()`` helper. """ s = np.sum(v) if s > 0: return v / s return np.ones(len(v)) / len(v) def _compute_noise_level(Y_pre_control: np.ndarray) -> float: """Compute noise level from first-differences of control outcomes. Matches R's ``sd(apply(Y[1:N0, 1:T0], 1, diff))`` which computes first-differences across time for each control unit, then takes the pooled standard deviation. Parameters ---------- Y_pre_control : np.ndarray Control unit pre-treatment outcomes, shape (n_pre, n_control). Returns ------- float Noise level (standard deviation of first-differences). """ if HAS_RUST_BACKEND: return float(_rust_compute_noise_level(np.ascontiguousarray(Y_pre_control))) return _compute_noise_level_numpy(Y_pre_control) def _compute_noise_level_numpy(Y_pre_control: np.ndarray) -> float: """Pure NumPy implementation of noise level computation.""" if Y_pre_control.shape[0] < 2: return 0.0 # R: apply(Y[1:N0, 1:T0], 1, diff) computes diff per row (unit). # Our matrix is (T, N) so diff along axis=0 gives (T-1, N). 
first_diffs = np.diff(Y_pre_control, axis=0) # (T_pre-1, N_co) if first_diffs.size <= 1: return 0.0 return float(np.std(first_diffs, ddof=1)) def _compute_regularization( Y_pre_control: np.ndarray, n_treated: int, n_post: int, ) -> tuple: """Compute auto-regularization parameters matching R's synthdid. Parameters ---------- Y_pre_control : np.ndarray Control unit pre-treatment outcomes, shape (n_pre, n_control). n_treated : int Number of treated units. n_post : int Number of post-treatment periods. Returns ------- tuple (zeta_omega, zeta_lambda) regularization parameters. """ sigma = _compute_noise_level(Y_pre_control) eta_omega = (n_treated * n_post) ** 0.25 eta_lambda = 1e-6 return eta_omega * sigma, eta_lambda * sigma def _fw_step( A: np.ndarray, x: np.ndarray, b: np.ndarray, eta: float, ) -> np.ndarray: """Single Frank-Wolfe step on the simplex. Matches R's ``fw.step()`` in synthdid's ``sc.weight.fw()``. Parameters ---------- A : np.ndarray Matrix of shape (N, T0). x : np.ndarray Current weight vector of shape (T0,). b : np.ndarray Target vector of shape (N,). eta : float Regularization strength (N * zeta^2). Returns ------- np.ndarray Updated weight vector on the simplex. """ Ax = A @ x half_grad = A.T @ (Ax - b) + eta * x i = int(np.argmin(half_grad)) d_x = -x.copy() d_x[i] += 1.0 if np.allclose(d_x, 0.0): return x.copy() d_err = A[:, i] - Ax denom = d_err @ d_err + eta * (d_x @ d_x) if denom <= 0: return x.copy() step = -(half_grad @ d_x) / denom step = float(np.clip(step, 0.0, 1.0)) return x + step * d_x def _sc_weight_fw( Y: np.ndarray, zeta: float, intercept: bool = True, init_weights: Optional[np.ndarray] = None, min_decrease: float = 1e-5, max_iter: int = 10000, return_convergence: bool = False, reg_weights: Optional[np.ndarray] = None, ): """Compute synthetic control weights via Frank-Wolfe optimization. Matches R's ``sc.weight.fw()`` from the synthdid package. 
Solves:: min_{lambda on simplex} zeta^2 * ||lambda||^2 + (1/N) * ||A_centered @ lambda - b_centered||^2 With ``reg_weights`` set, solves the weighted-regularization variant used by SDID survey-bootstrap (PR #352):: min_{lambda on simplex} zeta^2 * sum_j reg_weights[j] * lambda[j]^2 + (1/N) * ||A_centered @ lambda - b_centered||^2 Parameters ---------- Y : np.ndarray Matrix of shape (N, T0+1). Last column is the target (post-period mean or treated pre-period mean depending on context). zeta : float Regularization strength. intercept : bool, default True If True, column-center Y before optimization. init_weights : np.ndarray, optional Initial weights. If None, starts with uniform weights. min_decrease : float, default 1e-5 Convergence criterion: stop when objective decreases by less than ``min_decrease**2``. R uses ``1e-5 * noise_level``; the caller should pass the data-dependent value for best results. max_iter : int, default 10000 Maximum number of iterations. Matches R's default. return_convergence : bool, default False If True, returns a tuple ``(weights, converged)`` where ``converged`` is ``True`` iff the min-decrease criterion fired rather than ``max_iter`` being reached. Dispatches to the Rust ``sc_weight_fw_with_convergence`` entry point when available, and to ``_sc_weight_fw_numpy(return_convergence=True)`` otherwise. Used by SDID bootstrap to surface per-draw FW non-convergence explicitly instead of relying on ``warnings.catch_warnings`` (the default Rust FW entry point is silent on non-convergence). reg_weights : np.ndarray, optional Per-coordinate regularization weights of shape ``(T0,)``. When set, switches to the weighted-regularization Rust kernel (``sc_weight_fw_weighted`` / ``_with_convergence``) which solves the SDID survey-bootstrap objective with ``ζ²·Σ rw·ω²`` in place of the uniform ``ζ²·||ω||²``. The caller is responsible for any column-scaling of ``Y`` to match the loss form. 
Default ``None`` delegates to the unweighted kernel — preserves the legacy ABI for all existing callers. Returns ------- np.ndarray or Tuple[np.ndarray, bool] Weights of shape (T0,) on the simplex; with ``return_convergence=True``, additionally the convergence flag. """ Y_c = np.ascontiguousarray(Y, dtype=np.float64) init_c = ( np.ascontiguousarray(init_weights, dtype=np.float64) if init_weights is not None else None ) rw_c = np.ascontiguousarray(reg_weights, dtype=np.float64) if reg_weights is not None else None if rw_c is not None: # Validate reg_weights shape at the dispatcher so Rust and NumPy # backends share a single failure surface. The Rust # ``sc_weight_fw_weighted_internal`` silently falls back to the # unweighted kernel on a length mismatch, while the NumPy # implementation raises — dispatching without a shared upstream # check would let callers get the wrong objective on the Rust # path with no error (PR #355 R5 P2). expected_t0 = Y_c.shape[1] - 1 if rw_c.shape != (expected_t0,): raise ValueError( f"reg_weights shape {rw_c.shape} does not match expected " f"({expected_t0},) — must equal Y.shape[1] - 1" ) if HAS_RUST_BACKEND: if reg_weights is not None: if return_convergence: weights, converged = _rust_sc_weight_fw_weighted_with_convergence( Y_c, zeta, intercept, init_c, min_decrease, max_iter, rw_c, ) return np.asarray(weights), converged return np.asarray( _rust_sc_weight_fw_weighted( Y_c, zeta, intercept, init_c, min_decrease, max_iter, rw_c, ) ) if return_convergence: weights, converged = _rust_sc_weight_fw_with_convergence( Y_c, zeta, intercept, init_c, min_decrease, max_iter, ) return np.asarray(weights), converged return np.asarray( _rust_sc_weight_fw( Y_c, zeta, intercept, init_c, min_decrease, max_iter, ) ) return _sc_weight_fw_numpy( Y, zeta, intercept, init_weights, min_decrease, max_iter, return_convergence=return_convergence, reg_weights=reg_weights, ) def _sc_weight_fw_numpy( Y: np.ndarray, zeta: float, intercept: bool = True, init_weights: 
Optional[np.ndarray] = None, min_decrease: float = 1e-5, max_iter: int = 10000, return_convergence: bool = False, reg_weights: Optional[np.ndarray] = None, ): """Pure NumPy implementation of Frank-Wolfe SC weight solver. When ``return_convergence=True``, returns a tuple ``(weights, converged)`` and suppresses the default ``warn_if_not_converged`` side effect — the caller is responsible for deciding how to surface non-convergence. With ``reg_weights`` set, solves the weighted-regularization variant (matches the Rust ``sc_weight_fw_weighted`` kernel; PR #352). The loss term is unchanged; only the regularization becomes ``ζ²·Σ_j reg_weights[j]·lam[j]²`` and the FW step uses the diag(rw)- weighted simplex direction norm. """ T0 = Y.shape[1] - 1 N = Y.shape[0] if T0 <= 0: lam_trivial = np.ones(max(T0, 1)) if return_convergence: return lam_trivial, True return lam_trivial # Column-center if using intercept (matches R's intercept=TRUE default) if intercept: Y = Y - Y.mean(axis=0) A = Y[:, :T0] b = Y[:, T0] eta = N * zeta**2 if init_weights is not None: lam = init_weights.copy() else: lam = np.ones(T0) / T0 if reg_weights is not None: rw = np.asarray(reg_weights, dtype=np.float64) if rw.shape != (T0,): raise ValueError( f"reg_weights shape {rw.shape} does not match expected " f"({T0},) — must equal A.shape[1]" ) else: rw = None vals = np.full(max_iter, np.nan) converged = False for t in range(max_iter): if rw is None: lam = _fw_step(A, lam, b, eta) err = Y @ np.append(lam, -1.0) vals[t] = zeta**2 * np.sum(lam**2) + np.sum(err**2) / N else: # Weighted FW step with diag(rw) regularization. Mirrors the # Rust sc_weight_fw_*_weighted derivation in rust/src/weights.rs. 
ax_minus_b = A @ lam - b half_grad = A.T @ ax_minus_b + eta * rw * lam i = int(np.argmin(half_grad)) d = -lam.copy() d[i] += 1.0 d_x_w_norm_sq = float(np.sum(rw * d * d)) if d_x_w_norm_sq < 1e-24: err = ax_minus_b vals[t] = zeta**2 * float(np.sum(rw * lam * lam)) + float(np.sum(err**2)) / N if t >= 1 and vals[t - 1] - vals[t] < min_decrease**2: converged = True break continue d_err_sq = float(np.sum((A @ d) ** 2)) denom = d_err_sq + eta * d_x_w_norm_sq if denom <= 0.0: err = ax_minus_b vals[t] = zeta**2 * float(np.sum(rw * lam * lam)) + float(np.sum(err**2)) / N if t >= 1 and vals[t - 1] - vals[t] < min_decrease**2: converged = True break continue hg_dot_dx = float(half_grad @ d) step = float(np.clip(-hg_dot_dx / denom, 0.0, 1.0)) lam = lam + step * d err = A @ lam - b vals[t] = zeta**2 * float(np.sum(rw * lam * lam)) + float(np.sum(err**2)) / N if t >= 1 and vals[t - 1] - vals[t] < min_decrease**2: converged = True break if return_convergence: return lam, converged warn_if_not_converged(converged, "Frank-Wolfe SC weight solver", max_iter, min_decrease) return lam def _sparsify(v: np.ndarray) -> np.ndarray: """Sparsify weight vector by zeroing out small entries. Matches R's synthdid ``sparsify_function``: ``v[v <= max(v)/4] = 0; v = v / sum(v)`` Parameters ---------- v : np.ndarray Weight vector. Returns ------- np.ndarray Sparsified weight vector summing to 1. """ v = v.copy() max_v = np.max(v) if max_v <= 0: return np.ones(len(v)) / len(v) v[v <= max_v / 4] = 0.0 return _sum_normalize(v) def compute_time_weights( Y_pre_control: np.ndarray, Y_post_control: np.ndarray, zeta_lambda: float, intercept: bool = True, min_decrease: float = 1e-5, max_iter_pre_sparsify: int = 100, max_iter: int = 10000, init_weights: Optional[np.ndarray] = None, return_convergence: bool = False, ): """Compute SDID time weights via Frank-Wolfe optimization. Matches R's ``synthdid::sc.weight.fw(Yc[1:N0, ], zeta=zeta.lambda, intercept=TRUE)`` where ``Yc`` is the collapsed-form matrix. 
Uses two-pass optimization with sparsification (same as unit weights), matching R's default ``sparsify=sparsify_function``. Parameters ---------- Y_pre_control : np.ndarray Control outcomes in pre-treatment periods, shape (n_pre, n_control). Y_post_control : np.ndarray Control outcomes in post-treatment periods, shape (n_post, n_control). zeta_lambda : float Regularization parameter for time weights. intercept : bool, default True If True, column-center the optimization matrix. min_decrease : float, default 1e-5 Convergence criterion for Frank-Wolfe. R uses ``1e-5 * noise_level``. max_iter_pre_sparsify : int, default 100 Iterations for first pass (before sparsification). max_iter : int, default 10000 Maximum iterations for second pass (after sparsification). Matches R's default. init_weights : np.ndarray, optional Warm-start weights for the first Frank-Wolfe pass, shape ``(n_pre,)``. If None (default), the solver starts from uniform, matching the top-level ``synthdid_estimate(update.lambda=TRUE)`` path. When provided, the Rust fast-path is skipped in favor of the Python two-pass dispatcher so the first-pass init can be threaded through; this matches R's ``synthdid::bootstrap_sample`` shape (which passes ``weights$lambda`` as FW init per draw). Used by ``SyntheticDiD._bootstrap_se`` on the refit loop. return_convergence : bool, default False If True, returns a tuple ``(weights, converged)`` where ``converged`` is the AND of the first-pass and second-pass convergence flags from the underlying ``_sc_weight_fw`` calls (True iff the min-decrease criterion fired on BOTH passes; False if either hit ``max_iter``). Setting this flag also forces the Python two-pass dispatcher even when ``init_weights`` is None, because the Rust top-level fast-path is silent on non-convergence. Used by SDID bootstrap to surface per-draw FW non-convergence explicitly; standalone callers can leave this at the default to preserve the legacy ABI. 
Returns ------- np.ndarray or Tuple[np.ndarray, bool] Time weights of shape (n_pre,) on the simplex. With ``return_convergence=True``, additionally the two-pass convergence flag (as described above). """ if Y_post_control.shape[0] == 0: raise ValueError( "Y_post_control has no rows. At least one post-treatment period " "is required for time weight computation." ) # When the caller asks for convergence tracking, skip the Rust top-level # fast-path even if init_weights is None — that entry point bypasses the # Python two-pass dispatcher and is silent on FW non-convergence. if HAS_RUST_BACKEND and init_weights is None and not return_convergence: return np.asarray( _rust_compute_time_weights( np.ascontiguousarray(Y_pre_control, dtype=np.float64), np.ascontiguousarray(Y_post_control, dtype=np.float64), zeta_lambda, intercept, min_decrease, max_iter_pre_sparsify, max_iter, ) ) n_pre = Y_pre_control.shape[0] if n_pre <= 1: lam_trivial = np.ones(n_pre) if return_convergence: return lam_trivial, True return lam_trivial # Build collapsed form: (N_co, T_pre + 1), last col = per-control post mean post_means = np.mean(Y_post_control, axis=0) # (N_co,) Y_time = np.column_stack([Y_pre_control.T, post_means]) # (N_co, T_pre+1) # First pass: limited iterations (matching R's max.iter.pre.sparsify). # init_weights is either None (uniform start) or the caller-supplied # warm-start; the inner _sc_weight_fw still dispatches to Rust for the # 100-iter run, so we only pay a Python-level dispatch overhead. 
if return_convergence: lam, conv1 = _sc_weight_fw( Y_time, zeta=zeta_lambda, intercept=intercept, init_weights=init_weights, min_decrease=min_decrease, max_iter=max_iter_pre_sparsify, return_convergence=True, ) else: lam = _sc_weight_fw( Y_time, zeta=zeta_lambda, intercept=intercept, init_weights=init_weights, min_decrease=min_decrease, max_iter=max_iter_pre_sparsify, ) # Sparsify: zero out small weights, renormalize (R's sparsify_function) lam = _sparsify(lam) # Second pass: from sparsified initialization (matching R's max.iter) if return_convergence: lam, conv2 = _sc_weight_fw( Y_time, zeta=zeta_lambda, intercept=intercept, init_weights=lam, min_decrease=min_decrease, max_iter=max_iter, return_convergence=True, ) return lam, bool(conv1 and conv2) lam = _sc_weight_fw( Y_time, zeta=zeta_lambda, intercept=intercept, init_weights=lam, min_decrease=min_decrease, max_iter=max_iter, ) return lam def compute_sdid_unit_weights( Y_pre_control: np.ndarray, Y_pre_treated_mean: np.ndarray, zeta_omega: float, intercept: bool = True, min_decrease: float = 1e-5, max_iter_pre_sparsify: int = 100, max_iter: int = 10000, init_weights: Optional[np.ndarray] = None, return_convergence: bool = False, ): """Compute SDID unit weights via Frank-Wolfe with two-pass sparsification. Matches R's ``synthdid::sc.weight.fw(t(Yc[, 1:T0]), zeta=zeta.omega, intercept=TRUE)`` followed by the sparsify/re-optimize pass. Parameters ---------- Y_pre_control : np.ndarray Control outcomes in pre-treatment periods, shape (n_pre, n_control). Y_pre_treated_mean : np.ndarray Mean treated outcomes in pre-treatment periods, shape (n_pre,). zeta_omega : float Regularization parameter for unit weights. intercept : bool, default True If True, column-center the optimization matrix. min_decrease : float, default 1e-5 Convergence criterion for Frank-Wolfe. R uses ``1e-5 * noise_level``. max_iter_pre_sparsify : int, default 100 Iterations for first pass (before sparsification). 
max_iter : int, default 10000 Iterations for second pass (after sparsification). Matches R's default. init_weights : np.ndarray, optional Warm-start weights for the first Frank-Wolfe pass, shape ``(n_control,)``. If None (default), the solver starts from uniform — matching the top-level ``synthdid_estimate(update.omega=TRUE)`` path. When provided, the Rust fast-path is skipped in favor of the Python two-pass dispatcher so the first-pass init can be threaded through; this matches R's ``synthdid::bootstrap_sample`` shape (which passes ``sum_normalize(weights$omega[...])`` as FW init per draw). Used by ``SyntheticDiD._bootstrap_se`` on the refit loop. return_convergence : bool, default False If True, returns a tuple ``(weights, converged)`` where ``converged`` is the AND of the first-pass and second-pass convergence flags from the underlying ``_sc_weight_fw`` calls (True iff the min-decrease criterion fired on BOTH passes; False if either hit ``max_iter``). Setting this flag also forces the Python two-pass dispatcher even when ``init_weights`` is None, because the Rust top-level fast-path is silent on non-convergence. Used by SDID bootstrap to surface per-draw FW non-convergence explicitly; standalone callers can leave this at the default to preserve the legacy ABI. Returns ------- np.ndarray or Tuple[np.ndarray, bool] Unit weights of shape (n_control,) on the simplex. With ``return_convergence=True``, additionally the two-pass convergence flag (as described above). """ n_control = Y_pre_control.shape[1] if n_control == 0: empty = np.asarray([]) if return_convergence: return empty, True return empty if n_control == 1: singleton = np.asarray([1.0]) if return_convergence: return singleton, True return singleton # When the caller asks for convergence tracking, skip the Rust top-level # fast-path even if init_weights is None — that entry point bypasses the # Python two-pass dispatcher and is silent on FW non-convergence. 
if HAS_RUST_BACKEND and init_weights is None and not return_convergence: return np.asarray( _rust_sdid_unit_weights( np.ascontiguousarray(Y_pre_control, dtype=np.float64), np.ascontiguousarray(Y_pre_treated_mean, dtype=np.float64), zeta_omega, intercept, min_decrease, max_iter_pre_sparsify, max_iter, ) ) # Build collapsed form: (T_pre, N_co + 1), last col = treated pre means Y_unit = np.column_stack([Y_pre_control, Y_pre_treated_mean.reshape(-1, 1)]) # First pass: limited iterations. init_weights is either None (uniform # start) or the caller-supplied warm-start; the inner _sc_weight_fw # still dispatches to Rust for the 100-iter run, so we only pay a # Python-level dispatch overhead. if return_convergence: omega, conv1 = _sc_weight_fw( Y_unit, zeta=zeta_omega, intercept=intercept, init_weights=init_weights, max_iter=max_iter_pre_sparsify, min_decrease=min_decrease, return_convergence=True, ) else: omega = _sc_weight_fw( Y_unit, zeta=zeta_omega, intercept=intercept, init_weights=init_weights, max_iter=max_iter_pre_sparsify, min_decrease=min_decrease, ) # Sparsify: zero out weights <= max/4, renormalize omega = _sparsify(omega) # Second pass: from sparsified initialization if return_convergence: omega, conv2 = _sc_weight_fw( Y_unit, zeta=zeta_omega, intercept=intercept, init_weights=omega, max_iter=max_iter, min_decrease=min_decrease, return_convergence=True, ) return omega, bool(conv1 and conv2) omega = _sc_weight_fw( Y_unit, zeta=zeta_omega, intercept=intercept, init_weights=omega, max_iter=max_iter, min_decrease=min_decrease, ) return omega # ============================================================================= # Survey-weighted SDID FW helpers (PR #352 — internal, called from # SyntheticDiD._bootstrap_se on per-draw survey-weighted refits) # ============================================================================= def compute_sdid_unit_weights_survey( Y_pre_control: np.ndarray, Y_pre_treated_mean: np.ndarray, rw_control: np.ndarray, zeta_omega: 
float, intercept: bool = True, min_decrease: float = 1e-5, max_iter_pre_sparsify: int = 100, max_iter: int = 10000, init_weights: Optional[np.ndarray] = None, return_convergence: bool = False, ): """Survey-weighted SDID unit weights via two-pass weighted Frank-Wolfe. Solves the weighted-FW objective (PR #352 §2.2):: min_{ω on simplex} Σ_t (Σ_i rw_control[i]·ω[i]·Y_pre_control[t,i] - Y_pre_treated_mean[t])² + ζ²·Σ_i rw_control[i]·ω[i]² Implementation: pre-scales each control column of Y_unit by ``rw_control`` (so the loss term picks up the per-control linear combination) and passes ``rw_control`` as ``reg_weights`` to ``_sc_weight_fw`` (so the regularization picks up the per-ω scaling). Two-pass sparsify-refit structure mirrors ``compute_sdid_unit_weights``. The returned ω is on the standard simplex. The caller (typically ``SyntheticDiD._bootstrap_se``) is responsible for composing ``ω_eff = rw_control·ω / Σ(rw_control·ω)`` for the downstream SDID estimator, which expects a normalized weight vector. Parameters ---------- Y_pre_control : np.ndarray Control outcomes in pre-treatment periods, shape (n_pre, n_control). Y_pre_treated_mean : np.ndarray Mean treated outcomes in pre-treatment periods, shape (n_pre,). rw_control : np.ndarray Per-control survey weights, shape (n_control,). Must be non-negative. For pweight-only bootstrap this is the constant survey weight per control unit; for Rao-Wu bootstrap this is the per-draw rescaled weight (``generate_rao_wu_weights`` output sliced to control units). zeta_omega : float Regularization parameter (already normalized by Y_scale). intercept : bool, default True Column-center the optimization matrix. min_decrease : float, default 1e-5 Convergence criterion. max_iter_pre_sparsify : int, default 100 First-pass iteration cap before sparsification. max_iter : int, default 10000 Second-pass iteration cap. init_weights : np.ndarray, optional Warm-start weights for the first pass; shape (n_control,). 
return_convergence : bool, default False If True, returns ``(ω, converged)`` where converged is the AND of both passes' convergence flags. Returns ------- np.ndarray or Tuple[np.ndarray, bool] ω on the simplex (NOT ω_eff). """ n_control = Y_pre_control.shape[1] if rw_control.shape != (n_control,): raise ValueError( f"rw_control shape {rw_control.shape} does not match expected " f"({n_control},)" ) if n_control == 0: empty = np.asarray([]) return (empty, True) if return_convergence else empty if n_control == 1: singleton = np.asarray([1.0]) return (singleton, True) if return_convergence else singleton # Build the column-scaled Y matrix: each control column j is multiplied by # rw_control[j], so A·ω in the loss equals Σ_j rw_j·ω_j·Y_j,pre. rw = np.ascontiguousarray(rw_control, dtype=np.float64) Y_scaled = np.column_stack( [ Y_pre_control * rw[np.newaxis, :], Y_pre_treated_mean.reshape(-1, 1), ] ) if return_convergence: omega, conv1 = _sc_weight_fw( Y_scaled, zeta=zeta_omega, intercept=intercept, init_weights=init_weights, max_iter=max_iter_pre_sparsify, min_decrease=min_decrease, return_convergence=True, reg_weights=rw, ) else: omega = _sc_weight_fw( Y_scaled, zeta=zeta_omega, intercept=intercept, init_weights=init_weights, max_iter=max_iter_pre_sparsify, min_decrease=min_decrease, reg_weights=rw, ) omega = _sparsify(omega) if return_convergence: omega, conv2 = _sc_weight_fw( Y_scaled, zeta=zeta_omega, intercept=intercept, init_weights=omega, max_iter=max_iter, min_decrease=min_decrease, return_convergence=True, reg_weights=rw, ) return omega, bool(conv1 and conv2) return _sc_weight_fw( Y_scaled, zeta=zeta_omega, intercept=intercept, init_weights=omega, max_iter=max_iter, min_decrease=min_decrease, reg_weights=rw, ) def compute_time_weights_survey( Y_pre_control: np.ndarray, Y_post_control: np.ndarray, rw_control: np.ndarray, zeta_lambda: float, intercept: bool = True, min_decrease: float = 1e-5, max_iter_pre_sparsify: int = 100, max_iter: int = 10000, init_weights: 
Optional[np.ndarray] = None, return_convergence: bool = False, ): """Survey-weighted SDID time weights via two-pass row-weighted FW. Solves the WLS-style time-weight objective (PR #352 §2.2):: min_{λ on simplex} Σ_u rw_control[u]·(Σ_t λ[t]·Y_u,pre-centered[t] - Y_u,post_mean-centered)² + ζ²·||λ||² Regularization stays uniform on λ (rw is per-control, λ is per-period — no alignment for per-λ reg weighting). The loss term uses WLS-style row weights; when ``intercept=True``, the column-centering step is *also* survey-weighted (weighted mean across controls, weights ``rw_control``) so the centered loss minimizes ``Σ_u rw_u·(A_u·λ - b_u)²`` on the rw-centered matrix — equivalent to the stated weighted objective. The Rust kernel then sees the weighted-centered + sqrt(rw)-row-scaled matrix with ``intercept=False`` (no additional unweighted centering). The returned λ is on the standard simplex. Parameters ---------- Y_pre_control : np.ndarray Shape (n_pre, n_control). Y_post_control : np.ndarray Shape (n_post, n_control). rw_control : np.ndarray Shape (n_control,), non-negative. zeta_lambda : float Regularization parameter (already normalized by Y_scale). Other parameters mirror ``compute_time_weights``. Returns ------- np.ndarray or Tuple[np.ndarray, bool] λ on the simplex. """ n_pre = Y_pre_control.shape[0] n_control = Y_pre_control.shape[1] if rw_control.shape != (n_control,): raise ValueError( f"rw_control shape {rw_control.shape} does not match expected " f"({n_control},)" ) if Y_post_control.shape[0] == 0: raise ValueError( "Y_post_control has no rows. At least one post-treatment period " "is required for time weight computation." 
) if n_pre <= 1: lam_trivial = np.ones(n_pre) return (lam_trivial, True) if return_convergence else lam_trivial # Build collapsed form like compute_time_weights: (N_co, T_pre+1) post_means = np.mean(Y_post_control, axis=0) Y_time = np.column_stack([Y_pre_control.T, post_means]) # (N_co, T_pre+1) # Column-center the (N_co, T_pre+1) matrix using the SURVEY-WEIGHTED # mean across control units when ``intercept=True``. Plain # ``intercept=True`` inside the FW kernel would use an unweighted # column mean which does not correspond to the stated weighted-loss # objective once ``rw_control`` varies. Perform the weighted # centering here and pass ``intercept=False`` below so the kernel # does not re-center on the row-scaled matrix. rw_sum = float(np.sum(rw_control)) if intercept and rw_sum > 0: col_weighted_means = (Y_time * rw_control[:, np.newaxis]).sum(axis=0) / rw_sum Y_time = Y_time - col_weighted_means[np.newaxis, :] # Row-scale by sqrt(rw): after weighted centering (if any), each # control unit's contribution to the loss is weighted by # ``rw_control[u]`` via the sqrt(rw) row scaling, which reproduces # ``||diag(sqrt(rw))·(A·λ - b)||²`` = ``Σ_u rw_u·(A_u·λ - b_u)²``. # Reg on λ stays uniform (no reg_weights). 
sqrt_rw = np.sqrt(np.maximum(rw_control, 0.0)) Y_weighted = Y_time * sqrt_rw[:, np.newaxis] if return_convergence: lam, conv1 = _sc_weight_fw( Y_weighted, zeta=zeta_lambda, intercept=False, # weighted centering already applied above init_weights=init_weights, min_decrease=min_decrease, max_iter=max_iter_pre_sparsify, return_convergence=True, ) else: lam = _sc_weight_fw( Y_weighted, zeta=zeta_lambda, intercept=False, # weighted centering already applied above init_weights=init_weights, min_decrease=min_decrease, max_iter=max_iter_pre_sparsify, ) lam = _sparsify(lam) if return_convergence: lam, conv2 = _sc_weight_fw( Y_weighted, zeta=zeta_lambda, intercept=False, # weighted centering already applied above init_weights=lam, min_decrease=min_decrease, max_iter=max_iter, return_convergence=True, ) return lam, bool(conv1 and conv2) return _sc_weight_fw( Y_weighted, zeta=zeta_lambda, intercept=False, # weighted centering already applied above init_weights=lam, min_decrease=min_decrease, max_iter=max_iter, ) def compute_sdid_estimator( Y_pre_control: np.ndarray, Y_post_control: np.ndarray, Y_pre_treated: np.ndarray, Y_post_treated: np.ndarray, unit_weights: np.ndarray, time_weights: np.ndarray, ) -> float: """ Compute the Synthetic DiD estimator. Parameters ---------- Y_pre_control : np.ndarray Control outcomes in pre-treatment periods, shape (n_pre, n_control). Y_post_control : np.ndarray Control outcomes in post-treatment periods, shape (n_post, n_control). Y_pre_treated : np.ndarray Treated unit outcomes in pre-treatment periods, shape (n_pre,). Y_post_treated : np.ndarray Treated unit outcomes in post-treatment periods, shape (n_post,). unit_weights : np.ndarray Weights for control units, shape (n_control,). time_weights : np.ndarray Weights for pre-treatment periods, shape (n_pre,). Returns ------- float The synthetic DiD treatment effect estimate. 
Notes ----- The SDID estimator is: τ̂ = (Ȳ_treated,post - Σ_t λ_t * Y_treated,t) - Σ_j ω_j * (Ȳ_j,post - Σ_t λ_t * Y_j,t) Where: - ω_j are unit weights - λ_t are time weights - Ȳ denotes average over post periods """ # Weighted pre-treatment averages weighted_pre_control = time_weights @ Y_pre_control # shape: (n_control,) weighted_pre_treated = time_weights @ Y_pre_treated # scalar # Post-treatment averages mean_post_control = np.mean(Y_post_control, axis=0) # shape: (n_control,) mean_post_treated = np.mean(Y_post_treated) # scalar # DiD for treated: post - weighted pre did_treated = mean_post_treated - weighted_pre_treated # Weighted DiD for controls: sum over j of omega_j * (post_j - weighted_pre_j) did_control = unit_weights @ (mean_post_control - weighted_pre_control) # SDID estimator tau = did_treated - did_control return float(tau) def demean_by_group( data: pd.DataFrame, variables: List[str], group_var: str, inplace: bool = False, suffix: str = "", weights: Optional[np.ndarray] = None, ) -> Tuple[pd.DataFrame, int]: """ Demean variables by a grouping variable (one-way within transformation). For each variable, computes: x_ig - mean(x_g) where g is the group. When weights are provided, uses weighted group means: mean_g = sum(w_i * x_i) / sum(w_i) for i in group g. Parameters ---------- data : pd.DataFrame DataFrame containing the variables to demean. variables : list of str Column names to demean. group_var : str Column name for the grouping variable. inplace : bool, default False If True, modifies the original columns. If False, leaves original columns unchanged (demeaning is still applied to return value). suffix : str, default "" Suffix to add to demeaned column names (only used when inplace=False and you want to keep both original and demeaned columns). weights : np.ndarray, optional Observation weights for weighted group means. Returns ------- data : pd.DataFrame DataFrame with demeaned variables. 
n_effects : int Number of absorbed fixed effects (nunique - 1). Examples -------- >>> df, n_fe = demean_by_group(df, ['y', 'x1', 'x2'], 'unit') >>> # df['y'], df['x1'], df['x2'] are now demeaned by unit """ if not inplace: data = data.copy() # Count fixed effects (categories - 1 for identification) n_effects = data[group_var].nunique() - 1 if weights is not None: # Weighted demeaning: weighted_mean_g = sum(w*x) / sum(w) per group groups = data[group_var].values w = np.asarray(weights, dtype=np.float64) # Cache weight sums per group (invariant across variables) w_sum = pd.Series(w).groupby(groups).transform("sum") for var in variables: col_name = var if not suffix else f"{var}{suffix}" x = data[var].values.astype(np.float64) wx = pd.Series(w * x).groupby(groups).transform("sum") weighted_means = wx / w_sum data[col_name] = x - weighted_means.values else: # Cache the groupby object for efficiency grouper = data.groupby(group_var, sort=False) for var in variables: col_name = var if not suffix else f"{var}{suffix}" group_means = grouper[var].transform("mean") data[col_name] = data[var] - group_means return data, n_effects def within_transform( data: pd.DataFrame, variables: List[str], unit: str, time: str, inplace: bool = False, suffix: str = "_demeaned", weights: Optional[np.ndarray] = None, max_iter: int = 100, tol: float = 1e-8, ) -> pd.DataFrame: """ Apply two-way within transformation to remove unit and time fixed effects. Computes: y_it - y_i. - y_.t + y_.. for each variable. When weights are provided, uses weighted group means at each step. This is the standard fixed effects transformation for panel data that removes both unit-specific and time-specific effects. Parameters ---------- data : pd.DataFrame Panel data containing the variables to transform. variables : list of str Column names to transform. unit : str Column name for unit identifier. time : str Column name for time period identifier. 
inplace : bool, default False If True, modifies the original columns. If False, creates new columns with the specified suffix. suffix : str, default "_demeaned" Suffix for new column names when inplace=False. weights : np.ndarray, optional Observation weights for weighted group means. max_iter : int, default 100 Maximum number of alternating-projection iterations. Used only when ``weights`` is not ``None``; the unweighted path is a single pass and ignores this argument. Emits a ``UserWarning`` per call when any variable fails to converge within this budget. tol : float, default 1e-8 Convergence tolerance on the max absolute change across the iterate. Used only when ``weights`` is not ``None``. Returns ------- pd.DataFrame DataFrame with within-transformed variables. Notes ----- The within transformation removes variation that is constant within units (unit fixed effects) and constant within time periods (time fixed effects). The resulting estimates are equivalent to including unit and time dummies but is computationally more efficient for large panels. 
Examples -------- >>> df = within_transform(df, ['y', 'x'], 'unit_id', 'year') >>> # df now has 'y_demeaned' and 'x_demeaned' columns """ if not inplace: data = data.copy() if weights is not None: # Weighted within-transformation via iterative alternating projections w = np.asarray(weights, dtype=np.float64) unit_groups = data[unit].values time_groups = data[time].values # Cache weight sums per group (invariant across variables) unit_w_sum = pd.Series(w).groupby(unit_groups).transform("sum").values time_w_sum = pd.Series(w).groupby(time_groups).transform("sum").values def _weighted_group_demean(x, groups, w, w_sum): wx_sum = pd.Series(w * x).groupby(groups).transform("sum").values return x - wx_sum / w_sum non_converged_vars: List[str] = [] if inplace: for var in variables: x = data[var].values.astype(np.float64) converged = False for _iter in range(max_iter): x_old = x.copy() x = _weighted_group_demean(x, unit_groups, w, unit_w_sum) x = _weighted_group_demean(x, time_groups, w, time_w_sum) if np.max(np.abs(x - x_old)) < tol: converged = True break if not converged: non_converged_vars.append(var) data[var] = x else: demeaned_data = {} for var in variables: x = data[var].values.astype(np.float64) converged = False for _iter in range(max_iter): x_old = x.copy() x = _weighted_group_demean(x, unit_groups, w, unit_w_sum) x = _weighted_group_demean(x, time_groups, w, time_w_sum) if np.max(np.abs(x - x_old)) < tol: converged = True break if not converged: non_converged_vars.append(var) demeaned_data[f"{var}{suffix}"] = x demeaned_df = pd.DataFrame(demeaned_data, index=data.index) data = pd.concat([data, demeaned_df], axis=1) if non_converged_vars: warn_if_not_converged( False, f"within_transform weighted demean (variables: {non_converged_vars})", max_iter, tol, ) else: # Cache groupby objects for efficiency unit_grouper = data.groupby(unit, sort=False) time_grouper = data.groupby(time, sort=False) if inplace: for var in variables: unit_means = 
unit_grouper[var].transform("mean") time_means = time_grouper[var].transform("mean") grand_mean = data[var].mean() data[var] = data[var] - unit_means - time_means + grand_mean else: demeaned_data = {} for var in variables: unit_means = unit_grouper[var].transform("mean") time_means = time_grouper[var].transform("mean") grand_mean = data[var].mean() demeaned_data[f"{var}{suffix}"] = ( data[var] - unit_means - time_means + grand_mean ).values demeaned_df = pd.DataFrame(demeaned_data, index=data.index) data = pd.concat([data, demeaned_df], axis=1) return data