# Source code for diff_diff.pretrends

"""
Pre-trends power analysis for difference-in-differences designs.

This module implements the power analysis framework from Roth (2022) for assessing
the informativeness of pre-trends tests. It answers the question: "If my pre-trends
test passed, what violations would I have been able to detect?"

Key concepts:
- **Minimum Detectable Violation (MDV)**: The smallest pre-trends violation that
  would be detected with given power (e.g., 80%).
- **Power of Pre-Trends Test**: Probability of rejecting parallel trends given
  a specific violation pattern.
- **Relationship to HonestDiD**: If MDV is large relative to your estimated effect,
  a passing pre-trends test provides limited reassurance.

References
----------
Roth, J. (2022). Pretest with Caution: Event-Study Estimates after Testing for
    Parallel Trends. American Economic Review: Insights, 4(3), 305-322.
    https://doi.org/10.1257/aeri.20210236

See Also
--------
https://github.com/jonathandroth/pretrends - R package implementation
diff_diff.honest_did - Sensitivity analysis for parallel trends violations
"""

import warnings
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
import pandas as pd
from scipy import optimize, stats

from diff_diff.results import MultiPeriodDiDResults


def _compute_nis_acceptance_prob(
    M: float,
    weights: np.ndarray,
    vcov: np.ndarray,
    z_alpha: float,
) -> float:
    """
    Compute the NIS box acceptance probability ``P(β̂_pre ∈ B_NIS(Σ))``.

    Used by both ``PreTrendsPower._compute_power_nis`` and
    ``PreTrendsPowerResults.power_at()`` to avoid code duplication and
    centralize the analytical-or-MC fallback path.

    Under the alternative the pre-period estimates are ``N(M * weights, Σ)``.
    The acceptance region is the box ``|b_t| <= z_alpha * σ_t`` with
    ``σ_t = sqrt(Σ_tt)``. Re-centering at zero gives the integration limits
    ``[-z_alpha σ - δ, z_alpha σ - δ]`` with ``δ = M * weights``.

    Parameters
    ----------
    M : float
        Violation magnitude. The mean shift is ``M * weights``.
    weights : array-like
        Violation direction, one entry per pre-period. It is coerced to a
        float ndarray, so a Python list (or an integer ``M``) is treated as
        a vector and not as a sequence to be repeated.
    vcov : array-like
        Square covariance matrix of the pre-period estimates. It is coerced
        to a float ndarray.
    z_alpha : float
        Per-coordinate critical value that defines the box half-width in
        standard-error units.

    Returns
    -------
    accept_prob : float
        Acceptance probability, clipped to [0, 1]. If the analytical scipy
        MVN CDF raises (ValueError / LinAlgError) or returns a non-finite
        value (for example, on a numerically degenerate Σ), a seeded Monte
        Carlo estimate (N=20000) is used instead. The Monte Carlo draw can
        itself raise if ``vcov`` is unusable, for example non-finite.
    """
    weights = np.asarray(weights, dtype=float)
    vcov = np.asarray(vcov, dtype=float)

    # Negative diagonal entries (numerical noise) are floored at zero
    # before taking the square root.
    sigma = np.sqrt(np.maximum(np.diag(vcov), 0))
    delta = M * weights
    upper = z_alpha * sigma - delta
    lower = -z_alpha * sigma - delta
    mean = np.zeros(weights.shape[0])

    accept_prob: float
    try:
        accept_prob = float(
            stats.multivariate_normal.cdf(  # type: ignore[arg-type]
                upper,
                lower_limit=lower,
                mean=mean,
                cov=vcov,
                allow_singular=True,
            )
        )
    except (ValueError, np.linalg.LinAlgError):
        accept_prob = float("nan")

    # The scipy CDF can return nan on numerically degenerate Σ even when it
    # raises no exception (the Genz algorithm can suffer internal
    # cancellation). Falling back to simulation keeps the downstream MDV
    # solver from silently propagating nan and returning a wrong but finite
    # MDV. The fixed seed keeps the fallback deterministic.
    if not np.isfinite(accept_prob):
        rng = np.random.default_rng(0)
        samples = rng.multivariate_normal(mean=mean, cov=vcov, size=20000)
        in_box = np.all((samples >= lower[None, :]) & (samples <= upper[None, :]), axis=1)
        accept_prob = float(in_box.mean())

    return float(np.clip(accept_prob, 0.0, 1.0))


def _coerce_relative_times_from_reference(
    estimated_pre_periods: List[Any],
    reference_period: Any,
) -> Optional[np.ndarray]:
    """
    Express each pre-period label as an offset from ``reference_period``.

    The offsets are the Roth-style relative times used to build γ-unit
    linear-violation weights. Three regimes are tried in order:

    1. Numeric labels (``int`` / ``float`` / numpy scalars): the offset is
       ``float(label) - float(reference_period)``.
    2. Labels whose subtraction yields an offset-like object: a ``.n``
       attribute (pandas Period offsets) is used first, then ``.days``
       (``Timedelta`` / ``datetime.timedelta``), and finally division by a
       one-day ``np.timedelta64``. Units are therefore the Period frequency
       or days.
    3. Anything else (string IDs, unordered categoricals, mixed types): a
       ``UserWarning`` is emitted and ``None`` is returned. The caller then
       uses the legacy count-based linear direction, so the MDV is NOT in
       Roth's γ units.

    Returns
    -------
    np.ndarray or None
        Float offsets in the order of ``estimated_pre_periods``, or ``None``
        when no regime applies.
    """

    def _offset_as_float(offset: Any) -> float:
        # Period-style offsets carry an integer multiple of the frequency.
        multiple = getattr(offset, "n", None)
        if multiple is not None:
            return float(multiple)
        # Timedelta-like objects carry whole days.
        whole_days = getattr(offset, "days", None)
        if whole_days is not None:
            return float(whole_days)
        # Bare numpy.timedelta64: divide by a one-day unit.
        try:
            return float(offset / np.timedelta64(1, "D"))
        except (TypeError, ValueError):
            raise TypeError(
                f"cannot coerce difference {offset!r} of type {type(offset).__name__} "
                "to float days/periods"
            )

    # Regime 1: numeric labels.
    try:
        origin = float(reference_period)
        return np.asarray(
            [float(label) - origin for label in estimated_pre_periods],
            dtype=float,
        )
    except (TypeError, ValueError):
        pass

    # Regime 2: subtraction-based offsets (Period / Timestamp / datetime64).
    try:
        offsets = [label - reference_period for label in estimated_pre_periods]
        return np.asarray([_offset_as_float(o) for o in offsets], dtype=float)
    except (TypeError, ValueError):
        pass

    # Regime 3: warn here (not in a helper) so stacklevel points at the caller.
    warnings.warn(
        f"PreTrendsPower: reference_period {reference_period!r} (type "
        f"{type(reference_period).__name__}) is not numeric or datetime-like, "
        "so per-period relative times cannot be derived. Linear-violation "
        "weights will use the legacy count-based [n_pre-1, ..., 0]/||·||_2 "
        "direction; the reported MDV is NOT in Roth (2022) γ units. Re-fit "
        "with numeric period labels (int year, pandas.Period, datetime) to "
        "obtain γ-unit MDV.",
        UserWarning,
        stacklevel=3,
    )
    return None


