"""
Power analysis tools for difference-in-differences study design.
This module provides power calculations and simulation-based power analysis
for DiD study design, helping practitioners answer questions like:
- "How many units do I need to detect an effect of size X?"
- "What is the minimum detectable effect given my sample size?"
- "What power do I have to detect a given effect?"
References
----------
Bloom, H. S. (1995). "Minimum Detectable Effects: A Simple Way to Report the
Statistical Power of Experimental Designs." Evaluation Review, 19(5), 547-556.
Burlig, F., Preonas, L., & Woerman, M. (2020). "Panel Data and Experimental Design."
Journal of Development Economics, 144, 102458.
Djimeu, E. W., & Houndolo, D.-G. (2016). "Power Calculation for Causal Inference
in Social Science: Sample Size and Minimum Detectable Effect Determination."
Journal of Development Effectiveness, 8(4), 508-527.
"""
import warnings
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
from scipy import stats
# Maximum sample size returned when effect is too small to detect
# (e.g., zero effect or extremely small relative to noise).
# The value 2**31 - 1 is the largest signed 32-bit integer.
MAX_SAMPLE_SIZE = 2**31 - 1
# ---------------------------------------------------------------------------
# Estimator registry — maps estimator class names to DGP/fit/extract profiles
# ---------------------------------------------------------------------------
@dataclass
class _EstimatorProfile:
    """Internal profile describing how to run power simulations for an estimator.

    One profile is stored per estimator class name in the estimator registry
    and bundles the callables needed to generate data, fit the estimator, and
    read back its results.
    """

    # Data-generating function registered for the estimator
    # (e.g. ``generate_did_data``); presumably used when the caller does not
    # supply a custom generator — confirm against the simulation loop.
    default_dgp: Callable
    # Maps (n_units, n_periods, treatment_effect, treatment_fraction,
    # treatment_period, sigma) to the keyword arguments for ``default_dgp``.
    dgp_kwargs_builder: Callable
    # Maps (data, n_units, n_periods, treatment_period) to the keyword
    # arguments for the estimator's ``fit()`` call.
    fit_kwargs_builder: Callable
    # Reduces a fit result to ``(estimate, se, p_value, conf_int)``.
    result_extractor: Callable
    # NOTE(review): presumably the smallest n_units considered for this
    # estimator; the consumer is not visible here — confirm.
    min_n: int = 20
# ---------------------------------------------------------------------------
# SurveyPowerConfig — carries DGP survey params for simulation power
# ---------------------------------------------------------------------------
@dataclass
class SurveyPowerConfig:
    """Survey-design settings for simulation-based power analysis.

    Passing an instance to :func:`simulate_power`, :func:`simulate_mde`, or
    :func:`simulate_sample_size` makes the simulation loop draw its data from
    :func:`~diff_diff.prep.generate_survey_did_data` and automatically inject
    a ``SurveyDesign`` into the estimator's ``fit()`` call.

    Parameters
    ----------
    n_strata : int, default=5
        Number of geographic strata.
    psu_per_stratum : int, default=8
        Number of primary sampling units (PSUs) per stratum. Must be >= 2
        for Taylor Series Linearization variance estimation.
    fpc_per_stratum : float, default=200.0
        Finite population correction (total PSUs per stratum). Must be
        finite and at least ``psu_per_stratum``.
    weight_variation : str, default="moderate"
        Sampling weight dispersion: ``"none"`` (all equal), ``"moderate"``
        (range ~1-2), ``"high"`` (range ~1-4).
    psu_re_sd : float, default=2.0
        Standard deviation of PSU random effects. Controls intra-cluster
        correlation and drives DEFF > 1.
    psu_period_factor : float, default=0.5
        Multiplier for PSU-period interaction shocks.
    icc : float, optional
        Target intra-class correlation (0 < icc < 1). Overrides
        ``psu_re_sd`` via variance decomposition, so it cannot be combined
        with a non-default ``psu_re_sd``.
    weight_cv : float, optional
        Target coefficient of variation for weights. Overrides
        ``weight_variation``, so it cannot be combined with a non-default
        ``weight_variation``.
    informative_sampling : bool, default=False
        If True, weights correlate with Y(0).
    heterogeneous_te_by_strata : bool, default=False
        If True, treatment effect varies by stratum.
    include_replicate_weights : bool, default=False
        If True, add JK1 delete-one-PSU replicate weight columns.
    survey_design : SurveyDesign, optional
        Override the auto-built SurveyDesign. When None, a default
        ``SurveyDesign(weights="weight", strata="stratum", psu="psu",
        fpc="fpc")`` is used, matching ``generate_survey_did_data`` output.

    Raises
    ------
    ValueError
        On construction, if any field violates the constraints above.

    Examples
    --------
    >>> from diff_diff import CallawaySantAnna, simulate_power, SurveyPowerConfig
    >>> config = SurveyPowerConfig(n_strata=5, psu_per_stratum=8, icc=0.05)
    >>> results = simulate_power(
    ...     CallawaySantAnna(),
    ...     n_units=200,
    ...     treatment_effect=2.0,
    ...     survey_config=config,
    ...     n_simulations=100,
    ...     seed=42,
    ... )
    """

    n_strata: int = 5
    psu_per_stratum: int = 8
    fpc_per_stratum: float = 200.0
    weight_variation: str = "moderate"
    psu_re_sd: float = 2.0
    psu_period_factor: float = 0.5
    icc: Optional[float] = None
    weight_cv: Optional[float] = None
    informative_sampling: bool = False
    heterogeneous_te_by_strata: bool = False
    include_replicate_weights: bool = False
    survey_design: Optional[Any] = None

    def __post_init__(self) -> None:
        """Validate the fields; the first violated constraint raises ValueError."""
        if self.n_strata < 1:
            raise ValueError(f"n_strata must be >= 1, got {self.n_strata}")

        if self.psu_per_stratum < 2:
            raise ValueError(
                "psu_per_stratum must be >= 2 for TSL variance estimation, "
                f"got {self.psu_per_stratum}"
            )

        allowed_variation = ("none", "moderate", "high")
        if self.weight_variation not in allowed_variation:
            raise ValueError(
                "weight_variation must be 'none', 'moderate', or 'high', "
                f"got '{self.weight_variation}'"
            )

        if not (np.isfinite(self.psu_re_sd) and self.psu_re_sd >= 0):
            raise ValueError(f"psu_re_sd must be finite and >= 0, got {self.psu_re_sd}")

        if not np.isfinite(self.fpc_per_stratum):
            raise ValueError(f"fpc_per_stratum must be finite, got {self.fpc_per_stratum}")

        if self.icc is not None:
            if not (0 < self.icc < 1):
                raise ValueError(f"icc must be between 0 and 1 (exclusive), got {self.icc}")
            # A psu_re_sd other than the default 2.0 is treated as "explicitly set".
            if self.psu_re_sd != 2.0:
                raise ValueError(
                    "Cannot specify both icc and a non-default psu_re_sd. "
                    "icc overrides psu_re_sd via the ICC formula."
                )

        if self.weight_cv is not None:
            if not (np.isfinite(self.weight_cv) and self.weight_cv > 0):
                raise ValueError(f"weight_cv must be finite and > 0, got {self.weight_cv}")
            # A weight_variation other than the default is treated as "explicitly set".
            if self.weight_variation != "moderate":
                raise ValueError(
                    "Cannot specify both weight_cv and a non-default "
                    "weight_variation. weight_cv overrides weight_variation."
                )

        if not (np.isfinite(self.psu_period_factor) and self.psu_period_factor >= 0):
            raise ValueError(
                f"psu_period_factor must be finite and >= 0, got {self.psu_period_factor}"
            )

        if self.fpc_per_stratum < self.psu_per_stratum:
            raise ValueError(
                f"fpc_per_stratum ({self.fpc_per_stratum}) must be >= "
                f"psu_per_stratum ({self.psu_per_stratum})"
            )

    def _build_survey_design(self) -> Any:
        """Return the SurveyDesign to inject into ``fit()``.

        The user-supplied ``survey_design`` is returned as-is when set;
        otherwise a default design naming the ``weight``/``stratum``/``psu``/
        ``fpc`` columns is constructed. Nothing is cached on purpose: the
        current value of ``self.survey_design`` is read on every call, so
        reassigning the attribute after construction takes effect immediately.
        """
        design = self.survey_design
        if design is None:
            # Imported lazily so building a config does not import the survey module.
            from diff_diff.survey import SurveyDesign

            design = SurveyDesign(weights="weight", strata="stratum", psu="psu", fpc="fpc")
        return design

    @property
    def min_viable_n(self) -> int:
        """Minimum n_units for a viable survey design (>= 2 units per PSU)."""
        units_per_psu = 2
        return self.n_strata * self.psu_per_stratum * units_per_psu
