"""
Honest DiD sensitivity analysis (Rambachan & Roth 2023).
Provides robust inference for difference-in-differences designs when
parallel trends may be violated. Instead of assuming parallel trends
holds exactly, this module allows for bounded violations and computes
partially identified treatment effect bounds.
References
----------
Rambachan, A., & Roth, J. (2023). A More Credible Approach to Parallel Trends.
The Review of Economic Studies, 90(5), 2555-2591.
https://doi.org/10.1093/restud/rdad018
See Also
--------
https://github.com/asheshrambachan/HonestDiD - R package implementation
"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import numpy as np
import pandas as pd
from scipy import optimize, stats
from diff_diff.results import (
MultiPeriodDiDResults,
)
# =============================================================================
# Delta Restriction Classes
# =============================================================================
@dataclass
class DeltaSD:
    """
    Smoothness restriction on parallel-trends violations (Delta^{SD}).

    Bounds the second difference of the violation path in every period:

        |delta_{t+1} - 2*delta_t + delta_{t-1}| <= M

    Setting M=0 forces the violations onto a straight line (the pre-trend is
    extrapolated linearly into the post-period); increasing M admits more
    curvature in the violation path.

    Parameters
    ----------
    M : float
        Largest admissible absolute second difference. Must be non-negative;
        M=0 means linear trends only.

    Raises
    ------
    ValueError
        If ``M`` is negative.

    Examples
    --------
    >>> delta = DeltaSD(M=0.5)
    >>> delta.M
    0.5
    """

    M: float = 0.0

    def __post_init__(self) -> None:
        # Reject an invalid bound at construction time rather than at use time.
        if self.M < 0:
            raise ValueError(f"M must be non-negative, got M={self.M}")

    def __repr__(self) -> str:
        return f"DeltaSD(M={self.M})"
@dataclass
class DeltaRM:
    """
    Relative-magnitudes restriction on parallel-trends violations (Delta^{RM}).

    Every post-treatment violation is capped at ``Mbar`` times the largest
    absolute pre-treatment violation:

        |delta_post| <= Mbar * max(|delta_pre|)

    Mbar=0 imposes exact parallel trends after treatment; Mbar=1 lets the
    post-period violations be as large as the worst pre-period violation.

    Parameters
    ----------
    Mbar : float
        Multiplier applied to the largest pre-period violation. Must be
        non-negative.

    Raises
    ------
    ValueError
        If ``Mbar`` is negative.

    Examples
    --------
    >>> delta = DeltaRM(Mbar=1.0)
    >>> delta.Mbar
    1.0
    """

    Mbar: float = 1.0

    def __post_init__(self) -> None:
        # Reject an invalid multiplier at construction time.
        if self.Mbar < 0:
            raise ValueError(f"Mbar must be non-negative, got Mbar={self.Mbar}")

    def __repr__(self) -> str:
        return f"DeltaRM(Mbar={self.Mbar})"
@dataclass
class DeltaSDRM:
    """
    Joint smoothness + relative-magnitudes restriction (Delta^{SDRM}).

    A violation path must satisfy both conditions at once:

    1. Smoothness: |delta_{t+1} - 2*delta_t + delta_{t-1}| <= M
    2. Relative magnitudes: |delta_post| <= Mbar * max(|delta_pre|)

    The joint restriction is at least as tight as either one on its own.

    Parameters
    ----------
    M : float
        Largest admissible absolute second difference (smoothness part).
    Mbar : float
        Multiplier on the largest pre-period violation (relative-magnitudes part).

    Raises
    ------
    ValueError
        If ``M`` is negative (checked first) or ``Mbar`` is negative.

    Examples
    --------
    >>> delta = DeltaSDRM(M=0.5, Mbar=1.0)
    """

    M: float = 0.0
    Mbar: float = 1.0

    def __post_init__(self) -> None:
        # Smoothness bound is checked before the relative-magnitude multiplier.
        if self.M < 0:
            raise ValueError(f"M must be non-negative, got M={self.M}")
        if self.Mbar < 0:
            raise ValueError(f"Mbar must be non-negative, got Mbar={self.Mbar}")

    def __repr__(self) -> str:
        return f"DeltaSDRM(M={self.M}, Mbar={self.Mbar})"
# Type alias accepted wherever any of the restriction classes above may be used.
DeltaType = Union[DeltaSD, DeltaRM, DeltaSDRM]
# =============================================================================
# Results Classes
# =============================================================================
@dataclass
class HonestDiDResults:
    """
    Results from Honest DiD sensitivity analysis.

    Contains bounds on the treatment effect under the specified
    restrictions on violations of parallel trends.

    Attributes
    ----------
    lb : float
        Lower bound of identified set.
    ub : float
        Upper bound of identified set.
    ci_lb : float
        Lower bound of robust confidence interval.
    ci_ub : float
        Upper bound of robust confidence interval.
    M : float
        The restriction parameter value used.
    method : str
        The type of restriction ("smoothness", "relative_magnitude", or "combined").
    original_estimate : float
        The original point estimate (under parallel trends).
    original_se : float
        The original standard error.
    alpha : float
        Significance level for confidence interval.
    ci_method : str
        Method used for CI construction ("FLCI" or "C-LF").
    original_results : Any
        The original estimation results object.
    event_study_bounds : dict, optional
        Optional event-study bounds: a mapping from period to named float values.
    """

    lb: float
    ub: float
    ci_lb: float
    ci_ub: float
    M: float
    method: str
    original_estimate: float
    original_se: float
    alpha: float = 0.05
    ci_method: str = "FLCI"
    original_results: Optional[Any] = field(default=None, repr=False)
    # Event study bounds (optional)
    event_study_bounds: Optional[Dict[Any, Dict[str, float]]] = field(default=None, repr=False)

    def __repr__(self) -> str:
        return (
            f"HonestDiDResults(bounds=[{self.lb:.4f}, {self.ub:.4f}], "
            f"CI=[{self.ci_lb:.4f}, {self.ci_ub:.4f}]{self.significance_stars}, "
            f"M={self.M})"
        )

    @property
    def is_significant(self) -> bool:
        """
        Check if the robust CI excludes zero (effect is robust to violations).

        A CI with a NaN endpoint is reported as NOT significant.
        """
        # Phrased as "entirely above or entirely below zero" rather than
        # ``not (ci_lb <= 0 <= ci_ub)``: every comparison with NaN is False, so
        # the negated form would flag an undefined CI as significant.
        return bool(self.ci_lb > 0 or self.ci_ub < 0)

    @property
    def significance_stars(self) -> str:
        """
        Return significance indicator if robust CI excludes zero.

        Note: Unlike point estimation, partial identification does not yield
        a single p-value. This returns "*" if the robust CI excludes zero
        at the specified alpha level, indicating the effect is robust to
        the assumed violations of parallel trends.
        """
        return "*" if self.is_significant else ""

    @property
    def identified_set_width(self) -> float:
        """Width of the identified set."""
        return self.ub - self.lb

    @property
    def ci_width(self) -> float:
        """Width of the confidence interval."""
        return self.ci_ub - self.ci_lb

    def summary(self) -> str:
        """
        Generate formatted summary of sensitivity analysis results.

        Returns
        -------
        str
            Formatted summary.
        """
        # Show the confidence level without truncation (alpha=0.025 -> "97.5",
        # not "97"); rounding first strips float noise such as 94.99999999999999.
        conf_level = f"{round((1 - self.alpha) * 100, 6):g}"
        method_names = {
            "smoothness": "Smoothness (Delta^SD)",
            "relative_magnitude": "Relative Magnitudes (Delta^RM)",
            "combined": "Combined (Delta^SDRM)",
        }
        method_display = method_names.get(self.method, self.method)
        lines = [
            "=" * 70,
            "Honest DiD Sensitivity Analysis Results".center(70),
            "(Rambachan & Roth 2023)".center(70),
            "=" * 70,
            "",
            f"{'Method:':<30} {method_display}",
            f"{'Restriction parameter (M):':<30} {self.M:.4f}",
            f"{'CI method:':<30} {self.ci_method}",
            "",
            "-" * 70,
            "Original Estimate (under parallel trends)".center(70),
            "-" * 70,
            f"{'Point estimate:':<30} {self.original_estimate:.4f}",
            f"{'Standard error:':<30} {self.original_se:.4f}",
            "",
            "-" * 70,
            "Robust Results (allowing for violations)".center(70),
            "-" * 70,
            f"{'Identified set:':<30} [{self.lb:.4f}, {self.ub:.4f}]",
            f"{f'{conf_level}% Robust CI:':<30} [{self.ci_lb:.4f}, {self.ci_ub:.4f}]",
            "",
            f"{'Effect robust to violations:':<30} {'Yes' if self.is_significant else 'No'}",
            "",
        ]
        # Interpretation
        lines.extend(
            [
                "-" * 70,
                "Interpretation".center(70),
                "-" * 70,
            ]
        )
        if self.method == "relative_magnitude":
            # Two decimals so grid values such as 0.25 or 1.75 are not rounded away.
            lines.append(
                f"Post-treatment violations bounded at {self.M:.2f}x max pre-period violation."
            )
        elif self.method == "smoothness":
            if self.M == 0:
                lines.append("Violations follow linear extrapolation of pre-trends.")
            else:
                lines.append(
                    f"Violation curvature (second diff) bounded by {self.M:.4f} per period."
                )
        else:
            lines.append(f"Combined smoothness (M={self.M:.2f}) and relative magnitude bounds.")
        if self.is_significant:
            if self.ci_lb > 0:
                lines.append(f"Effect remains POSITIVE even with violations up to M={self.M}.")
            else:
                lines.append(f"Effect remains NEGATIVE even with violations up to M={self.M}.")
        else:
            lines.append(f"Cannot rule out zero effect when allowing violations up to M={self.M}.")
        lines.extend(["", "=" * 70])
        return "\n".join(lines)

    def print_summary(self) -> None:
        """Print summary to stdout."""
        print(self.summary())

    def to_dict(self) -> Dict[str, Any]:
        """Convert results to dictionary."""
        return {
            "lb": self.lb,
            "ub": self.ub,
            "ci_lb": self.ci_lb,
            "ci_ub": self.ci_ub,
            "M": self.M,
            "method": self.method,
            "original_estimate": self.original_estimate,
            "original_se": self.original_se,
            "alpha": self.alpha,
            "ci_method": self.ci_method,
            "is_significant": self.is_significant,
            "identified_set_width": self.identified_set_width,
            "ci_width": self.ci_width,
        }

    def to_dataframe(self) -> pd.DataFrame:
        """Convert results to DataFrame."""
        return pd.DataFrame([self.to_dict()])
@dataclass
class SensitivityResults:
    """
    Results from sensitivity analysis over a grid of M values.

    Contains bounds and confidence intervals for each M value,
    plus the breakdown value.

    Attributes
    ----------
    M_values : np.ndarray
        Grid of M parameter values.
    bounds : List[Tuple[float, float]]
        List of (lb, ub) identified set bounds for each M.
    robust_cis : List[Tuple[float, float]]
        List of (ci_lb, ci_ub) robust CIs for each M.
    breakdown_M : float or None
        Smallest M where robust CI includes zero, or None if the CI
        excludes zero at every M in the grid.
    method : str
        Type of restriction used.
    original_estimate : float
        Original point estimate.
    original_se : float
        Original standard error.
    alpha : float
        Significance level.
    """

    M_values: np.ndarray
    bounds: List[Tuple[float, float]]
    robust_cis: List[Tuple[float, float]]
    breakdown_M: Optional[float]
    method: str
    original_estimate: float
    original_se: float
    alpha: float = 0.05

    def __repr__(self) -> str:
        # Compare against None explicitly: a breakdown at M=0.0 is falsy but real
        # (and ``has_breakdown`` already treats it as a breakdown).
        breakdown_str = "None" if self.breakdown_M is None else f"{self.breakdown_M:.4f}"
        return f"SensitivityResults(n_M={len(self.M_values)}, " f"breakdown_M={breakdown_str})"

    @property
    def has_breakdown(self) -> bool:
        """Check if there is a finite breakdown value."""
        return self.breakdown_M is not None

    def summary(self) -> str:
        """Generate formatted summary."""
        lines = [
            "=" * 70,
            "Honest DiD Sensitivity Analysis".center(70),
            "=" * 70,
            "",
            f"{'Method:':<30} {self.method}",
            f"{'Original estimate:':<30} {self.original_estimate:.4f}",
            f"{'Original SE:':<30} {self.original_se:.4f}",
            f"{'M values tested:':<30} {len(self.M_values)}",
            "",
        ]
        if self.breakdown_M is not None:
            lines.append(f"{'Breakdown value:':<30} {self.breakdown_M:.4f}")
            lines.append("")
            lines.append(f"Result is robust to violations up to M = {self.breakdown_M:.4f}")
        else:
            lines.append(f"{'Breakdown value:':<30} None (always significant)")
        lines.extend(
            [
                "",
                "-" * 70,
                f"{'M':<10} {'Lower Bound':>12} {'Upper Bound':>12} {'CI Lower':>12} {'CI Upper':>12}",
                "-" * 70,
            ]
        )
        for i, M in enumerate(self.M_values):
            lb, ub = self.bounds[i]
            ci_lb, ci_ub = self.robust_cis[i]
            lines.append(f"{M:<10.4f} {lb:>12.4f} {ub:>12.4f} {ci_lb:>12.4f} {ci_ub:>12.4f}")
        lines.extend(["", "=" * 70])
        return "\n".join(lines)

    def print_summary(self) -> None:
        """Print summary to stdout."""
        print(self.summary())

    def to_dataframe(self) -> pd.DataFrame:
        """Convert to DataFrame with one row per M value."""
        rows = []
        for i, M in enumerate(self.M_values):
            lb, ub = self.bounds[i]
            ci_lb, ci_ub = self.robust_cis[i]
            rows.append(
                {
                    "M": M,
                    "lb": lb,
                    "ub": ub,
                    "ci_lb": ci_lb,
                    "ci_ub": ci_ub,
                    # Strictly-above / strictly-below test so a NaN CI is not
                    # reported as significant.
                    "is_significant": bool(ci_lb > 0 or ci_ub < 0),
                }
            )
        return pd.DataFrame(rows)

    def plot(
        self,
        ax=None,
        show_bounds: bool = True,
        show_ci: bool = True,
        breakdown_line: bool = True,
        **kwargs,
    ):
        """
        Plot sensitivity analysis results.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            Axes to plot on. If None, creates new figure.
        show_bounds : bool
            Whether to show identified set bounds.
        show_ci : bool
            Whether to show confidence intervals.
        breakdown_line : bool
            Whether to show vertical line at breakdown value.
        **kwargs
            Accepted for API compatibility; currently not forwarded to any
            plotting call.

        Returns
        -------
        ax : matplotlib.axes.Axes
            The axes with the plot.

        Raises
        ------
        ImportError
            If matplotlib is not installed.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError as err:
            raise ImportError("matplotlib is required for plotting") from err

        if ax is None:
            _, ax = plt.subplots(figsize=(10, 6))

        M = self.M_values
        bounds_arr = np.array(self.bounds)
        ci_arr = np.array(self.robust_cis)

        # Plot original estimate
        ax.axhline(
            y=self.original_estimate,
            color="black",
            linestyle="-",
            linewidth=1.5,
            label="Original estimate",
            alpha=0.7,
        )
        # Plot zero line
        ax.axhline(y=0, color="gray", linestyle="--", linewidth=1, alpha=0.5)
        if show_bounds:
            ax.fill_between(
                M,
                bounds_arr[:, 0],
                bounds_arr[:, 1],
                alpha=0.3,
                color="blue",
                label="Identified set",
            )
        if show_ci:
            ax.plot(M, ci_arr[:, 0], "b-", linewidth=1.5, label="Robust CI")
            ax.plot(M, ci_arr[:, 1], "b-", linewidth=1.5)
        if breakdown_line and self.breakdown_M is not None:
            ax.axvline(
                x=self.breakdown_M,
                color="red",
                linestyle=":",
                linewidth=2,
                label=f"Breakdown (M={self.breakdown_M:.2f})",
            )
        ax.set_xlabel("M (restriction parameter)")
        ax.set_ylabel("Treatment Effect")
        ax.set_title("Sensitivity Analysis: Treatment Effect Bounds")
        ax.legend(loc="best")
        return ax