def _extract_event_study_vcov_subblock(
    results: Any,
    pre_periods: List[int],
    ses: np.ndarray,
) -> Tuple[np.ndarray, str]:
    """
    Extract the pre-period sub-block of ``results.event_study_vcov`` when
    available; otherwise fall back to ``diag(ses**2)``.

    This is the canonical Σ_22 routing path for ``compute_pretrends_power``
    when the event-study result type exposes a full event-study covariance
    matrix (CallawaySantAnnaResults and SunAbrahamResults non-bootstrap
    fits). Bootstrap fits and replicate-weight survey fits clear
    ``event_study_vcov``, so the analytical VCV is not mixed with bootstrap
    or replicate SE overrides. Those cases fall through to the diagonal
    fallback.

    Parameters
    ----------
    results : event-study results object
        Read via ``getattr`` for ``event_study_vcov`` and
        ``event_study_vcov_index``. A missing attribute or a ``None`` value
        triggers the diagonal fallback.
    pre_periods : list of int
        Relative-time labels of the pre-period coefficients to extract. The
        output rows and columns follow this order.
    ses : array-like
        Per-period standard errors, in the same order as ``pre_periods``.
        Used only by the ``diag(ses**2)`` fallback and coerced with
        ``np.asarray`` so plain lists work.

    Returns
    -------
    vcov : np.ndarray
        The (n_pre, n_pre) covariance sub-block.
    source : str
        Provenance label for report-layer tier classification:
        ``"full_pre_period_vcov"`` when the full event-study sub-block was
        used, or ``"diag_fallback"`` when ``event_study_vcov`` or its index
        was unavailable.

    Raises
    ------
    ValueError
        If ``event_study_vcov_index`` lacks one of the ``pre_periods``
        labels. This fails loudly instead of silently substituting the
        diagonal.
    """
    es_vcov = getattr(results, "event_study_vcov", None)
    es_vcov_index = getattr(results, "event_study_vcov_index", None)
    if es_vcov is None or es_vcov_index is None:
        return np.diag(np.asarray(ses, dtype=float) ** 2), "diag_fallback"

    # Materialize the index once so each label lookup does not rebuild it.
    index_labels = list(es_vcov_index)
    try:
        positions = [index_labels.index(t) for t in pre_periods]
    except ValueError as e:
        # The index is out of sync with the filtered pre_periods. This
        # should not happen on the canonical construction paths.
        raise ValueError(
            f"event_study_vcov_index is missing one of the pre-period labels "
            f"{pre_periods}; cannot extract sub-block. Available index: "
            f"{index_labels}. Original error: {e}"
        ) from e

    return np.asarray(es_vcov)[np.ix_(positions, positions)], "full_pre_period_vcov"


# =============================================================================
# Results Classes
# =============================================================================


