Source code for diff_diff.had_pretests

"""Pre-test diagnostics for the HeterogeneousAdoptionDiD estimator.

Paper Section 4 (de Chaisemartin, Ciccia, D'Haultfoeuille, Knau 2026,
arXiv:2405.04465v6) prescribes a four-step pre-testing workflow for TWFE
validity in HADs. This module ships the tests and the composite workflow:

Single-horizon tests:

1. :func:`qug_test` - order-statistic ratio test of the support infimum
   ``H_0: d_lower = 0`` (paper Theorem 4). Closed-form, tuning-free.
2. :func:`stute_test` - Cramer-von Mises cusum test of linearity of
   ``E[ΔY | D_2]`` with Mammen (1993) wild bootstrap p-value (paper
   Appendix D).
3. :func:`yatchew_hr_test` - heteroskedasticity-robust variance-ratio
   specification test (paper Theorem 7 / Equation 29). Feasible at
   ``G >= 100k``. Two nulls via the keyword-only ``null=`` argument:
   ``"linearity"`` (default; paper Theorem 7, fits ``Y ~ 1 + D``) and
   ``"mean_independence"`` (R-parity extension mirroring R
   ``YatchewTest::yatchew_test(order=0)``; fits ``Y ~ 1``). The
   downstream variance-ratio machinery is shared between the two
   modes — only the residual definition differs.

Joint / multi-period tests (Phase 3 follow-up):

4. :func:`stute_joint_pretest` - residuals-in core that generalizes the
   single-horizon Stute CvM to K horizons with shared-η wild bootstrap
   and sum-of-CvMs aggregation (Delgado 1993; Escanciano 2006).
5. :func:`joint_pretrends_test` - data-in wrapper for the mean-
   independence null (paper step 2 pre-trends across pre-period
   placebos, Section 4.2 footnote 6 + Section 4.3 paragraph 1).
6. :func:`joint_homogeneity_test` - data-in wrapper for the linearity
   null across post-periods (paper Section 4.3 joint extension,
   page 32).

Composite workflow:

:func:`did_had_pretest_workflow` has two dispatch modes:

- ``aggregate="overall"`` (default, two-period panel): runs steps 1 + 3
  via :func:`qug_test` + :func:`stute_test` + :func:`yatchew_hr_test`.
  Paper step 2 is NOT run on this path (a two-period panel has no pre-
  period placebo); the verdict explicitly flags the Assumption 7 gap
  via the ``"paper step 2 deferred"`` caveat so callers do not get an
  unconditional "TWFE safe" signal.
- ``aggregate="event_study"`` (multi-period panel, >= 3 periods): runs
  QUG at ``F`` + joint pre-trends Stute across earlier pre-periods +
  joint homogeneity-linearity Stute across post-periods. Closes the
  paper step-2 gap and does NOT emit the step-2-deferred caveat in the
  verdict when at least one earlier pre-period is available. The
  step-3 alternative (Yatchew-HR linearity) is subsumed by joint Stute
  on this path; the paper does not derive a joint Yatchew variant, so
  users who need Yatchew robustness under multi-period data can call
  :func:`yatchew_hr_test` on each ``(base, post)`` pair manually.
  (Step 4 in the paper's workflow is the decision itself - "use TWFE
  if none of the tests rejects" - not a separate test.)

Eq. 17 / Eq. 18 linear-trend detrending (paper Section 5.2 Pierce-Schott
application) shipped in PR #389 (Phase 4 R-parity) as the
``trends_lin: bool = False`` keyword-only kwarg on
:func:`joint_pretrends_test`, :func:`joint_homogeneity_test`, AND
:meth:`HeterogeneousAdoptionDiD.fit` (event-study path). Mirrors R
``DIDHAD::did_had(..., trends_lin=TRUE)``. Survey-weighted variant is
not yet derived from the paper and raises ``NotImplementedError``;
tracked in ``TODO.md`` if user demand emerges. See
``docs/methodology/REGISTRY.md`` for the full algorithm narrative,
invariants, and deviation notes.
"""

from __future__ import annotations

import warnings
from dataclasses import dataclass
from typing import Any, Dict, Literal, Optional

import numpy as np
import pandas as pd
from scipy import stats

from diff_diff.bootstrap_utils import (
    apply_stratum_centering,
    generate_survey_multiplier_weights_batch,
)
from diff_diff.had import (
    _aggregate_first_difference,
    _aggregate_unit_resolved_survey,
    _aggregate_unit_weights,
    _json_safe_scalar,
    _validate_had_panel,
    _validate_had_panel_event_study,
)
from diff_diff.survey import (
    HAD_DEPRECATION_MSG_SURVEY_KWARG,
    HAD_DEPRECATION_MSG_WEIGHTS_KWARG_ARRAY_IN,
    HAD_DEPRECATION_MSG_WEIGHTS_KWARG_DATA_IN,
    HAD_DUAL_KNOB_MUTEX_MSG_ARRAY_IN,
    HAD_DUAL_KNOB_MUTEX_MSG_DATA_IN,
    SurveyDesign,
    make_pweight_design,
)
from diff_diff.utils import _generate_mammen_weights

__all__ = [
    "QUGTestResults",
    "StuteTestResults",
    "YatchewTestResults",
    "StuteJointResult",
    "HADPretestReport",
    "qug_test",
    "stute_test",
    "yatchew_hr_test",
    "stute_joint_pretest",
    "joint_pretrends_test",
    "joint_homogeneity_test",
    "did_had_pretest_workflow",
]


_MIN_G_QUG = 2
_MIN_G_STUTE = 10
_MIN_G_YATCHEW = 3
_MIN_N_BOOTSTRAP = 99
_STUTE_LARGE_G_THRESHOLD = 100_000


# Scale-invariant tolerance for detecting a numerically exact linear OLS fit.
# The ratio SSR / TSS = sum(eps^2) / sum((dy - dybar)^2) equals 1 - R^2
# and is BOTH TRANSLATION-INVARIANT (centering absorbs additive shifts)
# and SCALE-INVARIANT (the ratio is dimensionless under multiplicative
# rescaling of dy). Under exact Assumption 8, residuals are mathematically
# zero; in practice FP round-off leaves eps on the order of machine-epsilon
# (~1e-16). Squared that is ~1e-32. The threshold ~1e-24 leaves ~10^8
# accumulated FP operations of margin so genuinely-noisy data is never
# mis-classified.
#
# IMPORTANT: the comparison is purely ``eps^2 <= tol * dy_centered^2`` with
# NO additive floor (e.g. ``max(dy_centered^2, 1.0)`` would break scale
# invariance - scaling dy by 1e-12 would make dy_centered^2 ~ 1e-24 but
# the floor would hold the threshold at 1.0, firing the shortcut on
# noisy data that should not trigger it). The ``dy_centered^2 == 0``
# edge case (constant dy) is handled by a separate branch above the
# relative comparison, so the relative form is only applied when the
# denominator is genuinely positive.
_EXACT_LINEAR_RELATIVE_TOL = 1e-24
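
# Illustrative sketch, not part of this module: the relative comparison
# described above, written in plain Python. ``is_exact_linear`` and
# ``is_exact_linear_with_floor`` are hypothetical names, and the handling of
# the constant-``dy`` case is an assumption (the module treats it in a
# separate branch that this sketch does not reproduce). The second function
# is the additive-floor variant the comment warns against.

```python
def is_exact_linear(residuals, dy, rel_tol=1e-24):
    """Return True when SSR <= rel_tol * centered TSS (no additive floor)."""
    ssr = sum(e * e for e in residuals)
    dy_mean = sum(dy) / len(dy)
    tss = sum((y - dy_mean) ** 2 for y in dy)
    if tss == 0.0:
        # Assumption for this sketch only: constant dy counts as exact
        # when the residuals are also identically zero.
        return ssr == 0.0
    return ssr <= rel_tol * tss


def is_exact_linear_with_floor(residuals, dy, rel_tol=1e-24):
    """Counter-example: flooring TSS at 1.0 breaks scale invariance."""
    ssr = sum(e * e for e in residuals)
    dy_mean = sum(dy) / len(dy)
    tss = sum((y - dy_mean) ** 2 for y in dy)
    return ssr <= rel_tol * max(tss, 1.0)
```

# Rescaling noisy ``dy`` (and its residuals) by 1e-12 leaves the verdict of
# ``is_exact_linear`` unchanged, whereas the floored variant starts firing.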


# =============================================================================
# Result dataclasses
# =============================================================================


@dataclass
class QUGTestResults:
    """Result of :func:`qug_test` (paper Theorem 4).

    The QUG test rejects ``H_0: d_lower = 0`` when the order-statistic ratio
    ``T = D_{(1)} / (D_{(2)} - D_{(1)})`` exceeds ``1/alpha - 1``. Under the
    null, the asymptotic limit law of ``T`` is the ratio of two independent
    Exp(1) random variables, with CDF ``F(t) = t / (1 + t)``, so
    ``p_value = 1 / (1 + T)``.

    Attributes
    ----------
    t_stat : float
        ``D_{(1)} / (D_{(2)} - D_{(1)})``. NaN when fewer than 2 non-zero
        observations remain or when the two smallest doses tie.
    p_value : float
        ``1 / (1 + t_stat)`` under the null. NaN when ``t_stat`` is NaN.
    reject : bool
        ``True`` iff ``t_stat > critical_value``. ``False`` on NaN statistic.
    alpha : float
        Significance level used.
    critical_value : float
        ``1 / alpha - 1``. Populated even when the statistic is NaN so
        downstream readers can inspect the decision threshold.
    n_obs : int
        Number of observations after filtering to ``d > 0``.
    n_excluded_zero : int
        Number of zero-dose observations excluded from the sample.
    d_order_1 : float
        Smallest positive dose ``D_{(1)}``. NaN when ``n_obs < 2``.
    d_order_2 : float
        Second-smallest positive dose ``D_{(2)}``. NaN when ``n_obs < 2``.
    """

    t_stat: float
    p_value: float
    reject: bool
    alpha: float
    critical_value: float
    n_obs: int
    n_excluded_zero: int
    d_order_1: float
    d_order_2: float

    def __repr__(self) -> str:
        return (
            f"QUGTestResults(t_stat={self.t_stat:.4f}, p_value={self.p_value:.4f}, "
            f"reject={self.reject}, alpha={self.alpha}, n_obs={self.n_obs})"
        )

    def summary(self) -> str:
        """Formatted summary table."""
        width = 64
        lines = [
            "=" * width,
            "QUG null test (H_0: d_lower = 0)".center(width),
            "=" * width,
            f"{'Statistic T:':<30} {self.t_stat:>20.4f}",
            f"{'p-value:':<30} {self.p_value:>20.4f}",
            f"{'Critical value (1/alpha-1):':<30} {self.critical_value:>20.4f}",
            f"{'Reject H_0:':<30} {str(self.reject):>20}",
            f"{'alpha:':<30} {self.alpha:>20.4f}",
            f"{'Observations:':<30} {self.n_obs:>20}",
            f"{'Excluded (d == 0):':<30} {self.n_excluded_zero:>20}",
            f"{'D_(1):':<30} {self.d_order_1:>20.4f}",
            f"{'D_(2):':<30} {self.d_order_2:>20.4f}",
            "=" * width,
        ]
        return "\n".join(lines)

    def print_summary(self) -> None:
        """Print the summary to stdout."""
        print(self.summary())

    def to_dict(self) -> Dict[str, Any]:
        """Return results as a JSON-safe dict."""
        return {
            "test": "qug",
            "t_stat": _json_safe_scalar(self.t_stat),
            "p_value": _json_safe_scalar(self.p_value),
            "reject": bool(self.reject),
            "alpha": float(self.alpha),
            "critical_value": _json_safe_scalar(self.critical_value),
            "n_obs": int(self.n_obs),
            "n_excluded_zero": int(self.n_excluded_zero),
            "d_order_1": _json_safe_scalar(self.d_order_1),
            "d_order_2": _json_safe_scalar(self.d_order_2),
        }

    def to_dataframe(self) -> pd.DataFrame:
        """Return a one-row DataFrame of the result dict."""
        return pd.DataFrame([self.to_dict()])
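

# Illustrative sketch, not part of this module: the arithmetic documented on
# ``QUGTestResults`` (statistic, critical value, p-value), in plain Python.
# ``qug_sketch`` is a hypothetical name; it is not :func:`qug_test` and skips
# that function's input validation.

```python
import math


def qug_sketch(doses, alpha=0.05):
    """Return (t_stat, p_value, reject) from the two smallest positive doses."""
    positive = sorted(x for x in doses if x > 0)
    if len(positive) < 2 or positive[0] == positive[1]:
        # Fewer than two positive doses, or a tie: statistic undefined.
        return math.nan, math.nan, False
    d1, d2 = positive[0], positive[1]
    t_stat = d1 / (d2 - d1)
    p_value = 1.0 / (1.0 + t_stat)      # from the limit CDF F(t) = t / (1 + t)
    critical_value = 1.0 / alpha - 1.0  # t_stat > critical_value iff p_value < alpha
    return t_stat, p_value, t_stat > critical_value
```

# With doses ``[0.0, 1.0, 1.25, 2.0]`` the zero dose is dropped and
# ``T = 1.0 / 0.25 = 4``, well below the 5% critical value of 19.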


@dataclass
class StuteTestResults:
    """Result of :func:`stute_test` (paper Appendix D).

    The Stute test rejects the null that ``E[ΔY | D_2]`` is linear in
    ``D_2`` (paper Assumption 8) when the sorted-residual CvM statistic
    ``S = (1/G^2) Σ (Σ_{h=1}^g eps_{(h)})^2`` exceeds the Mammen wild
    bootstrap ``1 - alpha`` quantile.

    Attributes
    ----------
    cvm_stat : float
        CvM statistic. NaN when ``G < 10`` (below the threshold the
        statistic is not well-calibrated).
    p_value : float
        Bootstrap p-value ``(1 + sum(S_b >= S)) / (B + 1)``. NaN when the
        statistic is NaN.
    reject : bool
        ``True`` iff ``p_value <= alpha``. ``False`` on NaN.
    alpha : float
        Significance level used.
    n_bootstrap : int
        Number of Mammen wild bootstrap replications.
    n_obs : int
        Number of observations.
    seed : int or None
        Seed passed to ``np.random.default_rng``. ``None`` when unseeded.
    """

    cvm_stat: float
    p_value: float
    reject: bool
    alpha: float
    n_bootstrap: int
    n_obs: int
    seed: Optional[int]

    def __repr__(self) -> str:
        return (
            f"StuteTestResults(cvm_stat={self.cvm_stat:.4f}, "
            f"p_value={self.p_value:.4f}, reject={self.reject}, "
            f"alpha={self.alpha}, n_bootstrap={self.n_bootstrap}, "
            f"n_obs={self.n_obs})"
        )

    def summary(self) -> str:
        """Formatted summary table."""
        width = 64
        lines = [
            "=" * width,
            "Stute CvM linearity test (H_0: linear E[dY|D])".center(width),
            "=" * width,
            f"{'CvM statistic:':<30} {self.cvm_stat:>20.4f}",
            f"{'Bootstrap p-value:':<30} {self.p_value:>20.4f}",
            f"{'Reject H_0:':<30} {str(self.reject):>20}",
            f"{'alpha:':<30} {self.alpha:>20.4f}",
            f"{'Bootstrap replications:':<30} {self.n_bootstrap:>20}",
            f"{'Observations:':<30} {self.n_obs:>20}",
            f"{'Seed:':<30} {str(self.seed):>20}",
            "=" * width,
        ]
        return "\n".join(lines)

    def print_summary(self) -> None:
        """Print the summary to stdout."""
        print(self.summary())

    def to_dict(self) -> Dict[str, Any]:
        """Return results as a JSON-safe dict."""
        return {
            "test": "stute",
            "cvm_stat": _json_safe_scalar(self.cvm_stat),
            "p_value": _json_safe_scalar(self.p_value),
            "reject": bool(self.reject),
            "alpha": float(self.alpha),
            "n_bootstrap": int(self.n_bootstrap),
            "n_obs": int(self.n_obs),
            "seed": None if self.seed is None else int(self.seed),
        }

    def to_dataframe(self) -> pd.DataFrame:
        """Return a one-row DataFrame of the result dict."""
        return pd.DataFrame([self.to_dict()])
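

# Illustrative sketch, not part of this module: the sorted-residual CvM
# statistic ``S = (1/G^2) Σ_g (Σ_{h<=g} eps_{(h)})^2`` from the
# ``StuteTestResults`` docstring, in plain Python. ``cvm_sketch`` is a
# hypothetical name; the wild-bootstrap p-value is deliberately left out.

```python
def cvm_sketch(d, residuals):
    """CvM statistic of residuals cumulated in increasing order of dose d."""
    n = len(d)
    # Sort residuals by dose; sorted() is stable, so ties keep input order.
    ordered = [eps for _, eps in sorted(zip(d, residuals), key=lambda pair: pair[0])]
    running = 0.0
    total = 0.0
    for eps in ordered:
        running += eps          # partial sum up to the g-th smallest dose
        total += running ** 2
    return total / n ** 2
```

# For ``d = [3, 1, 2]`` and residuals ``[1, -1, 0]`` the dose-ordered
# residuals are ``[-1, 0, 1]``, the partial sums are ``[-1, -1, 0]``, and
# ``S = 2 / 9``.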


@dataclass
class YatchewTestResults:
    """Result of :func:`yatchew_hr_test` (paper Theorem 7 / Equation 29).

    Heteroskedasticity-robust specification test using Yatchew's
    difference-based variance estimator. Two nulls are supported via the
    ``null=`` argument on :func:`yatchew_hr_test` and reflected on the
    ``null_form`` attribute below: ``"linearity"`` (default; paper
    Theorem 7, the same null as :func:`stute_test`, residuals from OLS
    ``dy ~ 1 + d``) and ``"mean_independence"`` (R-parity extension
    mirroring R ``YatchewTest::yatchew_test(order=0)``, residuals from
    intercept-only OLS ``dy ~ 1``).

    The test statistic
    ``T_hr = sqrt(G) * (sigma2_lin - sigma2_diff) / sigma2_W`` is
    asymptotically N(0, 1) under H_0 in both modes; rejection uses the
    one-sided standard-normal critical value. Only the residual definition
    (and therefore ``sigma2_lin``) differs between modes - the
    ``sigma2_diff`` / ``sigma2_W`` / sort-by-``d`` machinery is shared.

    Attributes
    ----------
    t_stat_hr : float
        Test statistic ``T_hr`` from paper Equation 29. NaN when ``G < 3``.
    p_value : float
        ``1 - Phi(T_hr)``. NaN when the statistic is NaN.
    reject : bool
        ``True`` iff ``T_hr >= critical_value``. ``False`` on NaN.
    alpha : float
        Significance level used.
    critical_value : float
        One-sided standard-normal critical value ``z_{1 - alpha}``.
    sigma2_lin : float
        Residual variance under the chosen null. Under
        ``null_form="linearity"``: residual variance from OLS of ``dy`` on
        ``d``. Under ``null_form="mean_independence"``:
        ``(1/G) * sum((dy - mean(dy))^2)``, the population variance of
        ``dy``.
    sigma2_diff : float
        Yatchew differencing variance
        ``(1 / (2G)) * sum((dy_{(g)} - dy_{(g-1)})^2)`` - divisor is ``2G``
        (paper-literal), NOT ``2(G-1)``.
    sigma2_W : float
        Heteroskedasticity-robust scale
        ``sqrt((1 / (G-1)) * sum(eps_{(g)}^2 * eps_{(g-1)}^2))``.
    n_obs : int
        Number of observations.
    null_form : str
        ``"linearity"`` (default; H_0: ``E[dY|D]`` is linear in ``D``,
        residuals from OLS ``dy ~ 1 + d``) or ``"mean_independence"``
        (H_0: ``E[dY|D] = E[dY]``, residuals from intercept-only OLS
        ``dy ~ 1``). Mirrors R ``YatchewTest::yatchew_test``'s ``order``
        argument (``order=1`` ↔ ``"linearity"``; ``order=0`` ↔
        ``"mean_independence"``).
    """

    t_stat_hr: float
    p_value: float
    reject: bool
    alpha: float
    critical_value: float
    sigma2_lin: float
    sigma2_diff: float
    sigma2_W: float
    n_obs: int
    null_form: str = "linearity"

    def __repr__(self) -> str:
        return (
            f"YatchewTestResults(t_stat_hr={self.t_stat_hr:.4f}, "
            f"p_value={self.p_value:.4f}, reject={self.reject}, "
            f"alpha={self.alpha}, null_form={self.null_form!r}, "
            f"n_obs={self.n_obs})"
        )

    def summary(self) -> str:
        """Formatted summary table."""
        width = 64
        title = {
            "linearity": "Yatchew-HR linearity test (H_0: linear E[dY|D])",
            "mean_independence": (
                "Yatchew-HR mean-independence test (H_0: E[dY|D] = E[dY])"
            ),
        }.get(self.null_form, f"Yatchew-HR test (null_form={self.null_form!r})")
        lines = [
            "=" * width,
            title.center(width),
            "=" * width,
            f"{'T_hr statistic:':<30} {self.t_stat_hr:>20.4f}",
            f"{'p-value:':<30} {self.p_value:>20.4f}",
            f"{'Critical value (1-sided z):':<30} {self.critical_value:>20.4f}",
            f"{'Reject H_0:':<30} {str(self.reject):>20}",
            f"{'alpha:':<30} {self.alpha:>20.4f}",
            f"{'sigma^2_lin (OLS):':<30} {self.sigma2_lin:>20.4f}",
            f"{'sigma^2_diff (Yatchew):':<30} {self.sigma2_diff:>20.4f}",
            f"{'sigma^2_W (HR scale):':<30} {self.sigma2_W:>20.4f}",
            f"{'Observations:':<30} {self.n_obs:>20}",
            "=" * width,
        ]
        return "\n".join(lines)

    def print_summary(self) -> None:
        """Print the summary to stdout."""
        print(self.summary())

    def to_dict(self) -> Dict[str, Any]:
        """Return results as a JSON-safe dict."""
        return {
            "test": "yatchew_hr",
            "t_stat_hr": _json_safe_scalar(self.t_stat_hr),
            "p_value": _json_safe_scalar(self.p_value),
            "reject": bool(self.reject),
            "alpha": float(self.alpha),
            "critical_value": _json_safe_scalar(self.critical_value),
            "sigma2_lin": _json_safe_scalar(self.sigma2_lin),
            "sigma2_diff": _json_safe_scalar(self.sigma2_diff),
            "sigma2_W": _json_safe_scalar(self.sigma2_W),
            "n_obs": int(self.n_obs),
            "null_form": str(self.null_form),
        }

    def to_dataframe(self) -> pd.DataFrame:
        """Return a one-row DataFrame of the result dict."""
        return pd.DataFrame([self.to_dict()])
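

# Illustrative sketch, not part of this module: the three variance terms and
# ``T_hr`` from the ``YatchewTestResults`` docstring under the linearity
# null, in plain Python. ``yatchew_hr_sketch`` is a hypothetical name. The
# ``1/G`` divisor for ``sigma2_lin`` under linearity is an assumption (the
# docstring states it only for the mean-independence mode), and degenerate
# inputs (constant ``d``, zero ``sigma2_W``) are not guarded here.

```python
import math
from statistics import NormalDist


def yatchew_hr_sketch(d, dy):
    """Return (t_hr, p_value) for H_0: E[dy | d] linear in d."""
    n = len(d)
    d_mean = sum(d) / n
    dy_mean = sum(dy) / n
    # Closed-form OLS of dy on (1, d).
    sxx = sum((x - d_mean) ** 2 for x in d)
    slope = sum((x - d_mean) * (y - dy_mean) for x, y in zip(d, dy)) / sxx
    intercept = dy_mean - slope * d_mean
    # Sort by dose, then difference adjacent observations.
    order = sorted(range(n), key=lambda i: d[i])
    dy_sorted = [dy[i] for i in order]
    eps_sorted = [dy[i] - intercept - slope * d[i] for i in order]
    sigma2_lin = sum(e * e for e in eps_sorted) / n  # assumed divisor G
    sigma2_diff = sum(
        (dy_sorted[g] - dy_sorted[g - 1]) ** 2 for g in range(1, n)
    ) / (2 * n)  # divisor 2G, as documented
    sigma2_w = math.sqrt(
        sum(eps_sorted[g] ** 2 * eps_sorted[g - 1] ** 2 for g in range(1, n))
        / (n - 1)
    )
    t_hr = math.sqrt(n) * (sigma2_lin - sigma2_diff) / sigma2_w
    p_value = 1.0 - NormalDist().cdf(t_hr)  # one-sided, upper tail
    return t_hr, p_value
```

# On a clearly curved relationship such as ``dy = d**2`` the OLS residual
# variance exceeds the differencing variance, so ``T_hr`` comes out positive.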


@dataclass
class StuteJointResult:
    """Result of :func:`stute_joint_pretest` (joint Cramer-von Mises across horizons).

    Aggregates the per-horizon Stute (1997) CvM statistic into a joint
    specification test: ``S_joint = sum_k S_k``, where ``S_k`` is the
    single-horizon CvM on residuals ``eps_{g,k}``. Inference is via Mammen
    (1993) wild bootstrap with a **shared** multiplier ``eta_g`` across
    horizons per unit (Delgado-Manteiga 2001; Hlavka-Huskova 2020) to
    preserve the unit-level dependence structure of the vector-valued
    empirical process.

    Two nulls are supported via the thin wrappers
    :func:`joint_pretrends_test` (mean-independence:
    ``E[Y_t - Y_base | D] = mu_t``, design matrix ``[1]``) and
    :func:`joint_homogeneity_test` (linearity:
    ``E[Y_t - Y_base | D_t] = beta_{0,t} + beta_{fe,t} * D``, design matrix
    ``[1, D]``). Both wrappers accept a ``trends_lin: bool = False``
    keyword-only flag (PR #392): when ``True``, applies paper Eq 17 / Eq 18
    linear-trend detrending before the joint CvM using per-group slope
    ``Y[g, F-1] - Y[g, F-2]``.

    Attributes
    ----------
    cvm_stat_joint : float
        Joint statistic ``S_joint = sum_k S_k``. NaN on NaN-propagation.
    p_value : float
        Bootstrap p-value ``(1 + sum(S*_b >= S_joint)) / (B + 1)``. NaN
        when the statistic is NaN. ``1.0`` when the per-horizon
        exact-linear short-circuit fires (all horizons machine-exact
        linear).
    reject : bool
        ``True`` iff ``p_value <= alpha``. Always ``False`` on NaN.
    alpha : float
        Significance level.
    horizon_labels : list of str
        Horizon identifiers as ``str(t)`` for each period. **String
        identity only** - NOT a chronological ordering key. Callers who
        need chronological order should preserve the original period
        values alongside (a downstream plotter sorting labels
        lexicographically will misorder e.g.
        ``["2003-Q10", "2003-Q2", ...]``).
    per_horizon_stats : dict[str, float]
        ``{label: S_k}`` diagnostic. Per-horizon p-values are NOT exposed
        (decomposing the joint bootstrap into K independent loops is a
        K-fold memory/time cost; deferred). Callers who need per-horizon
        p-values can call :func:`stute_test` separately on each
        (period, residual) pair. On NaN-propagation (any horizon has NaN
        input), this dict is preserved with
        ``{label: np.nan for label in horizon_labels}``, NOT an empty
        dict, NOT a partial dict: the keys carry diagnostic value (which
        horizons were attempted), the NaN values signal non-propagation.
    n_bootstrap : int
    n_obs : int
        Number of units ``G``.
    n_horizons : int
    seed : int or None
    null_form : str
        ``"mean_independence"`` (from :func:`joint_pretrends_test`) or
        ``"linearity"`` (from :func:`joint_homogeneity_test`).
        ``"custom"`` when called directly via :func:`stute_joint_pretest`
        without a wrapper.
    exact_linear_short_circuited : bool
        ``True`` when every horizon's residual SSR to centered TSS ratio
        is below :data:`_EXACT_LINEAR_RELATIVE_TOL`; bootstrap is skipped
        and ``p_value = 1.0``. The per-horizon check ensures a single
        degenerate horizon does not collapse the joint test when other
        horizons have nontrivial residuals.
    """

    cvm_stat_joint: float
    p_value: float
    reject: bool
    alpha: float
    horizon_labels: list
    per_horizon_stats: Dict[str, float]
    n_bootstrap: int
    n_obs: int
    n_horizons: int
    seed: Optional[int]
    null_form: str
    exact_linear_short_circuited: bool

    def __repr__(self) -> str:
        return (
            f"StuteJointResult(cvm_stat_joint={self.cvm_stat_joint:.4f}, "
            f"p_value={self.p_value:.4f}, reject={self.reject}, "
            f"n_horizons={self.n_horizons}, null_form={self.null_form!r}, "
            f"n_obs={self.n_obs})"
        )

    def summary(self) -> str:
        """Formatted summary table."""
        width = 64
        per_horizon_lines = [
            f"  {label:<20} {stat:>20.4f}"
            for label, stat in self.per_horizon_stats.items()
        ]
        null_label = {
            "mean_independence": "mean-independence (pre-trends)",
            "linearity": "linearity (post-homogeneity)",
        }.get(self.null_form, self.null_form)
        lines = [
            "=" * width,
            f"Joint Stute CvM test ({null_label})".center(width),
            "=" * width,
            f"{'Joint CvM statistic:':<30} {self.cvm_stat_joint:>20.4f}",
            f"{'Bootstrap p-value:':<30} {self.p_value:>20.4f}",
            f"{'Reject H_0:':<30} {str(self.reject):>20}",
            f"{'alpha:':<30} {self.alpha:>20.4f}",
            f"{'Bootstrap replications:':<30} {self.n_bootstrap:>20}",
            f"{'Horizons:':<30} {self.n_horizons:>20}",
            f"{'Observations:':<30} {self.n_obs:>20}",
            f"{'Seed:':<30} {str(self.seed):>20}",
            f"{'Exact-linear short-circuit:':<30} "
            f"{str(self.exact_linear_short_circuited):>20}",
            "-" * width,
            "Per-horizon statistics:",
            *per_horizon_lines,
            "=" * width,
        ]
        return "\n".join(lines)

    def print_summary(self) -> None:
        """Print the summary to stdout."""
        print(self.summary())

    def to_dict(self) -> Dict[str, Any]:
        """Return results as a JSON-safe dict."""
        return {
            "test": "stute_joint",
            "cvm_stat_joint": _json_safe_scalar(self.cvm_stat_joint),
            "p_value": _json_safe_scalar(self.p_value),
            "reject": bool(self.reject),
            "alpha": float(self.alpha),
            "horizon_labels": [str(label) for label in self.horizon_labels],
            "per_horizon_stats": {
                str(k): _json_safe_scalar(v) for k, v in self.per_horizon_stats.items()
            },
            "n_bootstrap": int(self.n_bootstrap),
            "n_obs": int(self.n_obs),
            "n_horizons": int(self.n_horizons),
            "seed": None if self.seed is None else int(self.seed),
            "null_form": str(self.null_form),
            "exact_linear_short_circuited": bool(self.exact_linear_short_circuited),
        }

    def to_dataframe(self) -> pd.DataFrame:
        """Return a one-row DataFrame of the top-level result fields."""
        return pd.DataFrame(
            [
                {
                    "test": "stute_joint",
                    "cvm_stat_joint": _json_safe_scalar(self.cvm_stat_joint),
                    "p_value": _json_safe_scalar(self.p_value),
                    "reject": bool(self.reject),
                    "alpha": float(self.alpha),
                    "n_bootstrap": int(self.n_bootstrap),
                    "n_obs": int(self.n_obs),
                    "n_horizons": int(self.n_horizons),
                    "null_form": str(self.null_form),
                }
            ]
        )


@dataclass
class HADPretestReport:
    """Composite output of :func:`did_had_pretest_workflow`.

    Two dispatch shapes, distinguished by :attr:`aggregate`:

    ``aggregate="overall"`` (default, two-period panel): bundles paper
    steps 1 (QUG) and 3 (linearity via Stute + Yatchew-HR) on a two-period
    first-differenced sample. Step 2 (Assumption 7 pre-trends) is NOT
    implemented on this path and is explicitly flagged in the verdict;
    callers must run pre-trends separately.

    ``aggregate="event_study"`` (multi-period panel, >= 3 periods): bundles
    QUG + joint pre-trends Stute + joint homogeneity-linearity Stute. The
    joint Stute variants close the paper step-2 gap; the event-study
    verdict does NOT emit the "paper step 2 deferred" caveat. Step 3
    adjudication uses joint Stute only - no joint Yatchew variant exists
    because the paper does not derive one; users who need Yatchew
    robustness under multi-period data can run :func:`yatchew_hr_test` on
    each (base, post) pair manually.

    Attributes
    ----------
    qug : QUGTestResults or None
        Populated by default; ``None`` only when the workflow runs under
        ``survey=`` / ``weights=`` (Phase 4.5 C path), where the QUG step
        is permanently skipped per Phase 4.5 C0 (extreme-value theory
        under complex sampling not a settled toolkit; see
        :func:`qug_test`).
    stute : StuteTestResults or None
        Populated when ``aggregate == "overall"``; ``None`` when
        ``aggregate == "event_study"``.
    yatchew : YatchewTestResults or None
        Populated when ``aggregate == "overall"``; ``None`` when
        ``aggregate == "event_study"``.
    pretrends_joint : StuteJointResult or None
        Populated when ``aggregate == "event_study"`` and at least one
        earlier pre-period exists; ``None`` on the overall path or when
        only the immediate base pre-period is available.
    homogeneity_joint : StuteJointResult or None
        Populated when ``aggregate == "event_study"``; ``None`` on the
        overall path.
    all_pass : bool
        On the **unweighted overall path**: same Phase 3 semantics - True
        iff QUG is conclusive AND at least one of Stute/Yatchew is
        conclusive AND no conclusive test rejects.

        On the **unweighted event-study path**: True iff
        ``np.isfinite(qug.p_value)``,
        ``pretrends_joint is not None and
        np.isfinite(pretrends_joint.p_value)``,
        ``np.isfinite(homogeneity_joint.p_value)``, AND none of the three
        rejects.

        On the **survey/weights path** (Phase 4.5 C) the
        QUG-conclusiveness gate is dropped (``qug=None`` per C0 deferral);
        the linearity-conditional rule splits by aggregate:

        - ``aggregate="overall"`` survey: True iff at least one of
          Stute/Yatchew is conclusive AND no conclusive test rejects.
        - ``aggregate="event_study"`` survey: True iff
          ``pretrends_joint`` is non-None and conclusive,
          ``homogeneity_joint`` is conclusive, AND neither rejects. (Both
          joint variants must be conclusive on the event-study path - same
          step-2 + step-3 closure as the unweighted aggregate, just
          without the QUG step.)

        Mirrors Phase 3's ``bool(np.isfinite(p_value))`` convention - no
        ``.conclusive()`` helper on any result dataclass.
    verdict : str
        Human-readable classification. Paper rule applies symmetrically:
        TWFE is admissible only if NONE of the implemented tests rejects.
        Conclusive rejections are the primary verdict; unresolved steps
        append as ``"; additional steps unresolved: ..."`` rather than
        replacing the rejection.
    alpha : float
    n_obs : int
        Unit count. For overall: units after two-period first-difference
        aggregation. For event_study: units after balanced-panel
        validation and (if applicable) last-cohort auto-filter.
    aggregate : str
        ``"overall"`` or ``"event_study"``. Determines which component
        fields are populated and which branch of serialization methods to
        render.
    """

    qug: Optional[QUGTestResults]
    stute: Optional[StuteTestResults]
    yatchew: Optional[YatchewTestResults]
    all_pass: bool
    verdict: str
    alpha: float
    n_obs: int
    pretrends_joint: Optional[StuteJointResult] = None
    homogeneity_joint: Optional[StuteJointResult] = None
    aggregate: str = "overall"

    def __repr__(self) -> str:
        # Preserve Phase 3 repr bit-exactly on the overall path. The
        # aggregate kwarg is only surfaced on the event-study path so
        # downstream consumers comparing repr strings on two-period
        # reports see identical output.
        if self.aggregate == "event_study":
            return (
                f"HADPretestReport(aggregate={self.aggregate!r}, "
                f"all_pass={self.all_pass}, "
                f"verdict={self.verdict!r}, n_obs={self.n_obs})"
            )
        return (
            f"HADPretestReport(all_pass={self.all_pass}, "
            f"verdict={self.verdict!r}, n_obs={self.n_obs})"
        )

    def summary(self) -> str:
        """Formatted summary of all tests and the verdict."""
        width = 72
        # Preserve Phase 3 summary bit-exactly on the overall path. The
        # `aggregate: ...` header line is only rendered on the event-
        # study path; two-period reports produce the Phase 3 layout.
        # QUG block: rendered when self.qug is populated, else a skip note
        # (Phase 4.5 C survey/weights path leaves qug=None; see C0 deferral).
        qug_block = (
            self.qug.summary()
            if self.qug is not None
            else "(QUG step skipped - permanently deferred under survey/weights per Phase 4.5 C0)"
        )
        if self.aggregate == "event_study":
            header = [
                "=" * width,
                "HAD pre-test workflow".center(width),
                f"aggregate: {self.aggregate}".center(width),
                "=" * width,
                qug_block,
                "",
            ]
            if self.pretrends_joint is not None:
                body = [self.pretrends_joint.summary(), ""]
            else:
                body = [
                    "(joint pre-trends skipped - no earlier pre-period)",
                    "",
                ]
            if self.homogeneity_joint is not None:
                body += [self.homogeneity_joint.summary(), ""]
        else:
            # aggregate == "overall" - Phase 3 layout preserved when qug is
            # not None (unweighted path); QUG-skip block on the survey path.
            header = [
                "=" * width,
                "HAD pre-test workflow".center(width),
                "=" * width,
                qug_block,
                "",
            ]
            body = []
            if self.stute is not None:
                body += [self.stute.summary(), ""]
            if self.yatchew is not None:
                body += [self.yatchew.summary(), ""]
        footer = [
            "=" * width,
            f"{'All pass:':<30} {str(self.all_pass):>40}",
            f"Verdict: {self.verdict}",
            "=" * width,
        ]
        return "\n".join(header + body + footer)

    def print_summary(self) -> None:
        """Print the summary to stdout."""
        print(self.summary())

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-safe nested dict of the full report.

        On ``aggregate="overall"``, the output schema is bit-exact with
        Phase 3 (``{qug, stute, yatchew, all_pass, verdict, alpha, n_obs}``)
        - no new keys, no aggregate field. On ``aggregate="event_study"``,
        the output carries ``aggregate``, ``pretrends_joint``,
        ``homogeneity_joint`` and omits the ``None``-valued ``stute`` /
        ``yatchew`` keys entirely.
        """
        # qug serializes as None on the survey/weights path (Phase 4.5 C
        # QUG-skip per C0 deferral); rendered as the existing dict on the
        # default unweighted path.
        qug_dict = None if self.qug is None else self.qug.to_dict()
        if self.aggregate == "event_study":
            return {
                "aggregate": str(self.aggregate),
                "qug": qug_dict,
                "pretrends_joint": (
                    None if self.pretrends_joint is None else self.pretrends_joint.to_dict()
                ),
                "homogeneity_joint": (
                    None if self.homogeneity_joint is None else self.homogeneity_joint.to_dict()
                ),
                "all_pass": bool(self.all_pass),
                "verdict": str(self.verdict),
                "alpha": float(self.alpha),
                "n_obs": int(self.n_obs),
            }
        # aggregate == "overall" - Phase 3 schema preserved bit-exactly on
        # the unweighted path (qug populated); the qug=None survey path
        # surfaces qug: null.
        return {
            "qug": qug_dict,
            "stute": None if self.stute is None else self.stute.to_dict(),
            "yatchew": None if self.yatchew is None else self.yatchew.to_dict(),
            "all_pass": bool(self.all_pass),
            "verdict": str(self.verdict),
            "alpha": float(self.alpha),
            "n_obs": int(self.n_obs),
        }

    def to_dataframe(self) -> pd.DataFrame:
        """Return a tidy 3-row DataFrame (one row per implemented test).

        Columns (stable across aggregates): ``[test, statistic_name,
        statistic_value, p_value, reject, alpha, n_obs]``.

        Row identifiers vary by aggregate:

        - ``aggregate="overall"``: rows are ``qug``, ``stute``,
          ``yatchew_hr`` (Phase 3 schema, unchanged).
        - ``aggregate="event_study"``: rows are ``qug``,
          ``pretrends_joint``, ``homogeneity_joint``.

        Rows for ``None``-valued components (e.g. ``pretrends_joint`` when
        no earlier pre-period exists) are emitted with NaN statistic values
        and ``reject=False`` to preserve the 3-row shape.
        """
        # qug row: NaN-skip when self.qug is None (Phase 4.5 C survey/weights
        # path leaves qug=None per C0 deferral). Mirrors the joint NaN-row
        # shape from `_joint_row_or_nan` so the 3-row contract is preserved.
        if self.qug is None:
            qug_row = {
                "test": "qug",
                "statistic_name": "t_stat",
                "statistic_value": float("nan"),
                "p_value": float("nan"),
                "reject": False,
                "alpha": float(self.alpha),
                "n_obs": int(self.n_obs),
            }
        else:
            qug_row = {
                "test": "qug",
                "statistic_name": "t_stat",
                "statistic_value": _json_safe_scalar(self.qug.t_stat),
                "p_value": _json_safe_scalar(self.qug.p_value),
                "reject": bool(self.qug.reject),
                "alpha": float(self.qug.alpha),
                "n_obs": int(self.qug.n_obs),
            }
        if self.aggregate == "event_study":
            pre_row = self._joint_row_or_nan("pretrends_joint", self.pretrends_joint)
            hom_row = self._joint_row_or_nan("homogeneity_joint", self.homogeneity_joint)
            rows = [qug_row, pre_row, hom_row]
        else:
            stute_row = (
                {
                    "test": "stute",
                    "statistic_name": "cvm_stat",
                    "statistic_value": _json_safe_scalar(self.stute.cvm_stat),
                    "p_value": _json_safe_scalar(self.stute.p_value),
                    "reject": bool(self.stute.reject),
                    "alpha": float(self.stute.alpha),
                    "n_obs": int(self.stute.n_obs),
                }
                if self.stute is not None
                else {
                    "test": "stute",
                    "statistic_name": "cvm_stat",
                    "statistic_value": float("nan"),
                    "p_value": float("nan"),
                    "reject": False,
                    "alpha": float(self.alpha),
                    "n_obs": int(self.n_obs),
                }
            )
            yatchew_row = (
                {
                    "test": "yatchew_hr",
                    "statistic_name": "t_stat_hr",
                    "statistic_value": _json_safe_scalar(self.yatchew.t_stat_hr),
                    "p_value": _json_safe_scalar(self.yatchew.p_value),
                    "reject": bool(self.yatchew.reject),
                    "alpha": float(self.yatchew.alpha),
                    "n_obs": int(self.yatchew.n_obs),
                }
                if self.yatchew is not None
                else {
                    "test": "yatchew_hr",
                    "statistic_name": "t_stat_hr",
                    "statistic_value": float("nan"),
                    "p_value": float("nan"),
                    "reject": False,
                    "alpha": float(self.alpha),
                    "n_obs": int(self.n_obs),
                }
            )
            rows = [qug_row, stute_row, yatchew_row]
        cols = [
            "test",
            "statistic_name",
            "statistic_value",
            "p_value",
            "reject",
            "alpha",
            "n_obs",
        ]
        return pd.DataFrame(rows).reindex(columns=cols)

    def _joint_row_or_nan(
        self, test_label: str, joint: Optional[StuteJointResult]
    ) -> Dict[str, Any]:
        """Build a to_dataframe row for a joint-Stute component.

        When ``joint`` is ``None`` (e.g. pretrends_joint skipped because no
        earlier pre-period), emit a NaN row preserving the 3-row shape for
        downstream plotting.
        """
        if joint is None:
            return {
                "test": test_label,
                "statistic_name": "cvm_stat_joint",
                "statistic_value": float("nan"),
                "p_value": float("nan"),
                "reject": False,
                "alpha": float(self.alpha),
                "n_obs": int(self.n_obs),
            }
        return {
            "test": test_label,
            "statistic_name": "cvm_stat_joint",
            "statistic_value": _json_safe_scalar(joint.cvm_stat_joint),
            "p_value": _json_safe_scalar(joint.p_value),
            "reject": bool(joint.reject),
            "alpha": float(joint.alpha),
            "n_obs": int(joint.n_obs),
        }


# =============================================================================
# Private helpers
# =============================================================================


def _validate_1d_numeric(arr: np.ndarray, name: str) -> np.ndarray:
    """Return ``arr`` as a 1D float ndarray or raise ``ValueError``."""
    a = np.asarray(arr)
    if a.ndim != 1:
        raise ValueError(f"{name} must be 1-dimensional, got shape {a.shape}.")
    a = a.astype(np.float64, copy=False)
    if np.isnan(a).any():
        raise ValueError(f"{name} contains NaN values.")
    if not np.isfinite(a).all():
        raise ValueError(f"{name} contains non-finite values (inf).")
    return a


def _fit_ols_intercept_slope(d: np.ndarray, dy: np.ndarray) -> "tuple[float, float, np.ndarray]":
    """Fit ``dy = a + b*d + eps`` via closed-form OLS.

    Returns ``(a_hat, b_hat, residuals)`` where ``residuals`` has the same
    length as ``d`` in the ORIGINAL input order (not sorted).
    """
    d_mean = d.mean()
    dy_mean = dy.mean()
    d_dev = d - d_mean
    var_d = np.dot(d_dev, d_dev)
    if var_d <= 0.0:
        # Degenerate case: all dose values equal. Slope undefined.
        # Caller is responsible for gating before we reach here; if we
        # do reach here, return (mean(dy), 0, dy - mean(dy)).
        return float(dy_mean), 0.0, dy - dy_mean
    b_hat = float(np.dot(d_dev, dy - dy_mean) / var_d)
    a_hat = float(dy_mean - b_hat * d_mean)
    residuals = dy - a_hat - b_hat * d
    return a_hat, b_hat, residuals


def _fit_weighted_ols_intercept_slope(
    d: np.ndarray, dy: np.ndarray, w: np.ndarray
) -> "tuple[float, float, np.ndarray]":
    """Weighted OLS analog of :func:`_fit_ols_intercept_slope`.

    Solves the weighted normal equations for ``dy = a + b*d + eps`` where
    each observation has weight ``w_g``. Returns ``(a_hat, b_hat,
    residuals)`` with ``residuals`` in the ORIGINAL input order (not
    sorted) and on the un-weighted scale (``residuals = dy - a_hat -
    b_hat * d``, NOT ``sqrt(w) * (dy - ...)``).

    At ``w = ones(G)`` reduces bit-exactly to ``_fit_ols_intercept_slope``
    (Phase 4.5 C stability invariant #1; locked at ``atol=1e-14`` by the
    survey-path tests).

    Closed form:

        b_hat = sum(w * (d - d_w_mean) * (dy - dy_w_mean))
                / sum(w * (d - d_w_mean)^2)
        a_hat = dy_w_mean - b_hat * d_w_mean

    where ``d_w_mean = sum(w * d) / sum(w)`` (and similarly for ``dy``).
    """
    sw = float(np.sum(w))
    if sw <= 0.0:
        raise ValueError(
            f"_fit_weighted_ols_intercept_slope: sum(w) = {sw} <= 0; "
            "weighted OLS requires positive total mass."
        )
    d_wmean = float(np.sum(w * d) / sw)
    dy_wmean = float(np.sum(w * dy) / sw)
    d_dev = d - d_wmean
    var_d_w = float(np.sum(w * d_dev * d_dev))
    if var_d_w <= 0.0:
        # Degenerate case: all dose values equal (under weights).
        return float(dy_wmean), 0.0, dy - dy_wmean
    b_hat = float(np.sum(w * d_dev * (dy - dy_wmean)) / var_d_w)
    a_hat = float(dy_wmean - b_hat * d_wmean)
    residuals = dy - a_hat - b_hat * d
    return a_hat, b_hat, residuals


def _fit_ols_intercept_only(dy: np.ndarray) -> "tuple[float, float, np.ndarray]":
    """Fit ``dy = a + eps`` (intercept-only OLS, mean-independence null).

    Returns ``(a_hat, 0.0, residuals)`` where ``a_hat = mean(dy)`` and
    ``residuals = dy - a_hat`` in the ORIGINAL input order. Slope
    ``b_hat = 0.0`` is returned for tuple-symmetry with
    :func:`_fit_ols_intercept_slope`; the downstream Yatchew variance
    code consumes only ``residuals``. Mirrors R
    ``YatchewTest::yatchew_test(order=0)``.
    """
    a_hat = float(np.mean(dy))
    residuals = dy - a_hat
    return a_hat, 0.0, residuals


def _fit_weighted_ols_intercept_only(
    dy: np.ndarray, w: np.ndarray
) -> "tuple[float, float, np.ndarray]":
    """Weighted intercept-only OLS analog of :func:`_fit_ols_intercept_only`.

    Returns ``(a_hat, 0.0, residuals)`` where ``a_hat = sum(w * dy) /
    sum(w)`` (the weighted mean) and ``residuals = dy - a_hat`` in the
    ORIGINAL input order on the un-weighted scale. At ``w = ones(G)``
    reduces bit-exactly to :func:`_fit_ols_intercept_only`.
    """
    sw = float(np.sum(w))
    if sw <= 0.0:
        raise ValueError(
            f"_fit_weighted_ols_intercept_only: sum(w) = {sw} <= 0; "
            "weighted OLS requires positive total mass."
        )
    a_hat = float(np.sum(w * dy) / sw)
    residuals = dy - a_hat
    return a_hat, 0.0, residuals


def _cvm_statistic(eps_sorted: np.ndarray, d_sorted: np.ndarray) -> float:
    """Compute the tie-safe Cramer-von Mises cusum statistic.

    Paper definition (Appendix D):

        c_G(d) := G^{-1/2} * sum_g 1{D_g <= d} * eps_g
        S      := (1/G) * sum_g c_G^2(D_g)
                = (1/G^2) * sum_g (C_g)^2

    where ``C_g = sum_{h : D_h <= D_g} eps_h`` is the cumulative residual
    sum up to and including ALL observations with dose <= D_g. This
    definition is tie-safe: at a tied dose value ``D_g == D_{g+1}``, both
    c_G(D_g) and c_G(D_{g+1}) include all tie-block members, so the
    cumulative sum used at each tied observation is the cumulative sum
    through the END of the tie block.

    A naive per-observation ``cumsum`` on sorted residuals violates this
    at tie blocks (each tied observation sees a partial within-block
    cumulative sum). This implementation collapses each tie block to the
    post-tie cumulative sum before squaring, matching the paper
    definition.

    Parameters
    ----------
    eps_sorted : np.ndarray, shape (G,)
        Residuals sorted by ``d_sorted``.
    d_sorted : np.ndarray, shape (G,)
        Regressor values sorted ascending. Must be sorted consistently
        with ``eps_sorted``.

    Returns
    -------
    float
        ``S = (1 / G^2) * sum_g C_g^2``.
    """
    G = eps_sorted.shape[0]
    cumsum = np.cumsum(eps_sorted)
    # Tie-safe correction: replace within-tie-block values with the
    # cumulative sum at the END of each tie block. np.unique on the
    # already-sorted regressor gives per-unique-value counts; the last
    # index of each tie block is `cumsum(counts) - 1`, and np.repeat
    # expands that back to per-observation.
    _, counts = np.unique(d_sorted, return_counts=True)
    tie_end_idx = np.cumsum(counts) - 1
    cumsum_tie_safe = np.repeat(cumsum[tie_end_idx], counts)
    return float(np.sum(cumsum_tie_safe * cumsum_tie_safe) / (G * G))


def _has_lonely_psu_adjust_singletons(resolved: Any) -> bool:
    """Detect singleton strata under ``lonely_psu='adjust'``.

    Returns ``True`` iff (a) the resolved design uses
    ``lonely_psu='adjust'`` AND (b) at least one stratum has fewer than 2
    PSUs. Under those conditions, the bootstrap multiplier helper pools
    singletons into a pseudo-stratum with NONZERO multipliers while the
    analytical variance target requires a pseudo-stratum centering
    transform that is not derived for the Stute CvM (Phase 4.5 C R5 P1;
    mirrors the explicit lonely-PSU reject on HeterogeneousAdoptionDiD's
    sup-t bootstrap at ``had.py:2081-2118``).
    """
    if getattr(resolved, "lonely_psu", "remove") != "adjust":
        return False
    strata_arr = resolved.strata
    if strata_arr is None:
        return False
    psu_arr = resolved.psu
    for h in np.unique(strata_arr):
        mask_h = np.asarray(strata_arr) == h
        if psu_arr is not None:
            n_psu_h = int(np.unique(np.asarray(psu_arr)[mask_h]).shape[0])
        else:
            n_psu_h = int(mask_h.sum())
        if n_psu_h < 2:
            return True
    return False


def _cvm_statistic_weighted(
    eps_sorted: np.ndarray, d_sorted: np.ndarray, w_sorted: np.ndarray
) -> float:
    """Weighted analog of :func:`_cvm_statistic` (survey-weighted plug-in).

    The unweighted Stute CvM ``S = (1/G) * sum_g c_G(D_g)^2`` integrates
    the squared cusum process against the empirical CDF
    ``F_hat = (1/G) sum_i delta_{D_i}``. The weighted plug-in replaces
    ``F_hat`` by the survey-weighted EDF
    ``F_hat_w = (1/W) sum_i w_i delta_{D_i}``, which weights BOTH the
    inner cusum AND the outer integration measure:

        C_g = sum_{h : D_h <= D_g} w_h * eps_h     (inner cusum, weighted)
        S_w = (1 / W^2) * sum_g w_g * (C_g)^2      (outer measure, weighted)
        W   = sum(w)

    The outer ``w_g`` factor on each squared cusum (R7 P0 fix) is what
    distinguishes this from a count-weighted-cusum form
    ``(1/W^2) * sum_g C_g^2`` (no outer ``w_g``), which silently
    misreports survey-weighted Stute statistics for non-uniform weights.
    At ``w = ones(G)`` both forms reduce to ``(1/G^2) sum_g C_g^2``
    (unweighted) -- only non-uniform weights distinguish them.

    Tie-block collapse uses the same ``np.unique(d_sorted)`` count
    machinery as the unweighted form -- positions are determined by
    ``d_sorted`` ties (independent of weights), so the collapse pattern
    is weight-invariant. The outer ``w_sorted`` factor applies to the
    tie-collapsed cusum at each observation.

    Parameters
    ----------
    eps_sorted, d_sorted, w_sorted : np.ndarray, shape (G,)
        Residuals, regressor values, and weights sorted CONSISTENTLY by
        ``d``. Caller is responsible for the sort alignment.

    Returns
    -------
    float
        Weighted CvM statistic.
    """
    weighted_eps = w_sorted * eps_sorted
    cumsum = np.cumsum(weighted_eps)
    _, counts = np.unique(d_sorted, return_counts=True)
    tie_end_idx = np.cumsum(counts) - 1
    cumsum_tie_safe = np.repeat(cumsum[tie_end_idx], counts)
    W = float(np.sum(w_sorted))
    # R7 P0: integrate outer measure against F_hat_w via the w_sorted
    # factor on each squared cusum (NOT against uniform 1/G measure).
    return float(np.sum(w_sorted * cumsum_tie_safe * cumsum_tie_safe) / (W * W))


def _compose_verdict(
    qug: QUGTestResults, stute: StuteTestResults, yatchew: YatchewTestResults
) -> str:
    """Build the :class:`HADPretestReport` verdict string.

    Paper Section 4.2-4.3 specifies a four-step workflow; Phase 3 ships
    step 1 (QUG) and step 3 (linearity, via ``stute_test`` OR
    ``yatchew_hr_test``). The linearity step accepts either test, so a
    conclusive Stute result alone suffices even when Yatchew is NaN
    (e.g. tied doses, which Yatchew rejects by contract).

    Paper logic: TWFE is admissible only if NONE of the implemented tests
    rejects. A conclusive rejection must therefore never be hidden by a
    purely-inconclusive verdict just because another step is NaN - it is
    reported as the primary outcome and any unresolved steps are appended
    as a suffix.

    Priority:

    1. Collect all rejection reasons from CONCLUSIVE tests. If any
       conclusive test rejected, that is the primary verdict. Unresolved
       steps (QUG NaN, or BOTH linearity tests NaN) are appended as
       ``"; additional steps unresolved: ..."`` rather than replacing the
       rejection.
    2. If no conclusive test rejected but a required step is unresolved,
       return a pure ``"inconclusive - ..."`` verdict naming the
       unresolved step(s).
    3. Otherwise (all required steps conclusive and none reject), return
       the partial-workflow fail-to-reject verdict flagging the
       Assumption 7 gap. When ONE linearity test was NaN and the other
       was conclusive, the verdict carries a suffix naming the skipped
       test, e.g. ``" (Yatchew NaN - skipped)"`` or
       ``" (Stute NaN - skipped)"``.
    """
    qug_ok = bool(np.isfinite(qug.p_value))
    stute_ok = bool(np.isfinite(stute.p_value))
    yatchew_ok = bool(np.isfinite(yatchew.p_value))
    # Rejections from conclusive tests only. NaN-p tests have reject=False
    # by convention, so the ``ok and reject`` guard is defensive.
    qug_rej = qug_ok and qug.reject
    stute_rej = stute_ok and stute.reject
    yatchew_rej = yatchew_ok and yatchew.reject
    reasons = []
    if qug_rej:
        reasons.append("support infimum rejected - continuous_at_zero design invalid (QUG)")
    if stute_rej or yatchew_rej:
        which = ",".join(
            name
            for name, rejected in (("Stute", stute_rej), ("Yatchew", yatchew_rej))
            if rejected
        )
        reasons.append(f"linearity rejected - heterogeneity bias ({which})")
    # Unresolved steps: QUG is required; step 3 requires at least one
    # conclusive linearity test.
    unresolved = []
    if not qug_ok:
        unresolved.append("QUG NaN")
    if not stute_ok and not yatchew_ok:
        unresolved.append("both Stute and Yatchew linearity tests NaN")
    if reasons:
        # A conclusive rejection is the primary outcome. Append any
        # unresolved-step note rather than replacing the rejection.
        verdict = "; ".join(reasons)
        if unresolved:
            verdict += "; additional steps unresolved: " + "; ".join(unresolved)
        return verdict
    if unresolved:
        return "inconclusive - " + "; ".join(unresolved)
    # All required steps conclusive, none reject. Note any single skipped
    # linearity test (the OTHER linearity test was conclusive and
    # fail-to-reject, so step 3 IS resolved).
    skipped = []
    if not stute_ok:
        skipped.append("Stute NaN")
    if not yatchew_ok:
        skipped.append("Yatchew NaN")
    skip_note = f" ({'; '.join(skipped)} - skipped)" if skipped else ""
    return (
        "QUG and linearity diagnostics fail-to-reject"
        f"{skip_note}; Assumption 7 pre-trends test NOT run "
        "(paper step 2 deferred to Phase 3 follow-up)"
    )


# =============================================================================
# Public test functions
# =============================================================================
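
# Illustrative sketch (not part of diff_diff): a dependency-free check of the
# tie-block collapse documented for `_cvm_statistic` and
# `_cvm_statistic_weighted` above. The helper names `cvm_naive`,
# `cvm_tie_safe`, and `cvm_bruteforce` are hypothetical, and the toy data is
# made up; the aim is only to show that collapsing each tie block to its
# end-of-block cumulative sum matches a direct transcription of
# S = (1/G^2) * sum_g (sum_{h : D_h <= D_g} eps_h)^2, while a plain cumsum
# does not.

```python
from itertools import accumulate


def cvm_naive(eps_sorted):
    """Per-observation cumsum; differs from the definition when doses tie."""
    g = len(eps_sorted)
    return sum(c * c for c in accumulate(eps_sorted)) / (g * g)


def cvm_tie_safe(eps_sorted, d_sorted, w_sorted=None):
    """Tie-safe CvM; pass w_sorted for the weighted plug-in form."""
    g = len(eps_sorted)
    w = [1.0] * g if w_sorted is None else list(w_sorted)
    cums = list(accumulate(wi * ei for wi, ei in zip(w, eps_sorted)))
    total_w = sum(w)
    stat = 0.0
    for i in range(g):
        # Walk to the last index of the tie block containing observation i.
        j = i
        while j + 1 < g and d_sorted[j + 1] == d_sorted[i]:
            j += 1
        # Outer weight w[i] times the squared end-of-block cusum.
        stat += w[i] * cums[j] ** 2
    return stat / (total_w * total_w)


def cvm_bruteforce(eps, d):
    """Direct transcription of the unweighted definition (O(G^2))."""
    g = len(eps)
    return sum(
        sum(e for e, dh in zip(eps, d) if dh <= dg) ** 2 for dg in d
    ) / (g * g)


d = [0.5, 1.0, 1.0, 2.0]      # doses, sorted, with one tie block at 1.0
eps = [1.0, -2.0, 0.5, 0.5]   # residuals in the same order

s_naive = cvm_naive(eps)
s_safe = cvm_tie_safe(eps, d)
s_brute = cvm_bruteforce(eps, d)
s_unit_w = cvm_tie_safe(eps, d, [1.0, 1.0, 1.0, 1.0])
s_weighted = cvm_tie_safe(eps, d, [2.0, 1.0, 1.0, 1.0])
```

# On this toy input the tie-safe and brute-force values agree, the naive
# cumsum disagrees, and unit weights reproduce the unweighted statistic.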


def qug_test(
    d: np.ndarray,
    alpha: float = 0.05,
    *,
    survey_design: Any = None,
    survey: Any = None,
    weights: Optional[np.ndarray] = None,
) -> QUGTestResults:
    """Run the QUG null test for the support infimum (paper Theorem 4).

    Tests ``H_0: d_lower = 0`` using the order-statistic ratio
    ``T = D_{(1)} / (D_{(2)} - D_{(1)})``, rejecting when
    ``T > 1/alpha - 1``. Under the null, the asymptotic limit law of ``T``
    is the ratio of two independent Exp(1) variables with CDF
    ``F(t) = t / (1 + t)``, so the one-sided p-value is ``1 / (1 + T)``.

    Zero-dose observations are filtered out (the test targets the infimum
    of the treated support). A ``UserWarning`` is emitted naming the
    exclusion count. When fewer than two positive doses remain, the test
    returns all-NaN inference with ``reject=False``.

    Parameters
    ----------
    d : np.ndarray, shape (G,)
        Post-period dose vector. Must be 1D numeric and contain no NaN.
    alpha : float, default 0.05
        One-sided significance level. Must satisfy ``0 < alpha < 1``.
    survey_design : ResolvedSurveyDesign or None, keyword-only, default None
        Permanently rejected with ``NotImplementedError`` (Phase 4.5 C0
        decision gate). Surface-symmetric kwarg with the rest of the HAD
        family: accepted in the signature so all 8 HAD entry points share
        the canonical kwarg name, but ``qug_test`` has no survey-aware
        migration target. See *Notes -- Survey/weighted data*.
    survey : SurveyDesign or None, keyword-only, default None
        DEPRECATED alias of ``survey_design=``. Surface-symmetric only;
        any non-``None`` value still raises ``NotImplementedError``. The
        deprecation is about kwarg-name consolidation, NOT a migration
        path (there is no survey-aware QUG). Will be removed in the next
        minor release.
    weights : np.ndarray or None, keyword-only, default None
        DEPRECATED alias of ``survey_design=`` for the per-row pweight
        shortcut on the rest of the HAD array-in family. On ``qug_test``,
        surface-symmetric only; any non-``None`` value still raises
        ``NotImplementedError``. There is no migration path
        (``make_pweight_design(arr)`` is NOT a valid QUG migration
        target). Will be removed in the next minor release.

    Returns
    -------
    QUGTestResults
        Result dataclass with ``t_stat``, ``p_value``, ``reject``, and
        sample metadata.

    Raises
    ------
    ValueError
        If ``d`` is not 1D numeric or contains NaN, or if ``alpha`` is not
        in ``(0, 1)``, or if more than one of
        ``survey_design``/``survey``/``weights`` is non-None (mutex).
    NotImplementedError
        If any of ``survey_design``, ``survey``, ``weights`` is non-None.
        See *Notes -- Survey/weighted data*.

    Notes
    -----
    Tie-break: when ``D_{(1)} == D_{(2)}`` the statistic is undefined.
    The test returns ``t_stat=NaN, p_value=NaN, reject=False`` with a
    ``UserWarning`` rather than raising.

    Survey/weighted data: QUG is permanently deferred under
    survey-weighted or pweight inputs (Phase 4.5 C0 decision gate,
    2026-04). The test statistic uses extreme order statistics
    ``(D_{(1)}, D_{(2)})``, which are NOT smooth functionals of the
    empirical CDF -- standard survey machinery (Binder TSL linearization,
    multiplier bootstrap, Rao-Wu rescaled bootstrap) does not yield a
    calibrated test, and under cluster sampling the ``Exp(1)/Exp(1)``
    limit law's independence assumption breaks. The
    extreme-value-theory-under-unequal-probability-sampling literature
    (Quintos et al. 2001, Beirlant et al.) addresses tail-index
    estimation, not boundary tests; no off-the-shelf survey-aware QUG
    exists. Phase 4.5 C ships survey-aware Stute via
    :func:`did_had_pretest_workflow` (which skips the QUG step under
    survey/weights and runs the linearity family with a PSU-level Mammen
    multiplier bootstrap for Stute and weighted OLS + pweight-sandwich
    variance components for Yatchew). See
    ``docs/methodology/REGISTRY.md`` § "QUG Null Test" for the full
    methodology note.

    References
    ----------
    de Chaisemartin, Ciccia, D'Haultfoeuille, Knau (2026,
    arXiv:2405.04465v6), Theorem 4 and Section 4.2.
    """
    if not (0.0 < alpha < 1.0):
        raise ValueError(f"alpha must satisfy 0 < alpha < 1, got {alpha}.")
    # Three-way mutex on survey_design / survey / weights. qug_test rejects
    # ALL non-None survey-aware inputs (Phase 4.5 C0 permanent deferral, see
    # NotImplementedError below), so the mutex message here is qug-specific
    # and does NOT point users to `make_pweight_design(arr)` (which the
    # array-in mutex on `stute_test`/`yatchew_hr_test`/`stute_joint_pretest`
    # does suggest as the migration target). PR #376 R2 P3 fix.
    n_set = sum(x is not None for x in (survey_design, survey, weights))
    if n_set > 1:
        raise ValueError(
            "qug_test: pass at most one of `survey_design=`, `survey=`, or "
            "`weights=`. All three are permanently rejected on qug_test "
            "(Phase 4.5 C0 deferral) -- there is no migration path; see the "
            "NotImplementedError raised below for the methodology rationale."
        )
    # Soft deprecation: route legacy survey=/weights= aliases through
    # survey_design= for the gated NotImplementedError below. PR #376 R10
    # P3: qug_test-specific deprecation messages -- the shared
    # HAD_DEPRECATION_MSG_*_KWARG_ARRAY_IN strings tell users to migrate to
    # `survey_design=` / `make_pweight_design(...)`, but qug_test
    # permanently rejects ALL survey-aware kwargs (Phase 4.5 C0 deferral).
    # Use qug-specific warning text that says the aliases are deprecated
    # but survey-aware QUG remains unsupported, and points users to
    # unweighted `qug_test()` or `did_had_pretest_workflow(...,
    # survey_design=...)` for the survey-aware linearity family.
    if survey is not None:
        warnings.warn(
            "`survey=` is deprecated on qug_test (will be removed in the "
            "next minor release). Note that qug_test does NOT support "
            "survey-aware inputs at all (Phase 4.5 C0 permanent deferral; "
            "see the NotImplementedError below). For survey-aware HAD "
            "pretesting, use `did_had_pretest_workflow(..., "
            "survey_design=...)` (the workflow skips the QUG step under "
            "survey/weights and runs the linearity family).",
            DeprecationWarning,
            stacklevel=2,
        )
        survey_design = survey
    elif weights is not None:
        warnings.warn(
            "`weights=` is deprecated on qug_test (will be removed in the "
            "next minor release). Note that qug_test does NOT support "
            "weighted/survey inputs at all (Phase 4.5 C0 permanent deferral; "
            "see the NotImplementedError below). For survey-aware HAD "
            "pretesting, use `did_had_pretest_workflow(..., "
            "survey_design=...)` (the workflow skips the QUG step under "
            "survey/weights and runs the linearity family).",
            DeprecationWarning,
            stacklevel=2,
        )
        survey_design = make_pweight_design(np.asarray(weights, dtype=np.float64))
    # Phase 4.5 C0 decision gate: QUG-under-survey is permanently deferred.
    # Extreme-order-statistic functionals are not smooth in the empirical
    # CDF, so standard survey machinery (Binder TSL linearization, Rao-Wu
    # rescaled bootstrap) does not provide a calibrated test. See
    # REGISTRY.md § "QUG Null Test" for the full methodology note.
    if survey_design is not None:
        raise NotImplementedError(
            "qug_test does not support survey_design= / survey= / "
            "weights= kwargs.\n"
            "\n"
            "QUG (de Chaisemartin et al. 2026, Theorem 4) tests "
            "H_0: d_lower = 0 via the ratio of the two smallest order "
            "statistics, T = D_(1) / (D_(2) - D_(1)). "
            "Extreme-order-statistic functionals are not smooth in the "
            "empirical CDF, so standard survey machinery (Binder "
            "linearization, multiplier bootstrap, Rao-Wu rescaled "
            "bootstrap) does not provide a calibrated test. Under "
            "cluster sampling the Exp(1)/Exp(1) limit law's independence "
            "assumption breaks. The literature on extreme-value theory "
            "under unequal-probability sampling (Quintos et al. 2001, "
            "Beirlant et al.) addresses tail-index estimation, not "
            "boundary tests; no off-the-shelf survey-aware QUG exists.\n"
            "\n"
            "For survey-aware HAD pretesting, use the joint Stute family "
            "via did_had_pretest_workflow(..., survey_design=..., "
            "aggregate=...) -- shipped in Phase 4.5 C. The workflow "
            "skips the QUG step under survey/weights with a UserWarning "
            "and runs the linearity family with a PSU-level Mammen "
            "multiplier bootstrap (Stute) + weighted OLS + pweight-"
            "sandwich variance components (Yatchew). See "
            "docs/methodology/REGISTRY.md § 'QUG Null Test' for the "
            "full methodology note."
        )
    d_arr = _validate_1d_numeric(d, "d")
    critical_value = 1.0 / alpha - 1.0
    # HAD support restriction: doses must be non-negative (paper Section 2).
    # Reject negative doses at the front door rather than silently filtering
    # them into the zero-exclusion counter.
    if (d_arr < 0).any():
        n_neg = int((d_arr < 0).sum())
        raise ValueError(
            f"qug_test: d contains {n_neg} negative value(s); HAD doses "
            f"must be non-negative (paper Section 2). Check your dose "
            f"column or pre-process before calling qug_test."
        )
    mask = d_arr > 0
    d_nz = d_arr[mask]
    n_excluded = int(d_arr.shape[0] - d_nz.shape[0])
    if n_excluded > 0:
        warnings.warn(
            f"qug_test: excluded {n_excluded} observation(s) with d == 0 "
            f"(the QUG null test targets the infimum of the treated-dose "
            f"support; zero-dose observations are not in scope).",
            UserWarning,
            stacklevel=2,
        )
    n_obs = int(d_nz.shape[0])
    if n_obs < _MIN_G_QUG:
        warnings.warn(
            f"qug_test: only {n_obs} positive-dose observation(s); need "
            f"at least {_MIN_G_QUG}. Returning NaN result.",
            UserWarning,
            stacklevel=2,
        )
        return QUGTestResults(
            t_stat=float("nan"),
            p_value=float("nan"),
            reject=False,
            alpha=alpha,
            critical_value=critical_value,
            n_obs=n_obs,
            n_excluded_zero=n_excluded,
            d_order_1=float("nan"),
            d_order_2=float("nan"),
        )
    # Use np.partition for O(G) extraction of the two smallest positive
    # doses (faster than full O(G log G) sort). For k=1, np.partition
    # guarantees partitioned[0] <= partitioned[1] = D_{(2)} (the 2nd-smallest),
    # which implies partitioned[0] = D_{(1)} (the minimum).
    partitioned = np.partition(d_nz, 1)
    D1 = float(partitioned[0])
    D2 = float(partitioned[1])
    if D2 == D1:
        warnings.warn(
            "qug_test: D_(1) == D_(2); the test statistic is undefined "
            "(ties at the minimum positive dose). Returning NaN result.",
            UserWarning,
            stacklevel=2,
        )
        return QUGTestResults(
            t_stat=float("nan"),
            p_value=float("nan"),
            reject=False,
            alpha=alpha,
            critical_value=critical_value,
            n_obs=n_obs,
            n_excluded_zero=n_excluded,
            d_order_1=D1,
            d_order_2=D2,
        )
    t_stat = D1 / (D2 - D1)
    p_value = 1.0 / (1.0 + t_stat)
    reject = t_stat > critical_value
    return QUGTestResults(
        t_stat=float(t_stat),
        p_value=float(p_value),
        reject=bool(reject),
        alpha=alpha,
        critical_value=critical_value,
        n_obs=n_obs,
        n_excluded_zero=n_excluded,
        d_order_1=D1,
        d_order_2=D2,
    )
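
# Illustrative sketch (not part of diff_diff): the QUG arithmetic described
# in the `qug_test` docstring (T = D_(1) / (D_(2) - D_(1)), p = 1 / (1 + T),
# reject when T > 1/alpha - 1) written with the standard library only. The
# name `qug_sketch` and the toy dose vectors are hypothetical; the sketch
# swaps `np.partition` for `heapq.nsmallest` and omits validation and
# warnings, so it is a simplified illustration rather than the implementation.

```python
import heapq


def qug_sketch(doses, alpha=0.05):
    """Return (t_stat, p_value, reject), or None when T is undefined."""
    positive = [x for x in doses if x > 0]  # zero doses are out of scope
    if len(positive) < 2:
        return None
    d1, d2 = heapq.nsmallest(2, positive)   # two smallest positive doses
    if d1 == d2:
        return None                          # tie at the minimum: T undefined
    t_stat = d1 / (d2 - d1)
    p_value = 1.0 / (1.0 + t_stat)           # tail of F(t) = t / (1 + t)
    reject = t_stat > 1.0 / alpha - 1.0
    return t_stat, p_value, reject


# Smallest dose close to zero relative to the gap: T is small, no rejection.
near_zero = qug_sketch([0.0, 0.25, 0.5, 2.0])
# Smallest dose far from zero relative to the gap: T is large, rejection.
far_from_zero = qug_sketch([1.0, 1.03125, 2.0, 3.0])
# Tie at the minimum positive dose: statistic undefined.
tied = qug_sketch([0.5, 0.5, 1.0])
```

# The rejection rule and the p-value agree by construction: T > 1/alpha - 1
# holds exactly when 1 / (1 + T) < alpha.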


def stute_test(
    d: np.ndarray,
    dy: np.ndarray,
    alpha: float = 0.05,
    n_bootstrap: int = 999,
    seed: Optional[int] = None,
    *,
    survey_design: Any = None,
    survey: Any = None,
    weights: Optional[np.ndarray] = None,
) -> StuteTestResults:
    """Run the Stute Cramer-von Mises linearity test (paper Appendix D).

    Tests ``H_0: E[ΔY | D_2]`` is linear in ``D_2`` (paper Assumption 8).
    The test statistic is the sorted-residual cusum CvM

        S = (1 / G^2) * sum_{g=1}^G (sum_{h=1}^g eps_(h))^2

    where ``eps_(h)`` is the ``h``-th OLS residual after sorting by ``d``.
    The p-value is the bootstrap tail probability
    ``(1 + sum(S_b >= S)) / (B + 1)`` under the Mammen (1993) two-point
    wild bootstrap; each bootstrap iteration refits OLS on
    ``dy_b = a_hat + b_hat * d + eps * eta`` with multiplier weights
    ``eta``.

    Parameters
    ----------
    d, dy : np.ndarray, shape (G,)
        Dose and first-difference outcome vectors.
    alpha : float, default 0.05
        Significance level. Must satisfy ``0 < alpha < 1``.
    n_bootstrap : int, default 999
        Number of Mammen wild bootstrap replications. Must be ``>= 99``
        (below which the discretised p-value grid is too coarse).
    seed : int or None, default None
        Seed for ``np.random.default_rng``. Pass an integer for
        reproducible results.
    survey_design : ResolvedSurveyDesign or None, keyword-only, default None
        Already-resolved survey design (per-unit). Array-in helpers accept
        ``ResolvedSurveyDesign`` ONLY; passing a ``SurveyDesign`` raises
        ``TypeError`` with migration guidance. For the pweight-only
        shortcut, use ``survey_design=make_pweight_design(arr)``. Triggers
        the survey-aware Stute calibration: PSU-level Mammen multipliers
        via
        :func:`diff_diff.bootstrap_utils.generate_survey_multiplier_weights_batch`,
        broadcast to per-unit residual perturbation, with weighted CvM
        recompute. Replicate-weight designs raise ``NotImplementedError``.
    survey : ResolvedSurveyDesign or None, keyword-only, default None
        DEPRECATED alias of ``survey_design=``. Will be removed in the
        next minor release.
    weights : np.ndarray or None, keyword-only, default None
        DEPRECATED alias of ``survey_design=make_pweight_design(arr)``.
        Will be removed in the next minor release.

    Returns
    -------
    StuteTestResults

    Raises
    ------
    ValueError
        If ``d`` / ``dy`` are not 1D numeric, contain NaN, have unequal
        lengths, if any ``d`` value is negative (paper Section 2 HAD
        support restriction), if ``alpha`` is outside ``(0, 1)``, or if
        ``n_bootstrap < 99``. Also raised if more than one of
        ``survey_design``, ``survey``, ``weights`` is supplied (3-way
        mutex; ``survey=`` and ``weights=`` are deprecated aliases of
        ``survey_design=``).
    TypeError
        If ``survey_design=SurveyDesign(...)`` (or the deprecated
        ``survey=SurveyDesign(...)`` alias) is passed; array-in helpers
        accept ``ResolvedSurveyDesign`` only. Use
        ``survey_design=make_pweight_design(arr)`` for pweight-only or
        pre-resolve via ``SurveyDesign(...).resolve(data)``.
    NotImplementedError
        If ``survey_design.replicate_weights is not None`` (including a
        design passed through the deprecated ``survey=`` alias).
        Replicate-weight pretests are a parallel follow-up after Phase
        4.5 C; the per-replicate weight-ratio rescaling for the
        OLS-on-residuals refit step is not covered by the
        multiplier-bootstrap composition used here.

    Notes
    -----
    Sample-size gate: below ``G = 10`` the CvM statistic is not
    well-calibrated. In that case the function emits ``UserWarning`` and
    returns all-NaN inference rather than raising.

    Large-G warning: at ``G > 100_000`` the per-iteration refit dominates
    runtime; the function emits a ``UserWarning`` pointing users to
    :func:`yatchew_hr_test`. Memory usage remains ``O(G)`` regardless
    (no G x G matrix).

    Survey/weighted data (Phase 4.5 C): when ``survey_design`` (or one of
    the deprecated ``survey=`` / ``weights=`` aliases) is supplied, the
    OLS baseline becomes weighted OLS
    (:func:`_fit_weighted_ols_intercept_slope`), the bootstrap multipliers
    become PSU-level Mammen draws (broadcast to per-obs perturbation), and
    the test statistic uses :func:`_cvm_statistic_weighted`. Per-unit
    constant-within-unit invariant on weights/strata/psu/fpc is the
    CALLER's responsibility; the workflow
    (:func:`did_had_pretest_workflow`) enforces it via
    :func:`_aggregate_unit_weights` /
    :func:`_aggregate_unit_resolved_survey` from ``had.py``. At
    ``w = ones(G)``, weighted helpers reduce bit-exactly to the unweighted
    versions but bootstrap p-values diverge by Monte-Carlo noise
    (different RNG consumption between batched
    ``generate_survey_multiplier_weights_batch`` and per-iteration
    ``_generate_mammen_weights``); use the distribution-equivalence
    reduction test (large B) for trivial-pweight parity, NOT numerical
    equivalence.

    References
    ----------
    Stute, W. (1997). Nonparametric model checks for regression.
    Annals of Statistics 25, 613-641.
    Mammen, E. (1993). Bootstrap and wild bootstrap for high-dimensional
    linear models. Annals of Statistics 21, 255-285.
    de Chaisemartin et al. (2026), Appendix D.
    """
    if not (0.0 < alpha < 1.0):
        raise ValueError(f"alpha must satisfy 0 < alpha < 1, got {alpha}.")
    if n_bootstrap < _MIN_N_BOOTSTRAP:
        raise ValueError(
            f"n_bootstrap must be >= {_MIN_N_BOOTSTRAP} (below this the "
            f"discretised p-value grid is too coarse to be meaningful). "
            f"Got n_bootstrap={n_bootstrap}."
        )
    # Three-way mutex on survey_design / survey / weights (array-in pattern).
    n_set = sum(x is not None for x in (survey_design, survey, weights))
    if n_set > 1:
        raise ValueError(HAD_DUAL_KNOB_MUTEX_MSG_ARRAY_IN)
    # Soft deprecation: route legacy survey=/weights= aliases to survey_design=
    # FIRST so the type guard below covers `survey=SurveyDesign(...)` too
    # (PR #376 R1 P1: alias must behave identically to the canonical kwarg).
    # The bit-exact normalization-order invariant requires passing UNNORMALIZED
    # weights to make_pweight_design; the unified path's mean=1 step (~line
    # 1669) fires downstream EXACTLY ONCE.
    if survey is not None:
        warnings.warn(HAD_DEPRECATION_MSG_SURVEY_KWARG, DeprecationWarning, stacklevel=2)
        survey_design = survey
    elif weights is not None:
        warnings.warn(
            HAD_DEPRECATION_MSG_WEIGHTS_KWARG_ARRAY_IN,
            DeprecationWarning,
            stacklevel=2,
        )
        survey_design = make_pweight_design(np.asarray(weights, dtype=np.float64))
    # Type guard: array-in helpers reject SurveyDesign (cannot resolve column
    # names without `data`). Runs AFTER alias rebinding so it covers both
    # `survey_design=SurveyDesign(...)` and the deprecated
    # `survey=SurveyDesign(...)` form identically.
    if survey_design is not None and isinstance(survey_design, SurveyDesign):
        raise TypeError(
            "stute_test: `survey_design=` accepts a pre-resolved "
            "ResolvedSurveyDesign only (array-in helpers have no `data` to "
            "resolve column names against). For pweight-only, use "
            "`survey_design=make_pweight_design(arr)`. For full PSU/strata/"
            "FPC, pre-resolve via `SurveyDesign(...).resolve(data)` and pass "
            "the result."
        )
    # Internal alias rebind: downstream code uses `survey` and `weights` as
    # internal variable names (Phase 4.5 C convention). After the deprecation
    # block, fold the canonical survey_design back into the legacy variable
    # names so the unchanged downstream logic consumes the input transparently.
    survey = survey_design
    weights = None  # weights= alias has been folded into survey_design
    # Replicate-weight rejection: the per-replicate weight-ratio rescaling for
    # the OLS-on-residuals refit step is not covered by the multiplier-bootstrap
    # composition. Parallel follow-up after Phase 4.5 C.
    if survey is not None and getattr(survey, "replicate_weights", None) is not None:
        raise NotImplementedError(
            "stute_test: replicate-weight survey designs (BRR/Fay/JK1/JKn/SDR) "
            "are not yet supported on HAD pretests. The per-replicate weight-"
            "ratio rescaling for the OLS-on-residuals refit step is not covered "
            "by the multiplier-bootstrap composition. Replicate-weight pretests "
            "are a parallel follow-up after Phase 4.5 C."
        )
    # R1 P1: pweight-only guard on the direct-helper survey entry (mirrors
    # _resolve_pretest_unit_weights for the workflow path).
    if survey is not None and getattr(survey, "weight_type", "pweight") != "pweight":
        raise ValueError(
            f"stute_test: HAD pretests require weight_type='pweight'. Got "
            f"weight_type={survey.weight_type!r}. aweight / fweight have "
            "different sandwich-variance semantics that are not derived "
            "for the Stute CvM bootstrap calibration."
        )
    d_arr = _validate_1d_numeric(d, "d")
    dy_arr = _validate_1d_numeric(dy, "dy")
    if d_arr.shape[0] != dy_arr.shape[0]:
        raise ValueError(
            f"d and dy must have the same length; got d.shape={d_arr.shape}, "
            f"dy.shape={dy_arr.shape}."
        )
    # HAD support restriction (paper Section 2): doses must be non-negative.
    # Mirror the front-door guard from qug_test / _validate_had_panel.
    if (d_arr < 0).any():
        n_neg = int((d_arr < 0).sum())
        raise ValueError(
            f"stute_test: d contains {n_neg} negative value(s); HAD doses "
            f"must be non-negative (paper Section 2). Check your dose "
            f"column or pre-process before calling stute_test."
        )
    G = int(d_arr.shape[0])
    if G < _MIN_G_STUTE:
        warnings.warn(
            f"stute_test: G = {G} is below the minimum {_MIN_G_STUTE} for "
            f"the CvM statistic to be well-calibrated. Returning NaN result.",
            UserWarning,
            stacklevel=2,
        )
        return StuteTestResults(
            cvm_stat=float("nan"),
            p_value=float("nan"),
            reject=False,
            alpha=alpha,
            n_bootstrap=int(n_bootstrap),
            n_obs=G,
            seed=seed,
        )
    if G > _STUTE_LARGE_G_THRESHOLD:
        warnings.warn(
            f"stute_test: G = {G} exceeds {_STUTE_LARGE_G_THRESHOLD}; the "
            f"per-iteration refit is O(G) per iteration so the "
            f"{n_bootstrap}-replication loop may take tens of seconds or "
            f"more. Consider yatchew_hr_test() instead (paper Theorem 7 "
            f"recommends Yatchew-HR at large G).",
            UserWarning,
            stacklevel=2,
        )
    # Phase 4.5 C: resolve effective per-unit weights (None on the
    # unweighted path, preserves bit-exact regression). When survey= is
    # supplied, w is taken from the resolved design.
    # R4 P1: validate 1D explicitly so column-vector inputs (e.g.
    # df[["w"]].to_numpy()) raise instead of silently broadcasting.
    if survey is not None:
        w_arr = _validate_1d_numeric(np.asarray(survey.weights), "stute_test: survey.weights")
        if w_arr.shape[0] != G:
            raise ValueError(
                f"stute_test: survey.weights length {w_arr.shape[0]} does not "
                f"match d/dy length {G}."
            )
        if (w_arr <= 0).any():
            raise ValueError(
                "stute_test: survey weights must be strictly positive. "
                "Zero / negative weights would leave units in the "
                "variance / CvM computation while contributing zero "
                "population mass; pre-filter the panel to the positive-"
                "weight subpopulation before calling stute_test."
            )
    elif weights is not None:
        # Note: unreachable after the alias fold above, which sets
        # `weights = None` once the alias is routed into `survey_design`.
        w_arr = _validate_1d_numeric(np.asarray(weights), "stute_test: weights")
        if w_arr.shape[0] != G:
            raise ValueError(
                f"stute_test: weights length {w_arr.shape[0]} does not match "
                f"d/dy length {G}."
            )
        if (w_arr <= 0).any():
            raise ValueError(
                "stute_test: weights must be strictly positive (the pweight "
                "shortcut does not support zero weights; use survey= with "
                "explicit lonely-PSU handling for zero-mass strata)."
            )
    else:
        w_arr = None
    # R4 P0: normalize pweights to mean=1 (matches SurveyDesign.resolve()
    # convention). Makes the test statistic scale-invariant under uniform
    # rescaling of weights AND ensures weights= shortcut and
    # survey=SurveyDesign(weights=...) produce identical results for the
    # same design. Stute is internally scale-invariant in functional form,
    # but the survey-aware bootstrap helper consumes weight values
    # directly under non-trivial PSU/strata, so normalization is required
    # for cross-path agreement.
    if w_arr is not None:
        w_arr = w_arr * (float(w_arr.shape[0]) / float(np.sum(w_arr)))
    if w_arr is None:
        a_hat, b_hat, eps = _fit_ols_intercept_slope(d_arr, dy_arr)
    else:
        a_hat, b_hat, eps = _fit_weighted_ols_intercept_slope(d_arr, dy_arr, w_arr)
    # Genuine degeneracy: zero dose variation. The CvM cusum is defined
    # against the regressor, and constant d carries no signal to test
    # linearity against - emit NaN.
    if np.var(d_arr) <= 0.0:
        warnings.warn(
            "stute_test: constant d (zero dose variation); the Stute "
            "linearity test requires regressor variation. Returning NaN result.",
            UserWarning,
            stacklevel=2,
        )
        return StuteTestResults(
            cvm_stat=float("nan"),
            p_value=float("nan"),
            reject=False,
            alpha=alpha,
            n_bootstrap=int(n_bootstrap),
            n_obs=G,
            seed=seed,
        )
    # Numerically exact linear fit: Assumption 8 holds to IEEE precision,
    # so the Stute CvM statistic is formally 0 and every bootstrap draw is
    # also 0. Short-circuit to p=1 to avoid FP-noise-driven bootstrap
    # comparisons where cvm_stat and S_b are both at machine-epsilon scale.
    # Comparison is purely relative against CENTERED TSS: both translation-
    # invariant (centering absorbs additive shifts) and scale-invariant
    # (ratio is dimensionless under multiplicative dy rescaling).
    eps_norm_sq = float(np.sum(eps * eps))
    dy_centered_sq = float(np.sum((dy_arr - dy_arr.mean()) ** 2))
    if dy_centered_sq <= 0.0:
        # Constant dy (zero centered TSS): trivially linear in d.
        # Return p = 1 without running the bootstrap.
        return StuteTestResults(
            cvm_stat=0.0,
            p_value=1.0,
            reject=False,
            alpha=alpha,
            n_bootstrap=int(n_bootstrap),
            n_obs=G,
            seed=seed,
        )
    if eps_norm_sq <= _EXACT_LINEAR_RELATIVE_TOL * dy_centered_sq:
        return StuteTestResults(
            cvm_stat=0.0,
            p_value=1.0,
            reject=False,
            alpha=alpha,
            n_bootstrap=int(n_bootstrap),
            n_obs=G,
            seed=seed,
        )
    idx = np.argsort(d_arr, kind="stable")
    d_sorted = d_arr[idx]
    if w_arr is None:
        S = _cvm_statistic(eps[idx], d_sorted)
    else:
        S = _cvm_statistic_weighted(eps[idx], d_sorted, w_arr[idx])
    rng = np.random.default_rng(seed)
    bootstrap_S = np.empty(n_bootstrap, dtype=np.float64)
    fitted = a_hat + b_hat * d_arr  # baseline fitted values under H_0
    if w_arr is None:
        # Unweighted bit-exact path - identical to pre-PR code.
for b in range(n_bootstrap): eta = _generate_mammen_weights(G, rng) dy_b = fitted + eps * eta _, _, eps_b = _fit_ols_intercept_slope(d_arr, dy_b) bootstrap_S[b] = _cvm_statistic(eps_b[idx], d_sorted) else: # Phase 4.5 C survey-aware path: PSU-level Mammen multipliers # (broadcast to per-obs perturbation), weighted OLS refit, weighted # CvM recompute. Routes via synthetic trivial ResolvedSurveyDesign # for the weights= shortcut to share the same kernel. resolved_for_boot = survey if survey is not None else make_pweight_design(w_arr) # Stratified designs are supported via the standard stratified # clustered wild-bootstrap correction on the PSU multipliers # (within-stratum demean + sqrt(n_h/(n_h-1)) Bessel rescale), # applied uniformly before the per-obs broadcast eta_obs = # psu_mults[b, psu_col_idx] below. See REGISTRY # § "Note (Stute stratified survey-bootstrap calibration)" and # ``apply_stratum_centering`` (bootstrap_utils.py) for the # derivation; the same helper backs the HAD sup-t event-study # bootstrap at had.py:2151+. # R5 P1: reject lonely_psu='adjust' singleton-strata designs. # This pseudo-stratum centering transform has not been derived # for the Stute CvM (same gap as the HAD sup-t deviation at # REGISTRY § 'Note (HAD sup-t lonely_psu="adjust") deviation'). if _has_lonely_psu_adjust_singletons(resolved_for_boot): raise NotImplementedError( "stute_test: SurveyDesign(lonely_psu='adjust') with " "singleton strata is not yet supported on the multiplier " "bootstrap. The bootstrap helper pools singletons with " "nonzero multipliers but the matching analytical " "variance target requires a pseudo-stratum centering " "transform that has not been derived for the Stute CvM. " "Use lonely_psu='remove' (drops singleton contributions) " "or 'certainty' (zero-variance singletons), or pre-" "process the panel to remove singleton strata." ) # R3 P0: variance-unidentified survey-design guard. When # n_psu - n_strata <= 0 (e.g. 
unstratified single-PSU, or one PSU # per stratum under lonely_psu='remove'/'certainty'), # generate_survey_multiplier_weights_batch returns an all-zero # multiplier matrix. Without this guard, the code below would # treat zero perturbations as a valid bootstrap law and emit # p_value = 1/(B+1) for any positive observed CvM (spurious # rejection). Mirrors compute_survey_vcov's df_survey-driven # NaN treatment elsewhere in the package. The lonely_psu='adjust' # singleton case (which has nonzero multipliers but a separate # methodology gap) is already rejected above, so this branch # only catches genuinely degenerate designs. df_survey = resolved_for_boot.df_survey if df_survey is None or df_survey <= 0: warnings.warn( f"stute_test: survey design is variance-unidentified " f"(df_survey={df_survey}); the multiplier bootstrap " "cannot calibrate the test (single-PSU unstratified or " "one-PSU-per-stratum design). Returning NaN result.", UserWarning, stacklevel=2, ) return StuteTestResults( cvm_stat=float(S), p_value=float("nan"), reject=False, alpha=alpha, n_bootstrap=int(n_bootstrap), n_obs=G, seed=seed, ) psu_mults, psu_ids = generate_survey_multiplier_weights_batch( n_bootstrap, resolved_for_boot, weight_type="mammen", rng=rng ) # Stratum centering + Bessel rescale on the PSU multipliers # before broadcast. Same algebra as the HAD sup-t bootstrap at # had.py:2151+ (applied to the influence tensor there), but # applied here to ``psu_mults`` because the Stute bootstrap is a # wild-residual / refit-in-loop bootstrap (no precomputed # influence tensor exists). See REGISTRY § "Note (Stute # stratified survey-bootstrap calibration)" for the derivation # and the non-strata calibration shift it introduces. apply_stratum_centering(psu_mults, resolved_for_boot, psu_ids, psu_axis=1) # Build per-obs PSU-column index. When psu is None (trivial path), # each obs is its own PSU and psu_ids = arange(G) - so psu_col_idx # is just arange(G). 
        if resolved_for_boot.psu is None:
            psu_col_idx = np.arange(G)
        else:
            psu_to_col = {int(p): c for c, p in enumerate(psu_ids)}
            psu_arr = np.asarray(resolved_for_boot.psu)
            psu_col_idx = np.array([psu_to_col[int(psu_arr[g])] for g in range(G)])

        for b in range(n_bootstrap):
            eta_obs = psu_mults[b, psu_col_idx]  # (G,)
            dy_b = fitted + eps * eta_obs
            _, _, eps_b = _fit_weighted_ols_intercept_slope(d_arr, dy_b, w_arr)
            bootstrap_S[b] = _cvm_statistic_weighted(eps_b[idx], d_sorted, w_arr[idx])

    p_value = float((1.0 + float(np.sum(bootstrap_S >= S))) / (n_bootstrap + 1.0))
    reject = p_value <= alpha
    return StuteTestResults(
        cvm_stat=float(S),
        p_value=p_value,
        reject=bool(reject),
        alpha=alpha,
        n_bootstrap=int(n_bootstrap),
        n_obs=G,
        seed=seed,
    )
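

# Illustrative sketch, not part of this module: the unweighted bootstrap loop
# in ``stute_test`` above can be reproduced in miniature as follows. The helper
# names (``mammen_weights``, ``cvm_sketch``, ``ols_residuals``,
# ``wild_bootstrap_pvalue_sketch``) are hypothetical stand-ins for the private
# helpers, and the CvM form ``(1/G^2) * sum(cusum(eps)^2)`` is an assumption
# (the package's ``_cvm_statistic`` is not shown here and also handles ties).
# The p-value formula ``(1 + #{S_b >= S}) / (B + 1)`` is the one used above.

```python
import numpy as np


def mammen_weights(n, rng):
    """Draw n two-point Mammen (1993) multipliers (mean 0, variance 1)."""
    sqrt5 = np.sqrt(5.0)
    low, high = -(sqrt5 - 1.0) / 2.0, (sqrt5 + 1.0) / 2.0
    p_low = (sqrt5 + 1.0) / (2.0 * sqrt5)
    return np.where(rng.random(n) < p_low, low, high)


def cvm_sketch(eps_sorted):
    """ASSUMED statistic form: (1/G^2) * sum of squared residual cusums."""
    g = eps_sorted.shape[0]
    return float(np.sum(np.cumsum(eps_sorted) ** 2) / g**2)


def ols_residuals(d, y):
    """Return (fitted, residuals) from OLS of y on an intercept and d."""
    X = np.column_stack([np.ones_like(d), d])
    beta, *_ = np.linalg.lstsq(X, y, rcond=None)
    fitted = X @ beta
    return fitted, y - fitted


def wild_bootstrap_pvalue_sketch(d, dy, n_bootstrap=199, seed=0):
    """Wild-bootstrap p-value for the sketch CvM statistic (no ties, no weights)."""
    rng = np.random.default_rng(seed)
    idx = np.argsort(d, kind="stable")
    fitted, eps = ols_residuals(d, dy)
    s_obs = cvm_sketch(eps[idx])
    s_boot = np.empty(n_bootstrap)
    for b in range(n_bootstrap):
        # Perturb residuals around the null fit, refit, recompute the statistic.
        dy_b = fitted + eps * mammen_weights(d.shape[0], rng)
        _, eps_b = ols_residuals(d, dy_b)
        s_boot[b] = cvm_sketch(eps_b[idx])
    return float((1.0 + np.sum(s_boot >= s_obs)) / (n_bootstrap + 1.0))
```

# Because the observed statistic is counted in the numerator, the sketch
# p-value can never fall below ``1 / (n_bootstrap + 1)``.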


def yatchew_hr_test(
    d: np.ndarray,
    dy: np.ndarray,
    alpha: float = 0.05,
    *,
    null: Literal["linearity", "mean_independence"] = "linearity",
    survey_design: Any = None,
    survey: Any = None,
    weights: Optional[np.ndarray] = None,
) -> YatchewTestResults:
    """Run the Yatchew heteroskedasticity-robust specification test.

    Tests one of two nulls (selected via ``null=``) using the
    variance-ratio statistic

        T_hr = sqrt(G) * (sigma2_lin - sigma2_diff) / sigma2_W

    where

        sigma2_lin  = (1/G) * sum(eps^2)  # residuals under chosen null
        sigma2_diff = (1/(2G)) * sum((dy_{(g)} - dy_{(g-1)})^2)  # Yatchew differencing
        sigma2_W    = sqrt((1/(G-1)) * sum(eps_{(g)}^2 * eps_{(g-1)}^2))

    and ``_{(g)}`` denotes sort by ``d``. Under ``null="linearity"``
    (default, paper Assumption 8 / Theorem 7) ``eps`` are residuals from
    OLS ``dy = a + b*d + eps``. Under ``null="mean_independence"``
    ``eps = dy - mean(dy)`` (intercept-only OLS), mirroring R
    ``YatchewTest::yatchew_test(order=0)``. The ``sigma2_diff`` and
    ``sigma2_W`` formulas are identical between the two modes - the only
    delta is the residual definition. Rejection uses the one-sided
    standard-normal critical value ``z_{1-alpha}``.

    Parameters
    ----------
    d, dy : np.ndarray, shape (G,)
        Dose and first-difference outcome vectors.
    alpha : float, default 0.05
        One-sided significance level.
    null : {"linearity", "mean_independence"}, keyword-only, default "linearity"
        Which null hypothesis the test targets:

        - ``"linearity"`` (default): H_0 ``E[dY | D]`` is linear in ``D``
          (paper Assumption 8, Theorem 7). Residuals come from OLS
          ``dy = a + b*d + eps``. Bit-exact backcompat with pre-PR calls.
        - ``"mean_independence"``: H_0 ``E[dY | D] = E[dY]`` (mean
          independence of ``dY`` from ``D``). Residuals come from
          intercept-only OLS ``dy = a + eps``, so ``eps = dy - mean(dy)``.
          Mirrors R ``YatchewTest::yatchew_test(order=0)``.
          Used by the R-parity test on placebo Yatchew rows
          (``Credible-Answers/did_had`` runs ``order=0`` on placebos to
          test pre-trends as a non-parametric mean-independence
          assertion).

        ``d`` is required under both modes (the sort-by-``d`` differencing
        step is null-agnostic).
    survey_design : ResolvedSurveyDesign or None, keyword-only, default None
        Already-resolved survey design (per-unit). Array-in helpers accept
        ``ResolvedSurveyDesign`` ONLY; passing a ``SurveyDesign`` raises
        ``TypeError``. For pweight-only, use
        ``survey_design=make_pweight_design(arr)``. When supplied, the OLS
        baseline becomes weighted OLS and all three variance components
        become their pweight-sandwich analogs. PSU clustering is NOT
        propagated through the variance-ratio statistic (would require
        deriving a survey-aware variance-of-variance estimator; out of
        scope per Phase 4.5 C). Replicate-weight designs raise
        ``NotImplementedError``.
    survey : ResolvedSurveyDesign or None, keyword-only, default None
        DEPRECATED alias of ``survey_design=``. Will be removed in the
        next minor release.
    weights : np.ndarray or None, keyword-only, default None
        DEPRECATED alias of ``survey_design=make_pweight_design(arr)``.
        Will be removed in the next minor release.

    Returns
    -------
    YatchewTestResults

    Raises
    ------
    ValueError
        If ``d`` / ``dy`` are not 1D numeric, contain NaN, have unequal
        lengths, if any ``d`` value is negative (paper Section 2 HAD
        support restriction), or if ``alpha`` is outside ``(0, 1)``.
        Also raised if more than one of ``survey_design``, ``survey``,
        ``weights`` is supplied (3-way mutex; ``survey=`` and ``weights=``
        are deprecated aliases of ``survey_design=``), or if any weight
        is non-positive.
    TypeError
        If ``survey_design=SurveyDesign(...)`` (or the deprecated
        ``survey=SurveyDesign(...)`` alias) is passed; array-in helpers
        accept ``ResolvedSurveyDesign`` only. Use
        ``survey_design=make_pweight_design(arr)`` for pweight-only or
        pre-resolve via ``SurveyDesign(...).resolve(data)``.
    NotImplementedError
        If ``survey_design.replicate_weights is not None`` (deferred
        follow-up).

    Notes
    -----
    Sample-size gate: below ``G = 3`` the difference-variance estimator is
    undefined; the function emits ``UserWarning`` and returns NaN rather
    than raising.

    Dose ties: REJECTED with ``UserWarning`` + all-NaN result. The
    difference-based variance estimator ``sigma2_diff`` and the
    heteroskedasticity-robust scale ``sigma4_W`` (the square of
    ``sigma2_W``, i.e. the quantity under the square root in the
    definition above) both combine adjacent pairs of quantities sorted by
    ``d`` (differences and products, respectively); under tied doses the
    within-tie row ordering is arbitrary (stable sort falls back to input
    order) so the statistic becomes order-dependent rather than
    data-dependent. Callers with tied doses (mass-point designs,
    discretised dose registers) should use :func:`stute_test` instead -
    its tie-safe Cramer-von Mises statistic collapses tie blocks to the
    post-tie cumulative sum and is provably order-invariant under
    within-tie permutations.

    Exact-linear short-circuit: when the OLS residual sum-of-squares is
    below IEEE precision relative to the centered total sum of squares
    (``sum(eps^2) <= 1e-24 * sum((dy - dybar)^2)``, i.e. essentially
    ``1 - R^2 == 0``), the test short-circuits to
    ``t_stat_hr=-inf, p_value=1.0, reject=False`` - Assumption 8 holds
    exactly, the formal statistic is ``-inf`` under the one-sided
    critical value, and the correct decision is fail-to-reject. This
    shortcut is translation-invariant because the comparison is against
    centered TSS (not raw ``sum(dy^2)``).

    Degenerate ``sigma4_W = 0`` with non-zero residuals: when the
    adjacent-residual-product sum vanishes AFTER the exact-linear
    shortcut is bypassed (e.g. residuals alternate zero/non-zero after
    sorting), the formal statistic is ``+inf`` or ``-inf`` depending on
    the sign of the numerator ``sigma2_lin - sigma2_diff``.
    The function returns the sign-aware limit (``p=0, reject=True`` for
    positive numerator; ``p=1, reject=False`` for negative; ``NaN`` for
    zero) with a ``UserWarning``, rather than unconditionally mapping
    this to ``p=1`` (which would flip a legitimate rejection).

    Survey/weighted data (Phase 4.5 C): when ``survey_design`` (or one of
    its deprecated aliases ``survey`` / ``weights``) is supplied, all
    three variance components use their pweight-sandwich analogs:

    - ``sigma2_lin = sum(w * eps^2) / sum(w)`` (weighted OLS residual
      variance).
    - ``sigma2_diff = sum(w_avg * (dy_g - dy_{g-1})^2) / (2 * sum(w))``
      where ``w_avg_g = (w_g + w_{g-1}) / 2`` and the divisor uses
      ``sum(w)`` (not ``sum(w_avg)``) so the formula reduces bit-exactly
      to the unweighted ``(1/(2G))`` divisor at ``w = ones(G)``.
    - ``sigma4_W = sum(w_avg * eps_g^2 * eps_{g-1}^2) / sum(w_avg)``
      with arithmetic-mean pair weights; reduces to the unweighted
      ``(1/(G-1))`` divisor at ``w = ones(G)``.
    - ``T_hr = sqrt(sum(w)) * (sigma2_lin - sigma2_diff) / sigma2_W``.

    The pair-weight convention follows Krieger-Pfeffermann (1997, §3) for
    design-consistent inference on smooth functionals; PSU clustering is
    NOT propagated through the variance-ratio statistic. Strictly
    positive weights are required (the adjacent-difference formula has
    ``sum(w_avg)`` in the denominator). Per-unit constant-within-unit
    invariant on weights/strata/psu/fpc is the CALLER's responsibility.

    References
    ----------
    Yatchew, A. (1997). An elementary estimator of the partial linear
    model. Economics Letters 57, 135-143.

    de Chaisemartin et al. (2026), Theorem 7 / Equation 29.

    Krieger, A., Pfeffermann, D. (1997). Testing of distribution
    functions from complex sample surveys. Journal of Official
    Statistics 13(2), 123-142.
    """
    if not (0.0 < alpha < 1.0):
        raise ValueError(f"alpha must satisfy 0 < alpha < 1, got {alpha}.")
    if null not in ("linearity", "mean_independence"):
        raise ValueError(
            f"yatchew_hr_test: null must be one of "
            f"('linearity', 'mean_independence'), got {null!r}."
) # Three-way mutex on survey_design / survey / weights (array-in pattern). n_set = sum(x is not None for x in (survey_design, survey, weights)) if n_set > 1: raise ValueError(HAD_DUAL_KNOB_MUTEX_MSG_ARRAY_IN) # Soft deprecation: route legacy survey=/weights= aliases to survey_design= # FIRST so the type guard below covers `survey=SurveyDesign(...)` too # (PR #376 R1 P1: alias must behave identically to the canonical kwarg). if survey is not None: warnings.warn(HAD_DEPRECATION_MSG_SURVEY_KWARG, DeprecationWarning, stacklevel=2) survey_design = survey elif weights is not None: warnings.warn( HAD_DEPRECATION_MSG_WEIGHTS_KWARG_ARRAY_IN, DeprecationWarning, stacklevel=2, ) survey_design = make_pweight_design(np.asarray(weights, dtype=np.float64)) # Type guard: array-in helpers reject SurveyDesign. Runs AFTER alias # rebinding so it covers both `survey_design=SurveyDesign(...)` and the # deprecated `survey=SurveyDesign(...)` form identically. if survey_design is not None and isinstance(survey_design, SurveyDesign): raise TypeError( "yatchew_hr_test: `survey_design=` accepts a pre-resolved " "ResolvedSurveyDesign only (array-in helpers have no `data` to " "resolve column names against). For pweight-only, use " "`survey_design=make_pweight_design(arr)`. For full PSU/strata/" "FPC, pre-resolve via `SurveyDesign(...).resolve(data)`." ) # Internal alias rebind for back-compat with downstream code. survey = survey_design weights = None # Replicate-weight rejection. if survey is not None and getattr(survey, "replicate_weights", None) is not None: raise NotImplementedError( "yatchew_hr_test: replicate-weight survey designs (BRR/Fay/JK1/JKn/" "SDR) are not yet supported on HAD pretests. Replicate-weight " "pretests are a parallel follow-up after Phase 4.5 C." ) # R1 P1: pweight-only guard (aweight/fweight have different sandwich- # variance semantics not derived for the variance-ratio statistic). 
if survey is not None and getattr(survey, "weight_type", "pweight") != "pweight": raise ValueError( f"yatchew_hr_test: HAD pretests require weight_type='pweight'. " f"Got weight_type={survey.weight_type!r}." ) d_arr = _validate_1d_numeric(d, "d") dy_arr = _validate_1d_numeric(dy, "dy") if d_arr.shape[0] != dy_arr.shape[0]: raise ValueError( f"d and dy must have the same length; got d.shape={d_arr.shape}, " f"dy.shape={dy_arr.shape}." ) # HAD support restriction (paper Section 2): doses must be non-negative. # Mirror the front-door guard from qug_test / _validate_had_panel. if (d_arr < 0).any(): n_neg = int((d_arr < 0).sum()) raise ValueError( f"yatchew_hr_test: d contains {n_neg} negative value(s); HAD " f"doses must be non-negative (paper Section 2). Check your " f"dose column or pre-process before calling yatchew_hr_test." ) G = int(d_arr.shape[0]) critical_value = float(stats.norm.ppf(1.0 - alpha)) # Phase 4.5 C: resolve effective per-unit weights. Strictly positive # required (the adjacent-difference formula divides by sum(w_avg) which # collapses to zero in any contiguous-zero block). # R4 P1: validate 1D explicitly so column-vector inputs raise. if survey is not None: w_arr = _validate_1d_numeric(np.asarray(survey.weights), "yatchew_hr_test: survey.weights") if w_arr.shape[0] != G: raise ValueError( f"yatchew_hr_test: survey.weights length {w_arr.shape[0]} " f"does not match d/dy length {G}." ) if (w_arr <= 0).any(): raise ValueError( "yatchew_hr_test: survey.weights must be strictly positive " "(adjacent-difference variance is undefined under contiguous " "zero-weight blocks)." ) elif weights is not None: w_arr = _validate_1d_numeric(np.asarray(weights), "yatchew_hr_test: weights") if w_arr.shape[0] != G: raise ValueError( f"yatchew_hr_test: weights length {w_arr.shape[0]} does not " f"match d/dy length {G}." 
) if (w_arr <= 0).any(): raise ValueError( "yatchew_hr_test: weights must be strictly positive (the " "adjacent-difference variance is undefined under contiguous " "zero-weight blocks)." ) else: w_arr = None # R4 P0: normalize pweights to mean=1 (matches SurveyDesign.resolve() # convention). Yatchew uses sqrt(sum(w)) as the effective sample size, # which without normalization would scale as sqrt(c) under uniform # rescaling weights -> w * c, producing different p-values for # weights=w vs weights=100*w. Normalization makes the statistic # scale-invariant AND ensures weights= and survey=SurveyDesign(...) # produce identical results (the latter resolve()s to mean=1 # internally, the former previously did not). if w_arr is not None: w_arr = w_arr * (float(w_arr.shape[0]) / float(np.sum(w_arr))) if G < _MIN_G_YATCHEW: warnings.warn( f"yatchew_hr_test: G = {G} is below the minimum {_MIN_G_YATCHEW} " f"(need at least 2 sorted differences). Returning NaN result.", UserWarning, stacklevel=2, ) return YatchewTestResults( t_stat_hr=float("nan"), p_value=float("nan"), reject=False, alpha=alpha, critical_value=critical_value, sigma2_lin=float("nan"), sigma2_diff=float("nan"), sigma2_W=float("nan"), n_obs=G, null_form=null, ) # Tie / constant-dose front-door guard. Yatchew's difference-based # variance estimator uses adjacent differences of dy sorted by d; # under tied doses the within-tie ordering is arbitrary (stable sort # falls back to input row order), so the statistic becomes # non-methodological and order-dependent. Reject at the front door # with a UserWarning + NaN result rather than silently permuting. # Mass-point designs and other tied-dose panels should use # `stute_test` instead (its tie-safe CvM handles ties correctly). 
n_unique_d = int(np.unique(d_arr).shape[0]) if n_unique_d < G: n_dups = G - n_unique_d warnings.warn( f"yatchew_hr_test: d contains {n_dups} duplicate value(s) " f"(only {n_unique_d} distinct dose values out of G={G}); " f"the difference-based variance estimator is not well-defined " f"under ties because adjacent-difference statistics depend on " f"arbitrary within-tie row ordering. Use stute_test() instead " f"(its tie-safe CvM handles ties correctly). Returning NaN result.", UserWarning, stacklevel=2, ) return YatchewTestResults( t_stat_hr=float("nan"), p_value=float("nan"), reject=False, alpha=alpha, critical_value=critical_value, sigma2_lin=float("nan"), sigma2_diff=float("nan"), sigma2_W=float("nan"), n_obs=G, null_form=null, ) # 4-arm dispatch on (null, weighted). Unweighted-linearity path is # bit-exact pre-PR (stability invariant #1). Mean-independence mode # mirrors R YatchewTest::yatchew_test(order=0) and uses intercept-only # OLS residuals; the downstream sigma2_diff / sigma2_W path is # identical across nulls. if null == "linearity": if w_arr is None: _, _, eps = _fit_ols_intercept_slope(d_arr, dy_arr) else: _, _, eps = _fit_weighted_ols_intercept_slope(d_arr, dy_arr, w_arr) else: # null == "mean_independence" if w_arr is None: _, _, eps = _fit_ols_intercept_only(dy_arr) else: _, _, eps = _fit_weighted_ols_intercept_only(dy_arr, w_arr) if w_arr is None: sigma2_lin = float(np.mean(eps * eps)) sum_w = float(G) # uniform-weights effective sample size = G else: sum_w = float(np.sum(w_arr)) sigma2_lin = float(np.sum(w_arr * eps * eps) / sum_w) # Numerically exact linear fit: same short-circuit as `stute_test`. # Assumption 8 holds to IEEE precision; the Yatchew statistic is # formally -inf (finite-negative numerator over zero denominator), # which maps to p = 1 under the one-sided standard-normal critical # value. Short-circuit so FP noise in ``sigma4_W`` cannot produce a # spuriously large finite ``T_hr``. 
Comparison is purely relative # against CENTERED TSS - translation- AND scale-invariant. eps_norm_sq = float(np.sum(eps * eps)) dy_centered_sq = float(np.sum((dy_arr - dy_arr.mean()) ** 2)) if dy_centered_sq <= 0.0 or eps_norm_sq <= _EXACT_LINEAR_RELATIVE_TOL * dy_centered_sq: # Exact-linear branch. Covers two cases: # - dy_centered_sq == 0: dy is constant (trivially linear). # - relative SSR below IEEE precision: near-exact OLS fit. # For reporting, compute sigma2_diff on the sorted dy (finite, # well-defined even in the exact-linear case). Use the weighted # divisor (2 * sum(w)) when weights are supplied so the reported # sigma2_diff matches the same convention as the active branch. idx_early = np.argsort(d_arr, kind="stable") if w_arr is None: sigma2_diff_exact = float(np.sum(np.diff(dy_arr[idx_early]) ** 2) / (2.0 * G)) else: w_s_e = w_arr[idx_early] w_avg_e = 0.5 * (w_s_e[1:] + w_s_e[:-1]) sigma2_diff_exact = float( np.sum(w_avg_e * np.diff(dy_arr[idx_early]) ** 2) / (2.0 * sum_w) ) return YatchewTestResults( t_stat_hr=float("-inf"), p_value=1.0, reject=False, alpha=alpha, critical_value=critical_value, sigma2_lin=sigma2_lin, sigma2_diff=sigma2_diff_exact, sigma2_W=0.0, n_obs=G, null_form=null, ) idx = np.argsort(d_arr, kind="stable") dy_s = dy_arr[idx] eps_s = eps[idx] diff_dy = np.diff(dy_s) # length G - 1 if w_arr is None: # Paper-literal divisor: 2G (NOT 2(G-1)). This matches paper review # line 168: sigma2_diff := (1/(2G)) * sum((dy_{(g)} - dy_{(g-1)})^2). sigma2_diff = float(np.sum(diff_dy * diff_dy) / (2.0 * G)) # sigma4_W = (1/(G-1)) * sum(eps_(g)^2 * eps_(g-1)^2) using np.mean # which divides by the length of the input (G-1 here). Matches paper # review line 171. sigma4_W = float(np.mean(eps_s[1:] ** 2 * eps_s[:-1] ** 2)) else: # Phase 4.5 C: pweight-sandwich weighted variance components. # Pair-weights (Krieger-Pfeffermann 1997 §3): w_avg_g = (w_g + w_{g-1})/2. 
# Reduction at w=ones(G): w_avg = ones(G-1), sum(w) = G, so # sigma2_diff = sum(diff^2) / (2*G) (matches existing 2G divisor) # sigma4_W = sum(prod) / (G-1) (matches existing G-1 divisor) w_s = w_arr[idx] w_avg = 0.5 * (w_s[1:] + w_s[:-1]) sigma2_diff = float(np.sum(w_avg * diff_dy * diff_dy) / (2.0 * sum_w)) sigma4_W = float(np.sum(w_avg * eps_s[1:] ** 2 * eps_s[:-1] ** 2) / np.sum(w_avg)) if sigma4_W <= 0.0: # sigma4_W = 0 AFTER the exact-linear short-circuit means OLS # residuals are NOT zero (the shortcut already caught that case) # but every adjacent pair of sorted squared residuals contains a # zero (e.g. residuals alternate zero / nonzero after sort). # The formal test statistic is ±inf depending on the sign of the # numerator ``sigma2_lin - sigma2_diff``; mapping every such case # to p=1 (as an earlier revision did) can flip a legitimate # rejection into a fail-to-reject. warnings.warn( f"yatchew_hr_test: sigma4_W = 0 with non-zero residuals " f"(sigma2_lin = {sigma2_lin:.6g}, sigma2_diff = {sigma2_diff:.6g}); " f"the formal test statistic is infinite. Returning the " f"sign-aware limit decision.", UserWarning, stacklevel=2, ) numerator = sigma2_lin - sigma2_diff if numerator > 0.0: # T_hr -> +inf: reject (far into right tail). t_stat_hr_val = float("inf") p_value_val = 0.0 reject_val = True elif numerator < 0.0: # T_hr -> -inf: fail-to-reject. t_stat_hr_val = float("-inf") p_value_val = 1.0 reject_val = False else: # 0/0: genuinely indeterminate. t_stat_hr_val = float("nan") p_value_val = float("nan") reject_val = False return YatchewTestResults( t_stat_hr=t_stat_hr_val, p_value=p_value_val, reject=reject_val, alpha=alpha, critical_value=critical_value, sigma2_lin=sigma2_lin, sigma2_diff=sigma2_diff, sigma2_W=0.0, n_obs=G, null_form=null, ) sigma2_W = float(np.sqrt(sigma4_W)) # Phase 4.5 C: effective sample size = sum(w) (=G under uniform weights, # so unweighted path is bit-exact). 
    t_stat_hr = float(np.sqrt(sum_w) * (sigma2_lin - sigma2_diff) / sigma2_W)
    p_value = float(1.0 - stats.norm.cdf(t_stat_hr))
    reject = t_stat_hr >= critical_value
    return YatchewTestResults(
        t_stat_hr=t_stat_hr,
        p_value=p_value,
        reject=bool(reject),
        alpha=alpha,
        critical_value=critical_value,
        sigma2_lin=sigma2_lin,
        sigma2_diff=sigma2_diff,
        sigma2_W=sigma2_W,
        n_obs=G,
        null_form=null,
    )
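

# Illustrative sketch, not part of this module: the unweighted linearity-null
# arm of ``yatchew_hr_test`` above, reduced to the formulas in its docstring.
# ``yatchew_hr_sketch`` is a hypothetical name, and the sketch deliberately
# omits every guard the real function applies (input validation, the tie and
# small-G gates, the exact-linear short-circuit, the ``sigma4_W = 0`` branch).
# The normal tail is computed with ``math.erfc`` so the block has no SciPy
# dependency.

```python
from math import erfc, sqrt

import numpy as np


def yatchew_hr_sketch(d, dy):
    """Return (T_hr, one-sided p-value) for distinct doses d and outcomes dy."""
    g = d.shape[0]
    # Residuals under the linearity null: OLS of dy on an intercept and d.
    X = np.column_stack([np.ones_like(d), d])
    beta, *_ = np.linalg.lstsq(X, dy, rcond=None)
    eps = dy - X @ beta

    # Sort by dose so that adjacent rows are dose-neighbours.
    idx = np.argsort(d, kind="stable")
    dy_s, eps_s = dy[idx], eps[idx]

    sigma2_lin = float(np.mean(eps**2))
    sigma2_diff = float(np.sum(np.diff(dy_s) ** 2) / (2.0 * g))  # divisor 2G
    sigma4_w = float(np.mean(eps_s[1:] ** 2 * eps_s[:-1] ** 2))  # divisor G-1

    t_hr = sqrt(g) * (sigma2_lin - sigma2_diff) / sqrt(sigma4_w)
    p_value = 0.5 * erfc(t_hr / sqrt(2.0))  # equals 1 - Phi(t_hr)
    return t_hr, p_value


# Usage: a strongly quadratic dose response with small noise.
rng = np.random.default_rng(0)
d_demo = np.linspace(0.01, 1.0, 400)
dy_demo = 4.0 * (d_demo - 0.5) ** 2 + 0.05 * rng.standard_normal(400)
t_demo, p_demo = yatchew_hr_sketch(d_demo, dy_demo)
```

# The linear-fit residual variance picks up the curvature while the
# differencing estimator tracks only the noise, so ``sigma2_lin`` exceeds
# ``sigma2_diff`` and the statistic lands far in the right tail.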


def _validate_multi_period_panel(
    data: pd.DataFrame,
    outcome_col: str,
    dose_col: str,
    time_col: str,
    unit_col: str,
    first_treat_col: Optional[str],
) -> "tuple[Any, list, list, pd.DataFrame, Optional[Dict[str, Any]]]":
    """Validate a multi-period HAD panel for joint pre-test dispatch.

    Thin wrapper over :func:`_validate_had_panel_event_study` (had.py) that
    inherits the full contract:

    - ``first_treat_col=None`` combined with a staggered panel → raises
      ``ValueError`` (the had.py helper does NOT silently accept; it
      requires an explicit first-treatment column to identify cohorts).
    - ``first_treat_col`` provided but identifies only one cohort → no
      auto-filter, proceeds.
    - ``first_treat_col`` provided with multiple cohorts → auto-filters to
      last-cohort + never-treated, emits ``UserWarning`` with
      ``filter_info`` summary.
    - Requires ≥ 3 time periods, balanced panel, ordered time dtype, and
      the pre-period D=0 invariant across all pre-periods.

    Guards the joint pre-tests rely on (both already enforced by had.py,
    so this wrapper adds no extra checks):

    - ``len(t_pre_list) >= 1`` (need ≥ 1 pre-period for the joint
      pre-trends infrastructure).
    - ``len(t_post_list) >= 1`` (need ≥ 1 post-period for joint
      homogeneity).

    Returns the same 5-tuple as the had.py helper:
    ``(F, t_pre_list, t_post_list, data_filtered, filter_info)``.
    """
    return _validate_had_panel_event_study(
        data,
        outcome_col=outcome_col,
        dose_col=dose_col,
        time_col=time_col,
        unit_col=unit_col,
        first_treat_col=first_treat_col,
    )


def _build_period_rank(data: pd.DataFrame, time_col: str) -> Dict[Any, int]:
    """Build a ``{period_label: chronological_rank}`` map.

    For ordered categorical time columns, uses the declared category
    order so that e.g. ``["q1", "q2", "q10"]`` ranks chronologically even
    though a lexical sort would place ``"q10"`` before ``"q2"``. For
    numeric or datetime time columns, uses natural Python `sorted` order
    on the unique period labels.
Object dtypes would fall back to lexicographic order - callers relying on chronology with object- dtype labels should convert to an ordered categorical first (this mirrors the contract in ``_validate_had_panel_event_study``). The rank map lets the joint-pretest wrappers compare period labels chronologically via ``rank[t1] < rank[t2]`` instead of raw Python ``t1 < t2``, which would silently misorder ordered-categorical panels (paper Appendix B.2 support contract). """ time_dtype = data[time_col].dtype if isinstance(time_dtype, pd.CategoricalDtype) and time_dtype.ordered: return {c: i for i, c in enumerate(time_dtype.categories)} periods = sorted(data[time_col].unique()) return {p: i for i, p in enumerate(periods)} def _aggregate_for_joint_test( data: pd.DataFrame, outcome_col: str, dose_col: str, time_col: str, unit_col: str, horizons: list, base_period: Any, ) -> "tuple[np.ndarray, Dict[str, np.ndarray], np.ndarray]": """Aggregate a multi-period panel for a joint-Stute test. Builds per-horizon first differences ``dy_t = Y_{g,t} - Y_{g,base}`` and the unit-level dose ``D_g`` for the joint-Stute test. All units must appear in every (horizon + base_period) period, matching the balanced-panel invariant of the single-period :func:`stute_test`. Dose extraction: ``D_g = max_t D_{g,t}`` under the HAD contract "once treated, stay treated with same dose". For pre-periods ``D_{g,t} = 0`` and for post-periods ``D_{g,t}`` is time-invariant per unit, so ``max`` recovers the realized post-period dose. Parameters ---------- data : pd.DataFrame outcome_col, dose_col, time_col, unit_col : str horizons : list Non-empty list of period labels to build ``dy_t`` for. ``base_period`` must not be in ``horizons``. All ``horizons`` and ``base_period`` must exist in the time column. base_period : period label The reference period for the first difference. 
Returns ------- d_arr : np.ndarray, shape (G,) dy_by_horizon : dict[str, np.ndarray] Keys are ``str(t)`` per horizon, values are ``dy_t`` arrays of shape ``(G,)``. Insertion order follows ``horizons``. unit_ids : np.ndarray, shape (G,) """ required = [outcome_col, dose_col, time_col, unit_col] missing = [c for c in required if c not in data.columns] if missing: raise ValueError(f"Missing column(s) in data: {missing}. Required: {required}.") if len(horizons) == 0: raise ValueError("horizons must be a non-empty list of period labels.") data_periods = set(data[time_col].unique()) needed_periods = list(horizons) + [base_period] missing_periods = [t for t in needed_periods if t not in data_periods] if missing_periods: raise ValueError( f"Period(s) {missing_periods} not found in time_col " f"{time_col!r}. Available periods: " f"{sorted(data_periods, key=lambda x: (x is None, x))}." ) if base_period in horizons: raise ValueError( f"base_period={base_period!r} must not appear in horizons " f"{list(horizons)!r}." ) mask = data[time_col].isin(needed_periods) subset = data.loc[mask].copy() for col in [outcome_col, dose_col, unit_col]: col_series = subset[col] if bool(pd.isna(col_series).any()): n_nan = int(pd.isna(col_series).sum()) raise ValueError( f"{n_nan} NaN value(s) found in column {col!r} across " f"periods {needed_periods}. Joint pre-test does not " f"silently drop rows; drop or impute before calling." ) # Row-level non-negative-dose guard (paper Section 2 HAD support # restriction `D_{g,t} >= 0`). Must run BEFORE the groupby/max() # collapse below, otherwise a negative post dose would silently # become 0 in the per-unit dose vector (since `max(0, -d) = 0` for # positive d), letting the wrappers run on invalid data and # potentially return finite results. 
This is the direct-wrapper # equivalent of the row-level check inside # `_validate_had_panel_event_study`, centralized so both # `joint_pretrends_test` and `joint_homogeneity_test` inherit it on # the `n_periods < 3` fallback path that skips the validator. negative_dose_mask = subset[dose_col] < 0 if bool(negative_dose_mask.any()): n_neg = int(negative_dose_mask.sum()) raise ValueError( f"{n_neg} negative dose value(s) found in column " f"{dose_col!r} across periods {needed_periods}. HAD support " f"restriction (paper Section 2) requires D_{{g,t}} >= 0 " f"for every (unit, period)." ) counts = subset.groupby(unit_col).size() n_needed = len(needed_periods) if (counts != n_needed).any(): n_bad = int((counts != n_needed).sum()) raise ValueError( f"Panel unbalanced across needed periods {needed_periods}: " f"{n_bad} unit(s) do not appear in all {n_needed} period(s). " f"Joint pre-test requires a balanced sub-panel." ) wide_y = subset.pivot(index=unit_col, columns=time_col, values=outcome_col) wide_y = wide_y.sort_index() unit_ids = np.asarray(wide_y.index) base_y = wide_y[base_period].to_numpy(dtype=np.float64) dy_by_horizon: Dict[str, np.ndarray] = {} for t in horizons: y_t = wide_y[t].to_numpy(dtype=np.float64) dy_by_horizon[str(t)] = y_t - base_y # Dose per unit is the HAD time-invariant post-period dose: # D_g = max_t D_{g,t}. Critically, compute this over the FULL data, # not just the subset of needed_periods - for joint pre-trends, # needed_periods contains only pre-periods (all D=0), so taking max # over the subset would yield D_g = 0 for every unit and collapse # the CvM sort to arbitrary ties. Paper HAD convention: dose is # fixed per unit once treated; pre-period zero-dose is enforced by # the upstream validator. d_per_unit = data.groupby(unit_col)[dose_col].max().sort_index() # Align dose with the subset's unit ordering (pivot sort_index uses # natural unit_col order; groupby/sort_index on the full data gives # the same order). 
d_per_unit = d_per_unit.loc[unit_ids] d_arr = d_per_unit.to_numpy(dtype=np.float64) return d_arr, dy_by_horizon, unit_ids def _compose_verdict_event_study( qug: QUGTestResults, pretrends_joint: Optional[StuteJointResult], homogeneity_joint: Optional[StuteJointResult], ) -> str: """Build the event-study :class:`HADPretestReport` verdict. Mirrors :func:`_compose_verdict` (two-period path) idiom verbatim: hyphen-separated ``"<concern> - <detail> (<source>)"`` reason strings, ``"; "`` join, ``"; additional steps unresolved: ..."`` suffix for conclusive rejections that coexist with unresolved steps, lowercase concerns. Coverage: - Step 1 (QUG): always runs on the event-study path. - Step 2 (Assumption 7 pre-trends): runs via ``pretrends_joint`` when at least one earlier pre-period is available. When skipped (only the immediate base pre-period), the verdict flags the skip but does NOT emit the Phase-3 "paper step 2 deferred to Phase 3 follow-up" caveat - this PR closes that gap. - Step 3 (Assumption 8 linearity/homogeneity): runs via ``homogeneity_joint`` (joint Stute only; no joint Yatchew variant exists in the paper). The step-3 alternative Yatchew-HR test is subsumed by joint Stute on this path. (Paper step 4 is the decision itself - "use TWFE if none of the tests rejects" - not a separate diagnostic, so it has no code path here.) Priority: 1. Any conclusive test rejecting → primary verdict bundles each rejection reason. Unresolved / skipped steps append as a suffix. 2. No conclusive rejection but a required step unresolved → ``"inconclusive - ..."``. 3. All required steps conclusive and none reject → admissible fail-to-reject string (Section 4 coverage). 
""" qug_ok = bool(np.isfinite(qug.p_value)) pretrends_ok = pretrends_joint is not None and bool(np.isfinite(pretrends_joint.p_value)) homogeneity_ok = homogeneity_joint is not None and bool(np.isfinite(homogeneity_joint.p_value)) qug_rej = qug_ok and qug.reject pretrends_rej = pretrends_joint is not None and pretrends_ok and bool(pretrends_joint.reject) homogeneity_rej = ( homogeneity_joint is not None and homogeneity_ok and bool(homogeneity_joint.reject) ) reasons = [] if qug_rej: reasons.append("support infimum rejected - continuous_at_zero design invalid (QUG)") if pretrends_rej: reasons.append("joint pre-trends rejected - assumption 7 violated (joint Stute)") if homogeneity_rej: reasons.append("joint linearity rejected - heterogeneity bias (joint Stute)") unresolved = [] if not qug_ok: unresolved.append("QUG NaN") if pretrends_joint is None: unresolved.append("joint pre-trends skipped (no earlier pre-period)") elif not pretrends_ok: unresolved.append("joint pre-trends NaN") if homogeneity_joint is None: unresolved.append("joint linearity skipped") elif not homogeneity_ok: unresolved.append("joint linearity NaN") if reasons: verdict = "; ".join(reasons) if unresolved: verdict += "; additional steps unresolved: " + "; ".join(unresolved) return verdict if unresolved: return "inconclusive - " + "; ".join(unresolved) return ( "QUG, joint pre-trends, and joint linearity diagnostics " "fail-to-reject (TWFE admissible under Section 4 assumptions)" )
def stute_joint_pretest( residuals_by_horizon: Dict[Any, np.ndarray], fitted_by_horizon: Dict[Any, np.ndarray], doses: np.ndarray, design_matrix: np.ndarray, *, alpha: float = 0.05, n_bootstrap: int = 999, seed: Optional[int] = None, null_form: str = "custom", survey_design: Any = None, survey: Any = None, weights: Optional[np.ndarray] = None, ) -> StuteJointResult: """Joint Cramer-von Mises pretest across multiple horizons. Generalizes :func:`stute_test` to K horizons with the joint statistic ``S_joint = sum_k S_k``, where ``S_k`` is the single-horizon CvM on residuals ``eps_{g,k}``. Inference is via Mammen wild bootstrap with a **shared** multiplier ``eta_g`` across horizons per unit to preserve the vector-valued empirical process's unit-level dependence. **Note:** sum-of-CvMs aggregation follows the standard joint specification-test construction (Delgado 1993; Escanciano 2006). The paper does not prescribe an aggregation; sum-of-CvMs balances power across diffuse vs concentrated alternatives and bootstraps cleanly with the shared-eta structure. Bootstrap uses the literal per-iteration OLS refit form (paper Appendix D) for consistency with Phase 3's :func:`stute_test`. ``XtX_inv_Xt`` is precomputed once (same design matrix each iteration), so the refit cost is O(Gp) per bootstrap draw and the overall loop is dominated by :func:`_cvm_statistic` across K horizons. Parameters ---------- residuals_by_horizon : dict[str, np.ndarray] ``{label: eps_g}`` per horizon. All values must have identical length ``G`` and be unit-ordered consistently with ``doses``. fitted_by_horizon : dict[str, np.ndarray] ``{label: fitted_g}`` per horizon. Required to reconstruct bootstrap outcomes ``dy*_{g,k} = fitted_{g,k} + eps_{g,k} * eta_g`` under the null. doses : np.ndarray, shape (G,) Dose per unit. Shared across horizons (HAD contract: dose is time-invariant per unit). Must be finite and non-negative. 
design_matrix : np.ndarray, shape (G, p) Regression design used in the per-horizon bootstrap refit. Mean-independence: ``[1]`` (intercept only). Linearity: ``[1, doses]``. The matrix is identical across horizons. alpha, n_bootstrap, seed : see :func:`stute_test`. null_form : str Diagnostic label recorded on the result (``"mean_independence"`` | ``"linearity"`` | ``"custom"``). The wrappers :func:`joint_pretrends_test` and :func:`joint_homogeneity_test` set this automatically. survey_design : ResolvedSurveyDesign or None, keyword-only, default None Already-resolved per-unit survey design (Phase 4.5 C). Array-in helpers accept ``ResolvedSurveyDesign`` ONLY; passing a ``SurveyDesign`` raises ``TypeError``. For pweight-only, use ``survey_design=make_pweight_design(arr)``. When supplied, the bootstrap is a PSU-level Mammen multiplier bootstrap with the multiplier matrix shared across horizons within each replicate (preserves both vector-valued empirical-process unit-level dependence + PSU clustering). Replicate-weight designs raise ``NotImplementedError``; non-pweight weight types are rejected. Variance-unidentified designs (``df_survey <= 0``) return NaN with a ``UserWarning`` instead of calibrating against an all-zero multiplier matrix. survey : ResolvedSurveyDesign or None, keyword-only, default None DEPRECATED alias of ``survey_design=``. Will be removed in the next minor release. weights : np.ndarray or None, keyword-only, default None DEPRECATED alias of ``survey_design=make_pweight_design(arr)``. Will be removed in the next minor release. Returns ------- StuteJointResult On the common path, a populated result with bootstrap-based ``p_value`` and ``cvm_stat_joint``. 
On the small-sample branch (``G < _MIN_G_STUTE``), constant-dose branch (``np.ptp(doses) <= 0``), or any-NaN branch in the input residuals / fitted arrays, returns an all-NaN result (with ``reject=False`` and the full ``per_horizon_stats`` dict keyed by the validated horizon labels) and emits a ``UserWarning`` for the first two branches. Mirrors the single-horizon :func:`stute_test` contract so event-study workflows on small or staggered-filtered panels surface an inconclusive report rather than crashing. Raises ------ ValueError On empty input, key-mismatch, stringified-label collisions between distinct raw keys, shape-mismatch, ``doses`` containing negative values, ``n_bootstrap < _MIN_N_BOOTSTRAP``, or invalid ``alpha``. ``G < _MIN_G_STUTE`` does NOT raise; see Returns. """ # Three-way mutex on survey_design / survey / weights (array-in pattern). n_set = sum(x is not None for x in (survey_design, survey, weights)) if n_set > 1: raise ValueError(HAD_DUAL_KNOB_MUTEX_MSG_ARRAY_IN) # Soft deprecation: route legacy survey=/weights= aliases to survey_design= # FIRST so the type guard below covers `survey=SurveyDesign(...)` too # (PR #376 R1 P1: alias must behave identically to the canonical kwarg). if survey is not None: warnings.warn(HAD_DEPRECATION_MSG_SURVEY_KWARG, DeprecationWarning, stacklevel=2) survey_design = survey elif weights is not None: warnings.warn( HAD_DEPRECATION_MSG_WEIGHTS_KWARG_ARRAY_IN, DeprecationWarning, stacklevel=2, ) survey_design = make_pweight_design(np.asarray(weights, dtype=np.float64)) # Type guard: array-in helpers reject SurveyDesign. Runs AFTER alias # rebinding so it covers both `survey_design=SurveyDesign(...)` and the # deprecated `survey=SurveyDesign(...)` form identically. if survey_design is not None and isinstance(survey_design, SurveyDesign): raise TypeError( "stute_joint_pretest: `survey_design=` accepts a pre-resolved " "ResolvedSurveyDesign only (array-in helpers have no `data` to " "resolve column names against). 
For pweight-only, use " "`survey_design=make_pweight_design(arr)`. For full PSU/strata/" "FPC, pre-resolve via `SurveyDesign(...).resolve(data)`." ) # Internal alias rebind for back-compat with downstream code. survey = survey_design weights = None # Replicate-weight rejection. if survey is not None and getattr(survey, "replicate_weights", None) is not None: raise NotImplementedError( "stute_joint_pretest: replicate-weight survey designs (BRR/Fay/JK1/" "JKn/SDR) are not yet supported on HAD pretests. Replicate-weight " "pretests are a parallel follow-up after Phase 4.5 C." ) # R1 P1: pweight-only guard. if survey is not None and getattr(survey, "weight_type", "pweight") != "pweight": raise ValueError( f"stute_joint_pretest: HAD pretests require weight_type='pweight'. " f"Got weight_type={survey.weight_type!r}." ) if not isinstance(residuals_by_horizon, dict) or not isinstance(fitted_by_horizon, dict): raise ValueError( "residuals_by_horizon and fitted_by_horizon must be dicts " "keyed by horizon label." ) if len(residuals_by_horizon) == 0: raise ValueError("residuals_by_horizon must contain at least one horizon.") if set(residuals_by_horizon.keys()) != set(fitted_by_horizon.keys()): raise ValueError( "residuals_by_horizon and fitted_by_horizon must have " "identical keys. Got " f"residuals keys: {sorted(residuals_by_horizon.keys())!r}, " f"fitted keys: {sorted(fitted_by_horizon.keys())!r}." ) doses_arr = _validate_1d_numeric(np.asarray(doses), "doses") G = doses_arr.shape[0] if np.any(doses_arr < 0): raise ValueError( "doses must be non-negative (HAD contract - paper Section 2). " f"Found {int(np.sum(doses_arr < 0))} negative value(s)." ) # G < _MIN_G_STUTE (CvM statistic not well-calibrated): mirror the # single-horizon `stute_test` contract - warn + return NaN result # rather than raise, so callers (including the event-study workflow # on a staggered panel whose last-cohort filter leaves fewer than # 10 units) get an inconclusive diagnostic instead of a crash. 
The # NaN return still satisfies the workflow's `np.isfinite(p_value)` # gating, so `all_pass` becomes False downstream. # Note: the actual `warn + return` happens below after horizon # labels are validated and collision-checked, so the NaN result # carries full per-horizon diagnostic keys. if n_bootstrap < _MIN_N_BOOTSTRAP: raise ValueError(f"n_bootstrap must be >= {_MIN_N_BOOTSTRAP}; got " f"{n_bootstrap}.") if not isinstance(alpha, (int, float)) or not (0 < float(alpha) < 1): raise ValueError(f"alpha must be in (0, 1); got {alpha!r}.") X = np.asarray(design_matrix, dtype=np.float64) if X.ndim != 2 or X.shape[0] != G: raise ValueError(f"design_matrix must have shape (G, p) with G={G}; got " f"{X.shape}.") if not np.all(np.isfinite(X)): raise ValueError("design_matrix contains non-finite values (NaN/inf).") raw_horizon_labels = list(residuals_by_horizon.keys()) K = len(raw_horizon_labels) # Stringified-label collision guard: distinct raw keys whose str() # representations collide (e.g. {1: ..., "1": ..., 1.0: ...}) would # overwrite each other in residuals_arrays / fitted_arrays, letting # the surviving horizon be double-counted in S_joint = sum of S_k # and leaving `n_horizons` inconsistent with the number of distinct # diagnostic statistics. Reject explicitly rather than silently # collapsing the test. str_labels = [str(k) for k in raw_horizon_labels] if len(set(str_labels)) != len(str_labels): from collections import Counter dup_strs = [s for s, c in Counter(str_labels).items() if c > 1] collisions = {s: [k for k in raw_horizon_labels if str(k) == s] for s in dup_strs} raise ValueError( f"Horizon label collision after str() stringification: " f"{collisions!r}. The joint Stute helpers index residuals " f"and fitted values by str(label); distinct raw keys whose " f"stringified form collides would silently overwrite each " f"other and double-count the surviving horizon in S_joint. " f"Use string-distinct horizon labels (e.g. 
1997 and 1998 " f'as int, or "1997" and "1998" as str; not both).' ) any_nan = False residuals_arrays: Dict[str, np.ndarray] = {} fitted_arrays: Dict[str, np.ndarray] = {} for k in raw_horizon_labels: eps_k = np.asarray(residuals_by_horizon[k], dtype=np.float64) fit_k = np.asarray(fitted_by_horizon[k], dtype=np.float64) if eps_k.shape != (G,) or fit_k.shape != (G,): raise ValueError( f"Horizon {k!r}: residuals shape {eps_k.shape} and " f"fitted shape {fit_k.shape} must both be ({G},) to " f"align with doses." ) if not (np.all(np.isfinite(eps_k)) and np.all(np.isfinite(fit_k))): any_nan = True residuals_arrays[str(k)] = eps_k fitted_arrays[str(k)] = fit_k # Re-key to str labels consistently (wrappers already pass str; direct # callers may pass int/object). String identity per the documented # horizon_labels contract. The collision guard above ensures this # stringification is injective on the provided keys. horizon_labels = str_labels # Small-G NaN result (paired with the comment near the top of this # function): mirror the single-horizon stute_test contract so the # event-study workflow on a small or staggered-filtered panel gets # an inconclusive diagnostic rather than an exception. Positioned # AFTER the label-collision / shape-alignment guards so the NaN # result carries a consistent per-horizon diagnostic surface. if G < _MIN_G_STUTE: warnings.warn( f"stute_joint_pretest: G = {G} is below the minimum " f"{_MIN_G_STUTE} for the CvM statistic to be well-calibrated. 
" f"Returning NaN result.", UserWarning, stacklevel=2, ) return StuteJointResult( cvm_stat_joint=float("nan"), p_value=float("nan"), reject=False, alpha=float(alpha), horizon_labels=horizon_labels, per_horizon_stats={k: float("nan") for k in horizon_labels}, n_bootstrap=int(n_bootstrap), n_obs=int(G), n_horizons=int(K), seed=None if seed is None else int(seed), null_form=str(null_form), exact_linear_short_circuited=False, ) if any_nan: return StuteJointResult( cvm_stat_joint=float("nan"), p_value=float("nan"), reject=False, alpha=float(alpha), horizon_labels=horizon_labels, per_horizon_stats={k: float("nan") for k in horizon_labels}, n_bootstrap=int(n_bootstrap), n_obs=int(G), n_horizons=int(K), seed=None if seed is None else int(seed), null_form=str(null_form), exact_linear_short_circuited=False, ) # Zero-variation-in-D degeneracy guard: mirrors stute_test's intent # (had_pretests.py:~1233). The CvM cusum is defined against the # dose regressor; constant d has no cross-sectional variation for # the test to detect nonlinearity. Under the mean-independence null # this yields a mechanically-zero statistic (bogus fail-to-reject); # under the linearity null a singular [1, d] design matrix crashes # the refit. Emit warning + NaN result instead. # # Uses ``ptp`` (peak-to-peak = max - min) rather than ``np.var`` for # the degeneracy check: ``np.var`` of a truly constant array returns # a small non-zero value (~1e-32) due to E[X^2] - E[X]^2 rounding # noise, so a ``<= 0`` comparison misses the degeneracy. ``ptp`` is # bit-exact for identical inputs. if float(np.ptp(doses_arr)) <= 0.0: warnings.warn( "stute_joint_pretest: constant doses (zero cross-sectional " "variation); the joint Stute CvM requires dose variation. 
" "Returning NaN result.", UserWarning, stacklevel=2, ) return StuteJointResult( cvm_stat_joint=float("nan"), p_value=float("nan"), reject=False, alpha=float(alpha), horizon_labels=horizon_labels, per_horizon_stats={k: float("nan") for k in horizon_labels}, n_bootstrap=int(n_bootstrap), n_obs=int(G), n_horizons=int(K), seed=None if seed is None else int(seed), null_form=str(null_form), exact_linear_short_circuited=False, ) # Phase 4.5 C: resolve effective per-unit weights (None → bit-exact # unweighted path). # R4 P1: validate 1D explicitly so column-vector inputs raise. if survey is not None: w_arr = _validate_1d_numeric( np.asarray(survey.weights), "stute_joint_pretest: survey.weights" ) if w_arr.shape[0] != G: raise ValueError( f"stute_joint_pretest: survey.weights length {w_arr.shape[0]} " f"does not match doses length {G}." ) # R1 P0: strictly-positive guard (mirrors stute_test single-horizon). if (w_arr <= 0).any(): raise ValueError( "stute_joint_pretest: survey weights must be strictly " "positive. Zero / negative weights would leave units in " "the variance / CvM computation while contributing zero " "population mass." ) elif weights is not None: w_arr = _validate_1d_numeric(np.asarray(weights), "stute_joint_pretest: weights") if w_arr.shape[0] != G: raise ValueError( f"stute_joint_pretest: weights length {w_arr.shape[0]} does " f"not match doses length {G}." ) if (w_arr <= 0).any(): raise ValueError( "stute_joint_pretest: weights must be strictly positive (the " "pweight shortcut does not support zero weights)." ) else: w_arr = None # R4 P0: normalize pweights to mean=1 (matches SurveyDesign.resolve() # convention; same fix as stute_test / yatchew_hr_test). 
if w_arr is not None: w_arr = w_arr * (float(w_arr.shape[0]) / float(np.sum(w_arr))) idx = np.argsort(doses_arr, kind="stable") d_sorted = doses_arr[idx] per_horizon_stats: Dict[str, float] = {} if w_arr is None: for k in horizon_labels: per_horizon_stats[k] = _cvm_statistic(residuals_arrays[k][idx], d_sorted) else: w_sorted = w_arr[idx] for k in horizon_labels: per_horizon_stats[k] = _cvm_statistic_weighted( residuals_arrays[k][idx], d_sorted, w_sorted ) S_joint = float(sum(per_horizon_stats.values())) # Per-horizon exact-linear short-circuit (scale- and translation- # invariant, matches Phase 3 invariant). A single degenerate horizon # does NOT collapse the joint test if other horizons have nontrivial # residuals. short_circuit = True for k in horizon_labels: eps_k = residuals_arrays[k] fit_k = fitted_arrays[k] dy_k = fit_k + eps_k tss_centered = float(np.sum((dy_k - dy_k.mean()) ** 2)) if tss_centered == 0.0: # Outcome identically constant: treat as trivially linear for # this horizon (ratio = 0). Does not force short-circuit # because other horizons may still be nontrivial. ratio = 0.0 else: ratio = float(np.sum(eps_k**2) / tss_centered) if ratio >= _EXACT_LINEAR_RELATIVE_TOL: short_circuit = False break if short_circuit: return StuteJointResult( cvm_stat_joint=S_joint, p_value=1.0, reject=False, alpha=float(alpha), horizon_labels=horizon_labels, per_horizon_stats=per_horizon_stats, n_bootstrap=int(n_bootstrap), n_obs=int(G), n_horizons=int(K), seed=None if seed is None else int(seed), null_form=str(null_form), exact_linear_short_circuited=True, ) # Precompute OLS projection matrix once: same X per bootstrap draw, # so (X'X)^-1 X' is constant across iterations. Keeps refit O(Gp) # per draw without changing semantics from the literal paper form. # Weighted variant uses (X' W X)^-1 X' W; same precompute idiom. 
# Catch rank-deficient designs explicitly rather than surfacing a # raw ``np.linalg.LinAlgError`` to direct callers of the public # residuals-in core; matches the front-door validation style of # the other guards in this function. try: if w_arr is None: XtX_inv_Xt = np.linalg.solve(X.T @ X, X.T) else: # Weighted OLS projection: (X' W X)^-1 X' W XtWX = X.T @ (w_arr[:, np.newaxis] * X) XtW = X.T * w_arr # broadcasts (p, G) XtX_inv_Xt = np.linalg.solve(XtWX, XtW) except np.linalg.LinAlgError as exc: raise ValueError( f"design_matrix is rank-deficient (singular X^T X); cannot " f"compute the OLS projection (X^T X)^-1 X^T for the " f"bootstrap refit. Check for duplicate or linearly-" f"dependent columns. shape={X.shape}." ) from exc rng = np.random.default_rng(seed) bootstrap_S = np.empty(n_bootstrap, dtype=np.float64) if w_arr is None: # Unweighted bit-exact path (stability invariant #1). for b in range(n_bootstrap): # SHARED eta across horizons - preserves unit-level dependence # in the vector-valued empirical process. Independent-per-horizon # draws would overstate precision. eta = _generate_mammen_weights(G, rng) S_b = 0.0 for k in horizon_labels: dy_b = fitted_arrays[k] + residuals_arrays[k] * eta beta_b = XtX_inv_Xt @ dy_b eps_b = dy_b - X @ beta_b S_b += _cvm_statistic(eps_b[idx], d_sorted) bootstrap_S[b] = S_b else: # Phase 4.5 C survey-aware path: PSU-level Mammen multipliers # SHARED across horizons within each replicate. The (B, n_psu) # matrix is drawn ONCE per replicate; the per-horizon loop # broadcasts the SAME multipliers, preserving both the # vector-valued empirical-process unit-level dependence (paper # convention) AND PSU clustering (Krieger-Pfeffermann 1997). 
resolved_for_boot = survey if survey is not None else make_pweight_design(w_arr) # Stratified designs are supported via the standard stratified # clustered wild-bootstrap correction on the PSU multipliers # (within-stratum demean + sqrt(n_h/(n_h-1)) Bessel rescale), # applied uniformly before the per-obs broadcast eta_obs = # psu_mults[b, psu_col_idx] below. The joint variant shares the # SAME multiplier row across horizons within each replicate, so # the stratum correction applies once and inherits across # horizons (preserving cross-horizon empirical-process # dependence per Hlávka & Huškova 2020 § 3). See REGISTRY # § "Note (Stute stratified survey-bootstrap calibration)". # R5 P1: reject lonely_psu='adjust' singleton-strata designs. # Same pseudo-stratum centering gap as stute_test / HAD sup-t. if _has_lonely_psu_adjust_singletons(resolved_for_boot): raise NotImplementedError( "stute_joint_pretest: SurveyDesign(lonely_psu='adjust') " "with singleton strata is not yet supported on the " "multiplier bootstrap. The bootstrap helper pools " "singletons with nonzero multipliers but the matching " "analytical variance target requires a pseudo-stratum " "centering transform that has not been derived for the " "Stute CvM. Use lonely_psu='remove' (drops singleton " "contributions) or 'certainty' (zero-variance " "singletons), or pre-process the panel to remove " "singleton strata." ) # R3 P0: variance-unidentified survey-design guard (mirrors # stute_test single-horizon). df_survey = resolved_for_boot.df_survey if df_survey is None or df_survey <= 0: warnings.warn( f"stute_joint_pretest: survey design is variance-" f"unidentified (df_survey={df_survey}); the multiplier " "bootstrap cannot calibrate the joint test. 
Returning " "NaN result.", UserWarning, stacklevel=2, ) return StuteJointResult( cvm_stat_joint=S_joint, p_value=float("nan"), reject=False, alpha=float(alpha), horizon_labels=horizon_labels, per_horizon_stats=per_horizon_stats, n_bootstrap=int(n_bootstrap), n_obs=int(G), n_horizons=int(K), seed=None if seed is None else int(seed), null_form=str(null_form), exact_linear_short_circuited=False, ) psu_mults, psu_ids = generate_survey_multiplier_weights_batch( n_bootstrap, resolved_for_boot, weight_type="mammen", rng=rng ) # Stratum centering + Bessel rescale on the PSU multipliers # before broadcast. Single application here (shared with the # per-horizon loop below) propagates the same centered # multipliers across all horizons in each replicate, preserving # the joint Stute's cross-horizon empirical-process dependence. # See REGISTRY § "Note (Stute stratified survey-bootstrap # calibration)". apply_stratum_centering(psu_mults, resolved_for_boot, psu_ids, psu_axis=1) if resolved_for_boot.psu is None: psu_col_idx = np.arange(G) else: psu_to_col = {int(p): c for c, p in enumerate(psu_ids)} psu_arr = np.asarray(resolved_for_boot.psu) psu_col_idx = np.array([psu_to_col[int(psu_arr[g])] for g in range(G)]) w_sorted = w_arr[idx] for b in range(n_bootstrap): eta_obs = psu_mults[b, psu_col_idx] # (G,) - shared across horizons S_b = 0.0 for k in horizon_labels: dy_b = fitted_arrays[k] + residuals_arrays[k] * eta_obs beta_b = XtX_inv_Xt @ dy_b eps_b = dy_b - X @ beta_b S_b += _cvm_statistic_weighted(eps_b[idx], d_sorted, w_sorted) bootstrap_S[b] = S_b p_value = float((1.0 + np.sum(bootstrap_S >= S_joint)) / (n_bootstrap + 1)) reject = bool(p_value <= alpha) return StuteJointResult( cvm_stat_joint=S_joint, p_value=p_value, reject=reject, alpha=float(alpha), horizon_labels=horizon_labels, per_horizon_stats=per_horizon_stats, n_bootstrap=int(n_bootstrap), n_obs=int(G), n_horizons=int(K), seed=None if seed is None else int(seed), null_form=str(null_form), 
exact_linear_short_circuited=False, )
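# Illustrative sketch (not part of the module): the shared-eta bootstrap used by
# ``stute_joint_pretest`` above, reduced to the unweighted path. The helper names
# ``mammen_weights``, ``cvm_stat`` and ``joint_cvm_pvalue`` are hypothetical, and
# the 1/G^2 scaling of the CvM cusum is an assumption; the library's own
# ``_cvm_statistic`` and ``_generate_mammen_weights`` may differ in normalization
# and tie handling. Only the structure (one multiplier draw per unit, reused
# across horizons, then a sum of per-horizon CvMs) is taken from the code above.

```python
import numpy as np


def mammen_weights(n, rng):
    """Mammen (1993) two-point multipliers with mean 0 and variance 1."""
    s5 = np.sqrt(5.0)
    lo, hi = -(s5 - 1.0) / 2.0, (s5 + 1.0) / 2.0
    p_lo = (s5 + 1.0) / (2.0 * s5)  # probability of drawing the low point
    return np.where(rng.random(n) < p_lo, lo, hi)


def cvm_stat(eps_sorted):
    """CvM cusum on residuals already sorted by dose (assumed 1/G^2 scaling)."""
    g = eps_sorted.shape[0]
    return float(np.sum(np.cumsum(eps_sorted) ** 2) / g**2)


def joint_cvm_pvalue(dy_by_horizon, doses, X, n_bootstrap=199, seed=0):
    """Sum-of-CvMs statistic with a wild bootstrap that shares eta across horizons."""
    idx = np.argsort(doses, kind="stable")
    proj = np.linalg.solve(X.T @ X, X.T)  # (X'X)^-1 X', reused on every draw
    fitted = {k: X @ (proj @ dy) for k, dy in dy_by_horizon.items()}
    resid = {k: dy - fitted[k] for k, dy in dy_by_horizon.items()}
    s_joint = sum(cvm_stat(resid[k][idx]) for k in resid)

    rng = np.random.default_rng(seed)
    boot = np.empty(n_bootstrap)
    for b in range(n_bootstrap):
        eta = mammen_weights(doses.shape[0], rng)  # ONE draw, shared by all horizons
        s_b = 0.0
        for k in resid:
            dy_b = fitted[k] + resid[k] * eta  # bootstrap outcome under the null
            eps_b = dy_b - X @ (proj @ dy_b)   # refit and take new residuals
            s_b += cvm_stat(eps_b[idx])
        boot[b] = s_b
    p_value = (1.0 + np.sum(boot >= s_joint)) / (n_bootstrap + 1)
    return float(s_joint), float(p_value)


# Synthetic two-horizon panel that is linear in dose by construction.
demo_rng = np.random.default_rng(1)
G = 200
d = demo_rng.uniform(0.0, 1.0, G)
X_lin = np.column_stack([np.ones(G), d])  # linearity null: [1, doses]
dy = {
    "1": 1.0 + 2.0 * d + demo_rng.normal(size=G),
    "2": 0.5 * d + demo_rng.normal(size=G),
}
s_joint, p_value = joint_cvm_pvalue(dy, d, X_lin)
```

# Drawing ``eta`` inside the horizon loop instead would treat the horizons as
# independent, which is the case the comment in the bootstrap loop above warns
# against.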
def _resolve_pretest_unit_weights( data: pd.DataFrame, unit_col: str, weights: Optional[np.ndarray], survey: Any, caller_name: str, ) -> "tuple[Optional[np.ndarray], Optional[Any]]": """Resolve per-row ``weights`` / ``survey`` kwargs to per-unit (G,) form. Used by ``joint_pretrends_test``, ``joint_homogeneity_test``, and ``did_had_pretest_workflow`` (data-in entry points). Reuses the HAD helpers ``_aggregate_unit_weights`` and ``_aggregate_unit_resolved_survey`` which enforce constant-within-unit invariant on weights and on every survey design column (strata, psu, fpc). Mutex on ``weights`` AND ``survey`` (cannot supply both). Replicate- weight survey designs raise ``NotImplementedError`` (deferred to a parallel follow-up after Phase 4.5 C). Returns ------- (weights_unit, resolved_unit) : Tuple[Optional[np.ndarray], Optional[ResolvedSurveyDesign]] - If neither kwarg supplied: ``(None, None)`` (unweighted path). - If ``weights`` supplied: ``(weights_unit, None)``. - If ``survey`` supplied: ``(None, resolved_unit)`` where ``resolved_unit.weights`` is the per-unit weight vector and strata/psu/fpc are also per-unit. """ if weights is None and survey is None: return None, None if weights is not None and survey is not None: raise ValueError( f"{caller_name}: pass survey=<SurveyDesign> OR weights=<array>, " "not both." ) if weights is not None: weights_arr = np.asarray(weights, dtype=np.float64) # R4 P1: validate 1D explicitly (column-vector inputs would otherwise # broadcast through downstream computations and silently corrupt # results). if weights_arr.ndim != 1: raise ValueError( f"{caller_name}: weights must be 1-dimensional, got shape " f"{weights_arr.shape}. 
(A common mistake is passing " "df[['w']].to_numpy() which produces (N, 1); use " "df['w'].to_numpy() for (N,).)" ) weights_unit = _aggregate_unit_weights(data, weights_arr, unit_col) # R1 P0: strictly-positive weights required on the pweight shortcut # (matches stute_test/yatchew_hr_test direct entry behavior; the CvM # cusum + adjacent-difference variance assume all rows contribute). if (weights_unit <= 0).any(): raise ValueError( f"{caller_name}: weights must be strictly positive at the " "per-unit level. Zero / negative weights would leave units " "in the variance/CvM computation while contributing zero " "mass; use survey= with explicit lonely-PSU handling for " "principled subpopulation analysis." ) # R4 P0: normalize per-unit weights to mean=1 (matches # SurveyDesign.resolve() convention so weights= and survey= entry # paths produce identical statistic values; ensures Yatchew is # scale-invariant under uniform rescaling). weights_unit = weights_unit * (float(weights_unit.shape[0]) / float(np.sum(weights_unit))) return weights_unit, None # survey is not None if not hasattr(survey, "resolve"): # PR #376 R9 P3: error message names the canonical kwarg # `survey_design=` (with the deprecated `survey=` alias mentioned # for back-compat), and points pre-resolved-design users to the # array-in pretest helpers where ResolvedSurveyDesign / # make_pweight_design(arr) belong. raise TypeError( f"{caller_name}: `survey_design=` (or the deprecated `survey=` " f"alias) accepts a SurveyDesign instance (column-referencing, " f"gets `.resolve(data)`'d at fit time) on data-in surfaces; " f"got {type(survey).__name__} (no `.resolve()` method). " "If you have a pre-resolved ResolvedSurveyDesign or used " "`make_pweight_design(arr)`, that pattern is for the array-in " "pretest helpers (`stute_test`, `yatchew_hr_test`, " "`stute_joint_pretest`). On data-in surfaces, add the weights " "as a column on `data` and pass " "`survey_design=SurveyDesign(weights='col_name', ...)`." 
) resolved_full = survey.resolve(data) if getattr(resolved_full, "replicate_weights", None) is not None: raise NotImplementedError( f"{caller_name}: replicate-weight survey designs (BRR/Fay/JK1/JKn/" "SDR) are not yet supported on HAD pretests. Replicate-weight " "pretests are a parallel follow-up after Phase 4.5 C." ) # R1 P1: pweight-only guard. aweight/fweight slip through pweight-only # formulas silently otherwise (mirrors HeterogeneousAdoptionDiD.fit() at # had.py:2976+ and survey._resolve_pweight_only at survey.py:914). if getattr(resolved_full, "weight_type", "pweight") != "pweight": raise ValueError( f"{caller_name}: HAD pretests require weight_type='pweight'. " f"Got weight_type={resolved_full.weight_type!r}. aweight / " "fweight have different sandwich-variance semantics that are " "not derived for the pretest variance components." ) resolved_unit = _aggregate_unit_resolved_survey(data, resolved_full, unit_col) # R1 P0: strictly-positive weights at the per-unit level (mirrors the # weights= shortcut). Zero per-unit weights leave units in the dose- # variation check / CvM sum while contributing zero population mass, # which can produce silently-wrong pretest decisions. if (np.asarray(resolved_unit.weights) <= 0).any(): raise ValueError( f"{caller_name}: survey weights must be strictly positive at " "the per-unit level. Zero / negative weights would leave units " "in the variance/CvM computation while contributing zero " "mass; this would produce silent wrong pretest decisions on " "subpopulation-restricted designs. Pre-filter the panel to " "the positive-weight subpopulation before calling the workflow." ) return None, resolved_unit
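# Illustrative sketch (not part of the module): the mean-1 pweight normalization
# applied on the ``weights=`` shortcut above, together with its 1-D and
# strictly-positive guards. ``normalize_pweights`` is a hypothetical helper name;
# the formula ``w * G / sum(w)`` is the one used in the code above.

```python
import numpy as np


def normalize_pweights(w):
    """Rescale strictly positive 1-D weights so that their mean is 1."""
    w = np.asarray(w, dtype=np.float64)
    if w.ndim != 1:
        # e.g. df[["w"]].to_numpy() has shape (N, 1); df["w"].to_numpy() is (N,)
        raise ValueError(f"weights must be 1-dimensional, got shape {w.shape}")
    if (w <= 0).any():
        raise ValueError("weights must be strictly positive")
    return w * (w.shape[0] / np.sum(w))


raw = np.array([2.0, 4.0, 6.0, 8.0])
a = normalize_pweights(raw)
b = normalize_pweights(1000.0 * raw)  # uniform rescaling of the same weights
```

# After normalization ``a`` and ``b`` agree, so any statistic computed from the
# normalized weights is unaffected by the overall scale of the raw weights.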
def joint_pretrends_test( data: pd.DataFrame, outcome_col: str, dose_col: str, time_col: str, unit_col: str, pre_periods: list, base_period: Any, first_treat_col: Optional[str] = None, *, alpha: float = 0.05, n_bootstrap: int = 999, seed: Optional[int] = None, survey_design: Any = None, survey: Any = None, weights: Optional[np.ndarray] = None, trends_lin: bool = False, ) -> StuteJointResult: """Joint Stute pre-trends test (paper Section 4.2 step 2). Data-in wrapper around :func:`stute_joint_pretest` for the mean-independence null ``E[Y_{g,t} - Y_{g,base} | D_{g,treat}] = mu_t`` across multiple pre-period placebos. For each ``t in pre_periods``, residuals are the deviations of ``Y_{g,t} - Y_{g,base}`` from their cross-unit mean (an intercept-only OLS fit); the joint CvM tests whether the conditional mean depends on ``D``. Use this wrapper to close the paper's step-2 pre-trends gap that :func:`did_had_pretest_workflow` otherwise flags. On a panel with at least one earlier pre-period, the ``aggregate="event_study"`` dispatch calls this wrapper internally. Parameters ---------- data : pd.DataFrame outcome_col, dose_col, time_col, unit_col : str pre_periods : list Non-empty list of pre-period labels (all ``< base_period``, all with ``D = 0`` across every unit). Empty list raises; the workflow dispatch handles the "no earlier pre-period" case by setting ``pretrends_joint=None`` rather than calling this wrapper. base_period : period label The reference period. Must not be in ``pre_periods``. Must also satisfy ``D = 0`` across every unit (reciprocal of the pre-period HAD invariant - base is itself a pre-period in the four-step workflow). first_treat_col : str or None Forwarded to the underlying panel validator; matched cohort handling follows the HAD contract (staggered auto-filter warns and proceeds on last cohort; solo cohort proceeds). alpha, n_bootstrap, seed : as in :func:`stute_test`. 
survey_design : SurveyDesign or None, keyword-only, default None Survey design (Phase 4.5 C). Resolved on the filtered panel; replicate-weight designs raise ``NotImplementedError``; ``weight_type`` must be ``"pweight"``. Forwarded to :func:`stute_joint_pretest` as a per-unit ``ResolvedSurveyDesign``. Mutually exclusive with the deprecated ``survey=`` and ``weights=`` aliases. survey : SurveyDesign or None, keyword-only, default None DEPRECATED alias of ``survey_design=``. Will be removed in the next minor release. weights : np.ndarray or None, keyword-only, default None DEPRECATED alias for the per-row pweight shortcut. Prefer ``survey_design=SurveyDesign(weights='col_name')`` against your dataframe instead. Will be removed in the next minor release. trends_lin : bool, default False, keyword-only When ``True``, applies paper Eq 17 / Eq 18 linear-trend detrending: per-group slope estimated as ``Y[g, base] - Y[g, base - 1]`` and subtracted from each pre-period horizon's outcome evolution as ``(t - base) × slope``. Mirrors R ``DIDHAD::did_had(..., trends_lin=TRUE)`` on its joint Stute pre-trends surface (paper Section 5.2 Pierce-Schott application). **Requires** ``base_period`` to equal the last validated pre-period (``t_pre_list[-1]``, i.e. the canonical ``F-1`` anchor). Direct callers passing a non-terminal base get a ``ValueError`` — Eq 17 / R both anchor at ``F-1`` and any other anchor would compute a different slope and detrending. The previous validated pre-period (``t_pre_list[-2]``, ``F-2``) must also be present so the slope is identified. The "consumed" placebo at ``F-2`` is dropped from ``pre_periods`` explicitly (its detrended residual is mechanically zero by construction); a ``UserWarning`` fires when the filter triggers. If ``pre_periods`` becomes empty after the drop, raises ``ValueError`` (no testable placebo horizons remain). 
Mutually exclusive with survey weighting (``survey_design`` / ``survey`` / ``weights``); raises ``NotImplementedError`` if combined. Default ``False`` preserves bit-exact backcompat. Returns ------- StuteJointResult with ``null_form = "mean_independence"``. """ # Three-way mutex on survey_design / survey / weights (data-in pattern). n_set = sum(x is not None for x in (survey_design, survey, weights)) if n_set > 1: raise ValueError(HAD_DUAL_KNOB_MUTEX_MSG_DATA_IN) # Soft deprecation: route legacy survey=/weights= aliases to survey_design=. if survey is not None: warnings.warn(HAD_DEPRECATION_MSG_SURVEY_KWARG, DeprecationWarning, stacklevel=2) survey_design = survey elif weights is not None: warnings.warn( HAD_DEPRECATION_MSG_WEIGHTS_KWARG_DATA_IN, DeprecationWarning, stacklevel=2, ) # weights= shortcut preserved as-is on the back end. # Internal alias rebind: downstream code uses `survey` and `weights`. if survey_design is not None and survey is None: survey = survey_design # ---- trends_lin × survey_design gate (PR #389 / Phase 4 R-parity). ---- # Detrending under survey weighting (weighted slope? per-PSU slope?) # is not derived from the paper. Use trends_lin without survey weights, # OR survey weights without trends_lin. Tracked as TODO follow-up. if trends_lin and (survey is not None or weights is not None): raise NotImplementedError( "joint_pretrends_test(trends_lin=True) is not yet supported " "with survey weighting (`survey_design=` / `survey=` / " "`weights=`). The per-group slope estimator's weighted " "variant is not derived from the paper. Use trends_lin=True " "WITHOUT survey weights, or use survey weights WITHOUT " "trends_lin. Tracked in TODO.md as a follow-up if user " "demand emerges." ) if len(pre_periods) == 0: raise ValueError( "pre_periods must be non-empty. Workflow dispatch handles " "the empty case by setting pretrends_joint=None; direct " "callers should not pass an empty list." 
) if base_period in pre_periods: raise ValueError( f"base_period={base_period!r} must not appear in " f"pre_periods {list(pre_periods)!r}." ) # Ordering check: all pre_periods strictly < base_period in # chronological order. Uses `_build_period_rank` to handle ordered- # categorical time columns correctly (raw Python `<` would fail on # categories whose lexical order disagrees with chronology, e.g. # ["q1", "q2", "q10"]). Numeric / datetime dtypes get natural order. period_rank = _build_period_rank(data, time_col) if base_period not in period_rank: raise ValueError( f"base_period={base_period!r} not found in time_col " f"{time_col!r}. Available: " f"{sorted(period_rank.keys(), key=lambda t: period_rank[t])!r}." ) missing_pre_in_data = [t for t in pre_periods if t not in period_rank] if missing_pre_in_data: raise ValueError( f"pre_periods entries {missing_pre_in_data!r} not found in " f"time_col {time_col!r}. Available: " f"{sorted(period_rank.keys(), key=lambda t: period_rank[t])!r}." ) base_rank = period_rank[base_period] out_of_order = [t for t in pre_periods if period_rank[t] >= base_rank] if out_of_order: raise ValueError( f"All pre_periods must be strictly < base_period in " f"chronological order. Violators: {out_of_order!r} " f"(base_period={base_period!r})." ) # ---- trends_lin: defer the consumed-placebo drop and base-1 # identification until AFTER the validator block runs (so we can # use t_pre_list to enforce the non-terminal-base guard and the # observed-period predecessor consistently). On 2-period panels # the validator does not run and trends_lin needs F-2, which is # impossible — front-door-reject here. base_minus_1_period: Any = None pre_periods_effective = list(pre_periods) # Event-study validation contract (paper Appendix B.2): # When the panel has >= 3 distinct periods, always route through # `_validate_had_panel_event_study`. 
This enforces (a) balanced # panel, (b) ordered time dtype, (c) D = 0 across every pre-period, # (d) last-cohort auto-filter under staggered timing with # UserWarning, (e) constant post-treatment dose within unit. When # first_treat_col is None and the panel is staggered, the validator # RAISES - matching the workflow dispatch contract. For 2-period # panels the validator does not apply; skip and fall through to the # simpler balance/invariant guards in `_aggregate_for_joint_test`. n_periods = int(data[time_col].nunique()) if trends_lin and n_periods < 3: raise ValueError( f"joint_pretrends_test(trends_lin=True) requires a panel " f"with at least 3 distinct time periods so the per-group " f"slope Y[g, base] - Y[g, base - 1] is identified. Got " f"n_periods={n_periods}." ) data_filtered: pd.DataFrame = data if n_periods >= 3: F_val, t_pre_list, _t_post_list, data_filtered, _filter_info = ( _validate_had_panel_event_study( data, outcome_col=outcome_col, dose_col=dose_col, time_col=time_col, unit_col=unit_col, first_treat_col=first_treat_col, ) ) # `_validate_had_panel_event_study` already emits its own # `UserWarning` on the staggered-filter path; the wrapper # consumes `_filter_info` silently to avoid duplicated console # noise (R4 code-quality fix). # Subset invariants: the caller's base_period and pre_periods # must be pre-treatment periods under the validator's partition. if base_period not in t_pre_list: raise ValueError( f"base_period={base_period!r} is not in the validated " f"pre-period set {list(t_pre_list)!r} (periods before " f"first-treatment period F={F_val!r}). For the HAD " f"pre-trends workflow, base_period must be a pre-period " f"anchor (typically the last pre-period, F-1)." ) not_pre = [t for t in pre_periods if t not in t_pre_list] if not_pre: raise ValueError( f"pre_periods must all be validated pre-treatment " f"periods. Not-pre entries: {not_pre!r}. Validator's " f"pre-period set: {list(t_pre_list)!r}." 
) # PR #392 R3 P1 (non-terminal base guard): paper Eq 17 / Eq 18 # and R `DIDHAD::did_had(..., trends_lin=TRUE)` anchor the # detrending at F-1 (the last validated pre-period) and use # Y[F-1] - Y[F-2] as the slope. A direct caller passing # base_period < F-1 (e.g. F-2) would compute a different slope # at a different anchor, silently changing the methodology # away from the documented Eq 17/18 construction. Reject # explicitly. Workflow + HAD.fit always pass F-1; this check # only fires on direct user calls with non-terminal bases. if trends_lin and base_period != t_pre_list[-1]: raise ValueError( f"joint_pretrends_test(trends_lin=True) requires " f"base_period to equal the last validated pre-period " f"({t_pre_list[-1]!r}, the canonical Eq 17 anchor " f"F-1). Got base_period={base_period!r}. Anchoring at " f"any other pre-period would compute a different " f"slope and detrending that does not match paper " f"Eq 17 / Eq 18 or R DIDHAD::did_had(trends_lin=TRUE)." ) # PR #392 R3 P1 (observed-period base-1 lookup) + R1 P0 # (consumed-placebo drop) consolidated: # base_minus_1_period = t_pre_list[-2] (= F-2, the validated # observed pre-period immediately before F-1). Using # t_pre_list ensures correctness on ordered-categorical panels # with unused intermediate levels (the validator's t_pre_list # is built from observed contiguous pre-periods, not from the # full dtype's category list). Then drop t_pre_list[-2] from # pre_periods if present (the consumed placebo whose detrended # residual is mechanically zero). if trends_lin: if len(t_pre_list) < 2: raise ValueError( f"joint_pretrends_test(trends_lin=True) requires " f"at least 2 validated pre-periods so the per-" f"group slope Y[g, F-1] - Y[g, F-2] is identified. " f"Got t_pre_list={list(t_pre_list)!r}." 
) base_minus_1_period = t_pre_list[-2] if base_minus_1_period in pre_periods_effective: warnings.warn( f"joint_pretrends_test(trends_lin=True): dropping " f"period {base_minus_1_period!r} from pre_periods " f"— it is the 'consumed' placebo (the F-2 → F-1 " f"evolution used by the per-group slope " f"estimator), so under trends_lin its detrended " f"residual is mechanically zero. R's " f"`did_had(trends_lin=TRUE)` reduces max placebo " f"lag by 1 with the same effect.", UserWarning, stacklevel=2, ) pre_periods_effective = [ t for t in pre_periods_effective if t != base_minus_1_period ] if len(pre_periods_effective) == 0: raise ValueError( f"joint_pretrends_test(trends_lin=True): no testable " f"placebo horizons remain after dropping the consumed " f"placebo at base_period - 1 = {base_minus_1_period!r}. " f"Pass at least one earlier observed pre-period when " f"using trends_lin=True." ) d_arr, dy_by_horizon, _ = _aggregate_for_joint_test( data_filtered, outcome_col=outcome_col, dose_col=dose_col, time_col=time_col, unit_col=unit_col, horizons=list(pre_periods_effective), base_period=base_period, ) G = d_arr.shape[0] # HAD invariant: D_{g,t} = 0 for every g and every pre_period (and # for base_period - it is itself a pre-period relative to the # treatment onset). We check this on the passed-in panel subset. # Use pre_periods_effective so the consumed placebo (dropped above # under trends_lin) is not in the dose-zero check window. needed_all_zero = list(pre_periods_effective) + [base_period] subset_zero_check = data_filtered[data_filtered[time_col].isin(needed_all_zero)] if (subset_zero_check[dose_col] != 0).any(): n_nonzero = int((subset_zero_check[dose_col] != 0).sum()) raise ValueError( f"Pre-trends test requires D = 0 in every pre-period " f"(including base_period). Found {n_nonzero} non-zero " f"dose observation(s) across periods " f"{needed_all_zero!r}. 
HAD contract (paper Section 2) and " f"pre-trends test design both require the zero-dose " f"invariant to hold in ALL periods used as placebo or " f"anchor." ) # ---- Apply trends_lin detrending (paper Eq 17 / Eq 18). # base_minus_1_period was computed and validated above (before the # consumed-placebo drop). Compute the per-group slope and apply the # `(t - base) × slope` adjustment to each remaining horizon. if trends_lin: # Extract Y[g, base] and Y[g, base-1] in unit_col-sorted order # (matching d_arr / dy_by_horizon ordering produced by # _aggregate_for_joint_test's wide-pivot sort by index). slope_subset = data_filtered[ data_filtered[time_col].isin([base_period, base_minus_1_period]) ] wide_y = slope_subset.pivot(index=unit_col, columns=time_col, values=outcome_col) wide_y = wide_y.sort_index() if wide_y[base_period].isna().any() or wide_y[base_minus_1_period].isna().any(): raise ValueError( f"joint_pretrends_test(trends_lin=True): NaN value(s) " f"in outcome at base_period={base_period!r} or " f"base_period-1={base_minus_1_period!r}. The slope " f"estimator requires complete observations at both " f"periods for every unit." ) slope = wide_y[base_period].to_numpy(dtype=np.float64) - wide_y[ base_minus_1_period ].to_numpy(dtype=np.float64) # PR #392 R4 P0: build the detrending rank from OBSERVED # periods (on data_filtered), not from the full categorical # dtype. Otherwise unused intermediate categorical levels # silently change the (t - base) multiplier and corrupt the # joint statistic. Mirrors HAD.fit's # `_aggregate_multi_period_first_differences` convention which # uses `sorted(t_pre_list + t_post_list, ...)` for the # event-time rank. observed_rank = { p: i for i, p in enumerate( sorted( set(data_filtered[time_col].unique()), key=lambda p: period_rank[p], ) ) } base_rank_observed = observed_rank[base_period] # Apply detrending in place to remaining dy_by_horizon entries. 
for t in pre_periods_effective: label = str(t) delta = observed_rank[t] - base_rank_observed # < 0 for pre-periods dy_by_horizon[label] = dy_by_horizon[label] - delta * slope # Phase 4.5 C: aggregate per-row weights/survey to per-unit (G,) # using the existing HAD helpers (constant-within-unit invariant # enforced; replicate-weight rejected on the survey path). # R2 P1 fix: subset row-level `weights` to data_filtered's rows BEFORE # resolution, mirroring did_had_pretest_workflow. When # _validate_had_panel_event_study auto-filters to the last cohort # under staggered timing, the original weights array no longer aligns # with data_filtered's row count. Survey= path is unaffected # (column references resolved internally on data_filtered). weights_for_resolve = weights if weights is not None: # R9 P1: validate 1D + length-matched-to-data BEFORE any # staggered-panel subsetting. Otherwise oversized arrays would # be silently truncated and undersized arrays would surface raw # NumPy indexing errors instead of the package's front-door # ValueError. weights_arr = np.asarray(weights, dtype=np.float64) if weights_arr.ndim != 1: raise ValueError( f"joint_pretrends_test: weights must be 1-dimensional, got " f"shape {weights_arr.shape}. (A common mistake is passing " "df[['w']].to_numpy() which produces (N, 1); use " "df['w'].to_numpy() for (N,).)" ) if weights_arr.shape[0] != len(data): raise ValueError( f"joint_pretrends_test: weights length {weights_arr.shape[0]} " f"does not match data length {len(data)}." ) if len(data_filtered) != len(data): pos_idx = data.index.get_indexer(data_filtered.index) if (pos_idx < 0).any(): raise ValueError( "joint_pretrends_test: cannot align row-level weights to " "the staggered-filtered panel; some data_filtered rows " "do not appear in original data.index." 
) weights_for_resolve = weights_arr[pos_idx] else: weights_for_resolve = weights_arr weights_unit, resolved_unit = _resolve_pretest_unit_weights( data_filtered, unit_col, weights_for_resolve, survey, "joint_pretrends_test" ) # Reorder per-unit weights to match d_arr/dy_by_horizon ordering. # _aggregate_for_joint_test sorts the wide pivot by index (unit_col), # so per-unit order is the SAME as _aggregate_unit_weights' output. w_eff = resolved_unit.weights if resolved_unit is not None else weights_unit residuals_by_horizon: Dict[str, np.ndarray] = {} fitted_by_horizon: Dict[str, np.ndarray] = {} for label, dy_t in dy_by_horizon.items(): if w_eff is None: mean_t = float(dy_t.mean()) else: mean_t = float(np.sum(w_eff * dy_t) / np.sum(w_eff)) fitted_t = np.full(G, mean_t, dtype=np.float64) residuals_t = dy_t - fitted_t residuals_by_horizon[label] = residuals_t fitted_by_horizon[label] = fitted_t design_matrix = np.ones((G, 1), dtype=np.float64) # Internal forwarding: pass survey_design= directly to stute_joint_pretest # to avoid emitting the deprecation warning on every internal call. The # canonical kwarg is the same on both ends; the warning fires ONCE at the # user-facing front door (this wrapper) when the user passed a deprecated # alias. if resolved_unit is not None: joint_survey_design = resolved_unit elif weights_unit is not None: joint_survey_design = make_pweight_design(weights_unit) else: joint_survey_design = None return stute_joint_pretest( residuals_by_horizon=residuals_by_horizon, fitted_by_horizon=fitted_by_horizon, doses=d_arr, design_matrix=design_matrix, alpha=alpha, n_bootstrap=n_bootstrap, seed=seed, null_form="mean_independence", survey_design=joint_survey_design, )
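# Illustrative aside (not part of the module): a minimal sketch of the
# trends_lin detrending used in joint_pretrends_test, on a hypothetical
# 3-unit panel held in a plain NumPy array. The period labels, the
# `rank` mapping and `detrended_evolution` are assumptions made for this
# example only. It shows why the F-2 placebo is "consumed": its
# detrended evolution is zero by construction.

```python
import numpy as np

# Hypothetical balanced panel: 3 units observed at periods F-3, F-2, F-1.
# Columns are ordered chronologically; F-1 is the base period.
y = np.array(
    [
        [1.0, 2.0, 4.0],
        [0.5, 0.0, 1.5],
        [3.0, 3.0, 2.0],
    ]
)
rank = {"F-3": 0, "F-2": 1, "F-1": 2}
base = "F-1"

# Per-unit slope from the last two pre-periods: Y[g, F-1] - Y[g, F-2].
slope = y[:, rank["F-1"]] - y[:, rank["F-2"]]


def detrended_evolution(t):
    """Return (Y[g, t] - Y[g, base]) - (rank[t] - rank[base]) * slope."""
    dy = y[:, rank[t]] - y[:, rank[base]]
    delta = rank[t] - rank[base]  # negative for pre-periods
    return dy - delta * slope


dy_f2 = detrended_evolution("F-2")  # the "consumed" placebo
dy_f3 = detrended_evolution("F-3")  # a testable placebo horizon
```

# For F-2 the raw evolution is exactly -slope and delta is -1, so the
# two terms cancel for every unit. Only horizons earlier than F-2 carry
# information once the linear trend is removed.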
def joint_homogeneity_test( data: pd.DataFrame, outcome_col: str, dose_col: str, time_col: str, unit_col: str, post_periods: list, base_period: Any, first_treat_col: Optional[str] = None, *, alpha: float = 0.05, n_bootstrap: int = 999, seed: Optional[int] = None, survey_design: Any = None, survey: Any = None, weights: Optional[np.ndarray] = None, trends_lin: bool = False, ) -> StuteJointResult: """Joint Stute homogeneity-linearity test (paper Section 4.3 joint). Data-in wrapper around :func:`stute_joint_pretest` for the linearity null ``E[Y_{g,t} - Y_{g,base} | D_{g,t}] = beta_{0,t} + beta_{fe,t} * D_{g,t}`` across multiple post-period horizons. For each ``t in post_periods``, residuals are from an OLS regression of ``Y_{g,t} - Y_{g,base}`` on ``[1, D_g]``; the joint CvM tests whether the conditional mean is nonlinear in ``D`` in any horizon. Used by :func:`did_had_pretest_workflow` with ``aggregate="event_study"`` as the step-3 test (no joint Yatchew variant exists - the paper does not derive one; users who need Yatchew-style adjacent-difference robustness can call :func:`yatchew_hr_test` on each (base, post) pair manually). Parameters ---------- data : pd.DataFrame outcome_col, dose_col, time_col, unit_col : str post_periods : list Non-empty list of post-period labels (all strictly ``> base_period`` by chronological order; each with ``D > 0`` for some unit, i.e. at least one treated unit per horizon). base_period : period label The reference period (last pre-period in the event-study convention). Must not be in ``post_periods``. first_treat_col : str or None Forwarded to the underlying panel validator. alpha, n_bootstrap, seed : as in :func:`stute_test`. survey_design : SurveyDesign or None, keyword-only, default None Survey design (Phase 4.5 C). Same contract as :func:`joint_pretrends_test`. Mutually exclusive with the deprecated ``survey=`` and ``weights=`` aliases. 
survey : SurveyDesign or None, keyword-only, default None DEPRECATED alias of ``survey_design=``. Will be removed in the next minor release. weights : np.ndarray or None, keyword-only, default None DEPRECATED alias for the per-row pweight shortcut. Prefer ``survey_design=SurveyDesign(weights='col_name')`` against your dataframe instead. Will be removed in the next minor release. trends_lin : bool, default False, keyword-only When ``True``, applies paper page-32 linear-trend detrending: per-group slope estimated as ``Y[g, base] - Y[g, base - 1]`` and applied to each post-period horizon's outcome evolution as ``(t - base) × slope`` (forward extrapolation into post). Same slope estimator as :func:`joint_pretrends_test`. Mirrors R ``DIDHAD::did_had(..., trends_lin=TRUE)`` on its joint homogeneity surface (paper Section 4.3, Pierce-Schott p=0.40 anchor). **Requires** ``base_period`` to equal the last validated pre-period (``t_pre_list[-1]``, the canonical ``F-1`` anchor) AND ``F-2`` to be present in the panel so the slope is identified. Direct callers passing a non- terminal base get a ``ValueError`` — Eq 17 / R both anchor at ``F-1``. Mutually exclusive with survey weighting; raises ``NotImplementedError`` if combined. Default ``False`` preserves bit-exact backcompat. Returns ------- StuteJointResult with ``null_form = "linearity"``. """ # Three-way mutex on survey_design / survey / weights (data-in pattern). n_set = sum(x is not None for x in (survey_design, survey, weights)) if n_set > 1: raise ValueError(HAD_DUAL_KNOB_MUTEX_MSG_DATA_IN) # Soft deprecation: route legacy survey=/weights= aliases to survey_design=. if survey is not None: warnings.warn(HAD_DEPRECATION_MSG_SURVEY_KWARG, DeprecationWarning, stacklevel=2) survey_design = survey elif weights is not None: warnings.warn( HAD_DEPRECATION_MSG_WEIGHTS_KWARG_DATA_IN, DeprecationWarning, stacklevel=2, ) # weights= shortcut preserved as-is on the back end. 
# Internal alias rebind: downstream code uses `survey` and `weights`. if survey_design is not None and survey is None: survey = survey_design # ---- trends_lin × survey_design gate (PR #389 / Phase 4 R-parity). # Twin of joint_pretrends_test guard. ---- if trends_lin and (survey is not None or weights is not None): raise NotImplementedError( "joint_homogeneity_test(trends_lin=True) is not yet " "supported with survey weighting (`survey_design=` / " "`survey=` / `weights=`). The per-group slope estimator's " "weighted variant is not derived from the paper. Use " "trends_lin=True WITHOUT survey weights, or use survey " "weights WITHOUT trends_lin. Tracked in TODO.md as a " "follow-up if user demand emerges." ) if len(post_periods) == 0: raise ValueError( "post_periods must be non-empty. Workflow dispatch handles " "the empty case upstream; direct callers should not pass " "an empty list." ) if base_period in post_periods: raise ValueError( f"base_period={base_period!r} must not appear in " f"post_periods {list(post_periods)!r}." ) # Ordering: all post_periods strictly > base_period in # chronological order. Uses `_build_period_rank` for ordered- # categorical correctness (raw Python `>` would misorder e.g. # "q10" > "q2"). period_rank = _build_period_rank(data, time_col) if base_period not in period_rank: raise ValueError( f"base_period={base_period!r} not found in time_col " f"{time_col!r}. Available: " f"{sorted(period_rank.keys(), key=lambda t: period_rank[t])!r}." ) missing_post_in_data = [t for t in post_periods if t not in period_rank] if missing_post_in_data: raise ValueError( f"post_periods entries {missing_post_in_data!r} not found in " f"time_col {time_col!r}. Available: " f"{sorted(period_rank.keys(), key=lambda t: period_rank[t])!r}." ) base_rank = period_rank[base_period] out_of_order = [t for t in post_periods if period_rank[t] <= base_rank] if out_of_order: raise ValueError( f"All post_periods must be strictly > base_period in " f"chronological order. 
Violators: {out_of_order!r} " f"(base_period={base_period!r})." ) # Event-study validation contract (paper Appendix B.2) - twin of # `joint_pretrends_test`. Same gating by `n_periods >= 3`; same # subset-invariant checks; emits the staggered-filter UserWarning. # The validator also enforces constant post-treatment dose within # unit, which is critical for the homogeneity path because a # time-varying post-dose would make the per-horizon refit on # `[1, D_g]` misspecify the regressor. n_periods = int(data[time_col].nunique()) if trends_lin and n_periods < 3: raise ValueError( f"joint_homogeneity_test(trends_lin=True) requires a " f"panel with at least 3 distinct time periods so the " f"per-group slope Y[g, base] - Y[g, base - 1] is " f"identified. Got n_periods={n_periods}." ) base_minus_1_period_validated: Any = None # set inside validator block under trends_lin data_filtered: pd.DataFrame = data if n_periods >= 3: F_val, t_pre_list, t_post_list, data_filtered, _filter_info = ( _validate_had_panel_event_study( data, outcome_col=outcome_col, dose_col=dose_col, time_col=time_col, unit_col=unit_col, first_treat_col=first_treat_col, ) ) # `_validate_had_panel_event_study` already emits its own # `UserWarning` on the staggered-filter path; the wrapper # consumes `_filter_info` silently to avoid duplicated console # noise (R4 code-quality fix). if base_period not in t_pre_list: raise ValueError( f"base_period={base_period!r} is not in the validated " f"pre-period set {list(t_pre_list)!r} (periods before " f"first-treatment period F={F_val!r}). For the HAD " f"homogeneity workflow, base_period must be a pre-period " f"anchor (typically the last pre-period, F-1)." ) not_post = [t for t in post_periods if t not in t_post_list] if not_post: raise ValueError( f"post_periods must all be validated post-treatment " f"periods. Not-post entries: {not_post!r}. Validator's " f"post-period set: {list(t_post_list)!r}." 
) # PR #392 R3 P1 (non-terminal base guard + observed-period # base-1 lookup, twin of joint_pretrends_test). Eq 17 anchors # at F-1 and uses Y[F-1] - Y[F-2] as slope; require base == # t_pre_list[-1] AND derive base-1 from t_pre_list[-2]. if trends_lin and base_period != t_pre_list[-1]: raise ValueError( f"joint_homogeneity_test(trends_lin=True) requires " f"base_period to equal the last validated pre-period " f"({t_pre_list[-1]!r}, the canonical Eq 17 anchor " f"F-1). Got base_period={base_period!r}. Anchoring at " f"any other pre-period would compute a different " f"slope and detrending that does not match paper " f"Eq 17 / page 32 or R DIDHAD::did_had(trends_lin=TRUE)." ) if trends_lin and len(t_pre_list) < 2: raise ValueError( f"joint_homogeneity_test(trends_lin=True) requires " f"at least 2 validated pre-periods so the per-group " f"slope Y[g, F-1] - Y[g, F-2] is identified. Got " f"t_pre_list={list(t_pre_list)!r}." ) # Capture the validator's predecessor for downstream use. if trends_lin: base_minus_1_period_validated = t_pre_list[-2] d_arr, dy_by_horizon, _ = _aggregate_for_joint_test( data_filtered, outcome_col=outcome_col, dose_col=dose_col, time_col=time_col, unit_col=unit_col, horizons=list(post_periods), base_period=base_period, ) G = d_arr.shape[0] # HAD invariant for the homogeneity path: base_period has D = 0 # (last pre-period contract); each post_period has D > 0 for SOME # unit (existence) and is NOT identically zero across all units # (reciprocal twin of the pretrends guard - an all-zero post-period # contradicts the HAD treatment-onset contract). base_doses = data_filtered.loc[data_filtered[time_col] == base_period, dose_col] if (base_doses != 0).any(): n_nonzero = int((base_doses != 0).sum()) raise ValueError( f"base_period={base_period!r} must have D = 0 across every " f"unit (HAD last-pre-period invariant). Found {n_nonzero} " f"non-zero dose observation(s) in base_period." 
) for t in post_periods: post_doses = data_filtered.loc[data_filtered[time_col] == t, dose_col] if not (post_doses > 0).any(): raise ValueError( f"post_period={t!r} has D = 0 for every unit. HAD " f"contract requires at least some unit to have D > 0 " f"in each post-period (reciprocal of the pre-period " f"zero-dose invariant)." ) # ---- Apply trends_lin detrending (paper Eq 17 / page 32 joint-Stute # post-period homogeneity null with industry-specific linear trends). # Twin of joint_pretrends_test detrending: per-group slope from # Y[g, base] - Y[g, base-1], applied to each post-period horizon's # dy_t. The post-period delta = t_rank - base_rank > 0, so the # subtraction extrapolates the linear trend FORWARD into post-periods. if trends_lin: # PR #392 R3 P1: use the validator's t_pre_list[-2] as the # predecessor (captured above as base_minus_1_period_validated). # This is robust to ordered-categorical panels with unused # intermediate levels because the validator builds t_pre_list # from observed contiguous pre-periods, not the full dtype # category list. base_minus_1_period_h = base_minus_1_period_validated slope_subset_h = data_filtered[ data_filtered[time_col].isin([base_period, base_minus_1_period_h]) ] wide_y_h = slope_subset_h.pivot(index=unit_col, columns=time_col, values=outcome_col) wide_y_h = wide_y_h.sort_index() if wide_y_h[base_period].isna().any() or wide_y_h[base_minus_1_period_h].isna().any(): raise ValueError( f"joint_homogeneity_test(trends_lin=True): NaN value(s) " f"in outcome at base_period={base_period!r} or " f"base_period-1={base_minus_1_period_h!r}. The slope " f"estimator requires complete observations at both " f"periods for every unit." ) slope_h = wide_y_h[base_period].to_numpy(dtype=np.float64) - wide_y_h[ base_minus_1_period_h ].to_numpy(dtype=np.float64) # PR #392 R4 P0: build the detrending rank from OBSERVED # periods on data_filtered (matching HAD.fit). Twin of # joint_pretrends_test fix. 
observed_rank_h = { p: i for i, p in enumerate( sorted( set(data_filtered[time_col].unique()), key=lambda p: period_rank[p], ) ) } base_rank_observed_h = observed_rank_h[base_period] for t in post_periods: label = str(t) delta = observed_rank_h[t] - base_rank_observed_h # > 0 for post-periods dy_by_horizon[label] = dy_by_horizon[label] - delta * slope_h # Phase 4.5 C: aggregate weights/survey to per-unit; thread through. # R2 P1 fix: subset row-level `weights` to data_filtered's rows BEFORE # resolution, mirroring did_had_pretest_workflow / joint_pretrends_test # for staggered last-cohort filtering. # R9 P1 fix: validate 1D + length-matched-to-data BEFORE subsetting. weights_for_resolve = weights if weights is not None: weights_arr = np.asarray(weights, dtype=np.float64) if weights_arr.ndim != 1: raise ValueError( f"joint_homogeneity_test: weights must be 1-dimensional, got " f"shape {weights_arr.shape}." ) if weights_arr.shape[0] != len(data): raise ValueError( f"joint_homogeneity_test: weights length {weights_arr.shape[0]} " f"does not match data length {len(data)}." ) if len(data_filtered) != len(data): pos_idx = data.index.get_indexer(data_filtered.index) if (pos_idx < 0).any(): raise ValueError( "joint_homogeneity_test: cannot align row-level weights to " "the staggered-filtered panel; some data_filtered rows do " "not appear in original data.index." 
) weights_for_resolve = weights_arr[pos_idx] else: weights_for_resolve = weights_arr weights_unit, resolved_unit = _resolve_pretest_unit_weights( data_filtered, unit_col, weights_for_resolve, survey, "joint_homogeneity_test" ) w_eff = resolved_unit.weights if resolved_unit is not None else weights_unit residuals_by_horizon: Dict[str, np.ndarray] = {} fitted_by_horizon: Dict[str, np.ndarray] = {} for label, dy_t in dy_by_horizon.items(): if w_eff is None: a_hat, b_hat, residuals_t = _fit_ols_intercept_slope(d_arr, dy_t) else: a_hat, b_hat, residuals_t = _fit_weighted_ols_intercept_slope(d_arr, dy_t, w_eff) fitted_t = a_hat + b_hat * d_arr residuals_by_horizon[label] = residuals_t fitted_by_horizon[label] = fitted_t design_matrix = np.column_stack([np.ones(G, dtype=np.float64), d_arr.astype(np.float64)]) # Internal forwarding via canonical kwarg (avoids deprecation warning). if resolved_unit is not None: joint_survey_design = resolved_unit elif weights_unit is not None: joint_survey_design = make_pweight_design(weights_unit) else: joint_survey_design = None return stute_joint_pretest( residuals_by_horizon=residuals_by_horizon, fitted_by_horizon=fitted_by_horizon, doses=d_arr, design_matrix=design_matrix, alpha=alpha, n_bootstrap=n_bootstrap, seed=seed, null_form="linearity", survey_design=joint_survey_design, )
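# Illustrative aside (not part of the module): a minimal sketch of the
# per-horizon refit on [1, D] that feeds residuals into the joint test
# under the linearity null. The body of `_fit_ols_intercept_slope` is
# not shown in this file section, so `ols_intercept_slope` below is a
# hypothetical stand-in built on `np.linalg.lstsq`, and the dose and
# outcome vectors are made up for the example.

```python
import numpy as np


def ols_intercept_slope(d, dy):
    """Fit dy ~ a + b * d by least squares; return (a, b, residuals)."""
    d = np.asarray(d, dtype=np.float64)
    dy = np.asarray(dy, dtype=np.float64)
    X = np.column_stack([np.ones_like(d), d])
    coef, *_ = np.linalg.lstsq(X, dy, rcond=None)
    a_hat, b_hat = float(coef[0]), float(coef[1])
    return a_hat, b_hat, dy - (a_hat + b_hat * d)


doses = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
dy_linear = 0.5 + 2.0 * doses  # satisfies the linearity null exactly
dy_quadratic = doses**2  # violates it

a_lin, b_lin, res_lin = ols_intercept_slope(doses, dy_linear)
a_quad, b_quad, res_quad = ols_intercept_slope(doses, dy_quadratic)
```

# Under exact linearity the residuals vanish. Under the quadratic
# alternative they are still orthogonal to [1, D] but follow a
# systematic pattern in the dose, which is the kind of structure a
# cusum of dose-ordered residuals is designed to pick up.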
_VALID_AGGREGATES = ("overall", "event_study")


_QUG_DEFERRED_SUFFIX = (
    " (linearity-conditional verdict; QUG-under-survey deferred per Phase 4.5 C0)"
)


def _compose_verdict_overall_survey(
    stute: Optional[StuteTestResults],
    yatchew: Optional[YatchewTestResults],
) -> str:
    """Build the overall-path :class:`HADPretestReport` verdict on the
    survey/weights branch (Phase 4.5 C).

    Drops the QUG step from consideration (skipped per Phase 4.5 C0) and
    composes the verdict from Stute + Yatchew alone, with the
    linearity-conditional suffix appended in every branch.

    R7 P1 fix: explicit survey-aware composer replaces the prior approach
    of composing the unweighted verdict with a NaN QUG and string-replacing
    the resulting "QUG NaN" suffix, which could leave pass cases starting
    with "inconclusive".

    Priority (mirrors :func:`_compose_verdict` minus QUG):

    1. Conclusive rejections of Stute or Yatchew lead.
    2. No conclusive rejection but linearity inconclusive (both NaN) ->
       "inconclusive - both Stute and Yatchew linearity tests NaN".
    3. Linearity conclusive (at least one of Stute/Yatchew finite) AND no
       rejection -> fail-to-reject string.

    All branches end with `_QUG_DEFERRED_SUFFIX`.
    """
    stute_ok = stute is not None and bool(np.isfinite(stute.p_value))
    yatchew_ok = yatchew is not None and bool(np.isfinite(yatchew.p_value))
    stute_rej = stute_ok and bool(stute.reject)
    yatchew_rej = yatchew_ok and bool(yatchew.reject)

    reasons = []
    if stute_rej:
        reasons.append("linearity rejected - heterogeneity bias (Stute)")
    if yatchew_rej:
        reasons.append("linearity rejected - heterogeneity bias (Yatchew)")

    unresolved = []
    if not stute_ok:
        unresolved.append("Stute NaN")
    if not yatchew_ok:
        unresolved.append("Yatchew NaN")

    if reasons:
        verdict = "; ".join(reasons)
        if unresolved:
            verdict += "; additional steps unresolved: " + "; ".join(unresolved)
        return verdict + _QUG_DEFERRED_SUFFIX

    # No rejections.
    if not (stute_ok or yatchew_ok):
        return "inconclusive - both Stute and Yatchew linearity tests NaN" + _QUG_DEFERRED_SUFFIX

    # At least one linearity test conclusive AND no rejection.
    skipped_note = ""
    if not stute_ok:
        skipped_note = " (Stute NaN - skipped)"
    elif not yatchew_ok:
        skipped_note = " (Yatchew NaN - skipped)"
    return (
        "Stute and Yatchew linearity diagnostics fail-to-reject"
        + skipped_note
        + _QUG_DEFERRED_SUFFIX
    )


def _compose_verdict_event_study_survey(
    pretrends_joint: Optional[StuteJointResult],
    homogeneity_joint: Optional[StuteJointResult],
) -> str:
    """Event-study survey-path verdict (R7 P1 fix; mirrors
    :func:`_compose_verdict_event_study` minus QUG)."""
    pretrends_ok = pretrends_joint is not None and bool(np.isfinite(pretrends_joint.p_value))
    homogeneity_ok = homogeneity_joint is not None and bool(np.isfinite(homogeneity_joint.p_value))
    pretrends_rej = pretrends_joint is not None and pretrends_ok and bool(pretrends_joint.reject)
    homogeneity_rej = (
        homogeneity_joint is not None and homogeneity_ok and bool(homogeneity_joint.reject)
    )

    reasons = []
    if pretrends_rej:
        reasons.append("joint pre-trends rejected - assumption 7 violated (joint Stute)")
    if homogeneity_rej:
        reasons.append("joint linearity rejected - heterogeneity bias (joint Stute)")

    unresolved = []
    if pretrends_joint is None:
        unresolved.append("joint pre-trends skipped (no earlier pre-period)")
    elif not pretrends_ok:
        unresolved.append("joint pre-trends NaN")
    if homogeneity_joint is None:
        unresolved.append("joint linearity skipped")
    elif not homogeneity_ok:
        unresolved.append("joint linearity NaN")

    if reasons:
        verdict = "; ".join(reasons)
        if unresolved:
            verdict += "; additional steps unresolved: " + "; ".join(unresolved)
        return verdict + _QUG_DEFERRED_SUFFIX
    if unresolved:
        return "inconclusive - " + "; ".join(unresolved) + _QUG_DEFERRED_SUFFIX
    return "joint pre-trends and joint linearity diagnostics fail-to-reject" + _QUG_DEFERRED_SUFFIX
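

# Illustrative sketch (not part of the library API): the verdict priority used
# by the two survey-path composers above, reduced to plain tuples. The name
# `_sketch_compose_verdict` and the generic message strings are hypothetical;
# the real composers emit the test-specific strings shown above. Priority:
# (1) conclusive rejections lead, (2) otherwise "inconclusive" when every test
# is NaN, (3) otherwise fail-to-reject.
import math
from typing import List, Tuple


def _sketch_compose_verdict(tests: List[Tuple[str, float, bool]], suffix: str = "") -> str:
    """`tests` holds (name, p_value, reject) triples; a NaN p-value is inconclusive."""
    reasons = []
    unresolved = []
    for name, p_value, reject in tests:
        if not math.isfinite(p_value):
            unresolved.append(f"{name} NaN")
        elif reject:
            # Only conclusive (finite-p) rejections count as reasons.
            reasons.append(f"{name} rejected")
    if reasons:
        verdict = "; ".join(reasons)
        if unresolved:
            verdict += "; additional steps unresolved: " + "; ".join(unresolved)
        return verdict + suffix
    if len(unresolved) == len(tests):
        return "inconclusive - all tests NaN" + suffix
    return "fail-to-reject" + suffix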
def did_had_pretest_workflow(
    data: pd.DataFrame,
    outcome_col: str,
    dose_col: str,
    time_col: str,
    unit_col: str,
    first_treat_col: Optional[str] = None,
    alpha: float = 0.05,
    n_bootstrap: int = 999,
    seed: Optional[int] = None,
    *,
    aggregate: str = "overall",
    survey_design: Any = None,
    survey: Any = None,
    weights: Optional[np.ndarray] = None,
    trends_lin: bool = False,
) -> HADPretestReport:
    """Run the HAD pre-test workflow (paper Section 4.2-4.3).

    Two dispatch modes via ``aggregate``:

    ``aggregate="overall"`` (default, two-period panel): runs paper
    steps 1 (:func:`qug_test`) and 3 (:func:`stute_test` +
    :func:`yatchew_hr_test`). Step 2 (Assumption 7 pre-trends) is NOT
    implemented on this path because a single-pre-period panel cannot
    support the joint Stute variant; the returned verdict flags the
    Assumption 7 gap explicitly so callers do not receive a misleading
    "TWFE safe" signal. For multi-period panels, pass
    ``aggregate="event_study"`` to close the step-2 gap.

    ``aggregate="event_study"`` (multi-period panel, >= 3 periods): runs
    QUG + joint pre-trends Stute + joint homogeneity-linearity Stute,
    covering paper Section 4 steps 1-3 together. The step-3 Yatchew-HR
    alternative (a single-horizon swap-in for Stute) is subsumed by joint
    Stute on this path - the paper does not derive a joint Yatchew
    variant, so users who need Yatchew robustness under multi-period data
    should call :func:`yatchew_hr_test` on each ``(base, post)`` pair
    manually.

    (Paper step 4 is the decision itself - "use TWFE if none of the tests
    rejects" - not a separate test, so it has no code path here. Mirrors
    the framing in the module-level docstring at line 54 and
    ``_compose_verdict_event_study`` at line 2735.)

    Eq 17 / Eq 18 linear-trend detrending (paper Section 5.2
    Pierce-Schott application) is now SHIPPED on the event-study path via
    the ``trends_lin`` keyword-only parameter (PR #392 / Phase 4 R-parity).
    When ``trends_lin=True``, this workflow forwards the flag to both
    :func:`joint_pretrends_test` and :func:`joint_homogeneity_test`; the
    consumed placebo at ``base_period - 1`` is auto-dropped from step 2
    and the workflow skips step 2 (``pretrends_joint=None``) if no earlier
    placebo survives. Mirrors R ``DIDHAD::did_had(..., trends_lin=TRUE)``.
    Mutually exclusive with ``aggregate="overall"`` (raises
    ``NotImplementedError``).

    Parameters
    ----------
    data : pd.DataFrame
        HAD panel. For ``aggregate="overall"``: balanced two-period
        panel with pre-period dose = 0 for every unit. For
        ``aggregate="event_study"``: balanced multi-period panel with
        >= 3 periods, an ordered time dtype (numeric, datetime, or
        ordered categorical), and the pre-period D=0 invariant across
        all pre-periods.
    outcome_col, dose_col, time_col, unit_col : str
    first_treat_col : str or None, default None
        Optional first-treatment-period column. Required on the
        ``aggregate="event_study"`` path when the panel is staggered
        (multi-cohort); the panel validator auto-filters to the last
        cohort and emits ``UserWarning``. The overall path uses this
        for cross-validation only.
    alpha : float, default 0.05
    n_bootstrap : int, default 999
        Replication count for the single-horizon Stute (overall) or
        joint Stute (event_study).
    seed : int or None, default None
        Seed forwarded to the Stute bootstrap. QUG / Yatchew are
        deterministic.
    aggregate : str, keyword-only, default ``"overall"``
        Dispatch mode. Invalid values raise ``ValueError``.
    survey_design : SurveyDesign or None, keyword-only, default None
        Survey design for design-based pretest inference.
        Linearity-family pretests use PSU-level Mammen multiplier
        bootstrap (Stute family) and weighted OLS + weighted variance
        components (Yatchew). The QUG step is skipped under survey with
        a ``UserWarning`` (permanent deferral per Phase 4.5 C0).
        Replicate-weight designs raise ``NotImplementedError``. Mutually
        exclusive with the deprecated ``survey=`` and ``weights=``
        aliases.
    survey : SurveyDesign or None, keyword-only, default None
        DEPRECATED alias of ``survey_design=``. Will be removed in the
        next minor release; prefer ``survey_design=``.
    weights : np.ndarray or None, keyword-only, default None
        DEPRECATED alias for the per-row pweight shortcut. Prefer adding
        the weights as a column on ``data`` and passing
        ``survey_design=SurveyDesign(weights='col_name')`` instead. Will
        be removed in the next minor release. Currently routed through a
        synthetic trivial ``ResolvedSurveyDesign`` so the same kernel
        handles both paths.
    trends_lin : bool, default False, keyword-only
        Forwards into :func:`joint_pretrends_test` and
        :func:`joint_homogeneity_test` on the event-study dispatch path.
        Mirrors R ``DIDHAD::did_had(..., trends_lin=TRUE)``. Requires
        ``aggregate="event_study"``; raises ``NotImplementedError`` on
        ``aggregate="overall"`` (the overall path's qug + stute +
        yatchew block has no joint-pretest surface). Mutually exclusive
        with survey weighting at the joint-pretest layer; the joint
        wrappers raise ``NotImplementedError`` if combined.

        **Effective step-2 rule under trends_lin**: the consumed placebo
        at ``base_period - 1`` is dropped before step 2 is dispatched;
        if no earlier placebo survives the drop (e.g., a minimal
        4-period panel with ``F=3`` where ``base_period=2`` and the only
        earlier placebo at ``t=1`` is the consumed one), step 2 is
        skipped (``pretrends_joint=None``) and the workflow proceeds to
        step 3 (homogeneity). Default ``False`` preserves bit-exact
        backcompat.

    Returns
    -------
    HADPretestReport
        On the overall path: ``stute`` and ``yatchew`` populated,
        ``pretrends_joint`` / ``homogeneity_joint`` are ``None``. On the
        event-study path: ``pretrends_joint`` (``None`` if no earlier
        pre-period) and ``homogeneity_joint`` populated, ``stute`` /
        ``yatchew`` are ``None``. ``aggregate`` is recorded on the
        report for serialization dispatch.
        On the survey/weights path, ``qug`` is ``None`` (Phase 4.5 C0
        deferral); other components populated as on the unweighted path.

    Raises
    ------
    ValueError
        On invalid ``aggregate``; if more than one of ``survey_design``,
        ``survey``, ``weights`` is supplied (3-way mutex; ``survey=`` and
        ``weights=`` are deprecated aliases of ``survey_design=``); or any
        downstream front-door failure (panel balance, dtype, dose
        invariant).
    NotImplementedError
        If ``survey.replicate_weights is not None`` (replicate-weight
        pretests deferred to a parallel follow-up after Phase 4.5 C).

    Notes
    -----
    Survey/weighted data (Phase 4.5 C): under ``survey_design=`` (or the
    deprecated ``survey=`` / ``weights=`` aliases), the workflow:

    1. **Skips QUG** with a ``UserWarning`` and sets ``qug=None`` on the
       report. QUG-under-survey is permanently deferred per Phase 4.5
       C0; extreme-order-statistic tests are not smooth functionals of
       the empirical CDF and have no off-the-shelf survey-aware analog.
       See :func:`qug_test` Notes for the full methodology rationale.
    2. **Runs the linearity family** with the survey-aware mechanism
       (PSU-level Mammen multiplier bootstrap for Stute / joint
       variants; weighted OLS + weighted variance components for
       Yatchew) routed via the existing kernels.
    3. **Verdict** carries a ``"linearity-conditional verdict;
       QUG-under-survey deferred per Phase 4.5 C0"`` suffix to remind
       callers that admissibility is conditional on the linearity family
       alone.
    4. **`all_pass`** drops the QUG-conclusiveness gate (one less
       precondition). The linearity-conditional rule splits by
       aggregate:

       - ``aggregate="overall"`` survey: ``True`` iff at least one of
         Stute/Yatchew is conclusive AND no conclusive test rejects
         (paper Section 4 step-3 "Stute OR Yatchew" wording).
       - ``aggregate="event_study"`` survey: ``True`` iff
         ``pretrends_joint`` is non-None and conclusive,
         ``homogeneity_joint`` is conclusive, AND neither rejects.
         Both joint variants must be conclusive on the event-study path
         (same step-2 + step-3 closure as the unweighted aggregate, just
         without the QUG step).

    Sister pretests are unchanged on the workflow path; direct callers
    can also pass ``weights=`` / ``survey=`` to :func:`stute_test`,
    :func:`yatchew_hr_test`, etc. (Phase 4.5 C extends each helper's
    signature). Per-unit constant-within-unit invariant on weights /
    strata / psu / fpc is enforced by the workflow via
    :func:`diff_diff.had._aggregate_unit_weights` /
    :func:`diff_diff.had._aggregate_unit_resolved_survey`.

    References
    ----------
    de Chaisemartin et al. (2026), Section 4.2-4.3, Theorem 4, Appendix D,
    Theorem 7.
    """
    if aggregate not in _VALID_AGGREGATES:
        raise ValueError(
            f"aggregate must be one of {list(_VALID_AGGREGATES)!r}; " f"got {aggregate!r}."
        )

    # ---- trends_lin scope gate (PR #392 R1 P1).
    # `trends_lin=True` is only meaningful on the event-study path because
    # it forwards into joint_pretrends_test / joint_homogeneity_test. The
    # overall path runs qug + stute + yatchew on a 2-period panel and has
    # no joint-pretest surface to receive the kwarg. Front-door reject
    # rather than silently ignore.
    if trends_lin and aggregate != "event_study":
        raise NotImplementedError(
            "did_had_pretest_workflow(trends_lin=True) requires "
            "aggregate='event_study' (the trends_lin kwarg forwards "
            "into the joint pretests, which only run on the event-"
            "study path). The overall path's qug + stute + yatchew "
            "block has no per-group slope surface; pass a multi-"
            "period panel and aggregate='event_study'."
        )

    # Three-way mutex on survey_design / survey / weights (data-in pattern).
    # R6 P1 fix: do NOT call _resolve_pretest_unit_weights on the FULL panel
    # here -- under aggregate='event_study' the panel may be staggered and the
    # cohort filter at _validate_multi_period_panel can drop units. If those
    # dropped units have zero/invalid weights, eager full-panel resolution
    # would abort an otherwise-valid event-study run.
    # Defer resolution to the per-aggregate branches: overall path resolves
    # on the original data (no filtering); event-study path lets the joint
    # wrappers handle resolution on data_filtered.
    n_set = sum(x is not None for x in (survey_design, survey, weights))
    if n_set > 1:
        raise ValueError(HAD_DUAL_KNOB_MUTEX_MSG_DATA_IN)
    # Soft deprecation: route legacy survey=/weights= aliases to survey_design=.
    # The internal back-end paths (_resolve_pretest_unit_weights + per-aggregate
    # dispatch) consume `survey` and `weights` as internal variable names, so
    # rebind both for back-compat with the unchanged downstream logic. The
    # bit-exact regression invariant is preserved because we only rebind names,
    # not values.
    if survey is not None:
        warnings.warn(HAD_DEPRECATION_MSG_SURVEY_KWARG, DeprecationWarning, stacklevel=2)
        survey_design = survey
    elif weights is not None:
        warnings.warn(
            HAD_DEPRECATION_MSG_WEIGHTS_KWARG_DATA_IN,
            DeprecationWarning,
            stacklevel=2,
        )
        # weights= shortcut preserved as-is on the back end. Don't rebind
        # survey_design -- the array is not a SurveyDesign.
    # Internal alias rebind: downstream code uses `survey` (when set, a
    # SurveyDesign or pre-resolved). Map the canonical input back so the
    # unchanged downstream `if survey is not None:` branches consume it.
    if survey_design is not None and survey is None:
        survey = survey_design

    use_survey_path = (survey is not None) or (weights is not None)
    if use_survey_path:
        # Phase 4.5 C0 deferral surface: skip QUG with educational warning.
        warnings.warn(
            "did_had_pretest_workflow: QUG step skipped under survey/weights "
            "(permanently deferred per Phase 4.5 C0; extreme-value theory "
            "under complex sampling is not a settled toolkit). Verdict "
            "reflects the linearity family only ('linearity-conditional').",
            UserWarning,
            stacklevel=2,
        )

    if aggregate == "event_study":
        F, t_pre_list, t_post_list, data_filtered, _filter_info = _validate_multi_period_panel(
            data,
            outcome_col=outcome_col,
            dose_col=dose_col,
            time_col=time_col,
            unit_col=unit_col,
            first_treat_col=first_treat_col,
        )
        # `_validate_multi_period_panel` delegates to
        # `_validate_had_panel_event_study`, which already emits its own
        # `UserWarning` on the staggered-filter path; we do NOT warn a
        # second time here (R4 code-quality fix - single emission point).

        # Base period for both joint tests is the last pre-period
        # (paper convention: anchor at F-1 under natural time order).
        # This is t_pre_list[-1] - NOT an arithmetic F-1, since the
        # time column may be non-integer (datetime, ordered categorical).
        base_period = t_pre_list[-1]

        # Step 1: QUG on dose distribution at F. Doses are
        # time-invariant in HAD, so D_g at F equals max_t D_{g,t}.
        # Phase 4.5 C: skipped under survey/weights (qug_res = None).
        doses_at_F = (
            data_filtered.loc[data_filtered[time_col] == F, [unit_col, dose_col]]
            .set_index(unit_col)
            .sort_index()[dose_col]
            .to_numpy(dtype=np.float64)
        )
        qug_res = None if use_survey_path else qug_test(doses_at_F, alpha=alpha)

        # Phase 4.5 C: forward weights/survey to the joint helpers. The
        # data-in wrappers handle their own per-row → per-unit aggregation
        # via _resolve_pretest_unit_weights internally on `data_filtered`.
        # R1 P1 fix: subset row-level `weights` to data_filtered's rows
        # BEFORE passing through. Otherwise on staggered panels (where
        # _validate_multi_period_panel auto-filters to last cohort),
        # the wrappers would call _aggregate_unit_weights(data_filtered,
        # weights[full_panel_length], ...) and crash on length mismatch.
        # Mirrors HeterogeneousAdoptionDiD.fit()'s positional-index
        # subsetting via `data.index.get_indexer(data_filtered.index)`.
        # `survey=` carries column references resolved internally on
        # data_filtered, so no subsetting needed there.
        if use_survey_path and weights is not None:
            # R9 P1: validate 1D + length-matched-to-data BEFORE staggered-
            # panel subsetting. Otherwise oversized arrays would be
            # silently truncated and undersized arrays would surface raw
            # NumPy indexing errors.
            weights_arr = np.asarray(weights, dtype=np.float64)
            if weights_arr.ndim != 1:
                raise ValueError(
                    "did_had_pretest_workflow: weights must be 1-dimensional, "
                    f"got shape {weights_arr.shape}. (A common mistake is "
                    "passing df[['w']].to_numpy() which produces (N, 1); "
                    "use df['w'].to_numpy() for (N,).)"
                )
            if weights_arr.shape[0] != len(data):
                raise ValueError(
                    f"did_had_pretest_workflow: weights length "
                    f"{weights_arr.shape[0]} does not match data length "
                    f"{len(data)}."
                )
            pos_idx = data.index.get_indexer(data_filtered.index)
            if (pos_idx < 0).any():
                raise ValueError(
                    "did_had_pretest_workflow: cannot align row-level "
                    "weights to the staggered-filtered panel "
                    "(some data_filtered rows do not appear in original "
                    "data.index). This is a bug; please report."
                )
            joint_weights = weights_arr[pos_idx]
        else:
            joint_weights = None
        joint_survey = survey if use_survey_path and survey is not None else None

        # Step 2: joint pre-trends on earlier pre-periods (those
        # strictly before base_period). If only the base pre-period is
        # available (len(t_pre_list) == 1), there are no earlier
        # placebos; set pretrends_joint=None and flag in verdict.
        # ``t_pre_list`` is returned chronologically sorted by
        # ``_validate_had_panel_event_study`` (using the column's
        # ordered-categorical category order or the natural numeric /
        # datetime order), so taking everything but the last element
        # gives the earlier pre-periods regardless of dtype. Raw
        # ``t < base_period`` would misorder ordered-categorical labels
        # whose lexical and chronological order disagree (e.g. "q10" <
        # "q2" lexically but > chronologically).
        earlier_pre = list(t_pre_list[:-1])
        # PR #392 R2 P1: under trends_lin=True, the consumed placebo at
        # base_period - 1 (= t_pre_list[-2] in the contiguous validated
        # pre-period list) is dropped by joint_pretrends_test downstream
        # because its detrended residual is mechanically zero. Pre-filter
        # it here so we can preserve the EXISTING "no earlier placebo →
        # pretrends_joint=None, skip step 2" verdict path (rather than
        # propagating joint_pretrends_test's `ValueError("no testable
        # placebo horizons remain")` and aborting the whole workflow).
        # The minimal valid trends_lin event-study panel (4 periods,
        # F=3, base=2, only earlier placebo at 1 = the consumed one)
        # hits this path; the workflow should still run step 3
        # (homogeneity) and emit the standard "step 2 skipped" verdict.
        if trends_lin and len(t_pre_list) >= 2:
            consumed_placebo_period = t_pre_list[-2]
            earlier_pre = [t for t in earlier_pre if t != consumed_placebo_period]
        # PR #376 R2 P3: when `weights=joint_weights` is forwarded to the joint
        # wrappers (the only joint-internal entry that takes a numpy array),
        # the wrapper would re-emit a DeprecationWarning. Suppress those
        # nested warnings; the user-facing warning has already fired at the
        # workflow's front door above. survey_design=joint_survey is a
        # SurveyDesign (column-referencing) on the survey path and goes
        # through canonically; only the weights= forwarding path needs the
        # suppression. The joint wrappers also can't accept a pre-resolved
        # ResolvedSurveyDesign (their `_resolve_pretest_unit_weights` requires
        # a SurveyDesign with .resolve()), so converting weights= to
        # survey_design= via make_pweight_design isn't an option here.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            if len(earlier_pre) >= 1:
                pretrends_joint = joint_pretrends_test(
                    data_filtered,
                    outcome_col=outcome_col,
                    dose_col=dose_col,
                    time_col=time_col,
                    unit_col=unit_col,
                    pre_periods=earlier_pre,
                    base_period=base_period,
                    first_treat_col=first_treat_col,
                    alpha=alpha,
                    n_bootstrap=n_bootstrap,
                    seed=seed,
                    survey_design=joint_survey,
                    weights=joint_weights,
                    trends_lin=trends_lin,
                )
            else:
                pretrends_joint = None

            # Step 3: joint homogeneity-linearity on post-periods.
            homogeneity_joint = joint_homogeneity_test(
                data_filtered,
                outcome_col=outcome_col,
                dose_col=dose_col,
                time_col=time_col,
                unit_col=unit_col,
                post_periods=list(t_post_list),
                base_period=base_period,
                first_treat_col=first_treat_col,
                alpha=alpha,
                n_bootstrap=n_bootstrap,
                seed=seed,
                survey_design=joint_survey,
                weights=joint_weights,
                trends_lin=trends_lin,
            )

        # Event-study `all_pass`. On the unweighted path, every implemented
        # step must be conclusive AND none reject (Phase 3 convention). On
        # the survey/weights path, drop the QUG-conclusiveness condition
        # (qug=None per Phase 4.5 C0 deferral); admissibility becomes
        # linearity-conditional.
        pretrends_ok = pretrends_joint is not None and bool(np.isfinite(pretrends_joint.p_value))
        homogeneity_ok = bool(np.isfinite(homogeneity_joint.p_value))
        if use_survey_path:
            all_pass = bool(
                pretrends_ok
                and pretrends_joint is not None
                and not pretrends_joint.reject
                and homogeneity_ok
                and not homogeneity_joint.reject
            )
            # R7 P1 fix: explicit survey-aware verdict composer instead
            # of post-processing the unweighted-verdict output (the
            # previous string-replace approach could leave pass cases
            # starting with "inconclusive" even when all_pass=True).
            verdict = _compose_verdict_event_study_survey(pretrends_joint, homogeneity_joint)
        else:
            qug_ok = bool(np.isfinite(qug_res.p_value))
            all_pass = bool(
                qug_ok
                and pretrends_ok
                and pretrends_joint is not None
                and not pretrends_joint.reject
                and homogeneity_ok
                and not homogeneity_joint.reject
                and not qug_res.reject
            )
            verdict = _compose_verdict_event_study(qug_res, pretrends_joint, homogeneity_joint)
        return HADPretestReport(
            qug=qug_res,
            stute=None,
            yatchew=None,
            all_pass=all_pass,
            verdict=verdict,
            alpha=alpha,
            n_obs=int(doses_at_F.shape[0]),
            pretrends_joint=pretrends_joint,
            homogeneity_joint=homogeneity_joint,
            aggregate="event_study",
        )

    # aggregate == "overall" - Phase 3 behavior on the unweighted path
    # (bit-exact regression preserved); Phase 4.5 C survey path skips QUG
    # and dispatches stute / yatchew with weights=/survey=.
    t_pre, t_post = _validate_had_panel(
        data, outcome_col, dose_col, time_col, unit_col, first_treat_col
    )
    d_arr, dy_arr, _, _ = _aggregate_first_difference(
        data,
        outcome_col,
        dose_col,
        time_col,
        unit_col,
        t_pre,
        t_post,
        cluster_col=None,  # pretests do not use cluster-robust SE
    )

    # R6 P1 fix: resolve weights/survey HERE (overall path operates on
    # the original data; no cohort filter to interact with).
    weights_unit, resolved_unit = _resolve_pretest_unit_weights(
        data, unit_col, weights, survey, "did_had_pretest_workflow"
    )

    qug_res = None if use_survey_path else qug_test(d_arr, alpha=alpha)
    # Forward weights/survey to per-test calls. The data-in workflow has
    # already aggregated to per-unit (weights_unit / resolved_unit); the
    # _aggregate_first_difference call above also collapses to per-unit
    # (one row per unit), so weights_unit and resolved_unit are aligned.
    # Internal forwarding uses the canonical survey_design= kwarg to skip
    # deprecation warnings; the user-facing warning has already fired at the
    # workflow's front door.
    if resolved_unit is not None:
        per_test_survey_design = resolved_unit
    elif weights_unit is not None:
        per_test_survey_design = make_pweight_design(weights_unit)
    else:
        per_test_survey_design = None
    stute_res = stute_test(
        d_arr,
        dy_arr,
        alpha=alpha,
        n_bootstrap=n_bootstrap,
        seed=seed,
        survey_design=per_test_survey_design,
    )
    # Linearity null is correct for the workflow's post-treatment Yatchew step
    # (paper Theorem 7); placebo mean-independence routing lives in the
    # R-parity test, not the workflow.
    yatchew_res = yatchew_hr_test(
        d_arr,
        dy_arr,
        alpha=alpha,
        survey_design=per_test_survey_design,
    )

    # `all_pass` must be conclusive under the paper's four-step workflow
    # (step 1 QUG + step 3 linearity via Stute OR Yatchew):
    #   - QUG must produce a finite p-value (step 1 is required).
    #   - At least ONE of Stute / Yatchew must produce a finite p-value
    #     (step 3 accepts either; the paper's wording is "Stute OR
    #     Yatchew"). This accommodates common QUG-style panels with
    #     repeated d=0 units, where Yatchew's duplicate-dose guard trips
    #     but Stute's tie-safe CvM still produces a conclusive result.
    #   - No conclusive test may reject. NaN-p tests have reject=False by
    #     convention, so the OR across `.reject` naturally counts only
    #     the conclusive rejections.
    # On the survey path, drop QUG conclusiveness (qug=None per C0 deferral).
    linearity_conclusive = bool(np.isfinite(stute_res.p_value) or np.isfinite(yatchew_res.p_value))
    if use_survey_path:
        any_reject = stute_res.reject or yatchew_res.reject
        all_pass = bool(linearity_conclusive and not any_reject)
        # R7 P1 fix: explicit survey-aware verdict composer.
        verdict = _compose_verdict_overall_survey(stute_res, yatchew_res)
    else:
        qug_conclusive = bool(np.isfinite(qug_res.p_value))
        any_reject = qug_res.reject or stute_res.reject or yatchew_res.reject
        all_pass = bool(qug_conclusive and linearity_conclusive and not any_reject)
        verdict = _compose_verdict(qug_res, stute_res, yatchew_res)

    return HADPretestReport(
        qug=qug_res,
        stute=stute_res,
        yatchew=yatchew_res,
        all_pass=all_pass,
        verdict=verdict,
        alpha=alpha,
        n_obs=int(d_arr.shape[0]),
        aggregate="overall",
    )
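

# Illustrative sketch (not part of the library API): the overall-path
# `all_pass` rule from the workflow above, written over plain scalars so the
# three conditions can be checked in isolation. The name
# `_sketch_all_pass_overall` and its flat argument list are hypothetical; the
# workflow itself reads these values from the QUG / Stute / Yatchew result
# objects. It assumes, as the workflow comments state, that a NaN-p test
# carries reject=False.
import math
from typing import Optional


def _sketch_all_pass_overall(
    qug_p: Optional[float],
    qug_reject: bool,
    stute_p: float,
    stute_reject: bool,
    yatchew_p: float,
    yatchew_reject: bool,
    survey: bool = False,
) -> bool:
    # Step 3 accepts either linearity test: one finite p-value is enough.
    linearity_conclusive = math.isfinite(stute_p) or math.isfinite(yatchew_p)
    if survey:
        # Survey path: QUG is skipped, so its arguments are never inspected.
        return bool(linearity_conclusive and not (stute_reject or yatchew_reject))
    qug_conclusive = qug_p is not None and math.isfinite(qug_p)
    any_reject = qug_reject or stute_reject or yatchew_reject
    return bool(qug_conclusive and linearity_conclusive and not any_reject)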