# Source code for diff_diff.pretrends

"""
Pre-trends power analysis for difference-in-differences designs.

This module implements the power analysis framework from Roth (2022) for assessing
the informativeness of pre-trends tests. It answers the question: "If my pre-trends
test passed, what violations would I have been able to detect?"

Key concepts:
- **Minimum Detectable Violation (MDV)**: The smallest pre-trends violation that
  would be detected with given power (e.g., 80%).
- **Power of Pre-Trends Test**: Probability of rejecting parallel trends given
  a specific violation pattern.
- **Relationship to HonestDiD**: If MDV is large relative to your estimated effect,
  a passing pre-trends test provides limited reassurance.

References
----------
Roth, J. (2022). Pretest with Caution: Event-Study Estimates after Testing for
    Parallel Trends. American Economic Review: Insights, 4(3), 305-322.
    https://doi.org/10.1257/aeri.20210236

See Also
--------
https://github.com/jonathandroth/pretrends - R package implementation
diff_diff.honest_did - Sensitivity analysis for parallel trends violations
"""

from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
import pandas as pd
from scipy import optimize, stats

from diff_diff.results import MultiPeriodDiDResults

# =============================================================================
# Results Classes
# =============================================================================


@dataclass
class PreTrendsPowerResults:
    """
    Results from pre-trends power analysis.

    Attributes
    ----------
    power : float
        Power to detect the specified violation pattern at given alpha.
    mdv : float
        Minimum detectable violation (smallest M detectable at target power).
    violation_magnitude : float
        The magnitude of violation tested (M parameter).
    violation_type : str
        Type of violation pattern ('linear', 'constant', 'last_period', 'custom').
    alpha : float
        Significance level for the pre-trends test.
    target_power : float
        Target power level used for MDV calculation.
    n_pre_periods : int
        Number of pre-treatment periods in the event study.
    test_statistic : float
        Expected test statistic under the specified violation.
    critical_value : float
        Critical value for the pre-trends test.
    noncentrality : float
        Non-centrality parameter under the alternative hypothesis.
    pre_period_effects : np.ndarray
        Estimated pre-period effects from the event study.
    pre_period_ses : np.ndarray
        Standard errors of pre-period effects.
    vcov : np.ndarray
        Variance-covariance matrix of pre-period effects.
    original_results : object, optional
        The event-study results object the analysis was computed from.
    """

    power: float
    mdv: float
    violation_magnitude: float
    violation_type: str
    alpha: float
    target_power: float
    n_pre_periods: int
    test_statistic: float
    critical_value: float
    noncentrality: float
    pre_period_effects: np.ndarray = field(repr=False)
    pre_period_ses: np.ndarray = field(repr=False)
    vcov: np.ndarray = field(repr=False)
    original_results: Optional[Any] = field(default=None, repr=False)

    def __repr__(self) -> str:
        return (
            f"PreTrendsPowerResults(power={self.power:.3f}, "
            f"mdv={self.mdv:.4f}, M={self.violation_magnitude:.4f})"
        )

    @property
    def is_informative(self) -> bool:
        """
        Check if the pre-trends test is informative.

        A pre-trends test is considered informative if the MDV is reasonably
        small relative to typical effect sizes. This is a heuristic check;
        see the summary for interpretation guidance.
        """
        # Heuristic: MDV < 2x the max observed pre-period SE
        max_se = np.max(self.pre_period_ses) if len(self.pre_period_ses) > 0 else 1.0
        return bool(self.mdv < 2 * max_se)

    @property
    def power_adequate(self) -> bool:
        """Check if power meets the target threshold."""
        return bool(self.power >= self.target_power)

    def summary(self) -> str:
        """
        Generate formatted summary of pre-trends power analysis.

        Returns
        -------
        str
            Formatted summary.
        """
        lines = [
            "=" * 70,
            "Pre-Trends Power Analysis Results".center(70),
            "(Roth 2022)".center(70),
            "=" * 70,
            "",
            f"{'Number of pre-periods:':<35} {self.n_pre_periods}",
            f"{'Significance level (alpha):':<35} {self.alpha:.3f}",
            f"{'Target power:':<35} {self.target_power:.1%}",
            f"{'Violation type:':<35} {self.violation_type}",
            "",
            "-" * 70,
            "Power Analysis".center(70),
            "-" * 70,
            f"{'Violation magnitude (M):':<35} {self.violation_magnitude:.4f}",
            f"{'Power to detect this violation:':<35} {self.power:.1%}",
            f"{'Minimum detectable violation:':<35} {self.mdv:.4f}",
            "",
            f"{'Test statistic (expected):':<35} {self.test_statistic:.4f}",
            f"{'Critical value:':<35} {self.critical_value:.4f}",
            f"{'Non-centrality parameter:':<35} {self.noncentrality:.4f}",
            "",
            "-" * 70,
            "Interpretation".center(70),
            "-" * 70,
        ]

        if self.power_adequate:
            lines.append(f"✓ Power ({self.power:.0%}) meets target ({self.target_power:.0%}).")
            lines.append(
                f" The pre-trends test would detect violations of magnitude {self.violation_magnitude:.3f}."
            )
        else:
            lines.append(f"✗ Power ({self.power:.0%}) below target ({self.target_power:.0%}).")
            lines.append(
                f" Would need violations of {self.mdv:.3f} to achieve {self.target_power:.0%} power."
            )

        lines.append("")
        lines.append(f"Minimum detectable violation (MDV): {self.mdv:.4f}")
        lines.append(" → Passing pre-trends test does NOT rule out violations up to this size.")

        lines.extend(["", "=" * 70])

        return "\n".join(lines)

    def print_summary(self) -> None:
        """Print summary to stdout."""
        print(self.summary())

    def to_dict(self) -> Dict[str, Any]:
        """Convert results to dictionary."""
        return {
            "power": self.power,
            "mdv": self.mdv,
            "violation_magnitude": self.violation_magnitude,
            "violation_type": self.violation_type,
            "alpha": self.alpha,
            "target_power": self.target_power,
            "n_pre_periods": self.n_pre_periods,
            "test_statistic": self.test_statistic,
            "critical_value": self.critical_value,
            "noncentrality": self.noncentrality,
            "is_informative": self.is_informative,
            "power_adequate": self.power_adequate,
        }

    def to_dataframe(self) -> pd.DataFrame:
        """Convert results to DataFrame."""
        return pd.DataFrame([self.to_dict()])

    def _reconstruct_weights(self) -> np.ndarray:
        """
        Rebuild the unit-norm violation pattern from ``violation_type``.

        Mirrors ``PreTrendsPower._get_violation_weights`` for the built-in
        patterns. Custom weights are not stored on this object, so for
        ``'custom'`` equal weights are used and a warning is emitted.
        """
        n_pre = self.n_pre_periods
        if self.violation_type == "linear":
            # [n-1, n-2, ..., 1, 0]: zero violation in the last pre-period
            weights = -np.arange(-n_pre + 1, 1, dtype=float)
        elif self.violation_type == "constant":
            weights = np.ones(n_pre)
        elif self.violation_type == "last_period":
            weights = np.zeros(n_pre)
            weights[-1] = 1.0
        else:
            import warnings

            warnings.warn(
                "Custom violation weights are not stored on the results object and "
                "cannot be recovered when the fitted violation magnitude is zero; "
                "falling back to equal weights.",
                UserWarning,
                stacklevel=3,
            )
            weights = np.ones(n_pre)

        norm = np.linalg.norm(weights)
        if norm > 0:
            weights = weights / norm
        return weights

    def power_at(self, M: float) -> float:
        """
        Compute power to detect a specific violation magnitude.

        This method allows computing power at different M values without
        re-fitting the model.

        Parameters
        ----------
        M : float
            Violation magnitude to evaluate.

        Returns
        -------
        float
            Power to detect violation of magnitude M.

        Notes
        -----
        The non-centrality parameter is quadratic in the magnitude,
        ``nc(M) = M**2 * w' V^{-1} w``. When the fitted ``violation_magnitude``
        is non-zero, the stored ``noncentrality`` pins down ``w' V^{-1} w``
        exactly for every violation type, including ``'custom'``. Only when
        the fitted magnitude is zero are the weights reconstructed from
        ``violation_type`` and ``vcov``.
        """
        fitted_M = float(self.violation_magnitude)
        fitted_nc = float(self.noncentrality)

        if fitted_M != 0 and np.isfinite(fitted_M) and np.isfinite(fitted_nc):
            # Exact for any pattern: w' V^{-1} w = nc(M0) / M0^2
            quad_form = fitted_nc / fitted_M**2
        else:
            weights = self._reconstruct_weights()
            try:
                vcov_inv = np.linalg.inv(self.vcov)
            except np.linalg.LinAlgError:
                vcov_inv = np.linalg.pinv(self.vcov)
            quad_form = float(weights @ vcov_inv @ weights)

        noncentrality = M**2 * quad_form
        if noncentrality <= 0:
            # No shift in the test statistic: rejection rate equals the size.
            return float(self.alpha)

        # sf is more accurate than 1 - cdf in the upper tail
        power = stats.ncx2.sf(self.critical_value, df=self.n_pre_periods, nc=noncentrality)
        return float(power)