@dataclass
class PreTrendsPowerResults:
    """
    Outcome of a Roth (2022) pre-trends power analysis.

    Attributes
    ----------
    power : float
        Rejection probability of the pre-trends test at level ``alpha``
        under a violation of magnitude ``violation_magnitude``.
    mdv : float
        Minimum detectable violation: the smallest magnitude ``M`` whose
        rejection probability reaches ``target_power``.
    violation_magnitude : float
        The magnitude ``M`` at which ``power`` was evaluated.
    violation_type : str
        Violation pattern: ``'linear'``, ``'constant'``, ``'last_period'``
        or ``'custom'``.
    alpha : float
        Significance level of the pre-trends test.
    target_power : float
        Power level that defines the MDV.
    n_pre_periods : int
        Number of pre-treatment coefficients in the event study.
    test_statistic : float
        Expected test statistic under the violation. Set for Wald fits;
        NaN for NIS fits.
    critical_value : float
        Critical value of the pre-trends test.
    noncentrality : float
        Noncentrality parameter under the alternative. Set for Wald fits;
        NaN for NIS fits.
    pre_period_effects : np.ndarray
        Estimated pre-period coefficients.
    pre_period_ses : np.ndarray
        Standard errors of the pre-period coefficients.
    vcov : np.ndarray
        Covariance matrix of the pre-period coefficients.
    original_results : object, optional
        The event-study result the analysis was computed from.
    pretest_form : {'nis', 'wald'}
        Acceptance region of the pretest.
        - ``'nis'``: Roth's no-individually-significant box (Section II.A-B).
        - ``'wald'``: the noncentral chi-squared test on
          ``delta' Sigma_22^{-1} delta``.

        The dataclass default is ``'wald'``, so results serialized before
        the field existed keep their original meaning.
    nis_box_probability : float
        NIS acceptance probability ``P(beta_hat_pre in B_NIS(Sigma))`` under
        the alternative ``M * weights``. NaN for Wald fits.
    violation_weights : np.ndarray, optional
        Violation direction used at fit time. Normalization depends on the
        type:

        - ``linear`` with relative times: ``|t|``, unnormalized, so
          ``δ_t = M·|t|`` and the MDV equals Roth's γ.
        - ``linear`` without relative times (legacy): ``[n_pre-1, ..., 0]``
          scaled to unit L2 norm.
        - ``constant``: ``[1, ..., 1]``, unnormalized (``δ_t = M``).
        - ``last_period``: ``[0, ..., 0, 1]``.
        - ``custom``: the user's vector scaled to unit L2 norm.

        ``None`` on results serialized by older versions. ``power_at()``
        then reconstructs the legacy weights, which is impossible for
        ``custom``.
    covariance_source : str
        Provenance of ``vcov``. One of:
        - ``'full_pre_period_vcov'``: off-diagonal covariances were used.
        - ``'diag_fallback'``: only per-period SEs were available.
        - ``'unknown'``: legacy serialized results.
    """

    power: float
    mdv: float
    violation_magnitude: float
    violation_type: str
    alpha: float
    target_power: float
    n_pre_periods: int
    test_statistic: float
    critical_value: float
    noncentrality: float
    pre_period_effects: np.ndarray = field(repr=False)
    pre_period_ses: np.ndarray = field(repr=False)
    vcov: np.ndarray = field(repr=False)
    original_results: Optional[Any] = field(default=None, repr=False)
    pretest_form: Literal["nis", "wald"] = "wald"
    nis_box_probability: float = np.nan
    violation_weights: Optional[np.ndarray] = field(default=None, repr=False)
    # Provenance tag read by downstream tier classification
    # (``diagnostic_report._infer_cov_source``). "unknown" marks legacy
    # serialized results that predate the field.
    covariance_source: str = "unknown"

    def __repr__(self) -> str:
        parts = (
            f"power={self.power:.3f}",
            f"mdv={self.mdv:.4f}",
            f"M={self.violation_magnitude:.4f}",
        )
        return f"PreTrendsPowerResults({', '.join(parts)})"

    @property
    def is_informative(self) -> bool:
        """
        Whether the pretest can rule out violations that are small on the
        standard-error scale.

        The test is considered informative when the largest level-scale
        deviation at the MDV (``max_abs_pre_violation``) is below twice the
        largest pre-period standard error. ``max_abs_pre_violation`` is used
        instead of the raw ``mdv`` because the linear ``mdv`` is a slope (γ
        units). Comparing a slope directly with level-scale SEs would mix
        units on irregular grids. With no standard errors, the scale
        defaults to 1.0.
        """
        if len(self.pre_period_ses) == 0:
            se_scale = 1.0
        else:
            se_scale = np.max(self.pre_period_ses)
        threshold = 2 * se_scale
        return bool(self.max_abs_pre_violation < threshold)

    @property
    def max_abs_pre_violation(self) -> float:
        """
        Largest absolute pre-period deviation ``δ_t`` when ``M`` equals the
        MDV, i.e. ``mdv * max(|violation_weights|)``.

        This is the level-scale quantity to compare with coefficients,
        per-period SEs, or HonestDiD's M bound. For example, under the
        γ-unit linear convention the grid ``[-3, -2, -1]`` gives ``3γ`` and
        ``[-5, -3, -1]`` gives ``5γ``. Constant ``[1, ..., 1]`` and
        last-period ``[0, ..., 0, 1]`` weights give back ``mdv`` itself.
        Custom weights are unit-L2-normalized, so the result depends on the
        user's direction.

        The raw ``mdv`` is returned when weights are missing (legacy
        results) or empty, or when ``mdv`` is not finite.
        """
        w = self.violation_weights
        no_scaling = w is None or len(w) == 0 or not np.isfinite(self.mdv)
        if no_scaling:
            return float(self.mdv)
        peak_weight = float(np.abs(w).max())
        return float(self.mdv * peak_weight)

    @property
    def power_adequate(self) -> bool:
        """True when ``power`` is at least ``target_power``."""
        return bool(self.power >= self.target_power)

    def summary(self) -> str:
        """
        Render a 70-column text report of the analysis.

        Returns
        -------
        str
            The report. It shows the NIS box probability for NIS fits, or
            the expected test statistic and noncentrality for Wald fits.
        """
        width = 70
        heavy = "=" * width
        light = "-" * width

        def row(label: str, text: str) -> str:
            return f"{label:<35} {text}"

        out = [
            heavy,
            "Pre-Trends Power Analysis Results".center(width),
            "(Roth 2022)".center(width),
            heavy,
            "",
            row("Number of pre-periods:", f"{self.n_pre_periods}"),
            row("Significance level (alpha):", f"{self.alpha:.3f}"),
            row("Target power:", f"{self.target_power:.1%}"),
            row("Violation type:", f"{self.violation_type}"),
            row("Pretest form:", f"{self.pretest_form}"),
            "",
            light,
            "Power Analysis".center(width),
            light,
            row("Violation magnitude (M):", f"{self.violation_magnitude:.4f}"),
            row("Power to detect this violation:", f"{self.power:.1%}"),
            row("Minimum detectable violation:", f"{self.mdv:.4f}"),
            "",
            row("Critical value:", f"{self.critical_value:.4f}"),
        ]

        # NIS reports the MVN box acceptance probability; Wald reports the
        # noncentral chi-squared quantities.
        if self.pretest_form == "nis":
            out.append(row("NIS box probability (accept):", f"{self.nis_box_probability:.4f}"))
        else:
            out += [
                row("Test statistic (expected):", f"{self.test_statistic:.4f}"),
                row("Non-centrality parameter:", f"{self.noncentrality:.4f}"),
            ]

        out += ["", light, "Interpretation".center(width), light]

        if self.power_adequate:
            out += [
                f"✓ Power ({self.power:.0%}) meets target ({self.target_power:.0%}).",
                f" The pre-trends test would detect violations of magnitude {self.violation_magnitude:.3f}.",
            ]
        else:
            out += [
                f"✗ Power ({self.power:.0%}) below target ({self.target_power:.0%}).",
                f" Would need violations of {self.mdv:.3f} to achieve {self.target_power:.0%} power.",
            ]

        out += [
            "",
            f"Minimum detectable violation (MDV): {self.mdv:.4f}",
            " → Passing pre-trends test does NOT rule out violations up to this size.",
            "",
            heavy,
        ]
        return "\n".join(out)

    def print_summary(self) -> None:
        """Write ``summary()`` to stdout."""
        report = self.summary()
        print(report)

    def to_dict(self) -> Dict[str, Any]:
        """
        Flatten the results into a plain dictionary.

        Includes the provenance fields ``violation_weights`` (as a
        ``list[float]`` or ``None``, so the dict can be passed to
        ``json.dumps``) and ``covariance_source``, plus the derived
        ``is_informative`` and ``power_adequate`` flags. Use the dataclass
        attribute directly when an ndarray of weights is needed.
        """
        raw_weights = self.violation_weights
        serial_weights: Optional[List[float]] = (
            None if raw_weights is None else [float(v) for v in np.ravel(raw_weights)]
        )
        return {
            "power": self.power,
            "mdv": self.mdv,
            "violation_magnitude": self.violation_magnitude,
            "violation_type": self.violation_type,
            "alpha": self.alpha,
            "target_power": self.target_power,
            "n_pre_periods": self.n_pre_periods,
            "test_statistic": self.test_statistic,
            "critical_value": self.critical_value,
            "noncentrality": self.noncentrality,
            "pretest_form": self.pretest_form,
            "nis_box_probability": self.nis_box_probability,
            "violation_weights": serial_weights,
            "covariance_source": self.covariance_source,
            "is_informative": self.is_informative,
            "power_adequate": self.power_adequate,
        }

    def to_dataframe(self) -> pd.DataFrame:
        """
        One-row DataFrame version of ``to_dict``.

        ``violation_weights`` is stored as a list in its cell, and
        ``covariance_source`` as a plain string.
        """
        record = self.to_dict()
        return pd.DataFrame([record])

    def _reconstruct_legacy_weights(self) -> np.ndarray:
        """
        Rebuild the unit-L2-norm weights used before ``violation_weights``
        was persisted. Only called when that field is ``None``.

        Raises
        ------
        NotImplementedError
            For ``violation_type='custom'``, whose direction cannot be
            recovered.
        ValueError
            For an unrecognized ``violation_type``.
        """
        n_pre = self.n_pre_periods
        kind = self.violation_type
        if kind == "custom":
            raise NotImplementedError(
                "PreTrendsPowerResults.power_at() cannot reconstruct "
                "custom violation weights from an older serialized result "
                "(violation_weights field is None). Refit "
                "PreTrendsPower(violation_type='custom', "
                "violation_weights=...) with the new M instead. "
                "Fresh fits from the current library version persist "
                "violation_weights and do not hit this guard."
            )
        if kind == "linear":
            # Count-based legacy direction [n_pre-1, ..., 1, 0].
            direction = -np.arange(-n_pre + 1, 1, dtype=float)
        elif kind == "constant":
            direction = np.ones(n_pre)
        elif kind == "last_period":
            direction = np.zeros(n_pre)
            direction[-1] = 1.0
        else:
            raise ValueError(
                f"Unknown violation_type: {self.violation_type!r}. "
                "Expected one of: 'linear', 'constant', 'last_period', 'custom'."
            )
        # The legacy convention normalized every reconstructed pattern.
        length = np.linalg.norm(direction)
        return direction / length if length > 0 else direction

    def power_at(self, M: float) -> float:
        """
        Power of the stored pretest against a violation of magnitude ``M``.

        Reuses the fitted ``violation_weights`` and ``pretest_form`` instead
        of re-fitting:

        - NIS fits return ``1 - P(box acceptance)``.
        - Wald fits return ``1 - F_ncx2(critical_value; df=n_pre_periods, nc)``
          with ``nc = M**2 * w' Σ^{-1} w``, using the pseudo-inverse if Σ is
          singular. Legacy results default to ``pretest_form='wald'``, so
          they take this path.

        Parameters
        ----------
        M : float
            Violation magnitude to evaluate.

        Returns
        -------
        float
            Rejection probability at magnitude ``M``.

        Raises
        ------
        NotImplementedError
            If ``violation_weights`` is ``None`` (a result from an older
            library version) and ``violation_type='custom'``. Custom weights
            cannot be reconstructed; refit
            ``PreTrendsPower(violation_type='custom', violation_weights=...)``
            instead.
        ValueError
            If ``violation_weights`` is ``None`` and ``violation_type`` is
            unrecognized.
        """
        if self.violation_weights is None:
            weights = self._reconstruct_legacy_weights()
        else:
            weights = np.asarray(self.violation_weights, dtype=float)

        if self.pretest_form == "nis":
            if np.isfinite(self.critical_value):
                z_alpha = float(self.critical_value)
            else:
                z_alpha = float(stats.norm.ppf(1 - self.alpha / 2))
            acceptance = _compute_nis_acceptance_prob(M, weights, self.vcov, z_alpha)
            return float(1.0 - acceptance)

        try:
            precision = np.linalg.inv(self.vcov)
        except np.linalg.LinAlgError:
            precision = np.linalg.pinv(self.vcov)
        noncentrality = M**2 * (weights @ precision @ weights)
        rejection = 1 - stats.ncx2.cdf(self.critical_value, df=self.n_pre_periods, nc=noncentrality)
        return float(rejection)