# =============================================================================
# Helper Functions
# =============================================================================
def _extract_event_study_params(
    results: Union[MultiPeriodDiDResults, Any],
) -> Tuple[np.ndarray, np.ndarray, int, int, List[Any], List[Any]]:
    """
    Extract event study parameters from results objects.

    Coefficients are returned in pre-then-post order. Periods whose estimate
    is not usable are dropped: a non-finite effect or SE, or a
    CallawaySantAnna normalization entry with ``n_groups == 0``. The returned
    period identifiers list exactly the periods kept in ``beta_hat``.

    Parameters
    ----------
    results : MultiPeriodDiDResults or CallawaySantAnnaResults
        Estimation results with event study structure.

    Returns
    -------
    beta_hat : np.ndarray
        Vector of event study coefficients (pre + post periods).
    sigma : np.ndarray
        Variance-covariance matrix of coefficients. For MultiPeriodDiD this is
        the matching sub-block of ``results.vcov`` when ``interaction_indices``
        is available, otherwise diagonal. For CallawaySantAnna it is always
        diagonal, built from the per-period SEs.
    num_pre_periods : int
        Number of pre-treatment periods in ``beta_hat``.
    num_post_periods : int
        Number of post-treatment periods in ``beta_hat``.
    pre_periods : list
        Identifiers of the pre-periods in ``beta_hat``, in order.
    post_periods : list
        Identifiers of the post-periods in ``beta_hat``, in order.

    Raises
    ------
    ValueError
        If a MultiPeriodDiD result has no finite coefficients or no finite
        pre-period coefficient, or a CallawaySantAnna result has no event
        study effects.
    TypeError
        If ``results`` is of an unsupported type.
    """
    if isinstance(results, MultiPeriodDiDResults):
        period_effects = results.period_effects
        # Drop periods with non-finite effects/SEs (e.g. from rank-deficient
        # designs) while keeping the pre-then-post ordering.
        finite_periods = {
            p
            for p, pe in period_effects.items()
            if np.isfinite(pe.effect) and np.isfinite(pe.se)
        }
        pre_estimated = [p for p in results.pre_periods if p in finite_periods]
        post_estimated = [p for p in results.post_periods if p in finite_periods]
        all_estimated = pre_estimated + post_estimated

        if not all_estimated:
            raise ValueError(
                "No period effects with finite estimates found. " "Cannot compute HonestDiD bounds."
            )

        num_pre_periods = len(pre_estimated)
        num_post_periods = len(post_estimated)
        if num_pre_periods == 0:
            raise ValueError(
                "No pre-period effects with finite estimates found. "
                "HonestDiD requires at least one identified pre-period "
                "coefficient."
            )

        beta_hat = np.array([period_effects[p].effect for p in all_estimated])

        if (
            results.vcov is not None
            and getattr(results, "interaction_indices", None) is not None
        ):
            # Use the full covariance sub-block of the kept interaction terms.
            indices = [results.interaction_indices[p] for p in all_estimated]
            sigma = results.vcov[np.ix_(indices, indices)]
        else:
            # Fallback: no covariance available, so use a diagonal from the SEs.
            ses = np.array([period_effects[p].se for p in all_estimated])
            sigma = np.diag(ses**2)

        # Return the identifiers that were actually kept so they line up with
        # beta_hat and the period counts (the CallawaySantAnna branch does the same).
        return beta_hat, sigma, num_pre_periods, num_post_periods, pre_estimated, post_estimated

    # Keep the try limited to the optional import itself.
    try:
        from diff_diff.staggered import CallawaySantAnnaResults
    except ImportError:
        CallawaySantAnnaResults = None

    if CallawaySantAnnaResults is not None and isinstance(results, CallawaySantAnnaResults):
        if results.event_study_effects is None:
            raise ValueError(
                "CallawaySantAnnaResults must have event_study_effects for HonestDiD. "
                "Re-run CallawaySantAnna.fit() with aggregate='event_study' to compute "
                "event study effects."
            )
        # Filter out normalization constraints (n_groups=0) and non-finite
        # effects/SEs, matching the MultiPeriodDiD branch above.
        event_effects = {
            t: data
            for t, data in results.event_study_effects.items()
            if data.get("n_groups", 1) > 0
            and np.isfinite(data.get("effect", np.nan))
            and np.isfinite(data.get("se", np.nan))
        }
        rel_times = sorted(event_effects)
        # Negative relative times are pre-treatment; 0 and later are post-treatment.
        pre_times = [t for t in rel_times if t < 0]
        post_times = [t for t in rel_times if t >= 0]

        beta_hat = np.array([event_effects[t]["effect"] for t in rel_times])
        ses = np.array([event_effects[t]["se"] for t in rel_times])
        # Only per-period SEs are available here, so the covariance is diagonal.
        sigma = np.diag(ses**2)
        return (beta_hat, sigma, len(pre_times), len(post_times), pre_times, post_times)

    raise TypeError(
        f"Unsupported results type: {type(results)}. "
        "Expected MultiPeriodDiDResults or CallawaySantAnnaResults."
    )