@dataclass
class PreTrendsPowerCurve:
    """
    Power curve across violation magnitudes.

    Attributes
    ----------
    M_values : np.ndarray
        Grid of violation magnitudes tested.
    powers : np.ndarray
        Power at each violation magnitude.
    mdv : float
        Minimum detectable violation.
    alpha : float
        Significance level.
    target_power : float
        Target power level.
    violation_type : str
        Type of violation pattern.
    """

    M_values: np.ndarray
    powers: np.ndarray
    mdv: float
    alpha: float
    target_power: float
    violation_type: str

    def __repr__(self) -> str:
        return f"PreTrendsPowerCurve(n_points={len(self.M_values)}, " f"mdv={self.mdv:.4f})"

    def to_dataframe(self) -> pd.DataFrame:
        """Convert to DataFrame with M and power columns."""
        return pd.DataFrame(
            {
                "M": self.M_values,
                "power": self.powers,
            }
        )

    def plot(
        self,
        ax=None,
        show_mdv: bool = True,
        show_target: bool = True,
        color: str = "#2563eb",
        mdv_color: str = "#dc2626",
        target_color: str = "#22c55e",
        **kwargs,
    ):
        """
        Plot the power curve.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            Axes to plot on. If None, creates new figure.
        show_mdv : bool, default=True
            Whether to show vertical line at MDV.
        show_target : bool, default=True
            Whether to show horizontal line at target power.
        color : str
            Color for power curve line.
        mdv_color : str
            Color for MDV vertical line.
        target_color : str
            Color for target power horizontal line.
        **kwargs
            Additional arguments passed to ``ax.plot()`` for the power curve.
            These may override the defaults ``color``, ``linewidth=2`` and
            ``label="Power"``.

        Returns
        -------
        ax : matplotlib.axes.Axes
            The axes with the plot.

        Raises
        ------
        ImportError
            If matplotlib is not installed.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError as err:
            raise ImportError("matplotlib is required for plotting") from err

        if ax is None:
            _, ax = plt.subplots(figsize=(10, 6))

        # Merge defaults with user kwargs so e.g. linewidth=/label= override
        # instead of raising a duplicate-keyword TypeError.
        curve_kwargs = {"color": color, "linewidth": 2, "label": "Power"}
        curve_kwargs.update(kwargs)
        ax.plot(self.M_values, self.powers, **curve_kwargs)

        # Target power line
        if show_target:
            ax.axhline(
                y=self.target_power,
                color=target_color,
                linestyle="--",
                linewidth=1.5,
                alpha=0.7,
                label=f"Target power ({self.target_power:.0%})",
            )

        # MDV line (skipped when the MDV is infinite / undefined)
        if show_mdv and self.mdv is not None and np.isfinite(self.mdv):
            ax.axvline(
                x=self.mdv,
                color=mdv_color,
                linestyle=":",
                linewidth=1.5,
                alpha=0.7,
                label=f"MDV = {self.mdv:.3f}",
            )

        ax.set_xlabel("Violation Magnitude (M)")
        ax.set_ylabel("Power")
        ax.set_title("Pre-Trends Test Power Curve")
        ax.set_ylim(0, 1.05)
        ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y:.0%}"))
        ax.legend(loc="lower right")
        ax.grid(True, alpha=0.3)

        return ax
# ============================================================================= # Main Class # =============================================================================
class PreTrendsPower:
    """
    Pre-trends power analysis (Roth 2022).

    Computes the power of pre-trends tests to detect violations of parallel
    trends, and the minimum detectable violation (MDV).

    Parameters
    ----------
    alpha : float, default=0.05
        Significance level for the pre-trends test.
    power : float, default=0.80
        Target power level for MDV calculation.
    violation_type : str, default='linear'
        Type of violation pattern to consider:
        - 'linear': Violations follow a linear trend (most common)
        - 'constant': Same violation in all pre-periods
        - 'last_period': Violation only in the last pre-period
        - 'custom': User-specified violation pattern (via violation_weights)
    violation_weights : array-like, optional
        Custom weights for violation pattern. Must be one-dimensional with
        length equal to the number of pre-periods and at least one non-zero
        entry. Only used when violation_type='custom'.

    Examples
    --------
    Basic usage with MultiPeriodDiD results:

    >>> from diff_diff import MultiPeriodDiD
    >>> from diff_diff.pretrends import PreTrendsPower
    >>>
    >>> # Fit event study
    >>> mp_did = MultiPeriodDiD()
    >>> results = mp_did.fit(data, outcome='y', treatment='treated',
    ...                      time='period', post_periods=[4, 5, 6, 7])
    >>>
    >>> # Analyze pre-trends power
    >>> pt = PreTrendsPower(alpha=0.05, power=0.80)
    >>> power_results = pt.fit(results)
    >>> print(power_results.summary())
    >>>
    >>> # Get power curve
    >>> curve = pt.power_curve(results)
    >>> curve.plot()

    Notes
    -----
    The pre-trends test is typically a joint test that all pre-period
    coefficients are zero. This test has limited power to detect small
    violations, especially when:

    1. There are few pre-periods
    2. Standard errors are large
    3. The violation pattern is smooth (e.g., linear trend)

    Passing a pre-trends test does NOT mean parallel trends holds. It means
    violations smaller than the MDV cannot be ruled out. For robust inference,
    combine with HonestDiD sensitivity analysis.

    References
    ----------
    Roth, J. (2022). Pretest with Caution: Event-Study Estimates after Testing
    for Parallel Trends. American Economic Review: Insights, 4(3), 305-322.
    """

    def __init__(
        self,
        alpha: float = 0.05,
        power: float = 0.80,
        violation_type: Literal["linear", "constant", "last_period", "custom"] = "linear",
        violation_weights: Optional[np.ndarray] = None,
    ):
        if not 0 < alpha < 1:
            raise ValueError(f"alpha must be between 0 and 1, got {alpha}")
        if not 0 < power < 1:
            raise ValueError(f"power must be between 0 and 1, got {power}")
        if violation_type not in ["linear", "constant", "last_period", "custom"]:
            raise ValueError(
                f"violation_type must be 'linear', 'constant', 'last_period', or 'custom', "
                f"got '{violation_type}'"
            )
        if violation_type == "custom" and violation_weights is None:
            raise ValueError("violation_weights must be provided when violation_type='custom'")

        self.alpha = alpha
        self.target_power = power
        self.violation_type = violation_type
        self.violation_weights = (
            np.asarray(violation_weights, dtype=float) if violation_weights is not None else None
        )

    def get_params(self) -> Dict[str, Any]:
        """Get parameters for this estimator."""
        return {
            "alpha": self.alpha,
            "power": self.target_power,
            "violation_type": self.violation_type,
            "violation_weights": self.violation_weights,
        }

    def set_params(self, **params) -> "PreTrendsPower":
        """Set parameters for this estimator."""
        for key, value in params.items():
            if key == "power":
                self.target_power = value
            elif hasattr(self, key):
                setattr(self, key, value)
            else:
                raise ValueError(f"Invalid parameter: {key}")
        return self

    # ------------------------------------------------------------------
    # Violation pattern
    # ------------------------------------------------------------------

    def _get_violation_weights(self, n_pre: int) -> np.ndarray:
        """
        Get violation weights based on violation type.

        Parameters
        ----------
        n_pre : int
            Number of pre-treatment periods.

        Returns
        -------
        np.ndarray
            Violation weights, normalized to have L2 norm of 1 (a degenerate
            all-zero built-in pattern is returned unnormalized).

        Raises
        ------
        ValueError
            If custom weights are missing, not 1-D, the wrong length,
            non-finite, or all zero.
        """
        if self.violation_type == "custom":
            if self.violation_weights is None:
                raise ValueError("violation_weights must be provided when violation_type='custom'")
            # np.array copies, and also accepts lists set via set_params()
            weights = np.array(self.violation_weights, dtype=float)
            if weights.ndim != 1:
                raise ValueError(
                    f"violation_weights must be one-dimensional, got shape {weights.shape}"
                )
            if len(weights) != n_pre:
                raise ValueError(
                    f"violation_weights has length {len(weights)}, "
                    f"but there are {n_pre} pre-periods"
                )
            if not np.all(np.isfinite(weights)):
                raise ValueError("violation_weights must be finite")
            if not np.any(weights != 0):
                raise ValueError("violation_weights must contain at least one non-zero entry")
        elif self.violation_type == "linear":
            # Linear trend: [n-1, n-2, ..., 1, 0], i.e. zero violation in the
            # last pre-period, growing linearly further back in time.
            weights = -np.arange(-n_pre + 1, 1, dtype=float)
        elif self.violation_type == "constant":
            # Same violation in all periods
            weights = np.ones(n_pre)
        elif self.violation_type == "last_period":
            # Violation only in last pre-period (period -1)
            weights = np.zeros(n_pre)
            weights[-1] = 1.0
        else:
            raise ValueError(f"Unknown violation_type: {self.violation_type}")

        # Normalize to unit norm (if not all zeros)
        norm = np.linalg.norm(weights)
        if norm > 0:
            weights = weights / norm

        return weights

    # ------------------------------------------------------------------
    # Extraction of pre-period estimates
    # ------------------------------------------------------------------

    @staticmethod
    def _multi_period_pre_params(
        results: MultiPeriodDiDResults,
        pre_periods: Optional[List[int]],
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]:
        """Extract pre-period effects, SEs and vcov from MultiPeriodDiDResults."""
        # Use explicit pre_periods if provided, otherwise the stored ones.
        if pre_periods is not None:
            candidate_periods = list(pre_periods)
        else:
            candidate_periods = list(results.pre_periods or [])

        if not candidate_periods:
            raise ValueError(
                "No pre-treatment periods found in results. "
                "Pre-trends power analysis requires pre-period coefficients. "
                "If you estimated all periods as post_periods, use the pre_periods "
                "parameter to specify which are actually pre-treatment."
            )

        # Keep only periods with an estimated coefficient and a usable
        # (finite, positive) SE; this drops the reference period.
        estimated = [
            p
            for p in candidate_periods
            if p in results.period_effects
            and np.isfinite(results.period_effects[p].se)
            and results.period_effects[p].se > 0
        ]

        if not estimated:
            raise ValueError(
                "No estimated pre-period coefficients found. "
                "The pre-trends test requires at least one estimated "
                "pre-period coefficient (excluding the reference period)."
            )

        effects = np.array([results.period_effects[p].effect for p in estimated], dtype=float)
        ses = np.array([results.period_effects[p].se for p in estimated], dtype=float)

        # Use the full covariance block when every period can be located in
        # the stored vcov; otherwise fall back to a diagonal approximation.
        index_map = getattr(results, "interaction_indices", None)
        if (
            results.vcov is not None
            and index_map is not None
            and all(p in index_map for p in estimated)
        ):
            idx = [index_map[p] for p in estimated]
            vcov = np.asarray(results.vcov)[np.ix_(idx, idx)]
        else:
            vcov = np.diag(ses**2)

        return effects, ses, vcov, len(estimated)

    @staticmethod
    def _event_study_pre_params(results: Any) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]:
        """
        Extract pre-period params from an ``event_study_effects`` mapping.

        Shared by the CallawaySantAnna and SunAbraham branches. The mapping is
        keyed by relative time ``t``, with dict values holding ``effect``,
        ``se`` and (optionally) ``n_groups``.
        """
        # Anticipation-aware cutoff (REGISTRY.md §CallawaySantAnna): with
        # ``anticipation=k`` the true pre-periods are ``t < -k``; the window
        # ``t in [-k, -1]`` is anticipation and must not be used here.
        anticipation = getattr(results, "anticipation", 0) or 0
        try:
            anticipation = int(anticipation)
        except (TypeError, ValueError):
            anticipation = 0
        cutoff = -anticipation

        # Drop normalization constraints (n_groups == 0) and rows whose SE is
        # non-finite or non-positive (undefined inference).
        event_study_effects = getattr(results, "event_study_effects", None) or {}
        pre_effects = {
            t: data
            for t, data in event_study_effects.items()
            if t < cutoff
            and data.get("n_groups", 1) > 0
            and np.isfinite(data.get("se", np.nan))
            and float(data.get("se", 0.0)) > 0
        }

        if not pre_effects:
            raise ValueError("No pre-treatment periods found in event study.")

        rel_times = sorted(pre_effects)
        effects = np.array([pre_effects[t]["effect"] for t in rel_times], dtype=float)
        ses = np.array([pre_effects[t]["se"] for t in rel_times], dtype=float)
        # Only per-period SEs are available here, so the vcov is diagonal.
        vcov = np.diag(ses**2)
        return effects, ses, vcov, len(rel_times)

    def _extract_pre_period_params(
        self,
        results: Union[MultiPeriodDiDResults, Any],
        pre_periods: Optional[List[int]] = None,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]:
        """
        Extract pre-period parameters from results.

        Parameters
        ----------
        results : MultiPeriodDiDResults or similar
            Results object from event study estimation.
        pre_periods : list of int, optional
            Explicit list of pre-treatment periods. If None, uses
            results.pre_periods. Only used for MultiPeriodDiDResults.

        Returns
        -------
        effects : np.ndarray
            Pre-period effect estimates.
        ses : np.ndarray
            Pre-period standard errors.
        vcov : np.ndarray
            Variance-covariance matrix for pre-period effects.
        n_pre : int
            Number of pre-periods.

        Raises
        ------
        ValueError
            If no usable pre-period coefficients are found.
        TypeError
            If the results type is not supported.
        """
        if isinstance(results, MultiPeriodDiDResults):
            return self._multi_period_pre_params(results, pre_periods)

        # Optional estimators: keep the ImportError guard around the import
        # only, so errors raised during extraction are never swallowed.
        try:
            from diff_diff.staggered import CallawaySantAnnaResults
        except ImportError:
            pass
        else:
            if isinstance(results, CallawaySantAnnaResults):
                if results.event_study_effects is None:
                    raise ValueError(
                        "CallawaySantAnnaResults must have event_study_effects. "
                        "Re-run with aggregate='event_study'."
                    )
                return self._event_study_pre_params(results)

        try:
            from diff_diff.sun_abraham import SunAbrahamResults
        except ImportError:
            pass
        else:
            if isinstance(results, SunAbrahamResults):
                return self._event_study_pre_params(results)

        raise TypeError(
            f"Unsupported results type: {type(results)}. "
            "Expected MultiPeriodDiDResults, CallawaySantAnnaResults, or SunAbrahamResults."
        )

    # ------------------------------------------------------------------
    # Power / MDV computations
    # ------------------------------------------------------------------

    @staticmethod
    def _inverse_quadratic_form(vec: np.ndarray, vcov: np.ndarray) -> float:
        """Return ``vec' V^{-1} vec``, using the pseudo-inverse if V is singular."""
        try:
            vcov_inv = np.linalg.inv(vcov)
        except np.linalg.LinAlgError:
            vcov_inv = np.linalg.pinv(vcov)
        return float(vec @ vcov_inv @ vec)

    def _compute_power(
        self,
        M: float,
        weights: np.ndarray,
        vcov: np.ndarray,
    ) -> Tuple[float, float, float, float]:
        """
        Compute power to detect violation of magnitude M.

        The pre-trends test is a Wald test:
            H0: delta = 0 vs H1: delta != 0

        Under H1 with violation delta = M * weights, the test statistic
        follows a non-central chi-squared distribution.

        Parameters
        ----------
        M : float
            Violation magnitude.
        weights : np.ndarray
            Normalized violation pattern.
        vcov : np.ndarray
            Variance-covariance matrix.

        Returns
        -------
        power : float
            Power to detect this violation.
        noncentrality : float
            Non-centrality parameter.
        test_stat : float
            Expected test statistic under H1.
        critical_value : float
            Critical value for the test.
        """
        n_pre = len(weights)

        # lambda = delta' V^{-1} delta with delta = M * weights
        noncentrality = self._inverse_quadratic_form(M * weights, vcov)

        critical_value = float(stats.chi2.ppf(1 - self.alpha, df=n_pre))

        if noncentrality > 0:
            # sf is more accurate than 1 - cdf in the upper tail
            power = float(stats.ncx2.sf(critical_value, df=n_pre, nc=noncentrality))
        else:
            power = self.alpha  # Size under null

        # Mean of a non-central chi2 is df + lambda
        test_stat = n_pre + noncentrality

        return power, noncentrality, test_stat, critical_value

    def _compute_mdv(
        self,
        weights: np.ndarray,
        vcov: np.ndarray,
    ) -> float:
        """
        Compute minimum detectable violation.

        Find the smallest M such that power >= target_power.

        Parameters
        ----------
        weights : np.ndarray
            Normalized violation pattern.
        vcov : np.ndarray
            Variance-covariance matrix.

        Returns
        -------
        mdv : float
            Minimum detectable violation. ``0.0`` if the target power does not
            exceed alpha (it is met even without a violation); ``np.inf`` if
            the violation pattern cannot shift the test statistic.
        """
        # Power at M = 0 equals alpha, so such a target is met trivially.
        if self.target_power <= self.alpha:
            return 0.0

        # nc = M^2 * w' V^{-1} w, so MDV = sqrt(target_nc / w' V^{-1} w)
        quad_form = self._inverse_quadratic_form(weights, vcov)
        if not quad_form > 0:
            return np.inf

        n_pre = len(weights)
        critical_value = stats.chi2.ppf(1 - self.alpha, df=n_pre)

        def power_gap(nc: float) -> float:
            """Power at non-centrality ``nc`` minus the target power."""
            if nc <= 0:
                return self.alpha - self.target_power
            return stats.ncx2.sf(critical_value, df=n_pre, nc=nc) - self.target_power

        # Bracket the root: power_gap(0) < 0 here, and power -> 1 as nc grows,
        # so doubling the upper bound eventually crosses the target.
        max_nc = 2.0**20
        nc_high = 1.0
        while power_gap(nc_high) < 0:
            if nc_high >= max_nc:
                return np.inf
            nc_high *= 2.0

        try:
            target_nc = optimize.brentq(power_gap, 0.0, nc_high)
        except (ValueError, RuntimeError):
            # Fallback: normal approximation
            # power ≈ Phi(sqrt(2*nc) - sqrt(2*cv)) => sqrt(2*nc) = z_power + sqrt(2*cv)
            z_power = stats.norm.ppf(self.target_power)
            target_nc = 0.5 * (z_power + np.sqrt(2 * critical_value)) ** 2

        return float(np.sqrt(target_nc / quad_form))

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def fit(
        self,
        results: Union[MultiPeriodDiDResults, Any],
        M: Optional[float] = None,
        pre_periods: Optional[List[int]] = None,
    ) -> PreTrendsPowerResults:
        """
        Compute pre-trends power analysis.

        Parameters
        ----------
        results : MultiPeriodDiDResults, CallawaySantAnnaResults, or SunAbrahamResults
            Results from an event study estimation.
        M : float, optional
            Specific violation magnitude to evaluate. If None, evaluates at
            the MDV when it is finite, otherwise at the largest pre-period SE.
        pre_periods : list of int, optional
            Explicit list of pre-treatment periods to use for power analysis.
            If None, attempts to infer from results.pre_periods. Use this when
            you've estimated an event study with all periods in post_periods
            and need to specify which are actually pre-treatment.

        Returns
        -------
        PreTrendsPowerResults
            Power analysis results including power and MDV.
        """
        effects, ses, vcov, n_pre = self._extract_pre_period_params(results, pre_periods)
        weights = self._get_violation_weights(n_pre)
        mdv = self._compute_mdv(weights, vcov)

        # Default M: use MDV if finite, else the largest pre-period SE
        if M is None:
            M = mdv if np.isfinite(mdv) else float(np.max(ses))

        power, noncentrality, test_stat, critical_value = self._compute_power(M, weights, vcov)

        return PreTrendsPowerResults(
            power=power,
            mdv=mdv,
            violation_magnitude=M,
            violation_type=self.violation_type,
            alpha=self.alpha,
            target_power=self.target_power,
            n_pre_periods=n_pre,
            test_statistic=test_stat,
            critical_value=critical_value,
            noncentrality=noncentrality,
            pre_period_effects=effects,
            pre_period_ses=ses,
            vcov=vcov,
            original_results=results,
        )

    def power_at(
        self,
        results: Union[MultiPeriodDiDResults, Any],
        M: float,
        pre_periods: Optional[List[int]] = None,
    ) -> float:
        """
        Compute power to detect a specific violation magnitude.

        Parameters
        ----------
        results : results object
            Event study results.
        M : float
            Violation magnitude.
        pre_periods : list of int, optional
            Explicit list of pre-treatment periods. See fit() for details.

        Returns
        -------
        float
            Power to detect violation of magnitude M.
        """
        return self.fit(results, M=M, pre_periods=pre_periods).power

    def power_curve(
        self,
        results: Union[MultiPeriodDiDResults, Any],
        M_grid: Optional[List[float]] = None,
        n_points: int = 50,
        pre_periods: Optional[List[int]] = None,
    ) -> PreTrendsPowerCurve:
        """
        Compute power across a range of violation magnitudes.

        Parameters
        ----------
        results : results object
            Event study results.
        M_grid : list of float, optional
            Specific violation magnitudes to evaluate. If None, creates
            automatic grid from 0 to 2.5 * MDV (or 0 to 10 * max SE when the
            MDV is infinite or zero).
        n_points : int, default=50
            Number of points in automatic grid.
        pre_periods : list of int, optional
            Explicit list of pre-treatment periods. See fit() for details.

        Returns
        -------
        PreTrendsPowerCurve
            Power curve data with plot method.
        """
        _, ses, vcov, n_pre = self._extract_pre_period_params(results, pre_periods)
        weights = self._get_violation_weights(n_pre)
        mdv = self._compute_mdv(weights, vcov)

        if M_grid is None:
            # Scale-free upper bound (no hard cap, which broke large outcomes)
            if np.isfinite(mdv) and mdv > 0:
                upper = 2.5 * mdv
            else:
                upper = 10 * float(np.max(ses))
            M_values = np.linspace(0.0, upper, n_points)
        else:
            M_values = np.asarray(M_grid, dtype=float)

        powers = np.array([self._compute_power(m, weights, vcov)[0] for m in M_values])

        return PreTrendsPowerCurve(
            M_values=M_values,
            powers=powers,
            mdv=mdv,
            alpha=self.alpha,
            target_power=self.target_power,
            violation_type=self.violation_type,
        )

    def sensitivity_to_honest_did(
        self,
        results: Union[MultiPeriodDiDResults, Any],
        pre_periods: Optional[List[int]] = None,
    ) -> Dict[str, Any]:
        """
        Compare pre-trends power analysis with HonestDiD sensitivity.

        This method helps interpret how informative a passing pre-trends
        test is in the context of HonestDiD's relative magnitudes restriction.

        Parameters
        ----------
        results : results object
            Event study results.
        pre_periods : list of int, optional
            Explicit list of pre-treatment periods. See fit() for details.

        Returns
        -------
        dict
            Dictionary with:
            - mdv: Minimum detectable violation from pre-trends test
            - max_pre_se: Largest pre-period standard error
            - mdv_in_ses: MDV expressed in units of max_pre_se
            - interpretation: Text explaining the relationship
        """
        pt_results = self.fit(results, pre_periods=pre_periods)
        mdv = pt_results.mdv

        # In HonestDiD's relative magnitudes framework, M=1 allows
        # post-treatment violations as large as the max pre-period violation;
        # the MDV indicates how large that max violation could be.
        max_pre_se = np.max(pt_results.pre_period_ses)
        if np.isfinite(mdv) and max_pre_se > 0:
            mdv_in_ses = mdv / max_pre_se
        else:
            mdv_in_ses = np.inf

        interpretation = [
            f"Minimum Detectable Violation (MDV): {mdv:.4f}",
            f"Max pre-period SE: {max_pre_se:.4f}",
        ]

        if np.isfinite(mdv):
            interpretation.append(f"MDV / max(SE): {mdv_in_ses:.2f}")
            if mdv_in_ses < 1:
                interpretation.append("→ Pre-trends test is fairly sensitive to violations.")
            elif mdv_in_ses < 2:
                interpretation.append("→ Pre-trends test has moderate sensitivity.")
            else:
                interpretation.append("→ Pre-trends test has low power to detect violations.")
                interpretation.append(
                    " Consider using HonestDiD with larger M values for robustness."
                )
        else:
            interpretation.append(
                "→ Pre-trends test cannot achieve target power for any violation size."
            )
            interpretation.append(" Use HonestDiD sensitivity analysis for inference.")

        return {
            "mdv": mdv,
            "max_pre_se": max_pre_se,
            "mdv_in_ses": mdv_in_ses,
            "interpretation": "\n".join(interpretation),
        }