@dataclass
class PreTrendsPowerCurve:
    """
    Power curve across violation magnitudes.

    Attributes
    ----------
    M_values : np.ndarray
        Grid of violation magnitudes tested.
    powers : np.ndarray
        Power at each violation magnitude.
    mdv : float
        Minimum detectable violation.
    alpha : float
        Significance level.
    target_power : float
        Target power level.
    violation_type : str
        Type of violation pattern.
    pretest_form : str
        Pretest acceptance-region form (``'nis'`` or ``'wald'``) used to
        compute the curve. NIS and Wald curves can differ materially under
        correlated Σ_22; storing the form prevents callers from
        misinterpreting a serialized or plotted curve.
    """

    M_values: np.ndarray
    powers: np.ndarray
    mdv: float
    alpha: float
    target_power: float
    violation_type: str
    pretest_form: Literal["nis", "wald"] = "wald"

    def __repr__(self) -> str:
        return f"PreTrendsPowerCurve(n_points={len(self.M_values)}, mdv={self.mdv:.4f})"

    def to_dataframe(self) -> pd.DataFrame:
        """Convert to a DataFrame with M, power, and pretest_form columns."""
        return pd.DataFrame(
            {
                "M": self.M_values,
                "power": self.powers,
                "pretest_form": self.pretest_form,
            }
        )

    def plot(
        self,
        ax=None,
        show_mdv: bool = True,
        show_target: bool = True,
        color: str = "#2563eb",
        mdv_color: str = "#dc2626",
        target_color: str = "#22c55e",
        **kwargs,
    ):
        """
        Plot the power curve.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            Axes to plot on. If None, a new 10x6 figure is created.
        show_mdv : bool, default=True
            Whether to draw a vertical line at the MDV. The line is skipped
            when the MDV is None or not finite.
        show_target : bool, default=True
            Whether to draw a horizontal line at the target power.
        color : str
            Color for the power-curve line.
        mdv_color : str
            Color for the MDV vertical line.
        target_color : str
            Color for the target-power horizontal line.
        **kwargs
            Additional arguments passed to ``ax.plot()`` for the power
            curve. They may override the defaults ``linewidth=2`` (also via
            the ``lw`` alias) and ``label="Power"``.

        Returns
        -------
        ax : matplotlib.axes.Axes
            The axes with the plot.

        Raises
        ------
        ImportError
            If matplotlib is not installed.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError as err:
            raise ImportError("matplotlib is required for plotting") from err

        if ax is None:
            _, ax = plt.subplots(figsize=(10, 6))

        # Treat linewidth and label as overridable defaults. Passing them
        # alongside **kwargs unconditionally would raise "multiple values
        # for keyword argument".
        line_kwargs = {"label": "Power", **kwargs}
        if "lw" not in line_kwargs:
            line_kwargs.setdefault("linewidth", 2)
        ax.plot(self.M_values, self.powers, color=color, **line_kwargs)

        if show_target:
            ax.axhline(
                y=self.target_power,
                color=target_color,
                linestyle="--",
                linewidth=1.5,
                alpha=0.7,
                label=f"Target power ({self.target_power:.0%})",
            )

        if show_mdv and self.mdv is not None and np.isfinite(self.mdv):
            ax.axvline(
                x=self.mdv,
                color=mdv_color,
                linestyle=":",
                linewidth=1.5,
                alpha=0.7,
                label=f"MDV = {self.mdv:.3f}",
            )

        ax.set_xlabel("Violation Magnitude (M)")
        ax.set_ylabel("Power")
        ax.set_title("Pre-Trends Test Power Curve")
        ax.set_ylim(0, 1.05)
        ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y:.0%}"))
        ax.legend(loc="lower right")
        ax.grid(True, alpha=0.3)

        return ax
# =============================================================================
# Main Class
# =============================================================================
[docs] class PreTrendsPower: """ Pre-trends power analysis (Roth 2022). Computes the power of pre-trends tests to detect violations of parallel trends, and the minimum detectable violation (MDV). Parameters ---------- alpha : float, default=0.05 Significance level for the pre-trends test. power : float, default=0.80 Target power level for MDV calculation. violation_type : str, default='linear' Type of violation pattern to consider: - 'linear': Violations follow a linear trend (most common) - 'constant': Same violation in all pre-periods - 'last_period': Violation only in the last pre-period - 'custom': User-specified violation pattern (via violation_weights) violation_weights : array-like, optional Custom weights for violation pattern. Length must equal number of pre-periods. Only used when violation_type='custom'. pretest_form : {'nis', 'wald'}, default='nis' Pre-trends test acceptance-region form: - ``'nis'``: Roth (2022) no-individually-significant pretest (Section II.A-B). Acceptance region is ``B_NIS(Σ) = { b : |b_t| <= z_{1-α/2} σ_t for all t }``. Power computed via multivariate normal box probability. This is the new default (PR-B 2026-05-17), matching both the paper's primary analysis and the R ``pretrends`` package. - ``'wald'``: Noncentral chi-squared on the quadratic form ``δ' Σ_22^{-1} δ`` (the shipped behavior prior to PR-B 2026-05-17). Retained as a paper-supported alternative under Propositions 1+3+4 (Wald acceptance region is a convex ellipsoid, so all four propositions apply). Use this for backwards-compat with shipped numerical baselines. Examples -------- Basic usage with MultiPeriodDiD results: >>> from diff_diff import MultiPeriodDiD >>> from diff_diff.pretrends import PreTrendsPower >>> >>> # Fit event study >>> mp_did = MultiPeriodDiD() >>> results = mp_did.fit(data, outcome='y', treatment='treated', ... 
time='period', post_periods=[4, 5, 6, 7]) >>> >>> # Analyze pre-trends power >>> pt = PreTrendsPower(alpha=0.05, power=0.80) >>> power_results = pt.fit(results) >>> print(power_results.summary()) >>> >>> # Get power curve >>> curve = pt.power_curve(results) >>> curve.plot() Notes ----- The pre-trends test is typically a joint test that all pre-period coefficients are zero. This test has limited power to detect small violations, especially when: 1. There are few pre-periods 2. Standard errors are large 3. The violation pattern is smooth (e.g., linear trend) Passing a pre-trends test does NOT mean parallel trends holds. It means violations smaller than the MDV cannot be ruled out. For robust inference, combine with HonestDiD sensitivity analysis. References ---------- Roth, J. (2022). Pretest with Caution: Event-Study Estimates after Testing for Parallel Trends. American Economic Review: Insights, 4(3), 305-322. """
[docs] def __init__( self, alpha: float = 0.05, power: float = 0.80, violation_type: Literal["linear", "constant", "last_period", "custom"] = "linear", violation_weights: Optional[np.ndarray] = None, pretest_form: Literal["nis", "wald"] = "nis", ): if not 0 < alpha < 1: raise ValueError(f"alpha must be between 0 and 1, got {alpha}") if not 0 < power < 1: raise ValueError(f"power must be between 0 and 1, got {power}") if violation_type not in ["linear", "constant", "last_period", "custom"]: raise ValueError( f"violation_type must be 'linear', 'constant', 'last_period', or 'custom', " f"got '{violation_type}'" ) if violation_type == "custom" and violation_weights is None: raise ValueError("violation_weights must be provided when violation_type='custom'") if pretest_form not in ("nis", "wald"): raise ValueError(f"pretest_form must be 'nis' or 'wald', got '{pretest_form}'") self.alpha = alpha self.target_power = power self.violation_type = violation_type self.violation_weights = ( np.asarray(violation_weights) if violation_weights is not None else None ) self.pretest_form = pretest_form
[docs] def get_params(self) -> Dict[str, Any]: """Get parameters for this estimator.""" return { "alpha": self.alpha, "power": self.target_power, "violation_type": self.violation_type, "violation_weights": self.violation_weights, "pretest_form": self.pretest_form, }
[docs] def set_params(self, **params) -> "PreTrendsPower": """Set parameters for this estimator.""" for key, value in params.items(): if key == "power": self.target_power = value elif hasattr(self, key): setattr(self, key, value) else: raise ValueError(f"Invalid parameter: {key}") return self
def _get_violation_weights(
    self,
    n_pre: int,
    relative_times: Optional[np.ndarray] = None,
) -> np.ndarray:
    """
    Get violation weights based on violation type.

    Parameters
    ----------
    n_pre : int
        Number of pre-treatment periods.
    relative_times : np.ndarray, optional
        Sorted relative-time labels for the pre-period coefficients (e.g.,
        ``[-3, -2, -1]`` for a regular grid, ``[-5, -3, -1]`` for an
        irregular grid, ``[-3, -2]`` for an anticipation-shifted grid with
        ``anticipation=1``). When provided AND ``violation_type='linear'``,
        weights are set to ``|t|`` directly with NO L2 normalization, so
        ``δ_t = M * |t|`` and the reported MDV is in Roth's γ units
        (δ_t = γ·t convention). When None, falls back to the legacy
        count-based ``[n_pre-1, ..., 1, 0] / ||·||_2`` direction (for callers
        that bypass ``fit()`` and call this helper without relative-time
        labels).

    Returns
    -------
    np.ndarray
        Violation weights, with per-violation-type normalization conventions
        chosen so the magnitude ``M`` matches what ``REGISTRY.md`` documents
        for the pattern:

        - ``'linear'`` with ``relative_times``: ``|t|`` directly, NOT
          L2-normalized (so ``δ_t = M * |t|``; MDV is in Roth's γ units).
        - ``'linear'`` without ``relative_times`` (legacy): the count-based
          ``[n_pre-1, ..., 0]`` direction, L2-normalized to unit norm.
        - ``'constant'``: ``[1, 1, ..., 1]`` directly, NOT normalized, so
          ``δ_t = M`` per period (a true level shift, ``δ_t = c``).
        - ``'last_period'``: ``[0, ..., 0, 1]`` directly (already unit-norm).
        - ``'custom'``: user-supplied ``violation_weights``, L2-normalized to
          unit norm (M is the magnitude along the user's direction).

    Raises
    ------
    ValueError
        If custom weights are missing, if the custom weights or
        ``relative_times`` do not have length ``n_pre``, or if the violation
        type is unknown.
    """
    if self.violation_type == "custom":
        # Explicit check instead of ``assert`` so it survives ``python -O``;
        # reachable when the type was switched to 'custom' after construction.
        if self.violation_weights is None:
            raise ValueError("violation_weights must be provided when violation_type='custom'")
        if len(self.violation_weights) != n_pre:
            raise ValueError(
                f"violation_weights has length {len(self.violation_weights)}, "
                f"but there are {n_pre} pre-periods"
            )
        weights = self.violation_weights.copy()
    elif self.violation_type == "linear":
        if relative_times is not None:
            # Roth (2022) δ_t = γ · t convention. Use |t| because pre-period
            # labels are negative; δ_pre = M * |t| satisfies M = γ exactly.
            # NO L2 normalization (early return skips the block below).
            if len(relative_times) != n_pre:
                raise ValueError(
                    f"relative_times has length {len(relative_times)}, "
                    f"but there are {n_pre} pre-periods"
                )
            return np.abs(np.asarray(relative_times)).astype(float)
        # Backwards-compatible fallback (no relative_times threaded):
        # legacy count-based [n_pre-1, ..., 1, 0] / ||·||_2 direction.
        weights = np.arange(-n_pre + 1, 1, dtype=float)
        weights = -weights  # Now [n-1, n-2, ..., 1, 0]
    elif self.violation_type == "constant":
        # δ_t = M for all pre-periods (level shift). Skip L2 normalization
        # so M is exactly the per-period level shift (`δ_t = c`).
        return np.ones(n_pre, dtype=float)
    elif self.violation_type == "last_period":
        # Violation only in the last pre-period. `[0, ..., 0, 1]` already
        # has unit L2 norm; early return keeps the level-shift contract
        # uniform across level-pattern violation types.
        weights = np.zeros(n_pre, dtype=float)
        weights[-1] = 1.0
        return weights
    else:
        raise ValueError(f"Unknown violation_type: {self.violation_type}")

    # Normalize to unit norm (if not all zeros). Only the legacy linear
    # fallback and 'custom' reach this point.
    norm = np.linalg.norm(weights)
    if norm > 0:
        weights = weights / norm

    return weights