# -- DGP kwargs adapters -----------------------------------------------------
def _basic_dgp_kwargs(
n_units: int,
n_periods: int,
treatment_effect: float,
treatment_fraction: float,
treatment_period: int,
sigma: float,
) -> Dict[str, Any]:
return dict(
n_units=n_units,
n_periods=n_periods,
treatment_effect=treatment_effect,
treatment_fraction=treatment_fraction,
treatment_period=treatment_period,
noise_sd=sigma,
)
def _staggered_dgp_kwargs(
n_units: int,
n_periods: int,
treatment_effect: float,
treatment_fraction: float,
treatment_period: int,
sigma: float,
) -> Dict[str, Any]:
return dict(
n_units=n_units,
n_periods=n_periods,
treatment_effect=treatment_effect,
never_treated_frac=1 - treatment_fraction,
cohort_periods=[treatment_period],
dynamic_effects=False,
noise_sd=sigma,
)
def _factor_dgp_kwargs(
n_units: int,
n_periods: int,
treatment_effect: float,
treatment_fraction: float,
treatment_period: int,
sigma: float,
) -> Dict[str, Any]:
n_pre = treatment_period
n_post = n_periods - treatment_period
return dict(
n_units=n_units,
n_pre=n_pre,
n_post=n_post,
n_treated=max(1, int(n_units * treatment_fraction)),
treatment_effect=treatment_effect,
noise_sd=sigma,
)
def _ddd_dgp_kwargs(
n_units: int,
n_periods: int,
treatment_effect: float,
treatment_fraction: float,
treatment_period: int,
sigma: float,
) -> Dict[str, Any]:
return dict(
n_per_cell=max(2, n_units // 8),
treatment_effect=treatment_effect,
noise_sd=sigma,
)
# -- Fit kwargs builders ------------------------------------------------------
def _basic_fit_kwargs(
data: pd.DataFrame,
n_units: int,
n_periods: int,
treatment_period: int,
) -> Dict[str, Any]:
return dict(outcome="outcome", treatment="treated", time="post")
def _twfe_fit_kwargs(
data: pd.DataFrame,
n_units: int,
n_periods: int,
treatment_period: int,
) -> Dict[str, Any]:
return dict(outcome="outcome", treatment="treated", time="post", unit="unit")
def _multiperiod_fit_kwargs(
data: pd.DataFrame,
n_units: int,
n_periods: int,
treatment_period: int,
) -> Dict[str, Any]:
return dict(
outcome="outcome",
treatment="treated",
time="period",
post_periods=list(range(treatment_period, n_periods)),
)
def _staggered_fit_kwargs(
data: pd.DataFrame,
n_units: int,
n_periods: int,
treatment_period: int,
) -> Dict[str, Any]:
return dict(outcome="outcome", unit="unit", time="period", first_treat="first_treat")
def _ddd_fit_kwargs(
data: pd.DataFrame,
n_units: int,
n_periods: int,
treatment_period: int,
) -> Dict[str, Any]:
return dict(outcome="outcome", group="group", partition="partition", time="time")
def _trop_fit_kwargs(
data: pd.DataFrame,
n_units: int,
n_periods: int,
treatment_period: int,
) -> Dict[str, Any]:
return dict(outcome="outcome", treatment="treated", unit="unit", time="period")
def _sdid_fit_kwargs(
data: pd.DataFrame,
n_units: int,
n_periods: int,
treatment_period: int,
) -> Dict[str, Any]:
periods = sorted(data["period"].unique())
post_periods = [p for p in periods if p >= treatment_period]
return dict(
outcome="outcome",
treatment="treat",
unit="unit",
time="period",
post_periods=post_periods,
)
# -- Survey-aware DGP kwargs adapter ------------------------------------------
def _survey_dgp_kwargs(
    n_units: int,
    n_periods: int,
    treatment_effect: float,
    treatment_fraction: float,
    treatment_period: int,
    sigma: float,
    survey_config: SurveyPowerConfig,
) -> Dict[str, Any]:
    """Build kwargs for generate_survey_did_data from simulate_power params.

    The design is a single cohort with dynamic effects off. The survey
    settings are copied field by field from ``survey_config``, and the true
    population ATT is always requested.
    """
    kwargs: Dict[str, Any] = {
        "n_units": n_units,
        "n_periods": n_periods,
        "treatment_effect": treatment_effect,
        "never_treated_frac": 1 - treatment_fraction,
        # 0-indexed treatment_period → 1-indexed cohort_periods
        "cohort_periods": [treatment_period + 1],
        "noise_sd": sigma,
        "dynamic_effects": False,
    }
    # Survey-design settings, forwarded unchanged from the config object.
    kwargs.update(
        n_strata=survey_config.n_strata,
        psu_per_stratum=survey_config.psu_per_stratum,
        fpc_per_stratum=survey_config.fpc_per_stratum,
        weight_variation=survey_config.weight_variation,
        psu_re_sd=survey_config.psu_re_sd,
        psu_period_factor=survey_config.psu_period_factor,
        icc=survey_config.icc,
        weight_cv=survey_config.weight_cv,
        informative_sampling=survey_config.informative_sampling,
        heterogeneous_te_by_strata=survey_config.heterogeneous_te_by_strata,
        include_replicate_weights=survey_config.include_replicate_weights,
    )
    kwargs["return_true_population_att"] = True
    return kwargs
# -- Survey-aware fit kwargs builders -----------------------------------------
def _survey_basic_fit_kwargs(
    data: pd.DataFrame,
    n_units: int,
    n_periods: int,
    treatment_period: int,
    survey_config: SurveyPowerConfig,
) -> Dict[str, Any]:
    """Fit kwargs for DifferenceInDifferences with survey design.

    The treatment column is ``ever_treated`` (a time-invariant group
    indicator), not the survey DGP's ``treated`` column, which is post-only:
    1{g>0, t>=g}. DifferenceInDifferences builds ``treatment * time``
    internally, so the post-only flag would make that interaction
    rank-deficient.
    """
    design = survey_config._build_survey_design()
    return {
        "outcome": "outcome",
        "treatment": "ever_treated",
        "time": "post",
        "survey_design": design,
    }
def _survey_twfe_fit_kwargs(
    data: pd.DataFrame,
    n_units: int,
    n_periods: int,
    treatment_period: int,
    survey_config: SurveyPowerConfig,
) -> Dict[str, Any]:
    """Fit kwargs for TwoWayFixedEffects with survey design.

    Uses the ``ever_treated`` group indicator with the ``post`` time column
    and the ``unit`` identifier, and attaches the survey design built from
    ``survey_config``.
    """
    design = survey_config._build_survey_design()
    return {
        "outcome": "outcome",
        "treatment": "ever_treated",
        "time": "post",
        "unit": "unit",
        "survey_design": design,
    }
def _survey_multiperiod_fit_kwargs(
    data: pd.DataFrame,
    n_units: int,
    n_periods: int,
    treatment_period: int,
    survey_config: SurveyPowerConfig,
) -> Dict[str, Any]:
    """Fit kwargs for MultiPeriodDiD with survey design (1-indexed periods).

    Post periods run from ``treatment_period + 1`` through ``n_periods``
    inclusive, i.e. the 0-indexed period range shifted by one.
    """
    # 1-indexed: post periods run from treatment_period+1 to n_periods
    first_post = treatment_period + 1
    post = [*range(first_post, n_periods + 1)]
    design = survey_config._build_survey_design()
    return {
        "outcome": "outcome",
        "treatment": "ever_treated",
        "unit": "unit",
        "time": "period",
        "post_periods": post,
        "survey_design": design,
    }
def _survey_staggered_fit_kwargs(
    data: pd.DataFrame,
    n_units: int,
    n_periods: int,
    treatment_period: int,
    survey_config: SurveyPowerConfig,
) -> Dict[str, Any]:
    """Fit kwargs for staggered estimators (CS, SA, etc.) with survey design.

    The column mapping matches the non-survey staggered mapping, with the
    survey design built from ``survey_config`` added.
    """
    design = survey_config._build_survey_design()
    return {
        "outcome": "outcome",
        "unit": "unit",
        "time": "period",
        "first_treat": "first_treat",
        "survey_design": design,
    }
# -- Result extractors --------------------------------------------------------
def _extract_simple(result: Any) -> Tuple[float, float, float, Tuple[float, float]]:
return (result.att, result.se, result.p_value, result.conf_int)
def _extract_multiperiod(
result: Any,
) -> Tuple[float, float, float, Tuple[float, float]]:
return (result.avg_att, result.avg_se, result.avg_p_value, result.avg_conf_int)
def _extract_staggered(
result: Any,
) -> Tuple[float, float, float, Tuple[float, float]]:
_nan = float("nan")
_nan_ci = (_nan, _nan)
def _first(r: Any, *attrs: str, default: Any = _nan) -> Any:
for a in attrs:
v = getattr(r, a, None)
if v is not None:
return v
return default
return (
result.overall_att,
_first(result, "overall_se", "overall_att_se"),
_first(result, "overall_p_value", "overall_att_p_value"),
_first(result, "overall_conf_int", "overall_att_ci", default=_nan_ci),
)
# Keys derived from simulate_power() public params — overriding these
# via data_generator_kwargs would desync the DGP from the result object.
_PROTECTED_DGP_KEYS = frozenset(
    {
        "treatment_effect",  # → true_effect in results / MDE search variable
        "noise_sd",  # → sigma param
        "n_units",  # → sample-size search variable
        "n_periods",  # → n_periods param
        "treatment_fraction",  # → treatment_fraction param
        "treatment_period",  # → treatment_period param
        "n_pre",  # → derived from treatment_period in factor-model DGPs
        "n_post",  # → derived from n_periods - treatment_period in factor-model DGPs
    }
)
# Keys managed by SurveyPowerConfig — block in data_generator_kwargs when
# survey_config is active to prevent silent conflicts. Besides the config
# fields themselves, this covers the design keys that _survey_dgp_kwargs
# sets on its own (return_true_population_att, dynamic_effects,
# cohort_periods, never_treated_frac).
_SURVEY_CONFIG_KEYS = frozenset(
    {
        "n_strata",
        "psu_per_stratum",
        "fpc_per_stratum",
        "weight_variation",
        "psu_re_sd",
        "psu_period_factor",
        "icc",
        "weight_cv",
        "informative_sampling",
        "heterogeneous_te_by_strata",
        "include_replicate_weights",
        "return_true_population_att",
        "dynamic_effects",
        "cohort_periods",
        "never_treated_frac",
    }
)
# -- Staggered DGP compatibility check ----------------------------------------
# Estimator class names treated as staggered-adoption estimators: they are
# checked by _check_staggered_dgp_compat and share one survey fit-kwargs
# builder in _SURVEY_FIT_BUILDERS.
_STAGGERED_ESTIMATORS = frozenset(
    {
        "CallawaySantAnna",
        "SunAbraham",
        "ImputationDiD",
        "TwoStageDiD",
        "StackedDiD",
        "EfficientDiD",
    }
)
# Estimators that need a derived `post` column when using survey DGP
# (survey DGP produces `period`/`first_treat` but not `post`).
_SURVEY_POST_ESTIMATORS = frozenset({"DifferenceInDifferences", "TwoWayFixedEffects"})
# Survey fit kwargs builder lookup — maps estimator name to builder function.
# Every name in _STAGGERED_ESTIMATORS maps to the same staggered builder.
_SURVEY_FIT_BUILDERS: Dict[str, Callable] = {
    "DifferenceInDifferences": _survey_basic_fit_kwargs,
    "TwoWayFixedEffects": _survey_twfe_fit_kwargs,
    "MultiPeriodDiD": _survey_multiperiod_fit_kwargs,
    **{name: _survey_staggered_fit_kwargs for name in _STAGGERED_ESTIMATORS},
}
# Unsupported: factor-model and triple-diff estimators (survey DGP produces
# staggered cohort data, not factor-model or 2x2x2 data).
_SURVEY_UNSUPPORTED = frozenset({"TROP", "SyntheticDiD", "TripleDifference"})
def _check_staggered_dgp_compat(
    estimator: Any,
    data_generator_kwargs: Optional[Dict[str, Any]],
) -> None:
    """Warn if a staggered estimator's settings don't match the default DGP.

    Estimators whose class name is not in ``_STAGGERED_ESTIMATORS`` are
    ignored. For the rest, three mismatches are looked for and reported
    together in a single ``UserWarning``:

    * ``control_group="not_yet_treated"`` without at least two distinct
      ``cohort_periods`` in ``data_generator_kwargs``;
    * ``anticipation > 0``;
    * ``StackedDiD`` with ``clean_control="strict"`` without at least two
      distinct cohorts.

    Parameters
    ----------
    estimator : Any
        Estimator instance; only its class name and the optional
        ``control_group``, ``anticipation`` and ``clean_control`` attributes
        are read.
    data_generator_kwargs : dict, optional
        User overrides for the DGP; only ``cohort_periods`` is inspected.
    """
    est_name = type(estimator).__name__
    if est_name not in _STAGGERED_ESTIMATORS:
        return

    cohorts = (data_generator_kwargs or {}).get("cohort_periods")
    multi_cohort = cohorts is not None and len(set(cohorts)) >= 2
    problems: List[str] = []

    # Check control_group="not_yet_treated" (CS, SA)
    control_group = getattr(estimator, "control_group", "never_treated")
    if control_group == "not_yet_treated" and not multi_cohort:
        problems.append(
            f' - {est_name} has control_group="not_yet_treated" but the default '
            "DGP generates a single treatment cohort with never-treated "
            "controls. Power may not reflect the intended not-yet-treated "
            "design.\n"
            " Fix: pass data_generator_kwargs="
            '{"cohort_periods": [2, 4], "never_treated_frac": 0.0} '
            "(or a custom data_generator)."
        )

    # Check anticipation > 0 (all staggered)
    anticipation = getattr(estimator, "anticipation", 0)
    if anticipation > 0:
        problems.append(
            f" - {est_name} has anticipation={anticipation} but the default DGP does "
            "not model anticipatory effects. The estimator will look for "
            f"treatment effects {anticipation} period(s) before the DGP generates "
            "them, biasing power estimates.\n"
            " Fix: supply a custom data_generator that shifts the "
            "effect onset."
        )

    # Check clean_control on StackedDiD
    if est_name == "StackedDiD":
        clean_control = getattr(estimator, "clean_control", "not_yet_treated")
        if clean_control == "strict" and not multi_cohort:
            problems.append(
                ' - StackedDiD has clean_control="strict" but the default '
                "single-cohort DGP makes strict controls equivalent to "
                "never-treated controls.\n"
                " Fix: pass data_generator_kwargs="
                '{"cohort_periods": [2, 4]} '
                "to test true strict clean-control behavior."
            )

    if not problems:
        return

    header = (
        f"Staggered power DGP mismatch for {est_name}. The default "
        "single-cohort DGP may not match the estimator "
        "configuration:\n"
    )
    warnings.warn(header + "\n".join(problems), UserWarning, stacklevel=2)
def _ddd_effective_n(
n_units: int, data_generator_kwargs: Optional[Dict[str, Any]]
) -> Optional[int]:
"""Return effective DDD sample size, or None if no rounding occurred."""
overrides = data_generator_kwargs or {}
if "n_per_cell" in overrides:
eff = overrides["n_per_cell"] * 8
else:
eff = max(2, n_units // 8) * 8
return eff if eff != n_units else None
def _check_ddd_dgp_compat(
    n_units: int,
    n_periods: int,
    treatment_fraction: float,
    treatment_period: int,
    data_generator_kwargs: Optional[Dict[str, Any]],
) -> None:
    """Warn when simulation inputs don't match DDD's fixed 2×2×2 design.

    A single ``UserWarning`` lists every input the DDD data generator
    ignores (``n_periods != 2``, ``treatment_period != 1``,
    ``treatment_fraction != 0.5``) and any difference between the requested
    ``n_units`` and the effective grid-rounded sample size. Nothing is
    emitted when all inputs match.
    """
    notes: List[str] = []

    # DDD is a fixed 2-period factorial; n_periods and treatment_period are ignored
    if n_periods != 2:
        notes.append(
            f"n_periods={n_periods} is ignored (DDD uses a fixed 2-period design: pre/post)"
        )
    if treatment_period != 1:
        notes.append(
            f"treatment_period={treatment_period} is ignored (DDD "
            "always treats in the second period)"
        )

    # DDD's 2×2×2 factorial has inherent 50% treatment fraction
    if treatment_fraction != 0.5:
        notes.append(
            f"treatment_fraction={treatment_fraction} is ignored "
            "(DDD uses a balanced 2×2×2 factorial where 50% of "
            "groups are treated)"
        )

    # n_units rounding: n_per_cell = max(2, n_units // 8)
    effective = _ddd_effective_n(n_units, data_generator_kwargs)
    if effective is not None:
        per_cell = effective // 8
        notes.append(
            f"effective sample size is {effective} "
            f"(n_per_cell={per_cell} × 8 cells), "
            f"not the requested n_units={n_units}"
        )

    if not notes:
        return

    prefix = "TripleDifference uses a fixed 2×2×2 factorial DGP (group × partition × time). "
    suffix = ". Pass a custom data_generator for non-standard DDD designs."
    warnings.warn(prefix + "; ".join(notes) + suffix, UserWarning, stacklevel=2)
def _check_sdid_placebo_data(
data: pd.DataFrame,
estimator: Any,
est_kwargs: Dict[str, Any],
) -> None:
"""Check SyntheticDiD placebo feasibility on realized data.
This catches infeasible designs on the custom-DGP path where the
pre-generation check (which uses ``n_units * treatment_fraction``)
cannot run because treatment allocation is determined by the DGP.
"""
vm = getattr(estimator, "variance_method", "placebo")
if vm != "placebo":
return
treat_col = est_kwargs.get("treatment", "treat")
unit_col = est_kwargs.get("unit", "unit")
if treat_col not in data.columns or unit_col not in data.columns:
return # fit will fail with a more specific error
unit_treat = data.groupby(unit_col)[treat_col].first()
n_treated = int(unit_treat.sum())
n_control = len(unit_treat) - n_treated
if n_control <= n_treated:
raise ValueError(
f"SyntheticDiD placebo variance requires more control than "
f"treated units, but the generated data has n_control={n_control}, "
f"n_treated={n_treated}. Either adjust your data_generator so that "
f"n_control > n_treated, or use "
f"SyntheticDiD(variance_method='bootstrap') (paper-faithful refit; "
f"~5-30x slower than placebo) or SyntheticDiD(variance_method='jackknife')."
)
# -- Registry construction (deferred to avoid import-time cost) ---------------
_ESTIMATOR_REGISTRY: Optional[Dict[str, _EstimatorProfile]] = None


def _get_registry() -> Dict[str, _EstimatorProfile]:
    """Lazily build and return the estimator registry.

    The registry maps estimator class names to their simulation profile. It
    is built on first call and stored in the module-level
    ``_ESTIMATOR_REGISTRY``, so the data-generator imports from
    ``diff_diff.prep`` are deferred until a registry is actually needed.
    Later calls return the same dictionary.
    """
    global _ESTIMATOR_REGISTRY  # noqa: PLW0603
    if _ESTIMATOR_REGISTRY is None:
        from diff_diff.prep import (
            generate_ddd_data,
            generate_did_data,
            generate_factor_data,
            generate_staggered_data,
        )

        registry: Dict[str, _EstimatorProfile] = {}

        # --- Basic DiD group ---
        for est_name, fit_builder, extractor in (
            ("DifferenceInDifferences", _basic_fit_kwargs, _extract_simple),
            ("TwoWayFixedEffects", _twfe_fit_kwargs, _extract_simple),
            ("MultiPeriodDiD", _multiperiod_fit_kwargs, _extract_multiperiod),
        ):
            registry[est_name] = _EstimatorProfile(
                default_dgp=generate_did_data,
                dgp_kwargs_builder=_basic_dgp_kwargs,
                fit_kwargs_builder=fit_builder,
                result_extractor=extractor,
                min_n=20,
            )

        # --- Staggered group ---
        for est_name in (
            "CallawaySantAnna",
            "SunAbraham",
            "ImputationDiD",
            "TwoStageDiD",
            "StackedDiD",
            "EfficientDiD",
        ):
            registry[est_name] = _EstimatorProfile(
                default_dgp=generate_staggered_data,
                dgp_kwargs_builder=_staggered_dgp_kwargs,
                fit_kwargs_builder=_staggered_fit_kwargs,
                result_extractor=_extract_staggered,
                min_n=40,
            )

        # --- Factor model group ---
        for est_name, fit_builder in (
            ("TROP", _trop_fit_kwargs),
            ("SyntheticDiD", _sdid_fit_kwargs),
        ):
            registry[est_name] = _EstimatorProfile(
                default_dgp=generate_factor_data,
                dgp_kwargs_builder=_factor_dgp_kwargs,
                fit_kwargs_builder=fit_builder,
                result_extractor=_extract_simple,
                min_n=30,
            )

        # --- Triple difference ---
        registry["TripleDifference"] = _EstimatorProfile(
            default_dgp=generate_ddd_data,
            dgp_kwargs_builder=_ddd_dgp_kwargs,
            fit_kwargs_builder=_ddd_fit_kwargs,
            result_extractor=_extract_simple,
            min_n=64,
        )

        # Publish only once the registry is fully populated.
        _ESTIMATOR_REGISTRY = registry
    return _ESTIMATOR_REGISTRY
@dataclass
class PowerResults:
    """
    Results from analytical power analysis.

    Attributes
    ----------
    power : float
        Statistical power (probability of rejecting H0 when effect exists).
    mde : float
        Minimum detectable effect size.
    required_n : int
        Required total sample size (treated + control).
    effect_size : float
        Effect size used in calculation.
    alpha : float
        Significance level.
    alternative : str
        Alternative hypothesis ('two-sided', 'greater', 'less').
    n_treated : int
        Number of treated units.
    n_control : int
        Number of control units.
    n_pre : int
        Number of pre-treatment periods.
    n_post : int
        Number of post-treatment periods.
    sigma : float
        Residual standard deviation.
    rho : float
        Within-unit (serial) equicorrelation (Burlig 2020 Eq. 2 equicorrelated case).
    deff : float
        Survey design effect (variance inflation factor).
    design : str
        Study design type ('basic_did', 'panel', 'staggered').
    """

    power: float
    mde: float
    required_n: int
    effect_size: float
    alpha: float
    alternative: str
    n_treated: int
    n_control: int
    n_pre: int
    n_post: int
    sigma: float
    rho: float = 0.0
    deff: float = 1.0
    design: str = "basic_did"

    def __repr__(self) -> str:
        """Concise string representation."""
        power_txt = f"{self.power:.3f}"
        mde_txt = f"{self.mde:.4f}"
        return f"PowerResults(power={power_txt}, mde={mde_txt}, required_n={self.required_n})"

    def summary(self) -> str:
        """
        Generate a formatted summary of power analysis results.

        The design effect row is shown only when ``deff`` differs from 1.0.

        Returns
        -------
        str
            Formatted summary table.
        """
        width = 60
        heavy_rule = "=" * width
        light_rule = "-" * width

        def section(title: str) -> List[str]:
            # Blank line, rule, centered title, rule.
            return ["", light_rule, title.center(width), light_rule]

        total_units = self.n_treated + self.n_control

        report: List[str] = [
            heavy_rule,
            "Power Analysis for Difference-in-Differences".center(width),
            heavy_rule,
            "",
            f"{'Design:':<30} {self.design}",
            f"{'Significance level (alpha):':<30} {self.alpha:.3f}",
            f"{'Alternative hypothesis:':<30} {self.alternative}",
        ]

        report += section("Sample Size")
        report += [
            f"{'Treated units:':<30} {self.n_treated:>10}",
            f"{'Control units:':<30} {self.n_control:>10}",
            f"{'Total units:':<30} {total_units:>10}",
            f"{'Pre-treatment periods:':<30} {self.n_pre:>10}",
            f"{'Post-treatment periods:':<30} {self.n_post:>10}",
        ]

        report += section("Variance Parameters")
        report += [
            f"{'Residual SD (sigma):':<30} {self.sigma:>10.4f}",
            f"{'Within-unit equicorrelation:':<30} {self.rho:>10.4f}",
        ]
        if self.deff != 1.0:
            report.append(f"{'Design effect (DEFF):':<30} {self.deff:>10.4f}")

        report += section("Power Analysis Results")
        report += [
            f"{'Effect size:':<30} {self.effect_size:>10.4f}",
            f"{'Power:':<30} {self.power:>10.1%}",
            f"{'Minimum detectable effect:':<30} {self.mde:>10.4f}",
            f"{'Required sample size:':<30} {self.required_n:>10}",
            heavy_rule,
        ]
        return "\n".join(report)

    def print_summary(self) -> None:
        """Print the summary to stdout."""
        print(self.summary())

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert results to a dictionary.

        Returns
        -------
        Dict[str, Any]
            Dictionary containing all power analysis results, keyed by
            attribute name in declaration order.
        """
        as_dict: Dict[str, Any] = {}
        as_dict["power"] = self.power
        as_dict["mde"] = self.mde
        as_dict["required_n"] = self.required_n
        as_dict["effect_size"] = self.effect_size
        as_dict["alpha"] = self.alpha
        as_dict["alternative"] = self.alternative
        as_dict["n_treated"] = self.n_treated
        as_dict["n_control"] = self.n_control
        as_dict["n_pre"] = self.n_pre
        as_dict["n_post"] = self.n_post
        as_dict["sigma"] = self.sigma
        as_dict["rho"] = self.rho
        as_dict["deff"] = self.deff
        as_dict["design"] = self.design
        return as_dict

    def to_dataframe(self) -> pd.DataFrame:
        """
        Convert results to a pandas DataFrame.

        Returns
        -------
        pd.DataFrame
            Single-row DataFrame with power analysis results.
        """
        row = self.to_dict()
        return pd.DataFrame([row])
[docs]
@dataclass
class SimulationPowerResults:
"""
Results from simulation-based power analysis.
Attributes
----------
power : float
Estimated power (proportion of simulations rejecting H0).
power_se : float
Standard error of power estimate.
power_ci : Tuple[float, float]
Confidence interval for power estimate.
rejection_rate : float
Proportion of simulations with p-value < alpha.
mean_estimate : float
Mean treatment effect estimate across simulations.
std_estimate : float
Standard deviation of estimates across simulations.
mean_se : float
Mean standard error across simulations.
coverage : float
Proportion of CIs containing true effect.
n_simulations : int
Number of simulations performed (successful count; see
``n_simulation_failures`` for failed-replicate count).
n_simulation_failures : int
Number of simulations at the primary effect size whose `estimator.fit`
(or result extraction) raised an exception and was skipped. Lets
callers programmatically detect fragile DGP/estimator pairings; a
proportional warning is also emitted above a 10% failure rate.
effect_sizes : List[float]
Effect sizes tested (if multiple).
powers : List[float]
Power at each effect size (if multiple).
true_effect : float
True treatment effect used in simulation.
alpha : float
Significance level.
estimator_name : str
Name of the estimator used.
effective_n_units : int or None
Effective sample size when it differs from the requested ``n_units``
(e.g., due to DDD grid rounding). ``None`` when no rounding occurred.
"""
power: float
power_se: float
power_ci: Tuple[float, float]
rejection_rate: float
mean_estimate: float
std_estimate: float
mean_se: float
coverage: float
n_simulations: int
effect_sizes: List[float]
powers: List[float]
true_effect: float
alpha: float
estimator_name: str
bias: float = field(init=False)
rmse: float = field(init=False)
simulation_results: Optional[List[Dict[str, Any]]] = field(default=None, repr=False)
effective_n_units: Optional[int] = None
survey_config: Optional[Any] = field(default=None, repr=False)
mean_deff: Optional[float] = None
mean_icc_realized: Optional[float] = None
n_simulation_failures: int = 0
[docs]
def __post_init__(self):
"""Compute derived statistics."""
self.bias = self.mean_estimate - self.true_effect
self.rmse = np.sqrt(self.bias**2 + self.std_estimate**2)
[docs]
def __repr__(self) -> str:
"""Concise string representation."""
return (
f"SimulationPowerResults(power={self.power:.3f} "
f"[{self.power_ci[0]:.3f}, {self.power_ci[1]:.3f}], "
f"n_simulations={self.n_simulations})"
)
[docs]
def summary(self) -> str:
"""
Generate a formatted summary of simulation power results.
Returns
-------
str
Formatted summary table.
"""
lines = [
"=" * 65,
"Simulation-Based Power Analysis Results".center(65),
"=" * 65,
"",
f"{'Estimator:':<35} {self.estimator_name}",
f"{'Number of simulations:':<35} {self.n_simulations}",
f"{'Simulation failures:':<35} {self.n_simulation_failures}",
f"{'True treatment effect:':<35} {self.true_effect:.4f}",
f"{'Significance level (alpha):':<35} {self.alpha:.3f}",
"",
"-" * 65,
"Power Estimates".center(65),
"-" * 65,
f"{'Power (rejection rate):':<35} {self.power:.1%}",
f"{'Standard error:':<35} {self.power_se:.4f}",
f"{'95% CI:':<35} [{self.power_ci[0]:.3f}, {self.power_ci[1]:.3f}]",
"",
"-" * 65,
"Estimation Performance".center(65),
"-" * 65,
f"{'Mean estimate:':<35} {self.mean_estimate:.4f}",
f"{'Bias:':<35} {self.bias:.4f}",
f"{'Std. deviation of estimates:':<35} {self.std_estimate:.4f}",
f"{'RMSE:':<35} {self.rmse:.4f}",
f"{'Mean standard error:':<35} {self.mean_se:.4f}",
f"{'Coverage (CI contains true):':<35} {self.coverage:.1%}",
]
if self.effective_n_units is not None:
lines.append(
f"{'Effective sample size:':<35} {self.effective_n_units}" f" (DDD grid-rounded)"
)
if self.survey_config is not None:
lines.extend(
[
"",
"-" * 65,
"Survey Design".center(65),
"-" * 65,
f"{'Strata:':<35} {self.survey_config.n_strata}",
f"{'PSUs per stratum:':<35} {self.survey_config.psu_per_stratum}",
]
)
if self.mean_deff is not None:
lines.append(f"{'Mean Kish DEFF:':<35} {self.mean_deff:.4f}")
if self.mean_icc_realized is not None:
lines.append(f"{'Mean realized ICC:':<35} {self.mean_icc_realized:.4f}")
lines.append("=" * 65)
return "\n".join(lines)
[docs]
def print_summary(self) -> None:
"""Print the summary to stdout."""
print(self.summary())
[docs]
def to_dict(self) -> Dict[str, Any]:
    """
    Convert results to a dictionary.

    The confidence interval stored in ``power_ci`` is flattened into the two
    keys ``power_ci_lower`` and ``power_ci_upper``; every other key has the
    same name as the attribute it is read from.

    Returns
    -------
    Dict[str, Any]
        Dictionary containing simulation power results.
    """
    # Attributes copied verbatim, split around the two CI keys so the key
    # order of the resulting dictionary is preserved.
    before_ci = ("power", "power_se")
    after_ci = (
        "rejection_rate",
        "mean_estimate",
        "std_estimate",
        "bias",
        "rmse",
        "mean_se",
        "coverage",
        "n_simulations",
        "n_simulation_failures",
        "true_effect",
        "alpha",
        "estimator_name",
        "effective_n_units",
        "mean_deff",
        "mean_icc_realized",
    )

    payload: Dict[str, Any] = {name: getattr(self, name) for name in before_ci}
    payload["power_ci_lower"] = self.power_ci[0]
    payload["power_ci_upper"] = self.power_ci[1]
    for name in after_ci:
        payload[name] = getattr(self, name)
    return payload
[docs]
def to_dataframe(self) -> pd.DataFrame:
    """
    Convert results to a pandas DataFrame.

    Returns
    -------
    pd.DataFrame
        Single-row DataFrame whose columns are the keys of ``to_dict()``.
    """
    record = self.to_dict()
    rows = [record]
    return pd.DataFrame(rows)
[docs]
def power_curve_df(self) -> pd.DataFrame:
    """
    Get power curve data as a DataFrame.

    Returns
    -------
    pd.DataFrame
        DataFrame with an ``effect_size`` column taken from
        ``self.effect_sizes`` and a ``power`` column taken from ``self.powers``.
    """
    curve_columns = {
        "effect_size": self.effect_sizes,
        "power": self.powers,
    }
    return pd.DataFrame(curve_columns)
[docs]
class PowerAnalysis:
    """
    Power analysis for difference-in-differences designs.

    Provides analytical power calculations for basic 2x2 DiD and panel DiD
    designs. For complex designs like staggered adoption, use simulate_power()
    instead.

    Parameters
    ----------
    alpha : float, default=0.05
        Significance level for hypothesis testing.
    power : float, default=0.80
        Target statistical power.
    alternative : str, default='two-sided'
        Alternative hypothesis: 'two-sided', 'greater', or 'less'.

    Examples
    --------
    Calculate minimum detectable effect:

    >>> from diff_diff import PowerAnalysis
    >>> pa = PowerAnalysis(alpha=0.05, power=0.80)
    >>> results = pa.mde(n_treated=50, n_control=50, sigma=1.0)
    >>> print(f"MDE: {results.mde:.3f}")

    Calculate required sample size:

    >>> results = pa.sample_size(effect_size=0.5, sigma=1.0)
    >>> print(f"Required N: {results.required_n}")

    Calculate power for given sample and effect:

    >>> results = pa.power(effect_size=0.5, n_treated=50, n_control=50, sigma=1.0)
    >>> print(f"Power: {results.power:.1%}")

    Notes
    -----
    The power calculations are based on the variance of the DiD estimator.

    Critical values use the **normal (z)** distribution following Bloom (1995):
    ``MDE = (z_{1-alpha/2} + z_{1-kappa}) * SE``. This is a large-sample
    approximation to Burlig et al.'s t-based multiplier (their Eq. 1) and is
    mildly anti-conservative for very small numbers of units.

    The variance is the **within-unit equicorrelated special case of Burlig,
    Preonas & Woerman (2020), Eq. 2** (psi^B = psi^A = psi^X = rho * sigma^2), for
    m = n_pre pre-periods and r = n_post post-periods::

        Var(ATT) = sigma^2 * (1/N_treated + 1/N_control) * (1/m + 1/r) * (1 - rho)

    where rho is the within-unit (serial) equicorrelation. Cross-period
    correlation **lowers** the DiD variance (differencing cancels the shared
    within-unit component), so the MDE *decreases* as rho increases -- the
    opposite of a Moulton mean-inflation factor.

    The basic 2x2 design (n_pre = n_post = 1) is the m = r = 1 special case
    (Burlig footnote 11), Var(ATT) = 2 * sigma^2 * (1/N_treated + 1/N_control) *
    (1 - rho), reducing to Bloom (1995) Eq. 1's DiD analog at rho = 0.

    The fully general serial-correlation-robust form (independent psi^B, psi^A,
    psi^X) is not implemented; see ``docs/methodology/REGISTRY.md``
    ``## PowerAnalysis`` and the source audits under ``docs/methodology/papers/``.

    References
    ----------
    Bloom, H. S. (1995). "Minimum Detectable Effects."
    Burlig, F., Preonas, L., & Woerman, M. (2020). "Panel Data and Experimental Design."
    """

    def __init__(
        self,
        alpha: float = 0.05,
        power: float = 0.80,
        alternative: str = "two-sided",
    ):
        """
        Store the test configuration after validating it.

        Raises
        ------
        ValueError
            If ``alpha`` or ``power`` is not strictly between 0 and 1, or
            ``alternative`` is not one of the three supported values.
        """
        if not 0 < alpha < 1:
            raise ValueError("alpha must be between 0 and 1")
        if not 0 < power < 1:
            raise ValueError("power must be between 0 and 1")
        if alternative not in ("two-sided", "greater", "less"):
            raise ValueError("alternative must be 'two-sided', 'greater', or 'less'")
        self.alpha = alpha
        self.target_power = power
        self.alternative = alternative

    # ------------------------------------------------------------------
    # Validation helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _validate_deff(deff: float) -> None:
        """Validate deff parameter and warn if < 1.

        Raises
        ------
        ValueError
            If ``deff`` is not finite or is <= 0.
        """
        if not np.isfinite(deff) or deff <= 0:
            raise ValueError(f"deff must be finite and > 0, got {deff}")
        if deff < 1.0:
            # stacklevel=3 skips this helper and the public method that called it.
            warnings.warn(
                f"deff={deff:.4f} < 1.0 implies net variance reduction "
                f"(e.g., from stratification). This is valid but unusual.",
                stacklevel=3,
            )

    @staticmethod
    def _validate_sigma(sigma: float) -> None:
        """Raise ``ValueError`` unless ``sigma`` is finite and non-negative."""
        if not np.isfinite(sigma) or sigma < 0:
            raise ValueError(f"sigma (residual SD) must be finite and >= 0, got {sigma}")

    @staticmethod
    def _validate_design_params(n_pre: int, n_post: int, rho: float) -> None:
        """Validate analytical-power inputs for the Burlig (2020) Eq. 2 variance.

        Applies to BOTH the 2x2 (n_pre = n_post = 1) and multi-period panel paths
        -- the 2x2 case is the m = r = 1 special case of the same equicorrelated
        variance. Requires at least one pre- and one post-period, and a within-unit
        equicorrelation ``rho`` in ``[-1/(T-1), 1)`` (T = n_pre + n_post):
        ``rho >= 1`` yields a non-positive residual variance, and
        ``rho < -1/(T-1)`` is not a valid equicorrelation structure.

        Raises
        ------
        ValueError
            If either period count is below 1 or ``rho`` is outside the range.
        """
        if n_pre < 1 or n_post < 1:
            raise ValueError(
                "Power analysis requires n_pre >= 1 and n_post >= 1 (a DiD design "
                f"needs both pre- and post-periods), got n_pre={n_pre}, n_post={n_post}."
            )
        T = n_pre + n_post
        rho_min = -1.0 / (T - 1)
        if not rho_min <= rho < 1.0:
            raise ValueError(
                f"rho must lie in [{rho_min:.4g}, 1) (valid within-unit "
                f"equicorrelation over T={T} periods; rho >= 1 implies zero "
                f"residual variance), got rho={rho}."
            )

    @staticmethod
    def _select_design(n_pre: int, n_post: int) -> str:
        """Return ``"panel"`` for more than two total periods, else ``"basic_did"``.

        Invalid period counts (e.g. ``n_pre=0``) also map to ``"basic_did"``;
        they are rejected later by ``_validate_design_params``.
        """
        return "panel" if n_pre + n_post > 2 else "basic_did"

    @staticmethod
    def _default_sample_size_grid(required_n: int) -> List[int]:
        """Build the default total-sample-size grid for ``sample_size_curve``.

        The grid runs from ``max(10, required_n // 4)`` to ``2 * required_n``
        in roughly 50 steps. When the required N is so small that the upper
        bound falls below the lower bound, the upper bound is widened to twice
        the lower bound so the grid is never empty.
        """
        min_n = max(10, required_n // 4)
        max_n = required_n * 2
        if max_n < min_n:
            max_n = 2 * min_n
        step = max(1, (max_n - min_n) // 50)
        return list(range(min_n, max_n + 1, step))

    # ------------------------------------------------------------------
    # Core formulas
    # ------------------------------------------------------------------
    def _get_critical_values(self) -> Tuple[float, float]:
        """Get normal (z) critical values for alpha and power.

        Uses standard-normal quantiles (the Bloom 1995 multiplier) -- a
        large-sample approximation to Burlig et al. (2020) Eq. 1's t-based
        multiplier.

        Returns
        -------
        tuple of float
            ``(z_alpha, z_beta)``; ``z_alpha`` uses ``alpha / 2`` for the
            two-sided alternative and ``alpha`` otherwise.
        """
        if self.alternative == "two-sided":
            z_alpha = stats.norm.ppf(1 - self.alpha / 2)
        else:
            z_alpha = stats.norm.ppf(1 - self.alpha)
        z_beta = stats.norm.ppf(self.target_power)
        return z_alpha, z_beta

    def _compute_variance(
        self,
        n_treated: int,
        n_control: int,
        n_pre: int,
        n_post: int,
        sigma: float,
        rho: float = 0.0,
        deff: float = 1.0,
        design: str = "basic_did",
    ) -> float:
        """
        Compute variance of the DiD estimator.

        Parameters
        ----------
        n_treated : int
            Number of treated units.
        n_control : int
            Number of control units.
        n_pre : int
            Number of pre-treatment periods.
        n_post : int
            Number of post-treatment periods.
        sigma : float
            Residual standard deviation.
        rho : float
            Within-unit (serial) equicorrelation for the panel design (Burlig
            2020 Eq. 2, equicorrelated case); higher rho lowers the variance.
        deff : float
            Survey design effect (variance inflation factor). Not redundant
            with ``rho``: ``rho`` models within-unit (serial) equicorrelation
            (Burlig 2020 Eq. 2 ``(1/m+1/r)(1-rho)`` factor), ``deff`` models
            survey clustering/weighting.
        design : str
            Study design type.

        Returns
        -------
        float
            Variance of the DiD estimator.

        Raises
        ------
        ValueError
            On invalid period counts, ``rho``, ``sigma``, group sizes, or an
            unknown ``design``.
        """
        # Validate inputs before routing so invalid two-period shapes (e.g.
        # n_pre=0) and out-of-range rho cannot fall through to basic_did silently.
        self._validate_design_params(n_pre, n_post, rho)
        self._validate_sigma(sigma)
        if n_treated <= 0 or n_control <= 0:
            raise ValueError(
                "n_treated and n_control must be > 0, got "
                f"n_treated={n_treated}, n_control={n_control}"
            )

        if design == "basic_did":
            # 2x2 DiD (n_pre = n_post = 1): the m = r = 1 special case of the
            # equicorrelated Burlig (2020) Eq. 2 variance (footnote 11 drops the
            # within-pre / within-post covariance terms, leaving the cross-period
            # term). Reduces to Bloom (1995) Eq. 1's DiD analog
            # 2 * sigma^2 * (1/n_T + 1/n_C) at rho = 0; the (1 - rho) factor
            # applies the correlation between the single pre and post observation.
            inv_treated = 1 / n_treated
            inv_control = 1 / n_control
            # One term per (group, period) cell; each group has the same number
            # of units in the pre and the post period.
            cell_factor = inv_treated + inv_treated + inv_control + inv_control
            variance = sigma**2 * cell_factor * (1 - rho)
        elif design == "panel":
            # Burlig, Preonas & Woerman (2020), Eq. 2, specialized to within-unit
            # equicorrelation (psi^B = psi^A = psi^X = rho * sigma^2):
            #   Var(ATT) = sigma^2 (1/n_T + 1/n_C) (1/m + 1/r) (1 - rho)
            # with m = n_pre, r = n_post. Cross-period correlation (rho) LOWERS the
            # DiD variance because differencing cancels the shared within-unit
            # component -- the opposite sign of a Moulton mean-inflation factor.
            period_factor = 1 / n_pre + 1 / n_post  # = (m + r) / (m * r)
            base_var = sigma**2 * (1 / n_treated + 1 / n_control)
            variance = base_var * period_factor * (1 - rho)
        else:
            raise ValueError(f"Unknown design: {design}")

        # Survey design effect (multiplicative variance inflation)
        variance *= deff
        return variance

    def _compute_mde_from_se(self, se: float) -> float:
        """Compute MDE given standard error: ``(z_alpha + z_beta) * se``."""
        z_alpha, z_beta = self._get_critical_values()
        return (z_alpha + z_beta) * se

    def _compute_required_n(
        self,
        effect_size: float,
        sigma: float,
        n_pre: int,
        n_post: int,
        rho: float,
        design: str,
        treat_frac: float = 0.5,
        deff: float = 1.0,
    ) -> int:
        """Compute required sample size for given effect.

        Note: this method has its own formula independent of _compute_variance,
        so deff must be applied here separately (not double-counting).

        Returns
        -------
        int
            Total number of units, never below 4. ``MAX_SAMPLE_SIZE`` is
            returned for a zero effect or when the formula evaluates to infinity.

        Raises
        ------
        ValueError
            On invalid period counts, ``rho``, ``sigma`` or ``treat_frac``.
        """
        # Validate inputs before routing (mirrors _compute_variance).
        self._validate_design_params(n_pre, n_post, rho)
        self._validate_sigma(sigma)
        if not 0 < treat_frac < 1:
            raise ValueError(f"treat_frac must be in (0, 1), got {treat_frac}")

        # Handle edge case of zero effect size
        if effect_size == 0:
            return MAX_SAMPLE_SIZE  # Can't detect zero effect

        z_alpha, z_beta = self._get_critical_values()

        if design == "basic_did":
            # 2x2 DiD = the m = r = 1 equicorrelated case (period factor
            # 1/1 + 1/1 = 2); the (1 - rho) factor mirrors _compute_variance and
            # reduces to Bloom's 2 * sigma^2 * (z..)^2 / (delta^2 f(1-f)) at rho = 0.
            n_total = (
                2
                * sigma**2
                * (z_alpha + z_beta) ** 2
                * (1 - rho)
                / (effect_size**2 * treat_frac * (1 - treat_frac))
            )
        else:  # panel
            # Burlig (2020) Eq. 2 (equicorrelated), inverted for required N.
            # period_factor = 1/n_pre + 1/n_post equals 2 at n_pre=n_post=1, so this
            # is continuous with the basic_did branch.
            period_factor = 1 / n_pre + 1 / n_post
            n_total = (
                sigma**2
                * (z_alpha + z_beta) ** 2
                * period_factor
                * (1 - rho)
                / (effect_size**2 * treat_frac * (1 - treat_frac))
            )

        # Survey design effect (multiplicative sample size inflation)
        n_total *= deff

        # Handle infinity case (extremely small effect)
        if np.isinf(n_total):
            return MAX_SAMPLE_SIZE
        return max(4, int(np.ceil(n_total)))  # At least 4 units

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def power(
        self,
        effect_size: float,
        n_treated: int,
        n_control: int,
        sigma: float,
        n_pre: int = 1,
        n_post: int = 1,
        rho: float = 0.0,
        deff: float = 1.0,
    ) -> PowerResults:
        """
        Calculate statistical power for given effect size and sample.

        Parameters
        ----------
        effect_size : float
            Expected treatment effect size.
        n_treated : int
            Number of treated units.
        n_control : int
            Number of control units.
        sigma : float
            Residual standard deviation.
        n_pre : int, default=1
            Number of pre-treatment periods.
        n_post : int, default=1
            Number of post-treatment periods.
        rho : float, default=0.0
            Within-unit (serial) equicorrelation for panel designs. Higher rho
            LOWERS the MDE (Burlig et al. 2020, Eq. 2, equicorrelated case);
            valid range [-1/(T-1), 1).
        deff : float, default=1.0
            Survey design effect (variance inflation factor). Not redundant
            with ``rho``: ``rho`` models within-unit serial correlation,
            ``deff`` models survey clustering/weighting.

        Returns
        -------
        PowerResults
            Power analysis results.

        Examples
        --------
        >>> pa = PowerAnalysis()
        >>> results = pa.power(effect_size=2.0, n_treated=50, n_control=50, sigma=5.0)
        >>> print(f"Power: {results.power:.1%}")
        """
        self._validate_deff(deff)
        design = self._select_design(n_pre, n_post)
        variance = self._compute_variance(
            n_treated, n_control, n_pre, n_post, sigma, rho, deff=deff, design=design
        )
        se = np.sqrt(variance)

        # Same critical value the MDE / sample-size formulas use.
        z_alpha, _ = self._get_critical_values()

        # Calculate power
        if self.alternative == "two-sided":
            # Power = P(reject | effect) = P(|Z| > z_alpha | effect)
            power_val = (
                1
                - stats.norm.cdf(z_alpha - effect_size / se)
                + stats.norm.cdf(-z_alpha - effect_size / se)
            )
        elif self.alternative == "greater":
            power_val = 1 - stats.norm.cdf(z_alpha - effect_size / se)
        else:  # less
            power_val = stats.norm.cdf(-z_alpha - effect_size / se)

        # Also compute MDE and required N for reference
        mde = self._compute_mde_from_se(se)
        required_n = self._compute_required_n(
            effect_size,
            sigma,
            n_pre,
            n_post,
            rho,
            design,
            n_treated / (n_treated + n_control),
            deff=deff,
        )

        return PowerResults(
            power=power_val,
            mde=mde,
            required_n=required_n,
            effect_size=effect_size,
            alpha=self.alpha,
            alternative=self.alternative,
            n_treated=n_treated,
            n_control=n_control,
            n_pre=n_pre,
            n_post=n_post,
            sigma=sigma,
            rho=rho,
            deff=deff,
            design=design,
        )

    def mde(
        self,
        n_treated: int,
        n_control: int,
        sigma: float,
        n_pre: int = 1,
        n_post: int = 1,
        rho: float = 0.0,
        deff: float = 1.0,
    ) -> PowerResults:
        """
        Calculate minimum detectable effect given sample size.

        The MDE is the smallest effect size that can be detected with the
        specified power and significance level.

        Parameters
        ----------
        n_treated : int
            Number of treated units.
        n_control : int
            Number of control units.
        sigma : float
            Residual standard deviation.
        n_pre : int, default=1
            Number of pre-treatment periods.
        n_post : int, default=1
            Number of post-treatment periods.
        rho : float, default=0.0
            Within-unit (serial) equicorrelation for panel designs. Higher rho
            LOWERS the MDE (Burlig et al. 2020, Eq. 2, equicorrelated case);
            valid range [-1/(T-1), 1).
        deff : float, default=1.0
            Survey design effect (variance inflation factor).

        Returns
        -------
        PowerResults
            Power analysis results including MDE. ``power`` is the target
            power, ``effect_size`` is set to the MDE and ``required_n`` to
            ``n_treated + n_control``.

        Examples
        --------
        >>> pa = PowerAnalysis(power=0.80)
        >>> results = pa.mde(n_treated=100, n_control=100, sigma=10.0)
        >>> print(f"MDE: {results.mde:.2f}")
        """
        self._validate_deff(deff)
        design = self._select_design(n_pre, n_post)
        variance = self._compute_variance(
            n_treated, n_control, n_pre, n_post, sigma, rho, deff=deff, design=design
        )
        se = np.sqrt(variance)
        mde = self._compute_mde_from_se(se)

        return PowerResults(
            power=self.target_power,
            mde=mde,
            required_n=n_treated + n_control,
            effect_size=mde,
            alpha=self.alpha,
            alternative=self.alternative,
            n_treated=n_treated,
            n_control=n_control,
            n_pre=n_pre,
            n_post=n_post,
            sigma=sigma,
            rho=rho,
            deff=deff,
            design=design,
        )

    def sample_size(
        self,
        effect_size: float,
        sigma: float,
        n_pre: int = 1,
        n_post: int = 1,
        rho: float = 0.0,
        treat_frac: float = 0.5,
        deff: float = 1.0,
    ) -> PowerResults:
        """
        Calculate required sample size to detect given effect.

        Parameters
        ----------
        effect_size : float
            Treatment effect to detect.
        sigma : float
            Residual standard deviation.
        n_pre : int, default=1
            Number of pre-treatment periods.
        n_post : int, default=1
            Number of post-treatment periods.
        rho : float, default=0.0
            Within-unit (serial) equicorrelation for panel designs. Higher rho
            LOWERS the MDE (Burlig et al. 2020, Eq. 2, equicorrelated case);
            valid range [-1/(T-1), 1).
        treat_frac : float, default=0.5
            Fraction of units assigned to treatment.
        deff : float, default=1.0
            Survey design effect (variance inflation factor).

        Returns
        -------
        PowerResults
            Power analysis results including required sample size. ``power``
            is the *target* power (not recomputed for the rounded allocation);
            ``mde`` is the MDE at the returned integer group sizes.

        Examples
        --------
        >>> pa = PowerAnalysis(power=0.80)
        >>> results = pa.sample_size(effect_size=5.0, sigma=10.0)
        >>> print(f"Required N: {results.required_n}")
        """
        self._validate_deff(deff)
        design = self._select_design(n_pre, n_post)

        n_total = self._compute_required_n(
            effect_size, sigma, n_pre, n_post, rho, design, treat_frac, deff=deff
        )
        # Integer allocation with at least two units per arm; the total is
        # recomputed because the per-arm floors can push it above n_total.
        n_treated = max(2, int(np.ceil(n_total * treat_frac)))
        n_control = max(2, n_total - n_treated)
        n_total = n_treated + n_control

        # MDE implied by the rounded integer allocation (for reference).
        variance = self._compute_variance(
            n_treated, n_control, n_pre, n_post, sigma, rho, deff=deff, design=design
        )
        se = np.sqrt(variance)
        mde = self._compute_mde_from_se(se)

        return PowerResults(
            power=self.target_power,
            mde=mde,
            required_n=n_total,
            effect_size=effect_size,
            alpha=self.alpha,
            alternative=self.alternative,
            n_treated=n_treated,
            n_control=n_control,
            n_pre=n_pre,
            n_post=n_post,
            sigma=sigma,
            rho=rho,
            deff=deff,
            design=design,
        )

    def power_curve(
        self,
        n_treated: int,
        n_control: int,
        sigma: float,
        effect_sizes: Optional[List[float]] = None,
        n_pre: int = 1,
        n_post: int = 1,
        rho: float = 0.0,
        deff: float = 1.0,
    ) -> pd.DataFrame:
        """
        Compute power for a range of effect sizes.

        Parameters
        ----------
        n_treated : int
            Number of treated units.
        n_control : int
            Number of control units.
        sigma : float
            Residual standard deviation.
        effect_sizes : list of float, optional
            Effect sizes to evaluate. If None, uses 50 evenly spaced values
            from 0 to 2.5*MDE.
        n_pre : int, default=1
            Number of pre-treatment periods.
        n_post : int, default=1
            Number of post-treatment periods.
        rho : float, default=0.0
            Within-unit (serial) equicorrelation for panel designs. Higher rho
            LOWERS the MDE (Burlig et al. 2020, Eq. 2, equicorrelated case);
            valid range [-1/(T-1), 1).
        deff : float, default=1.0
            Survey design effect (variance inflation factor).

        Returns
        -------
        pd.DataFrame
            DataFrame with columns 'effect_size' and 'power'.

        Examples
        --------
        >>> pa = PowerAnalysis()
        >>> curve = pa.power_curve(n_treated=50, n_control=50, sigma=5.0)
        >>> print(curve)
        """
        # Always computed: it validates the inputs up front and supplies the
        # scale for the default grid.
        mde_result = self.mde(n_treated, n_control, sigma, n_pre, n_post, rho, deff=deff)

        if effect_sizes is None:
            # Default grid: 50 points from 0 to 2.5*MDE
            effect_sizes = np.linspace(0, 2.5 * mde_result.mde, 50).tolist()

        powers = [
            self.power(
                effect_size=es,
                n_treated=n_treated,
                n_control=n_control,
                sigma=sigma,
                n_pre=n_pre,
                n_post=n_post,
                rho=rho,
                deff=deff,
            ).power
            for es in effect_sizes
        ]

        return pd.DataFrame({"effect_size": effect_sizes, "power": powers})

    def sample_size_curve(
        self,
        effect_size: float,
        sigma: float,
        sample_sizes: Optional[List[int]] = None,
        n_pre: int = 1,
        n_post: int = 1,
        rho: float = 0.0,
        treat_frac: float = 0.5,
        deff: float = 1.0,
    ) -> pd.DataFrame:
        """
        Compute power for a range of sample sizes.

        Parameters
        ----------
        effect_size : float
            Treatment effect size.
        sigma : float
            Residual standard deviation.
        sample_sizes : list of int, optional
            Total sample sizes to evaluate. If None, uses a grid from
            ``max(10, required_n // 4)`` to ``2 * required_n`` (about 50
            points), widened when the required N is very small so the grid
            is never empty.
        n_pre : int, default=1
            Number of pre-treatment periods.
        n_post : int, default=1
            Number of post-treatment periods.
        rho : float, default=0.0
            Within-unit (serial) equicorrelation for panel designs. Higher rho
            LOWERS the MDE (Burlig et al. 2020, Eq. 2, equicorrelated case);
            valid range [-1/(T-1), 1).
        treat_frac : float, default=0.5
            Fraction assigned to treatment.
        deff : float, default=1.0
            Survey design effect (variance inflation factor).

        Returns
        -------
        pd.DataFrame
            DataFrame with columns 'sample_size' and 'power'.
        """
        # Always computed: it validates the inputs up front and supplies the
        # scale for the default grid.
        required = self.sample_size(effect_size, sigma, n_pre, n_post, rho, treat_frac, deff=deff)

        if sample_sizes is None:
            sample_sizes = self._default_sample_size_grid(required.required_n)

        powers = []
        for n in sample_sizes:
            # Split each total into arms, keeping at least two units per arm.
            n_treated = max(2, int(n * treat_frac))
            n_control = max(2, n - n_treated)
            result = self.power(
                effect_size=effect_size,
                n_treated=n_treated,
                n_control=n_control,
                sigma=sigma,
                n_pre=n_pre,
                n_post=n_post,
                rho=rho,
                deff=deff,
            )
            powers.append(result.power)

        return pd.DataFrame({"sample_size": sample_sizes, "power": powers})
[docs]
def simulate_power(
estimator: Any,
n_units: int = 100,
n_periods: int = 4,
treatment_effect: float = 5.0,
treatment_fraction: float = 0.5,
treatment_period: int = 2,
sigma: float = 1.0,
n_simulations: int = 500,
alpha: float = 0.05,
effect_sizes: Optional[List[float]] = None,
seed: Optional[int] = None,
data_generator: Optional[Callable] = None,
data_generator_kwargs: Optional[Dict[str, Any]] = None,
estimator_kwargs: Optional[Dict[str, Any]] = None,
result_extractor: Optional[Callable] = None,
progress: bool = True,
survey_config: Optional[SurveyPowerConfig] = None,
) -> SimulationPowerResults:
"""
Estimate power using Monte Carlo simulation.
This function simulates datasets with known treatment effects and estimates
power as the fraction of simulations where the null hypothesis is rejected.
Most built-in estimators are supported via an internal registry that selects
the appropriate data-generating process and fit signature automatically.
Parameters
----------
estimator : estimator object
DiD estimator to use (e.g., DifferenceInDifferences, CallawaySantAnna).
n_units : int, default=100
Number of units per simulation.
n_periods : int, default=4
Number of time periods.
treatment_effect : float, default=5.0
True treatment effect to simulate.
treatment_fraction : float, default=0.5
Fraction of units that are treated.
treatment_period : int, default=2
First post-treatment period (0-indexed).
sigma : float, default=1.0
Residual standard deviation (noise level).
n_simulations : int, default=500
Number of Monte Carlo simulations.
alpha : float, default=0.05
Significance level for hypothesis tests.
effect_sizes : list of float, optional
Multiple effect sizes to evaluate for power curve.
If None, uses only treatment_effect.
seed : int, optional
Random seed for reproducibility.
data_generator : callable, optional
Custom data generation function. When provided, bypasses the
registry DGP and calls this function with the standard kwargs
(n_units, n_periods, treatment_effect, etc.).
data_generator_kwargs : dict, optional
Additional keyword arguments for data generator.
estimator_kwargs : dict, optional
Additional keyword arguments for estimator.fit().
result_extractor : callable, optional
Custom function to extract results from the estimator output.
Takes the estimator result object and returns a tuple of
``(att, se, p_value, conf_int)``. Useful for unregistered
estimators with non-standard result schemas.
progress : bool, default=True
Whether to print progress updates.
survey_config : SurveyPowerConfig, optional
When provided, generates survey-structured data via
``generate_survey_did_data`` and injects ``SurveyDesign`` into
estimator ``fit()``. Mutually exclusive with ``data_generator``.
Supported estimators: DiD, TWFE, MultiPeriod, CS, SA, Imputation,
TwoStage, Stacked, Efficient. Unsupported: TROP, SyntheticDiD,
TripleDifference. ``heterogeneous_te_by_strata`` must be False.
Returns
-------
SimulationPowerResults
Simulation-based power analysis results.
Examples
--------
Basic power simulation:
>>> from diff_diff import DifferenceInDifferences, simulate_power
>>> did = DifferenceInDifferences()
>>> results = simulate_power(
... estimator=did,
... n_units=100,
... treatment_effect=5.0,
... sigma=5.0,
... n_simulations=500,
... seed=42
... )
>>> print(f"Power: {results.power:.1%}")
Power curve over multiple effect sizes:
>>> results = simulate_power(
... estimator=did,
... effect_sizes=[1.0, 2.0, 3.0, 5.0, 7.0],
... n_simulations=200,
... seed=42
... )
>>> print(results.power_curve_df())
With Callaway-Sant'Anna (auto-detected, no custom DGP needed):
>>> from diff_diff import CallawaySantAnna
>>> cs = CallawaySantAnna()
>>> results = simulate_power(cs, n_simulations=200, seed=42)
Notes
-----
The simulation approach:
1. Generate data with known treatment effect
2. Fit the estimator and record the p-value
3. Repeat n_simulations times
4. Power = fraction of simulations where p-value < alpha
The analytical reference formulas this Monte Carlo path complements (the Bloom 1995 normal
multiplier and the Burlig et al. 2020 Eq. 2 equicorrelated panel variance) are documented in
``docs/methodology/REGISTRY.md`` ``## PowerAnalysis``.
References
----------
Burlig, F., Preonas, L., & Woerman, M. (2020). "Panel Data and Experimental Design."
"""
rng = np.random.default_rng(seed)
estimator_name = type(estimator).__name__
registry = _get_registry()
profile = registry.get(estimator_name)
# If no profile and no custom data_generator, raise
if profile is None and data_generator is None:
raise ValueError(
f"Estimator '{estimator_name}' not in registry. "
f"Provide a custom data_generator and estimator_kwargs "
f"(the full dict of keyword arguments for estimator.fit(), "
f"e.g. dict(outcome='y', treatment='treat', time='period'))."
)
# When a custom data_generator is provided, bypass registry DGP
use_custom_dgp = data_generator is not None
use_survey_dgp = survey_config is not None
# --- Survey config validation ---
if use_survey_dgp:
assert survey_config is not None # for type narrowing
if estimator_name in _SURVEY_UNSUPPORTED:
raise ValueError(
f"survey_config is not supported with {estimator_name}. "
f"generate_survey_did_data produces staggered cohort data "
f"incompatible with this estimator's DGP. Use the custom "
f"data_generator path for survey power with {estimator_name}."
)
if use_custom_dgp:
raise ValueError(
"survey_config and data_generator are mutually exclusive. "
"survey_config uses generate_survey_did_data internally."
)
if treatment_period < 1:
raise ValueError(
f"treatment_period must be >= 1 with survey_config "
f"(need at least one pre-treatment period), got {treatment_period}."
)
if estimator_name not in _SURVEY_FIT_BUILDERS:
raise ValueError(
f"No survey power profile for {estimator_name}. "
f"Supported: {sorted(_SURVEY_FIT_BUILDERS.keys())}."
)
if survey_config.heterogeneous_te_by_strata:
raise ValueError(
"heterogeneous_te_by_strata=True is not supported with "
"simulation power analysis. The DGP's population ATT diverges "
"from the input treatment_effect under heterogeneous effects, "
"which would make bias/coverage/RMSE metrics misleading."
)
data_gen_kwargs = data_generator_kwargs or {}
est_kwargs = estimator_kwargs or {}
# Block survey_design in estimator_kwargs when survey_config is active.
# Custom survey design overrides go through SurveyPowerConfig.survey_design.
if use_survey_dgp and "survey_design" in est_kwargs:
raise ValueError(
"estimator_kwargs cannot contain 'survey_design' when survey_config "
"is set. To override the auto-built SurveyDesign, pass it via "
"SurveyPowerConfig(survey_design=...)."
)
# Block survey-config-managed keys in data_generator_kwargs
if use_survey_dgp and data_gen_kwargs:
collisions = _SURVEY_CONFIG_KEYS & set(data_gen_kwargs)
if collisions:
raise ValueError(
f"data_generator_kwargs contains keys managed by survey_config: "
f"{sorted(collisions)}. Set these on SurveyPowerConfig instead."
)
# Block DGP params that make realized ATT diverge from scalar input,
# which would misstate bias/coverage/RMSE (same rationale as
# heterogeneous_te_by_strata rejection above).
te_interaction = data_gen_kwargs.get("te_covariate_interaction", 0.0)
if te_interaction != 0.0:
raise ValueError(
f"te_covariate_interaction={te_interaction} is not supported "
f"with survey_config. The DGP's population ATT diverges from "
f"the input treatment_effect under covariate-interaction "
f"heterogeneity, which would make bias/coverage/RMSE misleading."
)
# Enforce panel-mode alignment between DGP and estimator.
# Runs even with empty data_gen_kwargs to catch CS(panel=False) + default DGP.
if use_survey_dgp:
dgp_panel = data_gen_kwargs.get("panel", True)
est_panel = getattr(estimator, "panel", True)
if not dgp_panel:
if estimator_name != "CallawaySantAnna":
raise ValueError(
f"panel=False (repeated cross-sections) is not supported "
f"with {estimator_name} under survey_config. Only "
f"CallawaySantAnna supports repeated cross-sections."
)
if est_panel:
raise ValueError(
"data_generator_kwargs has panel=False but "
"CallawaySantAnna.panel=True. Use "
"CallawaySantAnna(panel=False) to match."
)
elif estimator_name == "CallawaySantAnna" and not est_panel:
raise ValueError(
"CallawaySantAnna(panel=False) requires "
"data_generator_kwargs={'panel': False} to generate "
"repeated cross-section data."
)
# Reject estimator settings that require a multi-cohort DGP.
# survey_config hard-codes a single-cohort DGP and blocks
# cohort_periods/never_treated_frac overrides.
control_group = getattr(estimator, "control_group", "never_treated")
clean_control = getattr(estimator, "clean_control", None)
if control_group in ("not_yet_treated", "last_cohort"):
raise ValueError(
f"survey_config does not support control_group='{control_group}' "
f"(requires multi-cohort DGP). Use the custom data_generator "
f"path for survey power with this control-group design."
)
if clean_control == "strict":
raise ValueError(
f"survey_config does not support clean_control='strict' "
f"(requires multi-cohort DGP). Use the custom data_generator "
f"path for survey power with strict clean controls."
)
# SyntheticDiD placebo variance requires n_control > n_treated.
# Check after merging data_generator_kwargs so overrides of n_treated
# are accounted for.
if estimator_name == "SyntheticDiD" and not use_custom_dgp:
vm = getattr(estimator, "variance_method", "placebo")
effective_n_treated = data_gen_kwargs.get(
"n_treated", max(1, int(n_units * treatment_fraction))
)
n_control = n_units - effective_n_treated
if vm == "placebo" and n_control <= effective_n_treated:
raise ValueError(
f"SyntheticDiD placebo variance requires more control than "
f"treated units (got n_control={n_control}, "
f"n_treated={effective_n_treated}). Either lower "
f"treatment_fraction so that n_control > n_treated, or use "
f"SyntheticDiD(variance_method='bootstrap') (paper-faithful refit; "
f"~5-30x slower than placebo) or "
f"SyntheticDiD(variance_method='jackknife')."
)
# Warn if staggered estimator settings don't match auto DGP
if profile is not None and not use_custom_dgp:
_check_staggered_dgp_compat(estimator, data_generator_kwargs)
# Block registry-path collisions on search-critical keys
if profile is not None and not use_custom_dgp and data_gen_kwargs:
sample_dgp_keys = set(
profile.dgp_kwargs_builder(
n_units=n_units,
n_periods=n_periods,
treatment_effect=treatment_effect,
treatment_fraction=treatment_fraction,
treatment_period=treatment_period,
sigma=sigma,
).keys()
)
collisions = _PROTECTED_DGP_KEYS & set(data_gen_kwargs) & sample_dgp_keys
if collisions:
raise ValueError(
f"data_generator_kwargs contains keys that conflict with "
f"registry-managed simulation inputs: {sorted(collisions)}. "
f"These are controlled by simulate_power() parameters directly. "
f"Use the corresponding function parameters instead, or pass a "
f"custom data_generator to override the DGP entirely."
)
# Warn if DDD design inputs are silently ignored
if estimator_name == "TripleDifference" and not use_custom_dgp:
_check_ddd_dgp_compat(
n_units,
n_periods,
treatment_fraction,
treatment_period,
data_generator_kwargs,
)
effective_n_units = _ddd_effective_n(n_units, data_generator_kwargs)
else:
effective_n_units = None
# Determine effect sizes to test
if effect_sizes is None:
effect_sizes = [treatment_effect]
all_powers = []
# For the primary effect, collect detailed results
if len(effect_sizes) == 1:
primary_idx = 0
else:
primary_idx = -1
for i, es in enumerate(effect_sizes):
if np.isclose(es, treatment_effect):
primary_idx = i
break
if primary_idx == -1:
primary_idx = len(effect_sizes) - 1
primary_effect = effect_sizes[primary_idx]
# Initialize so they are always bound
primary_estimates: List[float] = []
primary_ses: List[float] = []
primary_p_values: List[float] = []
primary_rejections: List[bool] = []
primary_ci_contains: List[bool] = []
primary_n_failures = 0
# Survey DGP truth accumulation (DEFF/ICC are DGP properties,
# independent of effect size, so averaging across all sims is correct)
deff_values: List[float] = []
icc_values: List[float] = []
# Lazy import for survey DGP (mirrors registry's lazy import pattern)
_generate_survey_did_data: Optional[Callable] = None
if use_survey_dgp:
from diff_diff.prep import generate_survey_did_data as _generate_survey_did_data
for effect_idx, effect in enumerate(effect_sizes):
is_primary = effect_idx == primary_idx
estimates: List[float] = []
ses: List[float] = []
p_values: List[float] = []
rejections: List[bool] = []
ci_contains_true: List[bool] = []
n_failures = 0
for sim in range(n_simulations):
if progress and sim % 100 == 0 and sim > 0:
pct = (sim + effect_idx * n_simulations) / (len(effect_sizes) * n_simulations)
print(f" Simulation progress: {pct:.0%}")
sim_seed = rng.integers(0, 2**31)
# --- Generate data ---
if use_survey_dgp:
assert survey_config is not None
assert _generate_survey_did_data is not None
dgp_kwargs = _survey_dgp_kwargs(
n_units=n_units,
n_periods=n_periods,
treatment_effect=effect,
treatment_fraction=treatment_fraction,
treatment_period=treatment_period,
sigma=sigma,
survey_config=survey_config,
)
dgp_kwargs.update(data_gen_kwargs)
dgp_kwargs.pop("seed", None)
data = _generate_survey_did_data(seed=sim_seed, **dgp_kwargs)
# Derive columns for non-staggered estimators.
# Survey DGP's `treated` is time-varying (1{g>0, t>=g}); basic/TWFE/
# MultiPeriod need a time-invariant group indicator (`ever_treated`).
if estimator_name not in _STAGGERED_ESTIMATORS:
data["ever_treated"] = (data["first_treat"] > 0).astype(int)
# Basic/TWFE also need a `post` period indicator.
if estimator_name in _SURVEY_POST_ESTIMATORS:
data["post"] = (data["period"] >= treatment_period + 1).astype(int)
# Collect DGP truth for metadata
dgp_truth = data.attrs.get("dgp_truth", {})
if dgp_truth:
kish = dgp_truth.get("deff_kish")
icc_r = dgp_truth.get("icc_realized")
if kish is not None:
deff_values.append(kish)
if icc_r is not None:
icc_values.append(icc_r)
elif use_custom_dgp:
assert data_generator is not None
data = data_generator(
n_units=n_units,
n_periods=n_periods,
treatment_effect=effect,
treatment_fraction=treatment_fraction,
treatment_period=treatment_period,
noise_sd=sigma,
seed=sim_seed,
**data_gen_kwargs,
)
else:
assert profile is not None
dgp_kwargs = profile.dgp_kwargs_builder(
n_units=n_units,
n_periods=n_periods,
treatment_effect=effect,
treatment_fraction=treatment_fraction,
treatment_period=treatment_period,
sigma=sigma,
)
dgp_kwargs.update(data_gen_kwargs)
dgp_kwargs.pop("seed", None)
data = profile.default_dgp(seed=sim_seed, **dgp_kwargs)
# Check SDID placebo feasibility on realized data (custom DGP path)
if effect_idx == 0 and sim == 0 and estimator_name == "SyntheticDiD":
_check_sdid_placebo_data(data, estimator, est_kwargs)
try:
# --- Fit estimator ---
if use_survey_dgp:
assert survey_config is not None
fit_builder = _SURVEY_FIT_BUILDERS[estimator_name]
fit_kwargs = fit_builder(
data, n_units, n_periods, treatment_period, survey_config
)
fit_kwargs.update(est_kwargs)
elif profile is not None and not use_custom_dgp:
fit_kwargs = profile.fit_kwargs_builder(
data, n_units, n_periods, treatment_period
)
fit_kwargs.update(est_kwargs)
else:
# Custom DGP fallback: use registry fit kwargs if available,
# otherwise use basic DiD signature
if profile is not None:
fit_kwargs = profile.fit_kwargs_builder(
data, n_units, n_periods, treatment_period
)
fit_kwargs.update(est_kwargs)
else:
fit_kwargs = dict(est_kwargs)
result = estimator.fit(data, **fit_kwargs)
# --- Extract results ---
if profile is not None:
att, se, p_val, ci = profile.result_extractor(result)
elif result_extractor is not None:
att, se, p_val, ci = result_extractor(result)
else:
att = result.att if hasattr(result, "att") else result.avg_att
se = result.se if hasattr(result, "se") else result.avg_se
p_val = result.p_value if hasattr(result, "p_value") else result.avg_p_value
ci = result.conf_int if hasattr(result, "conf_int") else result.avg_conf_int
# NaN p-value → treat as non-rejection
rejected = bool(p_val < alpha) if not np.isnan(p_val) else False
estimates.append(att)
ses.append(se)
p_values.append(p_val)
rejections.append(rejected)
ci_contains_true.append(ci[0] <= effect <= ci[1])
except (
ValueError,
np.linalg.LinAlgError,
KeyError,
RuntimeError,
ZeroDivisionError,
) as e:
n_failures += 1
if progress:
print(f" Warning: Simulation {sim} failed: {e}")
continue
# Warn if too many simulations failed
failure_rate = n_failures / n_simulations
if failure_rate > 0.1:
warnings.warn(
f"{n_failures}/{n_simulations} simulations ({failure_rate:.1%}) "
f"failed for effect_size={effect}. "
f"Check estimator and data generator.",
UserWarning,
)
if len(estimates) == 0:
raise RuntimeError("All simulations failed. Check estimator and data generator.")
power_val = np.mean(rejections)
all_powers.append(power_val)
if is_primary:
primary_estimates = estimates
primary_ses = ses
primary_p_values = p_values
primary_rejections = rejections
primary_ci_contains = ci_contains_true
primary_n_failures = n_failures
# Compute confidence interval for power (primary effect)
power_val = all_powers[primary_idx]
n_valid = len(primary_rejections)
power_se = np.sqrt(power_val * (1 - power_val) / n_valid)
z = stats.norm.ppf(0.975)
power_ci = (
max(0.0, power_val - z * power_se),
min(1.0, power_val + z * power_se),
)
mean_estimate = np.mean(primary_estimates)
std_estimate = np.std(primary_estimates, ddof=1)
mean_se = np.mean(primary_ses)
coverage = np.mean(primary_ci_contains)
return SimulationPowerResults(
power=power_val,
power_se=power_se,
power_ci=power_ci,
rejection_rate=power_val,
mean_estimate=mean_estimate,
std_estimate=std_estimate,
mean_se=mean_se,
coverage=coverage,
n_simulations=n_valid,
n_simulation_failures=primary_n_failures,
effect_sizes=effect_sizes,
powers=all_powers,
true_effect=primary_effect,
alpha=alpha,
estimator_name=estimator_name,
simulation_results=[
{"estimate": e, "se": s, "p_value": p, "rejected": r}
for e, s, p, r in zip(
primary_estimates,
primary_ses,
primary_p_values,
primary_rejections,
)
],
effective_n_units=effective_n_units,
survey_config=survey_config,
mean_deff=float(np.nanmean(deff_values)) if deff_values else None,
mean_icc_realized=float(np.nanmean(icc_values)) if icc_values else None,
)
# ---------------------------------------------------------------------------
# Simulation-based MDE and sample-size search
# ---------------------------------------------------------------------------
[docs]
@dataclass
class SimulationMDEResults:
    """
    Container for the outcome of a simulation-based MDE search.

    Attributes
    ----------
    mde : float
        Minimum detectable effect (smallest effect achieving target power).
    power_at_mde : float
        Power achieved at the MDE.
    target_power : float
        Target power used in the search.
    alpha : float
        Significance level.
    n_units : int
        Sample size used.
    n_simulations_per_step : int
        Number of simulations per bisection step.
    n_steps : int
        Number of bisection steps performed.
    search_path : list of dict
        Diagnostic trace of ``{effect_size, power}`` at each step.
    estimator_name : str
        Name of the estimator used.
    effective_n_units : int or None
        Effective sample size when it differs from the requested ``n_units``
        (e.g., due to DDD grid rounding). ``None`` when no rounding occurred.
    survey_config : object or None
        Survey configuration attached to the result by the caller, if any.
        It is not included in ``to_dict()`` / ``to_dataframe()`` output.
    """

    mde: float
    power_at_mde: float
    target_power: float
    alpha: float
    n_units: int
    n_simulations_per_step: int
    n_steps: int
    search_path: List[Dict[str, float]]
    estimator_name: str
    effective_n_units: Optional[int] = None
    survey_config: Optional[Any] = field(default=None, repr=False)

    def __repr__(self) -> str:
        """Return a compact one-line description (MDE, power, step count)."""
        text = f"SimulationMDEResults(mde={self.mde:.4f}, "
        text += f"power_at_mde={self.power_at_mde:.3f}, "
        text += f"n_steps={self.n_steps})"
        return text

    def summary(self) -> str:
        """Return a fixed-width, multi-line text report of the search."""
        heavy_rule = "=" * 65
        light_rule = "-" * 65

        # Header and design inputs.
        report = [heavy_rule, "Simulation-Based MDE Results".center(65), heavy_rule, ""]
        report.append(f"{'Estimator:':<35} {self.estimator_name}")
        report.append(f"{'Significance level (alpha):':<35} {self.alpha:.3f}")
        report.append(f"{'Target power:':<35} {self.target_power:.1%}")
        report.append(f"{'Sample size (n_units):':<35} {self.n_units}")

        # The effective-N row only appears when a value was recorded.
        if self.effective_n_units is not None:
            report.append(
                f"{'Effective sample size:':<35} {self.effective_n_units}" f" (DDD grid-rounded)"
            )

        # Search outcome.
        report.extend(
            [
                f"{'Simulations per step:':<35} {self.n_simulations_per_step}",
                "",
                light_rule,
                "Search Results".center(65),
                light_rule,
                f"{'Minimum detectable effect:':<35} {self.mde:.4f}",
                f"{'Power at MDE:':<35} {self.power_at_mde:.1%}",
                f"{'Bisection steps:':<35} {self.n_steps}",
                heavy_rule,
            ]
        )
        return "\n".join(report)

    def to_dict(self) -> Dict[str, Any]:
        """Return the scalar result fields as a plain dictionary."""
        # Key order is significant: it fixes the column order of to_dataframe().
        exported = (
            "mde",
            "power_at_mde",
            "target_power",
            "alpha",
            "n_units",
            "effective_n_units",
            "n_simulations_per_step",
            "n_steps",
            "estimator_name",
        )
        return {name: getattr(self, name) for name in exported}

    def to_dataframe(self) -> pd.DataFrame:
        """Return the ``to_dict()`` fields as a single-row DataFrame."""
        row = self.to_dict()
        return pd.DataFrame([row])
[docs]
@dataclass
class SimulationSampleSizeResults:
    """
    Container for the outcome of a simulation-based sample-size search.

    Attributes
    ----------
    required_n : int
        Required number of units to achieve target power.
    power_at_n : float
        Power achieved at the required N.
    target_power : float
        Target power used in the search.
    alpha : float
        Significance level.
    effect_size : float
        Effect size used in the search.
    n_simulations_per_step : int
        Number of simulations per bisection step.
    n_steps : int
        Number of bisection steps performed.
    search_path : list of dict
        Diagnostic trace of ``{n_units, power}`` at each step.
    estimator_name : str
        Name of the estimator used.
    effective_n_units : int or None
        Effective sample size when it differs from ``required_n``
        (e.g., due to DDD grid rounding). ``None`` when no rounding occurred
        or when the search already snapped to the estimator's grid.
    survey_config : object or None
        Survey configuration attached to the result by the caller, if any.
        It is not included in ``to_dict()`` / ``to_dataframe()`` output.
    """

    required_n: int
    power_at_n: float
    target_power: float
    alpha: float
    effect_size: float
    n_simulations_per_step: int
    n_steps: int
    search_path: List[Dict[str, float]]
    estimator_name: str
    effective_n_units: Optional[int] = None
    survey_config: Optional[Any] = field(default=None, repr=False)

    def __repr__(self) -> str:
        """Return a compact one-line description (N, power, step count)."""
        text = f"SimulationSampleSizeResults(required_n={self.required_n}, "
        text += f"power_at_n={self.power_at_n:.3f}, "
        text += f"n_steps={self.n_steps})"
        return text

    def summary(self) -> str:
        """Return a fixed-width, multi-line text report of the search."""
        heavy_rule = "=" * 65
        light_rule = "-" * 65

        # Header and design inputs.
        report = [heavy_rule, "Simulation-Based Sample Size Results".center(65), heavy_rule, ""]
        report.append(f"{'Estimator:':<35} {self.estimator_name}")
        report.append(f"{'Significance level (alpha):':<35} {self.alpha:.3f}")
        report.append(f"{'Target power:':<35} {self.target_power:.1%}")
        report.append(f"{'Effect size:':<35} {self.effect_size:.4f}")
        report.append(f"{'Simulations per step:':<35} {self.n_simulations_per_step}")

        # Search outcome.
        report.extend(
            [
                "",
                light_rule,
                "Search Results".center(65),
                light_rule,
                f"{'Required sample size:':<35} {self.required_n}",
                f"{'Power at required N:':<35} {self.power_at_n:.1%}",
                f"{'Bisection steps:':<35} {self.n_steps}",
            ]
        )

        # The effective-N row only appears when a value was recorded.
        if self.effective_n_units is not None:
            report.append(
                f"{'Effective sample size:':<35} {self.effective_n_units}" f" (DDD grid-rounded)"
            )

        report.append(heavy_rule)
        return "\n".join(report)

    def to_dict(self) -> Dict[str, Any]:
        """Return the scalar result fields as a plain dictionary."""
        # Key order is significant: it fixes the column order of to_dataframe().
        exported = (
            "required_n",
            "power_at_n",
            "target_power",
            "alpha",
            "effect_size",
            "n_simulations_per_step",
            "n_steps",
            "estimator_name",
            "effective_n_units",
        )
        return {name: getattr(self, name) for name in exported}

    def to_dataframe(self) -> pd.DataFrame:
        """Return the ``to_dict()`` fields as a single-row DataFrame."""
        row = self.to_dict()
        return pd.DataFrame([row])
[docs]
def simulate_mde(
    estimator: Any,
    n_units: int = 100,
    n_periods: int = 4,
    treatment_fraction: float = 0.5,
    treatment_period: int = 2,
    sigma: float = 1.0,
    n_simulations: int = 200,
    power: float = 0.80,
    alpha: float = 0.05,
    effect_range: Optional[Tuple[float, float]] = None,
    tol: float = 0.02,
    max_steps: int = 15,
    seed: Optional[int] = None,
    data_generator: Optional[Callable] = None,
    data_generator_kwargs: Optional[Dict[str, Any]] = None,
    estimator_kwargs: Optional[Dict[str, Any]] = None,
    result_extractor: Optional[Callable] = None,
    progress: bool = True,
    survey_config: Optional[SurveyPowerConfig] = None,
) -> SimulationMDEResults:
    """
    Find the minimum detectable effect via simulation-based bisection search.

    Searches over effect sizes to find the smallest effect that achieves the
    target power, using ``simulate_power()`` at each step. Every call to
    ``simulate_power()`` receives its own seed drawn from a master generator
    seeded with ``seed``, and is recorded in the result's ``search_path``.

    Parameters
    ----------
    estimator : estimator object
        DiD estimator to use.
    n_units : int, default=100
        Number of units per simulation.
    n_periods : int, default=4
        Number of time periods.
    treatment_fraction : float, default=0.5
        Fraction of units that are treated.
    treatment_period : int, default=2
        First post-treatment period (0-indexed).
    sigma : float, default=1.0
        Residual standard deviation. Also the starting upper bound of the
        automatic bracket.
    n_simulations : int, default=200
        Simulations per bisection step.
    power : float, default=0.80
        Target power.
    alpha : float, default=0.05
        Significance level.
    effect_range : tuple of (float, float), optional
        ``(lo, hi)`` bracket for the search. If None, auto-brackets starting
        from ``lo=0`` and ``hi=sigma``, doubling ``hi`` until the target power
        is reached (at most 10 evaluations).
    tol : float, default=0.02
        Convergence tolerance. Bisection stops when the simulated power is
        within ``tol`` of the target, or when the bracket width drops below
        ``max(tol * hi, 1e-6)``.
    max_steps : int, default=15
        Maximum bisection steps.
    seed : int, optional
        Random seed for reproducibility.
    data_generator : callable, optional
        Custom data generation function.
    data_generator_kwargs : dict, optional
        Additional keyword arguments for data generator.
    estimator_kwargs : dict, optional
        Additional keyword arguments for estimator.fit().
    result_extractor : callable, optional
        Custom function to extract results from the estimator output.
        Forwarded to ``simulate_power()``.
    progress : bool, default=True
        Whether to print progress updates.
    survey_config : SurveyPowerConfig, optional
        Survey-aware simulation config. Forwarded to ``simulate_power()``.
        See :func:`simulate_power` for details and constraints.

    Returns
    -------
    SimulationMDEResults
        Results including the MDE and search diagnostics.

    Warns
    -----
    UserWarning
        If the lower bound already reaches the target power (the lower bound
        is returned), if the supplied or automatic upper bound does not reach
        the target power, or if power at zero effect already meets the target.

    Notes
    -----
    When automatic bracketing fails, the upper bound that is reported (in the
    warning, and as ``mde`` if bisection finds nothing better) is the largest
    effect that was actually simulated, so ``mde`` and ``power_at_mde`` always
    refer to the same simulated effect size.

    Examples
    --------
    >>> from diff_diff import simulate_mde, DifferenceInDifferences
    >>> result = simulate_mde(DifferenceInDifferences(), n_simulations=100, seed=42)
    >>> print(f"MDE: {result.mde:.3f}")
    """
    master_rng = np.random.default_rng(seed)
    estimator_name = type(estimator).__name__
    search_path: List[Dict[str, float]] = []

    # Compute effective N for DDD (N is fixed throughout MDE search)
    if estimator_name == "TripleDifference" and data_generator is None:
        effective_n_units = _ddd_effective_n(n_units, data_generator_kwargs)
    else:
        effective_n_units = None

    common_kwargs: Dict[str, Any] = dict(
        estimator=estimator,
        n_units=n_units,
        n_periods=n_periods,
        treatment_fraction=treatment_fraction,
        treatment_period=treatment_period,
        sigma=sigma,
        n_simulations=n_simulations,
        alpha=alpha,
        data_generator=data_generator,
        data_generator_kwargs=data_generator_kwargs,
        estimator_kwargs=estimator_kwargs,
        result_extractor=result_extractor,
        progress=False,
        survey_config=survey_config,
    )

    def _power_at(effect: float) -> float:
        """Simulate power at ``effect`` with a fresh seed and log the step."""
        step_seed = int(master_rng.integers(0, 2**31))
        res = simulate_power(treatment_effect=effect, seed=step_seed, **common_kwargs)
        pwr = float(res.power)
        search_path.append({"effect_size": effect, "power": pwr})
        if progress:
            print(f" MDE search: effect={effect:.4f}, power={pwr:.3f}")
        return pwr

    def _build_result(mde_value: float, achieved_power: float) -> SimulationMDEResults:
        """Package the current search state into a result object."""
        return SimulationMDEResults(
            mde=mde_value,
            power_at_mde=achieved_power,
            target_power=power,
            alpha=alpha,
            n_units=n_units,
            n_simulations_per_step=n_simulations,
            n_steps=len(search_path),
            search_path=search_path,
            estimator_name=estimator_name,
            effective_n_units=effective_n_units,
            survey_config=survey_config,
        )

    # --- Bracket ---
    if effect_range is not None:
        lo, hi = effect_range
        power_lo = _power_at(lo)
        power_hi = _power_at(hi)
        if power_lo >= power:
            warnings.warn(
                f"Power at effect={lo} is {power_lo:.2f} >= target {power}. "
                f"Lower bound already exceeds target power. Returning lo as MDE.",
                UserWarning,
            )
            return _build_result(lo, power_lo)
        if power_hi < power:
            warnings.warn(
                f"Target power {power} not bracketed: power at effect={hi} "
                f"is {power_hi:.2f}. Upper bound may be too low.",
                UserWarning,
            )
    else:
        lo = 0.0
        # Check that power at zero is below target (no inflated Type I error)
        power_at_zero = _power_at(0.0)
        if power_at_zero >= power:
            warnings.warn(
                f"Power at effect=0 is {power_at_zero:.2f} >= target {power}. "
                f"This suggests inflated Type I error. Returning MDE=0.",
                UserWarning,
            )
            return _build_result(0.0, power_at_zero)

        hi = sigma
        max_bracket_evals = 10
        for attempt in range(max_bracket_evals):
            if _power_at(hi) >= power:
                break
            # Only grow the bound if the larger value will be simulated on the
            # next pass; otherwise `hi` would end up naming an effect that was
            # never evaluated while `search_path[-1]` refers to `hi / 2`.
            if attempt < max_bracket_evals - 1:
                hi *= 2
        else:
            warnings.warn(
                f"Could not bracket MDE (power at effect={hi} still below "
                f"{power}). Returning best upper bound.",
                UserWarning,
            )

    # --- Bisect ---
    # In both bracket branches the last recorded step is the evaluation at `hi`.
    best_effect = hi
    best_power = search_path[-1]["power"] if search_path else 0.0
    for _ in range(max_steps):
        mid = (lo + hi) / 2
        pwr = _power_at(mid)
        if pwr >= power:
            hi = mid
            best_effect = mid
            best_power = pwr
        else:
            lo = mid
        # Convergence: effect range is tight or power is close enough
        if hi - lo < max(tol * hi, 1e-6) or abs(pwr - power) < tol:
            break

    return _build_result(best_effect, best_power)
[docs]
def simulate_sample_size(
    estimator: Any,
    treatment_effect: float = 5.0,
    n_periods: int = 4,
    treatment_fraction: float = 0.5,
    treatment_period: int = 2,
    sigma: float = 1.0,
    n_simulations: int = 200,
    power: float = 0.80,
    alpha: float = 0.05,
    n_range: Optional[Tuple[int, int]] = None,
    max_steps: int = 15,
    seed: Optional[int] = None,
    data_generator: Optional[Callable] = None,
    data_generator_kwargs: Optional[Dict[str, Any]] = None,
    estimator_kwargs: Optional[Dict[str, Any]] = None,
    result_extractor: Optional[Callable] = None,
    progress: bool = True,
    survey_config: Optional[SurveyPowerConfig] = None,
) -> SimulationSampleSizeResults:
    """
    Find the required sample size via simulation-based bisection search.

    Searches over ``n_units`` to find the smallest N that achieves the
    target power, using ``simulate_power()`` at each step. Every call to
    ``simulate_power()`` receives its own seed drawn from a master generator
    seeded with ``seed``, and is recorded in the result's ``search_path``.

    Parameters
    ----------
    estimator : estimator object
        DiD estimator to use.
    treatment_effect : float, default=5.0
        True treatment effect to simulate.
    n_periods : int, default=4
        Number of time periods.
    treatment_fraction : float, default=0.5
        Fraction of units that are treated.
    treatment_period : int, default=2
        First post-treatment period (0-indexed).
    sigma : float, default=1.0
        Residual standard deviation.
    n_simulations : int, default=200
        Simulations per bisection step.
    power : float, default=0.80
        Target power.
    alpha : float, default=0.05
        Significance level.
    n_range : tuple of (int, int), optional
        ``(lo, hi)`` bracket for sample size. If None, auto-brackets: the
        search starts at the estimator's registry ``min_n`` (20 when the
        estimator is not registered), probes downward if that already meets
        the target, and otherwise doubles an upper bound (at most 10
        evaluations).
    max_steps : int, default=15
        Maximum bisection steps.
    seed : int, optional
        Random seed for reproducibility.
    data_generator : callable, optional
        Custom data generation function.
    data_generator_kwargs : dict, optional
        Additional keyword arguments for data generator.
    estimator_kwargs : dict, optional
        Additional keyword arguments for estimator.fit().
    result_extractor : callable, optional
        Custom function to extract results from the estimator output.
        Forwarded to ``simulate_power()``.
    progress : bool, default=True
        Whether to print progress updates.
    survey_config : SurveyPowerConfig, optional
        Survey-aware simulation config. Forwarded to ``simulate_power()``.
        When set, the bisection floor is raised to
        ``survey_config.min_viable_n`` to ensure viable survey structure.
        See :func:`simulate_power` for details and constraints.

    Returns
    -------
    SimulationSampleSizeResults
        Results including the required N and search diagnostics.

    Raises
    ------
    ValueError
        If ``data_generator_kwargs`` contains ``n_per_cell`` for a
        ``TripleDifference`` estimator without a custom ``data_generator``,
        or contains ``strata_sizes`` while ``survey_config`` is set. Both
        depend on ``n_units``, which varies during the search.

    Warns
    -----
    UserWarning
        If the lower bound already reaches the target power, if no smaller N
        below the target could be found, or if the supplied or automatic
        upper bound does not reach the target power.

    Notes
    -----
    For ``TripleDifference`` without a custom ``data_generator`` the bisection
    candidates are snapped to multiples of 8. When automatic bracketing fails,
    the reported upper bound is the largest N that was actually simulated, so
    ``required_n`` and ``power_at_n`` always refer to the same simulated N.

    Examples
    --------
    >>> from diff_diff import simulate_sample_size, DifferenceInDifferences
    >>> result = simulate_sample_size(
    ...     DifferenceInDifferences(), treatment_effect=5.0, n_simulations=100, seed=42
    ... )
    >>> print(f"Required N: {result.required_n}")
    """
    master_rng = np.random.default_rng(seed)
    estimator_name = type(estimator).__name__
    search_path: List[Dict[str, float]] = []

    # Determine min_n from registry
    registry = _get_registry()
    profile = registry.get(estimator_name)
    min_n = profile.min_n if profile is not None else 20

    # DDD grid snapping: bisection candidates must be multiples of 8
    is_ddd_grid = estimator_name == "TripleDifference" and data_generator is None
    grid_step = 8 if is_ddd_grid else 1
    convergence_threshold = grid_step + 1  # 9 for DDD, 2 for others

    if is_ddd_grid and data_generator_kwargs and "n_per_cell" in data_generator_kwargs:
        raise ValueError(
            "data_generator_kwargs contains 'n_per_cell', which conflicts with "
            "the sample-size search in simulate_sample_size(). For "
            "TripleDifference, n_per_cell is derived from n_units (the search "
            "variable). Use simulate_power() with a fixed n_per_cell override "
            "instead, or pass a custom data_generator."
        )

    def _snap_n(n: int, direction: str = "down", floor: Optional[int] = None) -> int:
        """Round ``n`` onto the search grid and clamp it to the floor."""
        actual_floor = floor if floor is not None else min_n
        if grid_step == 1:
            return max(actual_floor, n)
        if direction == "up":
            return max(actual_floor, ((n + grid_step - 1) // grid_step) * grid_step)
        return max(actual_floor, (n // grid_step) * grid_step)

    common_kwargs: Dict[str, Any] = dict(
        estimator=estimator,
        n_periods=n_periods,
        treatment_effect=treatment_effect,
        treatment_fraction=treatment_fraction,
        treatment_period=treatment_period,
        sigma=sigma,
        n_simulations=n_simulations,
        alpha=alpha,
        data_generator=data_generator,
        data_generator_kwargs=data_generator_kwargs,
        estimator_kwargs=estimator_kwargs,
        result_extractor=result_extractor,
        progress=False,
        survey_config=survey_config,
    )

    def _power_at_n(n: int) -> float:
        """Simulate power at ``n`` units with a fresh seed and log the step."""
        step_seed = int(master_rng.integers(0, 2**31))
        res = simulate_power(n_units=n, seed=step_seed, **common_kwargs)
        pwr = float(res.power)
        search_path.append({"n_units": float(n), "power": pwr})
        if progress:
            print(f" Sample size search: n={n}, power={pwr:.3f}")
        return pwr

    def _build_result(n_value: int, achieved_power: float) -> SimulationSampleSizeResults:
        """Package the current search state into a result object."""
        return SimulationSampleSizeResults(
            required_n=n_value,
            power_at_n=achieved_power,
            target_power=power,
            alpha=alpha,
            effect_size=treatment_effect,
            n_simulations_per_step=n_simulations,
            n_steps=len(search_path),
            search_path=search_path,
            estimator_name=estimator_name,
            survey_config=survey_config,
        )

    # Block strata_sizes in sample-size search (same class as n_per_cell for DDD):
    # strata_sizes requires sum(strata_sizes) == n_units, but n_units varies
    # during bisection so a fixed strata_sizes would fail mid-search.
    if survey_config is not None and data_generator_kwargs:
        if "strata_sizes" in data_generator_kwargs:
            raise ValueError(
                "strata_sizes in data_generator_kwargs is not supported with "
                "simulate_sample_size() because n_units varies during the "
                "bisection search. Use simulate_power() with a fixed n_units "
                "and strata_sizes instead."
            )

    # --- Bracket ---
    abs_min = 16 if is_ddd_grid else 4
    if survey_config is not None:
        abs_min = max(abs_min, survey_config.min_viable_n)

    if n_range is not None:
        lo = _snap_n(n_range[0], "up", floor=abs_min)
        hi = _snap_n(n_range[1], "down", floor=abs_min)
        if lo > hi:
            lo = hi  # collapsed bracket — evaluate single point
        power_lo = _power_at_n(lo)
        if power_lo >= power:
            warnings.warn(
                f"Power at n={lo} is {power_lo:.2f} >= target {power}. "
                f"Lower bound already achieves target power. Returning lo.",
                UserWarning,
            )
            return _build_result(lo, power_lo)
        power_hi = _power_at_n(hi)
        if power_hi < power:
            warnings.warn(
                f"Target power {power} not bracketed: power at n={hi} "
                f"is {power_hi:.2f}. Upper bound may be too low.",
                UserWarning,
            )
    else:
        lo = max(min_n, abs_min)
        power_lo = _power_at_n(lo)
        if power_lo >= power:
            # Floor achieves target — search downward for true minimum
            hi = lo
            found_lower = False
            probe = _snap_n(max(abs_min, lo // 2), floor=abs_min)
            for _ in range(8):
                if probe >= hi or probe < abs_min:
                    break
                pwr = _power_at_n(probe)
                if pwr < power:
                    lo = probe
                    found_lower = True
                    break
                hi = probe
                probe = _snap_n(max(abs_min, probe // 2), floor=abs_min)
            if not found_lower:
                # Even smallest viable N achieves target — return best found
                best = min(
                    (s for s in search_path if s["power"] >= power),
                    key=lambda s: s["n_units"],
                )
                # Clamp to abs_min (enforces survey min_viable_n contract)
                best_n = max(int(best["n_units"]), abs_min)
                warnings.warn(
                    f"Power at n={best_n} is "
                    f"{best['power']:.2f} >= target {power}. Could not "
                    f"find a smaller N below target power. Pass "
                    f"n_range=(lo, hi) to refine.",
                    UserWarning,
                )
                return _build_result(best_n, best["power"])
            # Fall through to bisection with lo..hi bracket
        else:
            hi = max(2 * lo, abs_min, 100)
            max_bracket_evals = 10
            for attempt in range(max_bracket_evals):
                if _power_at_n(hi) >= power:
                    break
                # Only grow the bound if the larger N will be simulated on the
                # next pass; otherwise `hi` would end up naming a sample size
                # that was never evaluated.
                if attempt < max_bracket_evals - 1:
                    hi *= 2
            else:
                warnings.warn(
                    f"Could not bracket required N (power at n={hi} still "
                    f"below {power}). Returning best upper bound.",
                    UserWarning,
                )

    # --- Bisect on integer n_units ---
    best_n = hi
    # Look up power at hi (search_path[-1] may not be hi after downward search)
    best_power = next(
        (s["power"] for s in reversed(search_path) if int(s["n_units"]) == hi),
        search_path[-1]["power"] if search_path else 0.0,
    )
    for _ in range(max_steps):
        if hi - lo <= convergence_threshold:
            break
        mid = _snap_n((lo + hi) // 2, floor=abs_min)
        if mid <= lo or mid >= hi:
            break
        pwr = _power_at_n(mid)
        if pwr >= power:
            hi = mid
            best_n = mid
            best_power = pwr
        else:
            lo = mid

    # `best_n` is reassigned every time `hi` is, so it already equals the
    # conservative ceiling `hi` and no extra evaluation is needed here.
    return _build_result(best_n, best_power)
[docs]
def compute_mde(
    n_treated: int,
    n_control: int,
    sigma: float,
    power: float = 0.80,
    alpha: float = 0.05,
    n_pre: int = 1,
    n_post: int = 1,
    rho: float = 0.0,
    deff: float = 1.0,
) -> float:
    """
    Compute the minimum detectable effect in a single call.

    Builds a ``PowerAnalysis`` with the given ``alpha`` and ``power`` and
    returns the ``mde`` attribute of its ``mde()`` result.

    Parameters
    ----------
    n_treated : int
        Number of treated units.
    n_control : int
        Number of control units.
    sigma : float
        Residual standard deviation.
    power : float, default=0.80
        Target statistical power.
    alpha : float, default=0.05
        Significance level.
    n_pre : int, default=1
        Number of pre-treatment periods.
    n_post : int, default=1
        Number of post-treatment periods.
    rho : float, default=0.0
        Within-unit (serial) equicorrelation for panel designs. Higher rho
        LOWERS the MDE (Burlig et al. 2020, Eq. 2, equicorrelated case);
        valid range [-1/(T-1), 1).
    deff : float, default=1.0
        Survey design effect (variance inflation factor).

    Returns
    -------
    float
        Minimum detectable effect size.

    Examples
    --------
    >>> mde = compute_mde(n_treated=50, n_control=50, sigma=10.0)
    >>> print(f"MDE: {mde:.2f}")
    """
    analysis = PowerAnalysis(alpha=alpha, power=power)
    return analysis.mde(n_treated, n_control, sigma, n_pre, n_post, rho, deff=deff).mde
[docs]
def compute_power(
    effect_size: float,
    n_treated: int,
    n_control: int,
    sigma: float,
    alpha: float = 0.05,
    n_pre: int = 1,
    n_post: int = 1,
    rho: float = 0.0,
    deff: float = 1.0,
) -> float:
    """
    Compute statistical power for a given effect and sample in a single call.

    Builds a ``PowerAnalysis`` with the given ``alpha`` and returns the
    ``power`` attribute of its ``power()`` result.

    Parameters
    ----------
    effect_size : float
        Expected treatment effect.
    n_treated : int
        Number of treated units.
    n_control : int
        Number of control units.
    sigma : float
        Residual standard deviation.
    alpha : float, default=0.05
        Significance level.
    n_pre : int, default=1
        Number of pre-treatment periods.
    n_post : int, default=1
        Number of post-treatment periods.
    rho : float, default=0.0
        Within-unit (serial) equicorrelation for panel designs. Higher rho
        LOWERS the MDE (Burlig et al. 2020, Eq. 2, equicorrelated case);
        valid range [-1/(T-1), 1).
    deff : float, default=1.0
        Survey design effect (variance inflation factor).

    Returns
    -------
    float
        Statistical power.

    Examples
    --------
    >>> power = compute_power(effect_size=5.0, n_treated=50, n_control=50, sigma=10.0)
    >>> print(f"Power: {power:.1%}")
    """
    analysis = PowerAnalysis(alpha=alpha)
    outcome = analysis.power(
        effect_size, n_treated, n_control, sigma, n_pre, n_post, rho, deff=deff
    )
    return outcome.power
[docs]
def compute_sample_size(
    effect_size: float,
    sigma: float,
    power: float = 0.80,
    alpha: float = 0.05,
    n_pre: int = 1,
    n_post: int = 1,
    rho: float = 0.0,
    treat_frac: float = 0.5,
    deff: float = 1.0,
) -> int:
    """
    Compute the required sample size in a single call.

    Builds a ``PowerAnalysis`` with the given ``alpha`` and ``power`` and
    returns the ``required_n`` attribute of its ``sample_size()`` result.

    Parameters
    ----------
    effect_size : float
        Treatment effect to detect.
    sigma : float
        Residual standard deviation.
    power : float, default=0.80
        Target statistical power.
    alpha : float, default=0.05
        Significance level.
    n_pre : int, default=1
        Number of pre-treatment periods.
    n_post : int, default=1
        Number of post-treatment periods.
    rho : float, default=0.0
        Within-unit (serial) equicorrelation for panel designs. Higher rho
        LOWERS the MDE (Burlig et al. 2020, Eq. 2, equicorrelated case);
        valid range [-1/(T-1), 1).
    treat_frac : float, default=0.5
        Fraction assigned to treatment.
    deff : float, default=1.0
        Survey design effect (variance inflation factor).

    Returns
    -------
    int
        Required total sample size.

    Examples
    --------
    >>> n = compute_sample_size(effect_size=5.0, sigma=10.0)
    >>> print(f"Required N: {n}")
    """
    analysis = PowerAnalysis(alpha=alpha, power=power)
    outcome = analysis.sample_size(
        effect_size, sigma, n_pre, n_post, rho, treat_frac, deff=deff
    )
    return outcome.required_n