def _construct_A_sd(num_periods: int) -> np.ndarray:
    """
    Build the second-difference operator used by the smoothness restriction.

    Row ``i`` of the returned matrix applied to ``delta`` yields
    ``delta[i] - 2*delta[i+1] + delta[i+2]``.

    Parameters
    ----------
    num_periods : int
        Length of the ``delta`` vector.

    Returns
    -------
    A : np.ndarray
        Float matrix of shape ``(num_periods - 2, num_periods)``. It has zero
        rows when ``num_periods < 3``, because no second difference exists.
    """
    if num_periods < 3:
        return np.zeros((0, num_periods))
    # Differencing the identity twice along its rows yields exactly the
    # [1, -2, 1] stencil, shifted one column to the right per row.
    return np.diff(np.eye(num_periods), n=2, axis=0)
def _construct_constraints_sd(
    num_pre_periods: int, num_post_periods: int, M: float
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Express the smoothness restriction as linear inequalities ``A @ delta <= b``.

    ``delta`` covers all pre- and post-periods. The absolute-value condition
    ``|D @ delta| <= M`` on the second-difference operator ``D`` is split into
    two one-sided blocks.

    Parameters
    ----------
    num_pre_periods : int
        Number of pre-treatment periods.
    num_post_periods : int
        Number of post-treatment periods.
    M : float
        Bound on every absolute second difference.

    Returns
    -------
    A_ineq : np.ndarray
        Matrix stacking ``D`` on top of ``-D``. It has zero rows when there are
        fewer than three periods.
    b_ineq : np.ndarray
        Vector filled with ``M``, one entry per row of ``A_ineq``.
    """
    n_total = num_pre_periods + num_post_periods
    second_diff = _construct_A_sd(n_total)
    n_rows = second_diff.shape[0]
    if n_rows == 0:
        return np.zeros((0, n_total)), np.zeros(0)
    # D @ delta <= M  and  -D @ delta <= M
    stacked = np.concatenate((second_diff, -second_diff), axis=0)
    return stacked, np.full(2 * n_rows, M)
def _construct_constraints_rm(
    num_pre_periods: int, num_post_periods: int, Mbar: float, max_pre_violation: float
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Express the relative-magnitudes restriction as ``A @ delta <= b``.

    Each post-period violation is boxed in by
    ``|delta_post[i]| <= Mbar * max_pre_violation``. Pre-period columns of
    ``A`` are left at zero.

    Parameters
    ----------
    num_pre_periods : int
        Number of pre-treatment periods (leading columns of ``A``).
    num_post_periods : int
        Number of post-treatment periods.
    Mbar : float
        Relative magnitude scaling factor.
    max_pre_violation : float
        Maximum absolute pre-period violation (estimated from data).

    Returns
    -------
    A_ineq : np.ndarray
        Shape ``(2 * num_post_periods, num_pre_periods + num_post_periods)``.
        Row ``2i`` is ``+e_{post i}`` and row ``2i+1`` is ``-e_{post i}``.
    b_ineq : np.ndarray
        Every entry equals ``Mbar * max_pre_violation``.
    """
    n_cols = num_pre_periods + num_post_periods
    cap = Mbar * max_pre_violation
    post_cols = np.arange(num_pre_periods, n_cols)
    upper_rows = 2 * np.arange(num_post_periods)

    A_ineq = np.zeros((2 * num_post_periods, n_cols))
    A_ineq[upper_rows, post_cols] = 1  # delta_post[i] <= cap
    A_ineq[upper_rows + 1, post_cols] = -1  # -delta_post[i] <= cap
    return A_ineq, np.full(2 * num_post_periods, cap)
def _solve_bounds_lp(
    beta_post: np.ndarray,
    l_vec: np.ndarray,
    A_ineq: np.ndarray,
    b_ineq: np.ndarray,
    num_pre_periods: int,
    lp_method: str = "highs",
) -> Tuple[float, float]:
    """
    Solve for identified set bounds using linear programming.

    The parameter of interest is theta = l' @ (beta_post - delta_post), and
    we find its min and max over delta in the constraint set
    ``A_ineq @ delta <= b_ineq``.

    The optimization is over delta for ALL periods (pre + post), but only the
    post-period components enter the objective. This lets constraints that
    link pre and post periods (e.g. smoothness) shape the bounds.

    Parameters
    ----------
    beta_post : np.ndarray
        Post-period coefficient estimates.
    l_vec : np.ndarray
        Weighting vector for aggregation (one weight per post-period).
    A_ineq : np.ndarray
        Inequality constraint matrix (for all periods).
    b_ineq : np.ndarray
        Inequality constraint vector.
    num_pre_periods : int
        Number of pre-periods; post-period entries of delta start at this index.
    lp_method : str
        ``method`` passed to ``scipy.optimize.linprog``. The default ``'highs'``
        requires SciPy >= 1.6. The legacy ``'interior-point'``,
        ``'revised simplex'`` and ``'simplex'`` methods were removed in
        SciPy 1.11; use ``'highs'``, ``'highs-ds'`` or ``'highs-ipm'`` instead.

    Returns
    -------
    lb : float
        Lower bound (``-inf`` if there are no constraints or the minimizing LP
        fails, e.g. because it is unbounded).
    ub : float
        Upper bound (``+inf`` under the same conditions for the maximizing LP).
    """
    num_post = len(beta_post)
    total_periods = A_ineq.shape[1] if A_ineq.shape[0] > 0 else num_pre_periods + num_post

    # Objective c @ delta == -l' @ delta_post; pre-period entries carry zero weight.
    c = np.zeros(total_periods)
    c[num_pre_periods : num_pre_periods + num_post] = -l_vec

    if A_ineq.shape[0] == 0:
        # No restriction on delta at all: theta can take any real value.
        return -np.inf, np.inf

    def _lp_min(objective: np.ndarray) -> float:
        """Minimize ``objective @ delta`` over the polytope; ``-inf`` on any failure."""
        try:
            res = optimize.linprog(
                objective, A_ub=A_ineq, b_ub=b_ineq, bounds=(None, None), method=lp_method
            )
        except (ValueError, TypeError):
            # Solver rejected the problem (e.g. unknown method): treat as unbounded.
            return -np.inf
        # Unbounded/infeasible/failed solves report success=False.
        return res.fun if res.success else -np.inf

    min_val = _lp_min(c)  # min(-l'@delta)
    max_val = -_lp_min(-c)  # max(-l'@delta) = -min(l'@delta)

    theta_base = np.dot(l_vec, beta_post)
    lb = theta_base + min_val  # = min over delta of l'@(beta - delta)
    ub = theta_base + max_val  # = max over delta of l'@(beta - delta)
    return lb, ub
def _compute_flci(lb: float, ub: float, se: float, alpha: float = 0.05) -> Tuple[float, float]:
    """
    Compute Fixed Length Confidence Interval (FLCI).

    The interval pads the identified set by ``z_{1 - alpha/2} * se`` on each
    side. NOTE(review): this is a simplified construction, not the
    optimal-length FLCI of Rambachan & Roth (2023), which uses a folded-normal
    critical value. Confirm the simplification is intended.

    Parameters
    ----------
    lb : float
        Lower bound of identified set (may be ``-inf``).
    ub : float
        Upper bound of identified set (may be ``+inf``).
    se : float
        Standard error of the estimator.
    alpha : float
        Significance level.

    Returns
    -------
    ci_lb : float
        Lower bound of confidence interval.
    ci_ub : float
        Upper bound of confidence interval.

    Raises
    ------
    ValueError
        If se is not a positive number (including NaN), or alpha is not in (0, 1).
    """
    # ``not se > 0`` (rather than ``se <= 0``) also rejects NaN, which would
    # otherwise slip through and silently produce a NaN interval.
    if not se > 0:
        raise ValueError(f"Standard error must be positive, got se={se}")
    if not (0 < alpha < 1):
        raise ValueError(f"alpha must be between 0 and 1, got alpha={alpha}")
    z = stats.norm.ppf(1 - alpha / 2)
    return lb - z * se, ub + z * se
def _compute_clf_ci(
    beta_post: np.ndarray,
    sigma_post: np.ndarray,
    l_vec: np.ndarray,
    Mbar: float,
    max_pre_violation: float,
    alpha: float = 0.05,
    n_draws: int = 1000,
) -> Tuple[float, float, float, float]:
    """
    Compute Conditional Least Favorable (C-LF) confidence interval.

    For relative magnitudes, each post-period violation satisfies
    ``|delta_t| <= Mbar * max_pre_violation``. The identified set for
    ``theta = l' @ beta_post`` is therefore
    ``theta +/- ||l||_1 * Mbar * max_pre_violation``.

    NOTE(review): despite the name, nothing is conditioned on or simulated
    here. The CI pads the identified set by ``z_{1-alpha/2} * se`` on each
    side, and the estimation noise in ``max_pre_violation`` is not
    propagated.

    Parameters
    ----------
    beta_post : np.ndarray
        Post-period coefficient estimates.
    sigma_post : np.ndarray
        Variance-covariance matrix for post-period coefficients.
    l_vec : np.ndarray
        Weighting vector.
    Mbar : float
        Relative magnitude parameter.
    max_pre_violation : float
        Estimated max pre-period violation.
    alpha : float
        Significance level.
    n_draws : int
        Currently unused; reserved for a Monte Carlo conditional CI.

    Returns
    -------
    lb : float
        Lower bound of identified set.
    ub : float
        Upper bound of identified set.
    ci_lb : float
        Lower bound of confidence interval.
    ci_ub : float
        Upper bound of confidence interval.
    """
    theta = np.dot(l_vec, beta_post)
    se = np.sqrt(l_vec @ sigma_post @ l_vec)
    # The worst case of l' @ delta under per-period caps is the cap times
    # ||l||_1. For the default averaging weights ||l||_1 == 1, so this matches
    # the previous theta +/- Mbar * max_pre_violation.
    bound = Mbar * max_pre_violation * np.sum(np.abs(l_vec))
    lb = theta - bound
    ub = theta + bound
    # CI with estimation uncertainty
    z = stats.norm.ppf(1 - alpha / 2)
    ci_lb = lb - z * se
    ci_ub = ub + z * se
    return lb, ub, ci_lb, ci_ub
# =============================================================================
# Main Class
# =============================================================================
[docs]
class HonestDiD:
"""
Honest DiD sensitivity analysis (Rambachan & Roth 2023).
Computes robust inference for difference-in-differences allowing
for bounded violations of parallel trends.
Parameters
----------
method : {"smoothness", "relative_magnitude", "combined"}
Type of restriction on trend violations:
- "smoothness": Bounds on second differences (Delta^SD)
- "relative_magnitude": Post violations <= M * max pre violation (Delta^RM)
- "combined": Both restrictions (Delta^SDRM)
M : float, optional
Restriction parameter. Interpretation depends on method:
- smoothness: Max second difference
- relative_magnitude: Scaling factor for max pre-period violation
Default is 1.0 for relative_magnitude, 0.0 for smoothness.
alpha : float
Significance level for confidence intervals.
l_vec : array-like or None
Weighting vector for scalar parameter (length = num_post_periods).
If None, uses uniform weights (average effect).
Examples
--------
>>> from diff_diff import MultiPeriodDiD
>>> from diff_diff.honest_did import HonestDiD
>>>
>>> # Fit event study
>>> mp_did = MultiPeriodDiD()
>>> results = mp_did.fit(data, outcome='y', treatment='treated',
... time='period', post_periods=[4,5,6,7])
>>>
>>> # Sensitivity analysis with relative magnitudes
>>> honest = HonestDiD(method='relative_magnitude', M=1.0)
>>> bounds = honest.fit(results)
>>> print(bounds.summary())
>>>
>>> # Sensitivity curve over M values
>>> sensitivity = honest.sensitivity_analysis(results, M_grid=[0, 0.5, 1, 1.5, 2])
>>> sensitivity.plot()
"""
def __init__(
    self,
    method: Literal["smoothness", "relative_magnitude", "combined"] = "relative_magnitude",
    M: Optional[float] = None,
    alpha: float = 0.05,
    l_vec: Optional[np.ndarray] = None,
):
    """
    Store the configuration and validate it.

    When ``M`` is omitted, the default is 1.0 for ``"relative_magnitude"``
    and 0.0 for every other method. ``_validate_params`` runs once all
    attributes are set.
    """
    self.method = method
    self.alpha = alpha
    self.l_vec = l_vec
    # Pick a method-specific default restriction when the caller gave none.
    if M is not None:
        self.M = M
    elif method == "relative_magnitude":
        self.M = 1.0
    else:
        self.M = 0.0
    self._validate_params()
def _validate_params(self):
    """
    Check the estimator configuration and raise ``ValueError`` on the first problem.

    The checks run in this order:

    1. ``method`` is a supported restriction name.
    2. ``M`` is not negative.
    3. ``alpha`` lies strictly between 0 and 1.
    """
    supported_methods = ("smoothness", "relative_magnitude", "combined")
    if self.method not in supported_methods:
        raise ValueError(
            f"method must be 'smoothness', 'relative_magnitude', or 'combined', "
            f"got method='{self.method}'"
        )
    if self.M < 0:
        raise ValueError(f"M must be non-negative, got M={self.M}")
    # A negated chained comparison, so a NaN alpha is rejected as well.
    if not 0 < self.alpha < 1:
        raise ValueError(f"alpha must be between 0 and 1, got alpha={self.alpha}")
def get_params(self) -> Dict[str, Any]:
    """Return the estimator's constructor parameters as a name -> value mapping."""
    return dict(
        method=self.method,
        M=self.M,
        alpha=self.alpha,
        l_vec=self.l_vec,
    )
def set_params(self, **params) -> "HonestDiD":
    """
    Set parameters for this estimator.

    Only the keys reported by ``get_params`` are accepted (``method``, ``M``,
    ``alpha``, ``l_vec``). The update is all-or-nothing. An unknown key is
    rejected before anything is modified. If the new combination fails
    ``_validate_params``, the previous values are restored before the error
    propagates.

    Parameters
    ----------
    **params
        New parameter values keyed by parameter name.

    Returns
    -------
    HonestDiD
        ``self``, to allow chaining.

    Raises
    ------
    ValueError
        If a key is not a parameter name, or the resulting parameters are invalid.
    """
    valid_keys = self.get_params()
    for key in params:
        # hasattr() would also accept method names such as ``fit`` and let
        # set_params silently overwrite them; only real parameters are allowed.
        if key not in valid_keys:
            raise ValueError(f"Invalid parameter: {key}")

    previous = {key: getattr(self, key) for key in params}
    for key, value in params.items():
        setattr(self, key, value)
    try:
        self._validate_params()
    except Exception:
        # Roll back so a rejected update leaves the estimator in its prior state.
        for key, value in previous.items():
            setattr(self, key, value)
        raise
    return self
def fit(
    self,
    results: Union[MultiPeriodDiDResults, Any],
    M: Optional[float] = None,
) -> HonestDiDResults:
    """
    Compute bounds and robust confidence intervals.

    Parameters
    ----------
    results : MultiPeriodDiDResults or CallawaySantAnnaResults
        Results from event study estimation.
    M : float, optional
        Override the M parameter for this fit. Must be non-negative.

    Returns
    -------
    HonestDiDResults
        Results containing bounds and robust confidence intervals.

    Raises
    ------
    ValueError
        If ``M`` is negative, if no finite post-period coefficient is
        available, or if ``l_vec`` does not have one weight per post-period.
    """
    M = M if M is not None else self.M
    # A per-call override never passes through _validate_params. A negative M
    # would make the smoothness LP infeasible or invert the relative-magnitude
    # interval, so reject it here with the same message.
    if M < 0:
        raise ValueError(f"M must be non-negative, got M={M}")

    # Extract event study parameters
    (beta_hat, sigma, num_pre, num_post, pre_periods, post_periods) = (
        _extract_event_study_params(results)
    )

    # beta_hat contains [pre-period effects, post-period effects] in order.
    # Extract just the post-period effects for HonestDiD bounds.
    if len(beta_hat) == num_post:
        # Already just post-period effects
        beta_post = beta_hat
    elif len(beta_hat) == num_pre + num_post:
        # Full event study, extract post-periods
        beta_post = beta_hat[num_pre:]
    else:
        # Assume it's post-period effects
        beta_post = beta_hat
        num_post = len(beta_hat)

    n_post = len(beta_post)
    # Handle sigma extraction for post periods
    if sigma.shape[0] == num_post and sigma.shape[0] == n_post:
        sigma_post = sigma
    elif sigma.shape[0] == num_pre + num_post:
        sigma_post = sigma[num_pre:, num_pre:]
    else:
        # Coefficients are ordered pre-then-post, so the post-period block is
        # the trailing one (the leading block would be pre-periods).
        sigma_post = sigma[-n_post:, -n_post:] if n_post else sigma[:0, :0]

    # Update num_post to match actual data
    num_post = n_post

    if num_post == 0:
        raise ValueError(
            "No post-period effects with finite estimates found. "
            "HonestDiD requires at least one identified post-period "
            "coefficient to compute bounds."
        )

    # Set up weighting vector
    if self.l_vec is None:
        l_vec = np.ones(num_post) / num_post  # Uniform weights
    else:
        l_vec = np.asarray(self.l_vec)
        if len(l_vec) != num_post:
            raise ValueError(f"l_vec must have length {num_post}, got {len(l_vec)}")

    # Compute original estimate and SE
    original_estimate = np.dot(l_vec, beta_post)
    original_se = np.sqrt(l_vec @ sigma_post @ l_vec)

    # Compute bounds based on method
    if self.method == "smoothness":
        lb, ub, ci_lb, ci_ub = self._compute_smoothness_bounds(
            beta_post, sigma_post, l_vec, num_pre, num_post, M
        )
        ci_method = "FLCI"
    elif self.method == "relative_magnitude":
        lb, ub, ci_lb, ci_ub = self._compute_rm_bounds(
            beta_post, sigma_post, l_vec, num_pre, num_post, M, pre_periods, results
        )
        ci_method = "C-LF"
    else:  # combined
        lb, ub, ci_lb, ci_ub = self._compute_combined_bounds(
            beta_post, sigma_post, l_vec, num_pre, num_post, M, pre_periods, results
        )
        ci_method = "FLCI"

    return HonestDiDResults(
        lb=lb,
        ub=ub,
        ci_lb=ci_lb,
        ci_ub=ci_ub,
        M=M,
        method=self.method,
        original_estimate=original_estimate,
        original_se=original_se,
        alpha=self.alpha,
        ci_method=ci_method,
        original_results=results,
    )
def _compute_smoothness_bounds(
    self,
    beta_post: np.ndarray,
    sigma_post: np.ndarray,
    l_vec: np.ndarray,
    num_pre: int,
    num_post: int,
    M: float,
) -> Tuple[float, float, float, float]:
    """
    Compute bounds under the smoothness restriction (Delta^SD).

    The method builds the ``|second difference| <= M`` constraints over all
    pre + post periods. It then solves the two linear programs for the
    identified set of ``l_vec @ (beta_post - delta_post)`` and pads that set
    into an FLCI.

    NOTE(review): the LP only bounds second differences of ``delta``. The
    pre-period entries of ``delta`` are never tied to the estimated pre-period
    coefficients, so any constant ``delta`` stays feasible. Whenever
    ``sum(l_vec) != 0``, including the default uniform weights, the LP is
    therefore unbounded and the bounds presumably come out as
    ``(-inf, inf)``. Rambachan & Roth identify ``delta_pre`` from ``beta_pre``;
    confirm whether that anchoring is intended here.

    Returns
    -------
    tuple of float
        ``(lb, ub, ci_lb, ci_ub)``.
    """
    A_ineq, b_ineq = _construct_constraints_sd(num_pre, num_post, M)
    set_lb, set_ub = _solve_bounds_lp(beta_post, l_vec, A_ineq, b_ineq, num_pre)
    theta_se = np.sqrt(l_vec @ sigma_post @ l_vec)
    ci_lower, ci_upper = _compute_flci(set_lb, set_ub, theta_se, self.alpha)
    return set_lb, set_ub, ci_lower, ci_upper
def _compute_rm_bounds(
    self,
    beta_post: np.ndarray,
    sigma_post: np.ndarray,
    l_vec: np.ndarray,
    num_pre: int,
    num_post: int,
    Mbar: float,
    pre_periods: List,
    results: Any,
) -> Tuple[float, float, float, float]:
    """
    Compute bounds under the relative-magnitudes restriction (Delta^RM).

    The post-period violation cap is ``Mbar`` times the maximum pre-period
    violation returned by ``self._estimate_max_pre_violation``.

    Returns
    -------
    tuple of float
        ``(lb, ub, ci_lb, ci_ub)``. When the estimated maximum pre-period
        violation is exactly zero, the identified set collapses to the point
        estimate and the CI is the usual two-sided normal interval.
    """
    max_pre_violation = self._estimate_max_pre_violation(results, pre_periods)
    if max_pre_violation != 0:
        return _compute_clf_ci(
            beta_post, sigma_post, l_vec, Mbar, max_pre_violation, self.alpha
        )

    # No measurable pre-trend: there is nothing to widen, so report the point
    # estimate with an ordinary normal CI.
    point = np.dot(l_vec, beta_post)
    half_width = stats.norm.ppf(1 - self.alpha / 2) * np.sqrt(l_vec @ sigma_post @ l_vec)
    return point, point, point - half_width, point + half_width
def _compute_combined_bounds(
    self,
    beta_post: np.ndarray,
    sigma_post: np.ndarray,
    l_vec: np.ndarray,
    num_pre: int,
    num_post: int,
    M: float,
    pre_periods: List,
    results: Any,
) -> Tuple[float, float, float, float]:
    """
    Compute bounds under the combined smoothness + RM restriction.

    The smoothness and relative-magnitude intervals are computed separately,
    with the single ``M`` also used as ``Mbar``, and then intersected. That
    intersection contains, and may be wider than, the set obtained by
    imposing both restrictions jointly in one LP. The FLCI is then built on
    the intersected interval.

    Returns
    -------
    tuple of float
        ``(lb, ub, ci_lb, ci_ub)``.
    """
    sd_lower, sd_upper, _, _ = self._compute_smoothness_bounds(
        beta_post, sigma_post, l_vec, num_pre, num_post, M
    )
    # The same M doubles as Mbar for the relative-magnitude part.
    rm_lower, rm_upper, _, _ = self._compute_rm_bounds(
        beta_post, sigma_post, l_vec, num_pre, num_post, M, pre_periods, results
    )

    lower = max(sd_lower, rm_lower)
    upper = min(sd_upper, rm_upper)
    if lower > upper:
        # Defensive: an empty intersection collapses to the point estimate.
        lower = upper = np.dot(l_vec, beta_post)

    theta_se = np.sqrt(l_vec @ sigma_post @ l_vec)
    ci_lower, ci_upper = _compute_flci(lower, upper, theta_se, self.alpha)
    return lower, upper, ci_lower, ci_upper
def _estimate_max_pre_violation(self, results: Any, pre_periods: List) -> float:
    """
    Estimate the maximum absolute pre-period violation.

    This value is used as the scale of the relative-magnitudes bounds.

    - ``MultiPeriodDiDResults``: the largest ``|effect|`` among
      ``pre_periods`` present in ``period_effects`` whose effect is finite.
      Falls back to ``results.avg_se`` when no such period exists.
    - ``CallawaySantAnnaResults`` (when ``diff_diff.staggered`` is
      importable): the largest ``|effect|`` among event-study entries with
      negative relative time, at least one contributing group, and a finite
      effect. Falls back to ``results.overall_se``.
    - Anything else: returns ``0.1``.

    Parameters
    ----------
    results : Any
        Estimation results.
    pre_periods : list
        Pre-treatment periods (used only for ``MultiPeriodDiDResults``).

    Returns
    -------
    float
        Estimated maximum pre-period violation.
    """
    if isinstance(results, MultiPeriodDiDResults):
        period_effects = results.period_effects
        # Skip non-finite effects (e.g. from rank-deficient designs).
        pre_effects = [
            abs(period_effects[p].effect)
            for p in pre_periods
            if p in period_effects and np.isfinite(period_effects[p].effect)
        ]
        if pre_effects:
            return max(pre_effects)
        # Fallback: use avg_se as a scale
        return results.avg_se

    # Optional dependency: only the import itself is guarded.
    try:
        from diff_diff.staggered import CallawaySantAnnaResults
    except ImportError:
        CallawaySantAnnaResults = None

    if CallawaySantAnnaResults is not None and isinstance(
        results, CallawaySantAnnaResults
    ):
        event_effects = results.event_study_effects
        if event_effects:
            # Exclude normalization constraints (n_groups=0, e.g. the
            # reference period) and non-finite effects, which would
            # otherwise turn the max into NaN.
            pre_effects = [
                abs(event_effects[t]["effect"])
                for t in event_effects
                if t < 0
                and event_effects[t].get("n_groups", 1) > 0
                and np.isfinite(event_effects[t]["effect"])
            ]
            if pre_effects:
                return max(pre_effects)
        return results.overall_se

    # Default fallback
    return 0.1
[docs]
def sensitivity_analysis(
    self,
    results: Union[MultiPeriodDiDResults, Any],
    M_grid: Optional[List[float]] = None,
) -> SensitivityResults:
    """
    Evaluate Honest DiD bounds over a grid of restriction parameters.

    Parameters
    ----------
    results : MultiPeriodDiDResults or CallawaySantAnnaResults
        Results from event study estimation.
    M_grid : list of float, optional
        Restriction parameters to evaluate. When omitted, a default grid is
        chosen: ``[0, 0.25, ..., 2.0]`` for ``"relative_magnitude"`` and
        ``[0, 0.1, 0.2, 0.3, 0.5, 0.75, 1.0]`` for the other methods.

    Returns
    -------
    SensitivityResults
        Bounds and robust CIs for each grid value, the breakdown value, and
        the original estimate / SE taken from a fit at ``M=0``.
    """
    if M_grid is None:
        M_grid = (
            [0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0]
            if self.method == "relative_magnitude"
            else [0, 0.1, 0.2, 0.3, 0.5, 0.75, 1.0]
        )
    M_values = np.array(M_grid)

    # One fit per grid value, in grid order.
    grid_fits = [self.fit(results, M=m) for m in M_values]
    bounds_list = [(fit_.lb, fit_.ub) for fit_ in grid_fits]
    ci_list = [(fit_.ci_lb, fit_.ci_ub) for fit_ in grid_fits]

    breakdown_M = self._find_breakdown(results, M_values, ci_list)

    # The reference fit at M=0 supplies the original estimate and SE.
    baseline = self.fit(results, M=0)

    return SensitivityResults(
        M_values=M_values,
        bounds=bounds_list,
        robust_cis=ci_list,
        breakdown_M=breakdown_M,
        method=self.method,
        original_estimate=baseline.original_estimate,
        original_se=baseline.original_se,
        alpha=self.alpha,
    )
def _find_breakdown(
    self, results: Any, M_values: np.ndarray, ci_list: List[Tuple[float, float]]
) -> Optional[float]:
    """
    Find the breakdown value where the robust CI first includes zero.

    The grid results bracket the transition from "CI excludes zero" to
    "CI includes zero". The bracket is then refined by bisection (20
    halvings), with extra calls to ``self.fit``.

    Parameters
    ----------
    results : Any
        Estimation results, forwarded to ``self.fit``.
    M_values : np.ndarray
        Grid of restriction parameters. It does not need to be sorted.
    ci_list : list of (float, float)
        Robust CI ``(ci_lb, ci_ub)`` for each entry of ``M_values``.

    Returns
    -------
    float or None
        The approximate breakdown value. Returns ``0.0`` when the CI already
        includes zero at M=0. Returns None when no CI on the grid includes
        zero.
    """

    def covers_zero(ci_lb, ci_ub) -> bool:
        return ci_lb <= 0 <= ci_ub

    # The bracketing logic needs M increasing along the grid. Sort the grid
    # (and its CIs) so user-supplied grids in any order are handled.
    grid = np.asarray(M_values, dtype=float)
    order = np.argsort(grid, kind="stable")
    grid = grid[order]
    includes_zero = [covers_zero(*ci_list[k]) for k in order]

    if not any(includes_zero):
        # Significant at every grid value - no breakdown within the grid.
        return None

    first = includes_zero.index(True)
    if first > 0:
        lo, hi = grid[first - 1], grid[first]
    else:
        # The CI covers zero at the smallest grid value. Report 0 only if
        # that value is 0 or the CI at M=0 also covers zero. Otherwise the
        # breakdown lies in [0, smallest grid value].
        if grid[0] <= 0:
            return 0.0
        at_zero = self.fit(results, M=0.0)
        if covers_zero(at_zero.ci_lb, at_zero.ci_ub):
            return 0.0
        lo, hi = 0.0, grid[0]

    for _ in range(20):  # 20 halvings for precision
        mid = (lo + hi) / 2
        trial = self.fit(results, M=mid)
        if covers_zero(trial.ci_lb, trial.ci_ub):
            hi = mid
        else:
            lo = mid
    return (lo + hi) / 2
[docs]
def breakdown_value(
    self,
    results: Union[MultiPeriodDiDResults, Any],
    tol: float = 0.01,
    M_max: float = 10.0,
) -> Optional[float]:
    """
    Find the breakdown value directly using binary search.

    The breakdown value is the smallest M for which the robust confidence
    interval includes zero. The search runs over ``[0, M_max]``.

    Parameters
    ----------
    results : MultiPeriodDiDResults or CallawaySantAnnaResults
        Results from event study estimation.
    tol : float
        Width of the final search interval. Must be positive.
    M_max : float
        Upper end of the search interval. Must be positive. Defaults to
        10.0, which was the previously hard-coded value.

    Returns
    -------
    float or None
        The breakdown value. Returns ``0.0`` if the CI already includes zero
        at M=0. Returns None if the effect is still significant at
        ``M_max``.

    Raises
    ------
    ValueError
        If ``tol`` or ``M_max`` is not positive (a non-positive ``tol``
        would make the bisection loop never terminate).
    """
    if not tol > 0:
        raise ValueError(f"tol must be positive, got tol={tol}")
    if not M_max > 0:
        raise ValueError(f"M_max must be positive, got M_max={M_max}")

    def covers_zero(res) -> bool:
        return res.ci_lb <= 0 <= res.ci_ub

    # Check at M=0
    if covers_zero(self.fit(results, M=0)):
        return 0.0

    # Still significant at the top of the search range
    if not covers_zero(self.fit(results, M=M_max)):
        return None

    # Binary search: CI excludes zero at lo and includes zero at hi.
    lo, hi = 0.0, float(M_max)
    while hi - lo > tol:
        mid = (lo + hi) / 2
        if covers_zero(self.fit(results, M=mid)):
            hi = mid
        else:
            lo = mid
    return (lo + hi) / 2
# =============================================================================
# Convenience Functions
# =============================================================================
[docs]
def compute_honest_did(
    results: Union[MultiPeriodDiDResults, Any],
    method: str = "relative_magnitude",
    M: float = 1.0,
    alpha: float = 0.05,
) -> HonestDiDResults:
    """
    Compute Honest DiD bounds in a single call.

    This is a shortcut for constructing ``HonestDiD(method=..., M=...,
    alpha=...)`` and calling its ``fit`` method.

    Parameters
    ----------
    results : MultiPeriodDiDResults or CallawaySantAnnaResults
        Results from event study estimation.
    method : str
        Restriction type: "smoothness", "relative_magnitude" or "combined".
    M : float
        Restriction parameter.
    alpha : float
        Significance level.

    Returns
    -------
    HonestDiDResults
        Identified-set bounds and robust confidence intervals.

    Examples
    --------
    >>> bounds = compute_honest_did(event_study_results, method='relative_magnitude', M=1.0)
    >>> print(f"Robust CI: [{bounds.ci_lb:.3f}, {bounds.ci_ub:.3f}]")
    """
    return HonestDiD(method=method, M=M, alpha=alpha).fit(results)
[docs]
def sensitivity_plot(
    results: Union[MultiPeriodDiDResults, Any],
    method: str = "relative_magnitude",
    M_grid: Optional[List[float]] = None,
    alpha: float = 0.05,
    ax=None,
    **kwargs,
):
    """
    Run a sensitivity analysis and plot the result.

    Parameters
    ----------
    results : MultiPeriodDiDResults or CallawaySantAnnaResults
        Results from event study estimation.
    method : str
        Restriction type passed to ``HonestDiD``.
    M_grid : list of float, optional
        Grid of restriction parameters. If omitted, the method's default
        grid is used.
    alpha : float
        Significance level.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on.
    **kwargs
        Forwarded unchanged to ``SensitivityResults.plot``.

    Returns
    -------
    ax : matplotlib.axes.Axes
        Whatever ``SensitivityResults.plot`` returns (the axes with the plot).
    """
    analyzer = HonestDiD(method=method, alpha=alpha)
    analysis = analyzer.sensitivity_analysis(results, M_grid=M_grid)
    return analysis.plot(ax=ax, **kwargs)