def _extract_pre_period_params(
    self,
    results: Union[MultiPeriodDiDResults, Any],
    pre_periods: Optional[List[int]] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int, Optional[np.ndarray], str]:
    """
    Extract pre-period parameters from results.

    Parameters
    ----------
    results : MultiPeriodDiDResults or similar
        Results object from event study estimation.
    pre_periods : list of int, optional
        Explicit list of pre-treatment periods. If None, uses
        results.pre_periods.

    Returns
    -------
    effects : np.ndarray
        Pre-period effect estimates.
    ses : np.ndarray
        Pre-period standard errors.
    vcov : np.ndarray
        Variance-covariance matrix for pre-period effects.
    n_pre : int
        Number of pre-periods.
    relative_times : np.ndarray or None
        Pre-period relative-time labels (Roth's δ_t = γ·t convention), or
        None for callers that bypass the labeled-grid path.
    covariance_source : str
        Provenance label describing which covariance path the extraction
        actually took:

        - ``"full_pre_period_vcov"`` when a full pre-period covariance
          sub-block was used (MPD with ``interaction_indices``, or CS/SA
          with populated ``event_study_vcov``).
        - ``"diag_fallback"`` when only the per-period standard errors were
          available (bootstrap / replicate-weight CS or SA fits, MPD without
          ``interaction_indices``).

        ``DiagnosticReport`` consumes this label downstream to decide
        whether the power-tier should be conservatively downgraded, rather
        than re-inferring covariance provenance from the result type.

    Raises
    ------
    ValueError
        If no usable pre-period coefficients are found.
    TypeError
        If ``results`` is not a supported results type.
    """
    if isinstance(results, MultiPeriodDiDResults):
        # Get pre-period information - use explicit pre_periods if provided
        if pre_periods is not None:
            all_pre_periods = list(pre_periods)
        else:
            all_pre_periods = results.pre_periods

        if len(all_pre_periods) == 0:
            raise ValueError(
                "No pre-treatment periods found in results. "
                "Pre-trends power analysis requires pre-period coefficients. "
                "If you estimated all periods as post_periods, use the pre_periods "
                "parameter to specify which are actually pre-treatment."
            )

        # Pre-period effects are in period_effects (excluding reference period)
        estimated_pre_periods = [
            p
            for p in all_pre_periods
            if p in results.period_effects and results.period_effects[p].se > 0
        ]

        if len(estimated_pre_periods) == 0:
            raise ValueError(
                "No estimated pre-period coefficients found. "
                "The pre-trends test requires at least one estimated "
                "pre-period coefficient (excluding the reference period)."
            )

        n_pre = len(estimated_pre_periods)
        effects = np.array([results.period_effects[p].effect for p in estimated_pre_periods])
        ses = np.array([results.period_effects[p].se for p in estimated_pre_periods])

        # Extract vcov using stored interaction indices for robust extraction
        if (
            results.vcov is not None
            and hasattr(results, "interaction_indices")
            and results.interaction_indices is not None
        ):
            indices = [results.interaction_indices[p] for p in estimated_pre_periods]
            vcov = results.vcov[np.ix_(indices, indices)]
            covariance_source = "full_pre_period_vcov"
        else:
            vcov = np.diag(ses**2)
            covariance_source = "diag_fallback"

        # For MultiPeriodDiDResults, period identifiers are generic (often
        # calendar years, sometimes pre-shifted relative times). Roth's
        # δ_t = γ·t convention needs RELATIVE offsets from the reference
        # period; the conversion (and the fallback for labels that cannot be
        # converted) is delegated to _coerce_relative_times_from_reference.
        ref = getattr(results, "reference_period", None)
        relative_times: Optional[np.ndarray] = None
        if ref is not None:
            relative_times = _coerce_relative_times_from_reference(estimated_pre_periods, ref)

        return effects, ses, vcov, n_pre, relative_times, covariance_source

    # Try CallawaySantAnnaResults
    try:
        from diff_diff.staggered import CallawaySantAnnaResults

        if isinstance(results, CallawaySantAnnaResults):
            if results.event_study_effects is None:
                raise ValueError(
                    "CallawaySantAnnaResults must have event_study_effects. "
                    "Re-run with aggregate='event_study'."
                )

            # Anticipation-aware cutoff: with ``anticipation=k``, true
            # pre-periods are ``t < -k``; ``t ∈ [-k, -1]`` is the
            # anticipation window and must not be used for pre-trends power.
            # Also filter out normalization constraints (n_groups=0) and
            # non-finite / non-positive SEs.
            _ant = getattr(results, "anticipation", 0) or 0
            try:
                _ant = int(_ant)
            except (TypeError, ValueError):
                _ant = 0
            _pre_cutoff = -_ant
            pre_effects = {
                t: data
                for t, data in results.event_study_effects.items()
                if t < _pre_cutoff
                and data.get("n_groups", 1) > 0
                and np.isfinite(data.get("se", np.nan))
                and float(data.get("se", 0.0)) > 0
            }

            if not pre_effects:
                raise ValueError("No pre-treatment periods found in event study.")

            pre_periods = sorted(pre_effects.keys())
            n_pre = len(pre_periods)

            effects = np.array([pre_effects[t]["effect"] for t in pre_periods])
            ses = np.array([pre_effects[t]["se"] for t in pre_periods])
            # Route through the full event_study_vcov when available;
            # otherwise the helper falls back to the diagonal of SEs.
            vcov, covariance_source = _extract_event_study_vcov_subblock(
                results, pre_periods, ses
            )

            relative_times = np.asarray(pre_periods, dtype=float)
            return effects, ses, vcov, n_pre, relative_times, covariance_source
    except ImportError:
        pass

    # Try SunAbrahamResults
    try:
        from diff_diff.sun_abraham import SunAbrahamResults

        if isinstance(results, SunAbrahamResults):
            # Same anticipation-aware pre-period cutoff as CallawaySantAnna.
            _ant = getattr(results, "anticipation", 0) or 0
            try:
                _ant = int(_ant)
            except (TypeError, ValueError):
                _ant = 0
            _pre_cutoff = -_ant
            # Guard against a missing event-study table (mirrors the CS
            # branch's None check) so the user gets the ValueError below
            # instead of an AttributeError on ``None.items()``.
            event_study_effects = results.event_study_effects or {}
            # Mirror the ``se > 0`` filter applied on the CS branch.
            pre_effects = {
                t: data
                for t, data in event_study_effects.items()
                if t < _pre_cutoff
                and data.get("n_groups", 1) > 0
                and np.isfinite(data.get("se", np.nan))
                and float(data.get("se", 0.0)) > 0
            }

            if not pre_effects:
                raise ValueError("No pre-treatment periods found in event study.")

            pre_periods = sorted(pre_effects.keys())
            n_pre = len(pre_periods)

            effects = np.array([pre_effects[t]["effect"] for t in pre_periods])
            ses = np.array([pre_effects[t]["se"] for t in pre_periods])
            # Route through the full event_study_vcov when available;
            # otherwise the helper falls back to the diagonal of SEs.
            vcov, covariance_source = _extract_event_study_vcov_subblock(
                results, pre_periods, ses
            )

            relative_times = np.asarray(pre_periods, dtype=float)
            return effects, ses, vcov, n_pre, relative_times, covariance_source
    except ImportError:
        pass

    raise TypeError(
        f"Unsupported results type: {type(results)}. "
        "Expected MultiPeriodDiDResults, CallawaySantAnnaResults, or SunAbrahamResults."
    )