# ============================================================================= # Convenience Functions # =============================================================================
def compute_pretrends_power(
    results: Union[MultiPeriodDiDResults, Any],
    M: Optional[float] = None,
    alpha: float = 0.05,
    target_power: float = 0.80,
    violation_type: str = "linear",
    pre_periods: Optional[List[int]] = None,
) -> PreTrendsPowerResults:
    """
    Run a one-shot pre-trends power analysis.

    Thin wrapper that builds a ``PreTrendsPower`` analyzer from the keyword
    settings and immediately fits it.

    Parameters
    ----------
    results : results object
        Event study results.
    M : float, optional
        Violation magnitude to evaluate; None lets ``PreTrendsPower.fit``
        choose its default.
    alpha : float, default=0.05
        Significance level of the pre-trends test.
    target_power : float, default=0.80
        Power level that defines the minimum detectable violation.
    violation_type : str, default='linear'
        Shape of the hypothesized violation.
    pre_periods : list of int, optional
        Explicit list of pre-treatment periods. If None, attempts to infer
        from results. Use when you've estimated all periods as post_periods.

    Returns
    -------
    PreTrendsPowerResults
        Power analysis results.

    Examples
    --------
    >>> from diff_diff import MultiPeriodDiD
    >>> from diff_diff.pretrends import compute_pretrends_power
    >>>
    >>> results = MultiPeriodDiD().fit(data, ...)
    >>> power_results = compute_pretrends_power(results, pre_periods=[0, 1, 2, 3])
    >>> print(f"MDV: {power_results.mdv:.3f}")
    >>> print(f"Power: {power_results.power:.1%}")
    """
    analyzer_settings = {
        "alpha": alpha,
        "power": target_power,
        "violation_type": violation_type,
    }
    analyzer = PreTrendsPower(**analyzer_settings)
    power_results = analyzer.fit(results, M=M, pre_periods=pre_periods)
    return power_results
def compute_mdv(
    results: Union[MultiPeriodDiDResults, Any],
    alpha: float = 0.05,
    target_power: float = 0.80,
    violation_type: str = "linear",
    pre_periods: Optional[List[int]] = None,
) -> float:
    """
    Return only the minimum detectable violation (MDV) for an event study.

    Parameters
    ----------
    results : results object
        Event study results.
    alpha : float, default=0.05
        Significance level of the pre-trends test.
    target_power : float, default=0.80
        Power level that defines the MDV.
    violation_type : str, default='linear'
        Shape of the hypothesized violation.
    pre_periods : list of int, optional
        Explicit list of pre-treatment periods. If None, attempts to infer
        from results. Use when you've estimated all periods as post_periods.

    Returns
    -------
    float
        Minimum detectable violation.
    """
    analyzer = PreTrendsPower(alpha=alpha, power=target_power, violation_type=violation_type)
    return analyzer.fit(results, pre_periods=pre_periods).mdv