def _compute_power(
    self,
    M: float,
    weights: np.ndarray,
    vcov: np.ndarray,
) -> Tuple[float, float, float, float]:
    """Dispatch to the configured pretest form (NIS by default)."""
    if self.pretest_form == "nis":
        return self._compute_power_nis(M, weights, vcov)
    return self._compute_power_wald(M, weights, vcov)

def _compute_power_wald(
    self,
    M: float,
    weights: np.ndarray,
    vcov: np.ndarray,
) -> Tuple[float, float, float, float]:
    """
    Compute power to detect violation of magnitude M under the Wald form.

    Wald pre-trends test: H0: delta = 0 vs H1: delta != 0. Under H1 with
    violation delta = M * weights, the test statistic
    ``delta' V^{-1} delta`` follows a non-central chi-squared distribution
    with df=K and noncentrality lambda = M^2 * (w' V^{-1} w). Convex
    (ellipsoid) acceptance region.

    Parameters
    ----------
    M : float
        Violation magnitude.
    weights : np.ndarray
        Violation pattern.
    vcov : np.ndarray
        Variance-covariance matrix.

    Returns
    -------
    power : float
        Power to detect this violation.
    noncentrality : float
        Non-centrality parameter.
    test_stat : float
        Expected test statistic under H1.
    critical_value : float
        Critical value for the test.
    """
    n_pre = len(weights)

    # Violation vector: delta = M * weights
    delta = M * weights

    # Non-centrality parameter: lambda = delta' * V^{-1} * delta
    try:
        vcov_inv = np.linalg.inv(vcov)
        noncentrality = delta @ vcov_inv @ delta
    except np.linalg.LinAlgError:
        # Singular matrix - use pseudo-inverse
        vcov_inv = np.linalg.pinv(vcov)
        noncentrality = delta @ vcov_inv @ delta

    # Critical value from chi-squared distribution
    critical_value = stats.chi2.ppf(1 - self.alpha, df=n_pre)

    # Power = P(chi2_nc > critical_value). Use the survival function (as
    # _compute_mdv_wald does) rather than 1 - cdf, which loses precision
    # as power approaches 1 and would make power/MDV disagree slightly.
    if noncentrality > 0:
        power = stats.ncx2.sf(critical_value, df=n_pre, nc=noncentrality)
    else:
        power = self.alpha  # Size under null

    # Expected test statistic under H1 (mean of non-central chi2)
    test_stat = n_pre + noncentrality

    return power, noncentrality, test_stat, critical_value

def _compute_power_nis(
    self,
    M: float,
    weights: np.ndarray,
    vcov: np.ndarray,
) -> Tuple[float, float, float, float]:
    """
    Compute power to detect violation of magnitude M under the NIS form.

    NIS (no-individually-significant) pre-trends test: passes iff every
    pre-period coefficient lies within its own ``+/- z_{1-alpha/2} * sigma_t``
    confidence interval. Roth (2022) Section II.A-B.

    Under H1 with violation ``delta_pre = M * weights``, the rejection
    probability is computed via the centered change-of-variable
    ``Y = beta_hat_pre - delta_pre ~ N(0, Sigma_22)``:

    .. math::

        \\text{Power} = 1 - P\\bigl(Y_t \\in [-z\\sigma_t - \\delta_t,
                                             z\\sigma_t - \\delta_t]
                                   \\text{ for all } t\\bigr)

    The acceptance probability comes from ``_compute_nis_acceptance_prob``
    (analytical MVN CDF with a Monte Carlo fallback).

    Parameters
    ----------
    M : float
        Violation magnitude.
    weights : np.ndarray
        Violation pattern (linear: ``|t|`` directly when fit() threads
        ``relative_times``; constant / last_period: level-scale;
        custom: unit-normalized).
    vcov : np.ndarray
        Variance-covariance matrix Sigma_22 of the pre-period coefficients.

    Returns
    -------
    power : float
        Probability the NIS test rejects under the alternative.
    noncentrality : float
        ``np.nan``. NIS does not have a noncentrality scalar.
    test_stat : float
        ``np.nan``. NIS rejects via a rectangular acceptance event, not a
        scalar test statistic.
    critical_value : float
        ``z_{1-alpha/2}``, the per-period normal critical value used to
        define ``B_NIS(Sigma)``.
    """
    z_alpha = float(stats.norm.ppf(1 - self.alpha / 2))
    # Centralized analytical-or-MC fallback (module-level helper);
    # handles both exception and non-finite-CDF cases.
    accept_prob = _compute_nis_acceptance_prob(M, weights, vcov, z_alpha)
    power = float(1.0 - accept_prob)
    return power, float("nan"), float("nan"), z_alpha

def _compute_mdv(
    self,
    weights: np.ndarray,
    vcov: np.ndarray,
) -> float:
    """Dispatch to the configured pretest form (NIS by default)."""
    if self.pretest_form == "nis":
        return self._compute_mdv_nis(weights, vcov)
    return self._compute_mdv_wald(weights, vcov)

def _compute_mdv_wald(
    self,
    weights: np.ndarray,
    vcov: np.ndarray,
) -> float:
    """
    Compute minimum detectable violation under the Wald form.

    Find the smallest M such that
    ``_compute_power_wald(M, weights, vcov) >= target_power``. Solves for the
    noncentrality parameter with a doubling bracket + ``brentq``, then
    converts back to M via ``nc = M^2 * (w' V^{-1} w)``.

    Parameters
    ----------
    weights : np.ndarray
        Violation pattern.
    vcov : np.ndarray
        Variance-covariance matrix.

    Returns
    -------
    mdv : float
        Minimum detectable violation in units of M (for linear weights
        threaded with ``relative_times``, this is Roth's gamma). ``0.0`` when
        the test's size already meets ``target_power``; ``np.inf`` when the
        target is not reached within the noncentrality cap or the pattern
        is undetectable (``w' V^{-1} w <= 0``).
    """
    n_pre = len(weights)

    # Boundary short-circuit (mirrors _compute_mdv_nis): the Wald test's
    # rejection rate under the null is exactly alpha, so if
    # target_power <= alpha no violation is needed. Without this, brentq
    # sees no sign change and the approximate fallback below returns a
    # spurious positive MDV.
    if self.alpha >= self.target_power:
        return 0.0

    # Critical value
    critical_value = stats.chi2.ppf(1 - self.alpha, df=n_pre)

    def power_minus_target(nc):
        if nc <= 0:
            return self.alpha - self.target_power
        return stats.ncx2.sf(critical_value, df=n_pre, nc=nc) - self.target_power

    # Expand the upper bracket until power exceeds target (capped).
    nc_low, nc_high = 0, 1
    while power_minus_target(nc_high) < 0 and nc_high < 1000:
        nc_high *= 2

    # The cap only bounds the doubling loop. The loop can exit at
    # nc_high = 1024 without ever evaluating power there, so check the
    # endpoint explicitly before declaring the target unreachable (same fix
    # as the NIS path; previously any root in (512, 1024] was reported as inf).
    if power_minus_target(nc_high) < 0:
        return np.inf

    try:
        target_nc = optimize.brentq(power_minus_target, nc_low, nc_high)
    except ValueError:
        # Defensive fallback: approximate formula
        # power ≈ Phi(sqrt(2*nc) - sqrt(2*cv)).
        z_power = stats.norm.ppf(self.target_power)
        target_nc = 0.5 * (z_power + np.sqrt(2 * critical_value)) ** 2

    # Convert non-centrality to M:
    # nc = delta' * V^{-1} * delta = M^2 * w' * V^{-1} * w
    try:
        vcov_inv = np.linalg.inv(vcov)
        w_Vinv_w = weights @ vcov_inv @ weights
    except np.linalg.LinAlgError:
        vcov_inv = np.linalg.pinv(vcov)
        w_Vinv_w = weights @ vcov_inv @ weights

    if w_Vinv_w > 0:
        mdv = np.sqrt(target_nc / w_Vinv_w)
    else:
        mdv = np.inf

    return mdv

def _compute_mdv_nis(
    self,
    weights: np.ndarray,
    vcov: np.ndarray,
) -> float:
    """
    Compute minimum detectable violation under the NIS form.

    Solves ``_compute_power_nis(M, weights, vcov) = target_power`` for M via
    a doubling expansion to bracket the root, then ``brentq``. If power at
    the capped upper bound (``M_high`` after doubling past 1000) is still
    below target, returns ``np.inf``.

    Parameters
    ----------
    weights : np.ndarray
        Violation pattern.
    vcov : np.ndarray
        Variance-covariance matrix Sigma_22.

    Returns
    -------
    mdv : float
        Minimum detectable violation. For linear weights threaded with
        ``relative_times``, this is Roth's gamma at the target power.
    """

    def power_minus_target(M: float) -> float:
        return self._compute_power_nis(M, weights, vcov)[0] - self.target_power

    # Boundary short-circuit: if the NIS rejection rate under the null
    # (≈ 1 - (1-α)^K under independence) already meets target_power, the MDV
    # is zero. NIS size is generally LARGER than α, so this is reachable for
    # small target_power (e.g., target=0.10, α=0.05, K=3 → size ≈ 0.143).
    if power_minus_target(0.0) >= 0:
        return 0.0

    # Doubling expansion to find an upper bound where power >= target.
    # The cap bounds the loop but does NOT by itself mean "unreachable";
    # power at the final M_high is checked explicitly below.
    M_high = 1.0
    while power_minus_target(M_high) < 0 and M_high < 1000:
        M_high *= 2

    if power_minus_target(M_high) < 0:
        # Power at the cap still fails to reach target_power.
        return np.inf

    # Bisect on [0, M_high]; both sign-change endpoints verified above.
    try:
        mdv = float(optimize.brentq(power_minus_target, 0.0, M_high))
    except ValueError:
        # Defensive fallback. Should be unreachable.
        mdv = float(M_high)

    return mdv
def fit(
    self,
    results: Union[MultiPeriodDiDResults, Any],
    M: Optional[float] = None,
    pre_periods: Optional[List[int]] = None,
) -> PreTrendsPowerResults:
    """
    Run the pre-trends power analysis on an event-study fit.

    Parameters
    ----------
    results : MultiPeriodDiDResults, CallawaySantAnnaResults, or SunAbrahamResults
        Results from an event study estimation.
    M : float, optional
        Violation magnitude at which to report power. When omitted, the MDV
        is used if it is finite; otherwise the largest pre-period standard
        error is used.
    pre_periods : list of int, optional
        Explicit list of pre-treatment periods. If None, inferred from
        ``results.pre_periods``. Useful when every period was estimated as a
        post-period and the true pre-treatment ones must be named.

    Returns
    -------
    PreTrendsPowerResults
        Power at ``M``, the MDV, and the inputs used to compute them
        (including the covariance-source provenance label).
    """
    extracted = self._extract_pre_period_params(results, pre_periods)
    effects, ses, vcov, n_pre, relative_times, covariance_source = extracted

    # Threading relative_times gives γ-unit MDV for linear violations.
    weights = self._get_violation_weights(n_pre, relative_times=relative_times)
    mdv = self._compute_mdv(weights, vcov)

    if M is None:
        M = np.max(ses) if not np.isfinite(mdv) else mdv

    power, noncentrality, test_stat, critical_value = self._compute_power(M, weights, vcov)

    # Box acceptance probability is only defined for the NIS form; Wald
    # fits report NaN here (their scalar summary is `noncentrality`).
    if self.pretest_form == "nis":
        nis_box_probability = 1.0 - power
    else:
        nis_box_probability = float("nan")

    return PreTrendsPowerResults(
        power=power,
        mdv=mdv,
        violation_magnitude=M,
        violation_type=self.violation_type,
        alpha=self.alpha,
        target_power=self.target_power,
        n_pre_periods=n_pre,
        test_statistic=test_stat,
        critical_value=critical_value,
        noncentrality=noncentrality,
        pre_period_effects=effects,
        pre_period_ses=ses,
        vcov=vcov,
        original_results=results,
        pretest_form=self.pretest_form,
        nis_box_probability=nis_box_probability,
        violation_weights=weights,
        covariance_source=covariance_source,
    )
def power_at(
    self,
    results: Union[MultiPeriodDiDResults, Any],
    M: float,
    pre_periods: Optional[List[int]] = None,
) -> float:
    """
    Return the pre-trends test's power against a violation of size ``M``.

    Thin wrapper over ``fit`` that keeps only the ``power`` field.

    Parameters
    ----------
    results : results object
        Event study results.
    M : float
        Violation magnitude.
    pre_periods : list of int, optional
        Explicit list of pre-treatment periods. See fit() for details.

    Returns
    -------
    float
        Power to detect violation of magnitude M.
    """
    return self.fit(results, M=M, pre_periods=pre_periods).power
def power_curve(
    self,
    results: Union[MultiPeriodDiDResults, Any],
    M_grid: Optional[List[float]] = None,
    n_points: int = 50,
    pre_periods: Optional[List[int]] = None,
) -> PreTrendsPowerCurve:
    """
    Compute power across a range of violation magnitudes.

    Parameters
    ----------
    results : results object
        Event study results.
    M_grid : list of float, optional
        Specific violation magnitudes to evaluate. If None, an automatic
        grid from 0 to ``min(2.5 * MDV, 100)`` is used. When the MDV is
        infinite (target power unreachable) or zero (target already met
        under the null), the upper end is ``min(10 * max(SE), 100)`` instead.
    n_points : int, default=50
        Number of points in the automatic grid.
    pre_periods : list of int, optional
        Explicit list of pre-treatment periods. See fit() for details.

    Returns
    -------
    PreTrendsPowerCurve
        Power curve data with plot method.
    """
    # Covariance-source label (last element) is not needed on the curve path.
    _, ses, vcov, n_pre, relative_times, _ = self._extract_pre_period_params(
        results, pre_periods
    )
    weights = self._get_violation_weights(n_pre, relative_times=relative_times)

    mdv = self._compute_mdv(weights, vcov)

    if M_grid is None:
        if np.isfinite(mdv) and mdv > 0:
            upper = 2.5 * mdv
        else:
            # mdv == inf would give an unbounded grid, and mdv == 0 (the MDV
            # solver's null-size short-circuit) would give a zero-width grid
            # of identical points; scale by the pre-period SEs instead.
            upper = 10 * np.max(ses)
        max_M = min(upper, 100)
        M_values = np.linspace(0, max_M, n_points)
    else:
        M_values = np.asarray(M_grid)

    powers = np.array([self._compute_power(M, weights, vcov)[0] for M in M_values])

    return PreTrendsPowerCurve(
        M_values=M_values,
        powers=powers,
        mdv=mdv,
        alpha=self.alpha,
        target_power=self.target_power,
        violation_type=self.violation_type,
        pretest_form=self.pretest_form,
    )
def sensitivity_to_honest_did(
    self,
    results: Union[MultiPeriodDiDResults, Any],
    pre_periods: Optional[List[int]] = None,
) -> Dict[str, Any]:
    """
    Relate the pre-trends MDV to HonestDiD's relative-magnitudes scale.

    Reports how large the biggest per-period pre-trend deviation at the MDV
    is relative to the largest pre-period standard error, with a short
    text interpretation of how informative a passing pre-test is.

    Parameters
    ----------
    results : results object
        Event study results.
    pre_periods : list of int, optional
        Explicit list of pre-treatment periods. See fit() for details.

    Returns
    -------
    dict
        Keys: ``mdv``, ``max_abs_pre_violation`` (level-scale maximum
        deviation at the MDV), ``max_pre_se``, ``mdv_in_ses`` (their ratio,
        ``inf`` when undefined) and ``interpretation`` (newline-joined text).
    """
    fitted = self.fit(results, pre_periods=pre_periods)
    mdv = fitted.mdv
    # For linear violations the raw MDV is a slope (γ units); the level-scale
    # figure comparable to the per-period SEs is max_abs_pre_violation.
    max_abs_pre_violation = fitted.max_abs_pre_violation
    max_pre_se = np.max(fitted.pre_period_ses)

    finite_violation = np.isfinite(max_abs_pre_violation)
    # Computed once and reused for both the text and the returned dict.
    if finite_violation and max_pre_se > 0:
        ratio = max_abs_pre_violation / max_pre_se
    else:
        ratio = np.inf

    lines = [
        f"Minimum Detectable Violation (MDV): {mdv:.4f}",
        f"Max pre-period level deviation at MDV: {max_abs_pre_violation:.4f}",
        f"Max pre-period SE: {max_pre_se:.4f}",
    ]

    if not finite_violation:
        lines.append("→ Pre-trends test cannot achieve target power for any violation size.")
        lines.append(" Use HonestDiD sensitivity analysis for inference.")
    else:
        lines.append(f"Max level deviation / max(SE): {ratio:.2f}")
        if ratio < 1:
            lines.append("→ Pre-trends test is fairly sensitive to violations.")
        elif ratio < 2:
            lines.append("→ Pre-trends test has moderate sensitivity.")
        else:
            lines.append("→ Pre-trends test has low power to detect violations.")
            lines.append(" Consider using HonestDiD with larger M values for robustness.")

    return {
        "mdv": mdv,
        "max_abs_pre_violation": float(max_abs_pre_violation),
        "max_pre_se": max_pre_se,
        "mdv_in_ses": ratio,
        "interpretation": "\n".join(lines),
    }
# ============================================================================= # Convenience Functions # =============================================================================
def compute_pretrends_power(
    results: Union[MultiPeriodDiDResults, Any],
    M: Optional[float] = None,
    alpha: float = 0.05,
    target_power: float = 0.80,
    violation_type: str = "linear",
    pre_periods: Optional[List[int]] = None,
    violation_weights: Optional[np.ndarray] = None,
    pretest_form: Literal["nis", "wald"] = "nis",
) -> PreTrendsPowerResults:
    """
    Run a pre-trends power analysis in a single call.

    Builds a ``PreTrendsPower`` from the keyword arguments and fits it.

    Parameters
    ----------
    results : results object
        Event study results.
    M : float, optional
        Violation magnitude to evaluate.
    alpha : float, default=0.05
        Significance level.
    target_power : float, default=0.80
        Target power for MDV calculation.
    violation_type : str, default='linear'
        Type of violation pattern: ``linear`` / ``constant`` /
        ``last_period`` / ``custom``. For ``custom``, also pass
        ``violation_weights``.
    pre_periods : list of int, optional
        Explicit list of pre-treatment periods. If None, attempts to infer
        from results. Use when you've estimated all periods as post_periods.
    violation_weights : np.ndarray, optional
        Custom violation pattern weights. Required when
        ``violation_type='custom'``; ignored for other violation types.
    pretest_form : {'nis', 'wald'}, default='nis'
        Pretest acceptance-region form. ``'nis'`` (default) is Roth (2022)
        Section II.A-B's no-individually-significant box test; ``'wald'`` is
        the noncentral-chi-squared form.

    Returns
    -------
    PreTrendsPowerResults
        Power analysis results.

    Examples
    --------
    >>> from diff_diff import MultiPeriodDiD
    >>> from diff_diff.pretrends import compute_pretrends_power
    >>>
    >>> results = MultiPeriodDiD().fit(data, ...)
    >>> power_results = compute_pretrends_power(results, pre_periods=[0, 1, 2, 3])
    >>> print(f"MDV: {power_results.mdv:.3f}")
    >>> print(f"Power: {power_results.power:.1%}")
    """
    config = {
        "alpha": alpha,
        "power": target_power,
        "violation_type": violation_type,
        "violation_weights": violation_weights,
        "pretest_form": pretest_form,
    }
    return PreTrendsPower(**config).fit(results, M=M, pre_periods=pre_periods)
def compute_mdv(
    results: Union[MultiPeriodDiDResults, Any],
    alpha: float = 0.05,
    target_power: float = 0.80,
    violation_type: str = "linear",
    pre_periods: Optional[List[int]] = None,
    violation_weights: Optional[np.ndarray] = None,
    pretest_form: Literal["nis", "wald"] = "nis",
) -> float:
    """
    Return only the minimum detectable violation of a pre-trends test.

    Parameters
    ----------
    results : results object
        Event study results.
    alpha : float, default=0.05
        Significance level.
    target_power : float, default=0.80
        Target power for MDV calculation.
    violation_type : str, default='linear'
        Type of violation pattern: ``linear`` / ``constant`` /
        ``last_period`` / ``custom``. For ``custom``, also pass
        ``violation_weights``.
    pre_periods : list of int, optional
        Explicit list of pre-treatment periods. If None, attempts to infer
        from results. Use when you've estimated all periods as post_periods.
    violation_weights : np.ndarray, optional
        Custom violation pattern weights. Required when
        ``violation_type='custom'``; ignored for other violation types.
    pretest_form : {'nis', 'wald'}, default='nis'
        Pretest acceptance-region form. See ``compute_pretrends_power`` and
        ``PreTrendsPower`` for the NIS-vs-Wald discussion.

    Returns
    -------
    float
        Minimum detectable violation.
    """
    analyzer = PreTrendsPower(
        alpha=alpha,
        power=target_power,
        violation_type=violation_type,
        violation_weights=violation_weights,
        pretest_form=pretest_form,
    )
    fitted = analyzer.fit(results, pre_periods=pre_periods)
    return fitted.mdv