"""
Data preparation utilities for difference-in-differences analysis.
This module provides helper functions to prepare data for DiD estimation,
including creating treatment indicators, reshaping panel data, and
generating synthetic datasets for testing.
Data generation functions (generate_*) are defined in prep_dgp.py and
re-exported here for backward compatibility.
"""
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
# Re-export data generation functions from prep_dgp for backward compatibility
from diff_diff.prep_dgp import ( # noqa: F401
generate_continuous_did_data,
generate_ddd_data,
generate_did_data,
generate_event_study_data,
generate_factor_data,
generate_panel_data,
generate_reversible_did_data,
generate_staggered_data,
generate_staggered_ddd_data,
generate_survey_did_data,
)
from diff_diff.survey import (
ResolvedSurveyDesign,
SurveyDesign,
_compute_if_variance_fast,
_precompute_psu_scaffolding,
_PsuScaffolding,
compute_replicate_if_variance,
compute_survey_if_variance,
)
from diff_diff.utils import _compute_noise_level, _sc_weight_fw
# Tuning constants for rank_control_units. Neither is referenced by the code
# visible in this chunk; presumably they are read by the treatment-candidate
# discovery helper (_suggest_treatment_candidates) -- NOTE(review): confirm.
_SIMILARITY_THRESHOLD_SD = 0.5  # Controls within this many SDs are "similar"
_OUTLIER_PENALTY_WEIGHT = 0.3  # Penalty weight for outcome outliers in treatment candidate scoring
[docs]
def make_treatment_indicator(
    data: pd.DataFrame,
    column: str,
    treated_values: Optional[Union[Any, List[Any]]] = None,
    threshold: Optional[float] = None,
    above_threshold: bool = True,
    new_column: str = "treated",
) -> pd.DataFrame:
    """
    Create a binary treatment indicator column from various input types.

    Parameters
    ----------
    data : pd.DataFrame
        Input DataFrame. It is not modified; a copy is returned.
    column : str
        Name of the column to use for creating the treatment indicator.
    treated_values : Any or list-like, optional
        Value(s) that indicate treatment. A list, tuple, set, frozenset,
        NumPy array, pandas Series or pandas Index is treated as a
        collection of treated values; any other object is treated as a
        single treated value. Rows whose value is in the collection get
        treatment=1, others get treatment=0.
    threshold : float, optional
        Numeric threshold for creating treatment. Used when the treatment
        is based on a continuous variable (e.g., treat firms above median size).
    above_threshold : bool, default=True
        If True, values >= threshold are treated. If False, values <= threshold
        are treated. Only used when threshold is specified.
    new_column : str, default="treated"
        Name of the new treatment indicator column (overwritten if present).

    Returns
    -------
    pd.DataFrame
        Copy of ``data`` with the new integer (0/1) treatment indicator column.

    Raises
    ------
    ValueError
        If both or neither of ``treated_values``/``threshold`` are given, or
        if ``column`` is not in ``data``.

    Examples
    --------
    Create treatment from categorical variable:

    >>> df = pd.DataFrame({'group': ['A', 'A', 'B', 'B'], 'y': [1, 2, 3, 4]})
    >>> df = make_treatment_indicator(df, 'group', treated_values='A')
    >>> df['treated'].tolist()
    [1, 1, 0, 0]

    Create treatment from numeric threshold:

    >>> df = pd.DataFrame({'size': [10, 50, 100, 200], 'y': [1, 2, 3, 4]})
    >>> df = make_treatment_indicator(df, 'size', threshold=75)
    >>> df['treated'].tolist()
    [0, 0, 1, 1]

    Treat units below a threshold:

    >>> df = make_treatment_indicator(df, 'size', threshold=75, above_threshold=False)
    >>> df['treated'].tolist()
    [1, 1, 0, 0]
    """
    if treated_values is not None and threshold is not None:
        raise ValueError("Specify either 'treated_values' or 'threshold', not both.")
    if treated_values is None and threshold is None:
        raise ValueError("Must specify either 'treated_values' or 'threshold'.")
    if column not in data.columns:
        raise ValueError(f"Column '{column}' not found in DataFrame.")

    # Copy only after validation so failed calls don't pay for a copy.
    df = data.copy()

    if treated_values is not None:
        # Array-likes must be passed to ``isin`` as-is. Wrapping them in a list
        # would make ``isin`` look for the array object itself, which either
        # raises (unhashable ndarray/Series) or silently matches nothing.
        if not isinstance(
            treated_values,
            (list, tuple, set, frozenset, np.ndarray, pd.Series, pd.Index),
        ):
            treated_values = [treated_values]
        df[new_column] = df[column].isin(treated_values).astype(int)
    elif above_threshold:
        df[new_column] = (df[column] >= threshold).astype(int)
    else:
        df[new_column] = (df[column] <= threshold).astype(int)
    return df
[docs]
def make_post_indicator(
    data: pd.DataFrame,
    time_column: str,
    post_periods: Optional[Union[Any, List[Any]]] = None,
    treatment_start: Optional[Any] = None,
    new_column: str = "post",
) -> pd.DataFrame:
    """
    Create a binary post-treatment indicator column.

    Parameters
    ----------
    data : pd.DataFrame
        Input DataFrame. It is not modified; a copy is returned.
    time_column : str
        Name of the time/period column.
    post_periods : Any or list-like, optional
        Specific period value(s) that are post-treatment. A list, tuple,
        set, frozenset, NumPy array, pandas Series or pandas Index (e.g. a
        ``pd.date_range``) is treated as a collection of periods; any other
        object is treated as a single period. Matching rows get post=1,
        others get post=0.
    treatment_start : Any, optional
        The first post-treatment period. All periods >= this value get post=1.
        Works with numeric periods, strings (compared lexicographically), or
        dates. For datetime columns the value is first converted with
        ``pd.to_datetime``.
    new_column : str, default="post"
        Name of the new post indicator column (overwritten if present).

    Returns
    -------
    pd.DataFrame
        Copy of ``data`` with the new integer (0/1) post indicator column.

    Raises
    ------
    ValueError
        If both or neither of ``post_periods``/``treatment_start`` are given,
        or if ``time_column`` is not in ``data``.

    Examples
    --------
    Using specific post periods:

    >>> df = pd.DataFrame({'year': [2018, 2019, 2020, 2021], 'y': [1, 2, 3, 4]})
    >>> df = make_post_indicator(df, 'year', post_periods=[2020, 2021])
    >>> df['post'].tolist()
    [0, 0, 1, 1]

    Using treatment start:

    >>> df = make_post_indicator(df, 'year', treatment_start=2020)
    >>> df['post'].tolist()
    [0, 0, 1, 1]

    Works with date columns:

    >>> df = pd.DataFrame({'date': pd.to_datetime(['2020-01-01', '2020-06-01', '2021-01-01'])})
    >>> df = make_post_indicator(df, 'date', treatment_start='2020-06-01')
    >>> df['post'].tolist()
    [0, 1, 1]
    """
    if post_periods is not None and treatment_start is not None:
        raise ValueError("Specify either 'post_periods' or 'treatment_start', not both.")
    if post_periods is None and treatment_start is None:
        raise ValueError("Must specify either 'post_periods' or 'treatment_start'.")
    if time_column not in data.columns:
        raise ValueError(f"Column '{time_column}' not found in DataFrame.")

    df = data.copy()

    if post_periods is not None:
        # Pass array-likes (e.g. pd.date_range output) to ``isin`` directly.
        # Wrapping them in a list makes ``isin`` hash the container itself.
        if not isinstance(
            post_periods,
            (list, tuple, set, frozenset, np.ndarray, pd.Series, pd.Index),
        ):
            post_periods = [post_periods]
        df[new_column] = df[time_column].isin(post_periods).astype(int)
    else:
        start = treatment_start
        # Allow string/date inputs to be compared against datetime columns.
        if pd.api.types.is_datetime64_any_dtype(df[time_column].dtype):
            start = pd.to_datetime(start)
        df[new_column] = (df[time_column] >= start).astype(int)
    return df
[docs]
def wide_to_long(
    data: pd.DataFrame,
    value_columns: List[str],
    id_column: str,
    time_name: str = "period",
    value_name: str = "value",
    time_values: Optional[List[Any]] = None,
) -> pd.DataFrame:
    """
    Convert wide-format panel data to long format for DiD analysis.

    Wide format has one row per unit with multiple columns for each time period.
    Long format has one row per unit-period combination.

    Parameters
    ----------
    data : pd.DataFrame
        Wide-format DataFrame with one row per unit.
    value_columns : list of str
        Column names containing the outcome values for each period.
        These should be in chronological order.
    id_column : str
        Column name for the unit identifier.
    time_name : str, default="period"
        Name for the new time period column. Must not clash with
        ``id_column``, ``value_name`` or any preserved column.
    value_name : str, default="value"
        Name for the new value/outcome column. Must not clash with
        ``id_column``, ``time_name`` or any preserved column.
    time_values : list, optional
        Values to use for time periods. If None, uses 0, 1, 2, ...
        Must have same length as value_columns.

    Returns
    -------
    pd.DataFrame
        Long-format DataFrame with one row per unit-period. Columns are
        ``[id_column, time_name, value_name]`` followed by every other column
        of ``data`` that is not a value column. Rows are sorted by unit, then
        time, with a fresh RangeIndex.

    Raises
    ------
    ValueError
        If ``value_columns`` is empty, a referenced column is missing,
        ``time_values`` has the wrong length, or ``time_name``/``value_name``
        clash with each other or with existing columns.

    Examples
    --------
    >>> wide_df = pd.DataFrame({
    ...     'firm_id': [1, 2, 3],
    ...     'sales_2019': [100, 150, 200],
    ...     'sales_2020': [110, 160, 210],
    ...     'sales_2021': [120, 170, 220]
    ... })
    >>> long_df = wide_to_long(
    ...     wide_df,
    ...     value_columns=['sales_2019', 'sales_2020', 'sales_2021'],
    ...     id_column='firm_id',
    ...     time_name='year',
    ...     value_name='sales',
    ...     time_values=[2019, 2020, 2021]
    ... )
    >>> len(long_df)
    9
    >>> long_df.columns.tolist()
    ['firm_id', 'year', 'sales']
    """
    if not value_columns:
        raise ValueError("value_columns cannot be empty.")
    if id_column not in data.columns:
        raise ValueError(f"Column '{id_column}' not found in DataFrame.")
    for col in value_columns:
        if col not in data.columns:
            raise ValueError(f"Column '{col}' not found in DataFrame.")

    if time_values is None:
        time_values = list(range(len(value_columns)))
    if len(time_values) != len(value_columns):
        raise ValueError(
            f"time_values length ({len(time_values)}) must match "
            f"value_columns length ({len(value_columns)})."
        )

    # Columns carried through unchanged (neither the id nor a value column).
    value_set = set(value_columns)
    other_cols = [c for c in data.columns if c != id_column and c not in value_set]

    # The new time/value columns must not overwrite or duplicate an existing
    # output column. Otherwise data would be silently clobbered or duplicate
    # labels would break the final selection/sort.
    if time_name == value_name:
        raise ValueError(f"time_name and value_name must differ (both are '{time_name}').")
    for label, name in (("time_name", time_name), ("value_name", value_name)):
        if name == id_column or name in other_cols:
            raise ValueError(
                f"{label} '{name}' clashes with an existing column; choose a different name."
            )

    # Melt straight into ``time_name``. This avoids a temporary helper column
    # that could collide with a user column of the same name.
    long_df = pd.melt(
        data,
        id_vars=[id_column] + other_cols,
        value_vars=value_columns,
        var_name=time_name,
        value_name=value_name,
    )
    # melt leaves the source column labels in ``time_name``; translate them.
    long_df[time_name] = long_df[time_name].map(dict(zip(value_columns, time_values)))

    cols = [id_column, time_name, value_name] + other_cols
    return long_df[cols].sort_values([id_column, time_name]).reset_index(drop=True)
[docs]
def balance_panel(
    data: pd.DataFrame,
    unit_column: str,
    time_column: str,
    method: str = "inner",
    fill_value: Optional[float] = None,
) -> pd.DataFrame:
    """
    Balance a panel dataset to ensure all units have all time periods.

    Parameters
    ----------
    data : pd.DataFrame
        Unbalanced panel data.
    unit_column : str
        Column name for unit identifier.
    time_column : str
        Column name for time period.
    method : str, default="inner"
        Balancing method:

        - "inner": Keep only units observed in every period (drops units).
        - "outer": Include all unit-period combinations (missing cells are NaN).
        - "fill": Like "outer", then fill the missing values.
    fill_value : float, optional
        Value used for missing observations when ``method="fill"``; only
        numeric columns are filled. If None with ``method="fill"``, missing
        values are forward-filled and then backward-filled within each unit
        (the result is sorted by unit and time).

    Returns
    -------
    pd.DataFrame
        Balanced panel DataFrame. For "inner", the original rows of the
        complete units. For "outer"/"fill", the full unit x period grid
        left-joined to ``data``.

    Raises
    ------
    ValueError
        If a column is missing or ``method`` is not recognized.

    Notes
    -----
    Missing (NaN) unit or time identifiers are not treated as a unit or a
    period. They do not count toward the number of periods a unit must
    cover. They are also not placed on the "outer"/"fill" grid, so rows
    carrying them are dropped by those methods.

    Examples
    --------
    Keep only complete units:

    >>> df = pd.DataFrame({
    ...     'unit': [1, 1, 1, 2, 2, 3, 3, 3],
    ...     'period': [1, 2, 3, 1, 2, 1, 2, 3],
    ...     'y': [10, 11, 12, 20, 21, 30, 31, 32]
    ... })
    >>> balanced = balance_panel(df, 'unit', 'period', method='inner')
    >>> balanced['unit'].unique().tolist()
    [1, 3]

    Include all combinations:

    >>> balanced = balance_panel(df, 'unit', 'period', method='outer')
    >>> len(balanced)
    9
    """
    if unit_column not in data.columns:
        raise ValueError(f"Column '{unit_column}' not found in DataFrame.")
    if time_column not in data.columns:
        raise ValueError(f"Column '{time_column}' not found in DataFrame.")
    if method not in ("inner", "outer", "fill"):
        raise ValueError(f"method must be 'inner', 'outer', or 'fill', got '{method}'")

    # Exclude NaN identifiers. A NaN period would otherwise count as a period
    # no unit can "have" (groupby.nunique ignores NaN), making every unit
    # incomplete. It would also add a spurious NaN row per unit to the grid.
    all_units = data[unit_column].dropna().unique()
    all_periods = sorted(data[time_column].dropna().unique())
    n_periods = len(all_periods)

    if method == "inner":
        unit_counts = data.groupby(unit_column)[time_column].nunique()
        complete_units = unit_counts[unit_counts == n_periods].index
        return data[data[unit_column].isin(complete_units)].copy()

    # "outer" / "fill": left-join the data onto the full unit x period grid.
    full_index = pd.MultiIndex.from_product(
        [all_units, all_periods], names=[unit_column, time_column]
    )
    full_df = pd.DataFrame(index=full_index).reset_index()
    result = full_df.merge(data, on=[unit_column, time_column], how="left")

    if method == "outer":
        return result

    cols_to_fill = [c for c in result.columns if c not in (unit_column, time_column)]
    if not cols_to_fill:
        # Nothing but the key columns: there is nothing to fill.
        return result

    if fill_value is not None:
        # Only numeric columns receive the (numeric) fill value.
        for col in result.select_dtypes(include=[np.number]).columns:
            if col in cols_to_fill:
                result[col] = result[col].fillna(fill_value)
    else:
        # Carry the nearest observation within each unit: forward first,
        # then backward to cover gaps at the start of a unit's history.
        result = result.sort_values([unit_column, time_column])
        result[cols_to_fill] = result.groupby(unit_column)[cols_to_fill].ffill()
        result[cols_to_fill] = result.groupby(unit_column)[cols_to_fill].bfill()
    return result
[docs]
def validate_did_data(
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    time: str,
    unit: Optional[str] = None,
    raise_on_error: bool = True,
) -> Dict[str, Any]:
    """
    Validate that data is properly formatted for DiD analysis.

    Checks for common data issues and provides informative error messages.

    Parameters
    ----------
    data : pd.DataFrame
        Data to validate.
    outcome : str
        Name of outcome variable column (must be numeric).
    treatment : str
        Name of treatment indicator column (must contain only 0/1).
    time : str
        Name of time/post indicator column.
    unit : str, optional
        Name of unit identifier column. When given, panel-specific checks
        (within-unit treatment variation, panel balance) add warnings.
    raise_on_error : bool, default=True
        If True, raises ValueError on validation failures.
        If False, returns validation results without raising.

    Returns
    -------
    dict
        Validation results with keys:

        - valid: bool indicating if data passed all checks
        - errors: list of error messages
        - warnings: list of warning messages
        - summary: dict with ``n_obs``, ``n_treated``, ``n_control``,
          ``n_periods`` (and ``n_units`` when ``unit`` is given); empty
          when the basic column/type/missing-value checks failed

    Raises
    ------
    ValueError
        If ``raise_on_error`` is True and any check fails.

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'y': [1, 2, 3, 4],
    ...     'treated': [0, 0, 1, 1],
    ...     'post': [0, 1, 0, 1]
    ... })
    >>> result = validate_did_data(df, 'y', 'treated', 'post', raise_on_error=False)
    >>> result['valid']
    True
    """

    def _display_sorted(values: Any) -> List[Any]:
        # Sorted copy of ``values`` for message text. Mutually incomparable
        # types (e.g. "a" and 1 in an object column) make plain ``sorted``
        # raise TypeError, which would mask the real validation error.
        # Fall back to ordering by repr.
        try:
            return sorted(values)
        except TypeError:
            return sorted(values, key=repr)

    errors: List[str] = []
    # Deliberately not named ``warnings``, which would shadow the module-level
    # ``warnings`` import; the returned dict still uses the "warnings" key.
    warning_msgs: List[str] = []

    # --- Required columns -------------------------------------------------
    required_cols = [outcome, treatment, time]
    if unit is not None:
        required_cols.append(unit)
    for col in required_cols:
        if col not in data.columns:
            errors.append(f"Required column '{col}' not found in DataFrame.")
    if errors:
        if raise_on_error:
            raise ValueError("\n".join(errors))
        return {"valid": False, "errors": errors, "warnings": warning_msgs, "summary": {}}

    # --- Types and values ---------------------------------------------------
    if not pd.api.types.is_numeric_dtype(data[outcome]):
        errors.append(
            f"Outcome column '{outcome}' must be numeric. Got type: {data[outcome].dtype}"
        )

    treatment_vals = data[treatment].dropna().unique()
    if not set(treatment_vals).issubset({0, 1}):
        errors.append(
            f"Treatment column '{treatment}' must be binary (0 or 1). "
            f"Found values: {_display_sorted(treatment_vals)}"
        )

    # A two-valued time column is fine, but 0/1 coding is the convention.
    time_vals = data[time].dropna().unique()
    if len(time_vals) == 2 and not set(time_vals).issubset({0, 1}):
        warning_msgs.append(
            f"Time column '{time}' has 2 values but they are not 0 and 1: "
            f"{_display_sorted(time_vals)}. "
            "For basic DiD, use 0 for pre-treatment and 1 for post-treatment."
        )

    for col in required_cols:
        n_missing = data[col].isna().sum()
        if n_missing > 0:
            errors.append(
                f"Column '{col}' has {n_missing} missing values. "
                "Please handle missing data before fitting."
            )

    # --- Summary statistics and design coverage --------------------------
    summary: Dict[str, Any] = {}
    if not errors:
        summary["n_obs"] = len(data)
        summary["n_treated"] = int((data[treatment] == 1).sum())
        summary["n_control"] = int((data[treatment] == 0).sum())
        summary["n_periods"] = len(time_vals)
        if unit is not None:
            summary["n_units"] = data[unit].nunique()

        # Both groups must be present.
        if summary["n_treated"] == 0:
            errors.append("No treated observations found (treatment column is all 0).")
        if summary["n_control"] == 0:
            errors.append("No control observations found (treatment column is all 1).")

        if len(time_vals) == 2:
            # 2-period DiD: every treatment x time cell must be populated.
            for t_val in [0, 1]:
                for p_val in time_vals:
                    count = len(data[(data[treatment] == t_val) & (data[time] == p_val)])
                    if count == 0:
                        errors.append(
                            f"No observations for treatment={t_val}, time={p_val}. "
                            "DiD requires observations in all treatment-time cells."
                        )
        else:
            # Multi-period: each group must be observed in at least 2 periods.
            for t_val in [0, 1]:
                n_periods_with_obs = data[data[treatment] == t_val][time].nunique()
                if n_periods_with_obs < 2:
                    group_name = "Treated" if t_val == 1 else "Control"
                    errors.append(
                        f"{group_name} group has observations in only {n_periods_with_obs} period(s). "
                        "DiD requires multiple periods per group."
                    )

    # --- Panel-specific checks (warnings only) -----------------------------
    if unit is not None and not errors:
        unit_treatment_var = data.groupby(unit)[treatment].nunique()
        units_with_varying_treatment = unit_treatment_var[unit_treatment_var > 1]
        if len(units_with_varying_treatment) > 0:
            warning_msgs.append(
                f"Treatment varies within {len(units_with_varying_treatment)} unit(s). "
                "For standard DiD, treatment should be constant within units. "
                "This may be intentional for staggered adoption designs."
            )

        periods_per_unit = data.groupby(unit)[time].nunique()
        if periods_per_unit.min() != periods_per_unit.max():
            warning_msgs.append(
                f"Unbalanced panel detected. Units have between "
                f"{periods_per_unit.min()} and {periods_per_unit.max()} periods. "
                "Consider using balance_panel() to balance the data."
            )

    valid = len(errors) == 0
    if raise_on_error and not valid:
        raise ValueError("Data validation failed:\n" + "\n".join(errors))
    return {"valid": valid, "errors": errors, "warnings": warning_msgs, "summary": summary}
[docs]
def summarize_did_data(
    data: pd.DataFrame, outcome: str, treatment: str, time: str, unit: Optional[str] = None
) -> pd.DataFrame:
    """
    Generate summary statistics by treatment group and time period.

    Parameters
    ----------
    data : pd.DataFrame
        Input data.
    outcome : str
        Name of outcome variable column.
    treatment : str
        Name of treatment indicator column (1 = treated, anything else is
        labelled "Control").
    time : str
        Name of time/period column.
    unit : str, optional
        Name of unit identifier column. Not used by the current
        implementation.

    Returns
    -------
    pd.DataFrame
        One row per (treatment, time) cell with ``n``, ``mean``, ``std``,
        ``min`` and ``max`` of the outcome, rounded to 4 decimals. With
        exactly two distinct time values, rows are labelled
        "<Group> - Pre/Post" (the larger value is "Post"). An extra
        "DiD Estimate" row is appended; its ``mean`` is the unrounded
        difference-in-differences of cell means, and its other fields are
        "-". Otherwise rows are labelled "<Group> - Period <value>".

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'y': [10, 11, 12, 13, 20, 21, 22, 23],
    ...     'treated': [0, 0, 1, 1, 0, 0, 1, 1],
    ...     'post': [0, 1, 0, 1, 0, 1, 0, 1]
    ... })
    >>> summary = summarize_did_data(df, 'y', 'treated', 'post')
    >>> summary.index.tolist()
    ['Control - Pre', 'Control - Post', 'Treated - Pre', 'Treated - Post', 'DiD Estimate']
    >>> float(summary.loc['DiD Estimate', 'mean'])
    0.0
    """
    cell_stats = data.groupby([treatment, time])[outcome].agg(
        [("n", "count"), ("mean", "mean"), ("std", "std"), ("min", "min"), ("max", "max")]
    )
    cell_stats = cell_stats.round(4)

    periods = sorted(data[time].unique())

    if len(periods) != 2:
        # Not a classic 2x2 design: label each cell with its raw period value.
        cell_stats.index = cell_stats.index.map(
            lambda key: f"{'Treated' if key[0] == 1 else 'Control'} - Period {key[1]}"
        )
        return cell_stats

    # Two periods: the later one is "Post", whatever its actual value.
    first_period, last_period = periods

    def _cell_label(key: tuple) -> str:
        group = "Treated" if key[0] == 1 else "Control"
        phase = "Post" if key[1] == last_period else "Pre"
        return f"{group} - {phase}"

    cell_stats.index = cell_stats.index.map(_cell_label)

    def _cell_mean(group_value: int, period: Any) -> float:
        # Recomputed from raw data so the DiD uses unrounded means.
        in_cell = (data[treatment] == group_value) & (data[time] == period)
        return data[in_cell][outcome].mean()

    treated_change = _cell_mean(1, last_period) - _cell_mean(1, first_period)
    control_change = _cell_mean(0, last_period) - _cell_mean(0, first_period)

    did_row = pd.DataFrame(
        {
            "n": ["-"],
            "mean": [treated_change - control_change],
            "std": ["-"],
            "min": ["-"],
            "max": ["-"],
        },
        index=["DiD Estimate"],
    )
    return pd.concat([cell_stats, did_row])
[docs]
def create_event_time(
    data: pd.DataFrame, time_column: str, treatment_time_column: str, new_column: str = "event_time"
) -> pd.DataFrame:
    """
    Create an event-time column relative to treatment timing.

    Useful for event study designs where treatment occurs at different
    times for different units.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data. It is not modified; a copy is returned.
    time_column : str
        Name of the calendar time column.
    treatment_time_column : str
        Name of the column indicating when each unit was treated.
        Units with NaN (or, for numeric columns, +/- infinity) are
        considered never-treated.
    new_column : str, default="event_time"
        Name of the new event-time column.

    Returns
    -------
    pd.DataFrame
        Copy of ``data`` with ``new_column`` = ``time_column`` minus
        ``treatment_time_column``. Values are:

        - Negative for pre-treatment periods
        - 0 for the treatment period
        - Positive for post-treatment periods
        - NaN for never-treated units

    Raises
    ------
    ValueError
        If either input column is missing (``time_column`` is checked first).

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'unit': [1, 1, 1, 2, 2, 2],
    ...     'year': [2018, 2019, 2020, 2018, 2019, 2020],
    ...     'treatment_year': [2019, 2019, 2019, 2020, 2020, 2020]
    ... })
    >>> df = create_event_time(df, 'year', 'treatment_year')
    >>> df['event_time'].tolist()
    [-1, 0, 1, -2, -1, 0]
    """
    result = data.copy()
    for required in (time_column, treatment_time_column):
        if required not in result.columns:
            raise ValueError(f"Column '{required}' not found in DataFrame.")

    # Calendar time minus adoption time.
    result[new_column] = result[time_column] - result[treatment_time_column]

    # Read the adoption column after the assignment above, in case it is the
    # same column as ``new_column``.
    adoption = result[treatment_time_column]
    never_treated = adoption.isna()
    if pd.api.types.is_numeric_dtype(adoption):
        # Infinity is a common sentinel for "never treated" in numeric data.
        never_treated = never_treated | np.isinf(adoption)

    result.loc[never_treated, new_column] = np.nan
    return result
[docs]
def aggregate_to_cohorts(
    data: pd.DataFrame,
    unit_column: str,
    time_column: str,
    treatment_column: str,
    outcome: str,
    covariates: Optional[List[str]] = None,
) -> pd.DataFrame:
    """
    Aggregate unit-level data to treatment cohort means.

    Useful for visualization and cohort-level analysis.

    Parameters
    ----------
    data : pd.DataFrame
        Unit-level panel data.
    unit_column : str
        Name of unit identifier column.
    time_column : str
        Name of time period column.
    treatment_column : str
        Name of treatment indicator column.
    outcome : str
        Name of outcome variable column.
    covariates : list of str, optional
        Additional columns to aggregate (will compute means). They keep
        their original names in the output.

    Returns
    -------
    pd.DataFrame
        One row per (treatment, time) group. Columns are
        ``treatment_column``, ``time_column``, ``mean_<outcome>``,
        ``n_units`` (distinct units in the group), then any covariate means.

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'unit': [1, 1, 2, 2, 3, 3, 4, 4],
    ...     'period': [0, 1, 0, 1, 0, 1, 0, 1],
    ...     'treated': [1, 1, 1, 1, 0, 0, 0, 0],
    ...     'y': [10, 15, 12, 17, 8, 10, 9, 11]
    ... })
    >>> cohort_df = aggregate_to_cohorts(df, 'unit', 'period', 'treated', 'y')
    >>> len(cohort_df)
    4
    """
    agg_spec = {outcome: "mean", unit_column: "nunique"}
    agg_spec.update({cov: "mean" for cov in covariates or []})

    return (
        data.groupby([treatment_column, time_column])
        .agg(agg_spec)
        .reset_index()
        .rename(columns={unit_column: "n_units", outcome: f"mean_{outcome}"})
    )
[docs]
def rank_control_units(
data: pd.DataFrame,
unit_column: str,
time_column: str,
outcome_column: str,
treatment_column: Optional[str] = None,
treated_units: Optional[List[Any]] = None,
pre_periods: Optional[List[Any]] = None,
covariates: Optional[List[str]] = None,
outcome_weight: float = 0.7,
covariate_weight: float = 0.3,
exclude_units: Optional[List[Any]] = None,
require_units: Optional[List[Any]] = None,
n_top: Optional[int] = None,
suggest_treatment_candidates: bool = False,
n_treatment_candidates: int = 5,
lambda_reg: float = 0.0,
) -> pd.DataFrame:
"""
Rank potential control units by their suitability for DiD analysis.
Evaluates control units based on pre-treatment outcome trend similarity
and optional covariate matching to treated units. Returns a ranked list
with quality scores.
Parameters
----------
data : pd.DataFrame
Panel data in long format.
unit_column : str
Column name for unit identifier.
time_column : str
Column name for time periods.
outcome_column : str
Column name for outcome variable.
treatment_column : str, optional
Column with binary treatment indicator (0/1). Used to identify
treated units from data.
treated_units : list, optional
Explicit list of treated unit IDs. Alternative to treatment_column.
pre_periods : list, optional
Pre-treatment periods for comparison. If None, uses first half of periods.
covariates : list of str, optional
Covariate columns for matching. Similarity is based on pre-treatment means.
outcome_weight : float, default=0.7
Weight for pre-treatment outcome trend similarity (0-1).
covariate_weight : float, default=0.3
Weight for covariate distance (0-1). Ignored if no covariates.
exclude_units : list, optional
Units that cannot be in control group.
require_units : list, optional
Units that must be in control group (will always appear in output).
n_top : int, optional
Return only top N control units. If None, return all.
suggest_treatment_candidates : bool, default=False
If True and no treated units specified, identify potential treatment
candidates instead of ranking controls.
n_treatment_candidates : int, default=5
Number of treatment candidates to suggest.
lambda_reg : float, default=0.0
Regularization for synthetic weights. Higher values give more uniform
weights across controls.
Returns
-------
pd.DataFrame
Ranked control units with columns:
- unit: Unit identifier
- quality_score: Combined quality score (0-1, higher is better)
- outcome_trend_score: Pre-treatment outcome trend similarity
- covariate_score: Covariate match score (NaN if no covariates)
- synthetic_weight: Informational heuristic weight from a single-pass
uncentered Frank-Wolfe solve; does NOT factor into ``quality_score``
(ranking) and is NOT the canonical SDID unit weight. For canonical
SDID weights use ``SyntheticDiD.fit()``.
- pre_trend_rmse: RMSE of pre-treatment outcome vs treated mean
- is_required: Whether unit was in require_units
If suggest_treatment_candidates=True (and no treated units):
- unit: Unit identifier
- treatment_candidate_score: Suitability as treatment unit
- avg_outcome_level: Pre-treatment outcome mean
- outcome_trend: Pre-treatment trend slope
- n_similar_controls: Count of similar potential controls
Examples
--------
Rank controls against treated units:
>>> data = generate_did_data(n_units=30, n_periods=6, seed=42)
>>> ranking = rank_control_units(
... data,
... unit_column='unit',
... time_column='period',
... outcome_column='outcome',
... treatment_column='treated',
... n_top=10
... )
>>> ranking['quality_score'].is_monotonic_decreasing
True
With covariates:
>>> data['size'] = np.random.randn(len(data))
>>> ranking = rank_control_units(
... data,
... unit_column='unit',
... time_column='period',
... outcome_column='outcome',
... treatment_column='treated',
... covariates=['size']
... )
Filter data for SyntheticDiD:
>>> top_controls = ranking['unit'].tolist()
>>> filtered = data[(data['treated'] == 1) | (data['unit'].isin(top_controls))]
"""
# -------------------------------------------------------------------------
# Input validation
# -------------------------------------------------------------------------
for col in [unit_column, time_column, outcome_column]:
if col not in data.columns:
raise ValueError(f"Column '{col}' not found in DataFrame.")
if treatment_column is not None and treatment_column not in data.columns:
raise ValueError(f"Treatment column '{treatment_column}' not found in DataFrame.")
if covariates:
for cov in covariates:
if cov not in data.columns:
raise ValueError(f"Covariate column '{cov}' not found in DataFrame.")
if not 0 <= outcome_weight <= 1:
raise ValueError("outcome_weight must be between 0 and 1")
if not 0 <= covariate_weight <= 1:
raise ValueError("covariate_weight must be between 0 and 1")
if treated_units is not None and treatment_column is not None:
raise ValueError("Specify either 'treated_units' or 'treatment_column', not both.")
if require_units and exclude_units:
invalid_required = [u for u in require_units if u in exclude_units]
if invalid_required:
raise ValueError(f"Units cannot be both required and excluded: {invalid_required}")
# -------------------------------------------------------------------------
# Determine pre-treatment periods
# -------------------------------------------------------------------------
all_periods = sorted(data[time_column].unique())
if pre_periods is None:
mid_point = len(all_periods) // 2
pre_periods = all_periods[:mid_point]
else:
pre_periods = list(pre_periods)
if len(pre_periods) == 0:
raise ValueError("No pre-treatment periods specified or inferred.")
# -------------------------------------------------------------------------
# Identify treated and control units
# -------------------------------------------------------------------------
all_units = list(data[unit_column].unique())
if treated_units is not None:
treated_set = set(treated_units)
elif treatment_column is not None:
unit_treatment = data.groupby(unit_column)[treatment_column].first()
treated_set = set(unit_treatment[unit_treatment == 1].index)
elif suggest_treatment_candidates:
# Treatment candidate discovery mode - no treated units
treated_set = set()
else:
raise ValueError(
"Must specify treated_units, treatment_column, or set "
"suggest_treatment_candidates=True"
)
# -------------------------------------------------------------------------
# Treatment candidate discovery mode
# -------------------------------------------------------------------------
if suggest_treatment_candidates and len(treated_set) == 0:
return _suggest_treatment_candidates(
data, unit_column, time_column, outcome_column, pre_periods, n_treatment_candidates
)
if len(treated_set) == 0:
raise ValueError("No treated units found.")
# Determine control candidates
control_candidates = [u for u in all_units if u not in treated_set]
if exclude_units:
control_candidates = [u for u in control_candidates if u not in exclude_units]
if len(control_candidates) == 0:
raise ValueError("No control units available after exclusions.")
# -------------------------------------------------------------------------
# Create outcome matrices (pre-treatment)
# -------------------------------------------------------------------------
pre_data = data[data[time_column].isin(pre_periods)]
pivot = pre_data.pivot(index=time_column, columns=unit_column, values=outcome_column)
# Filter to pre_periods that exist in data
valid_pre_periods = [p for p in pre_periods if p in pivot.index]
if len(valid_pre_periods) == 0:
raise ValueError("No data found for specified pre-treatment periods.")
# Filter control_candidates to those present in pivot (handles unbalanced panels)
control_candidates = [c for c in control_candidates if c in pivot.columns]
if len(control_candidates) == 0:
raise ValueError("No control units found in pre-treatment data.")
# Control outcomes: shape (n_pre_periods, n_control_candidates)
Y_control = pivot.loc[valid_pre_periods, control_candidates].values.astype(float)
# Treated outcomes mean: shape (n_pre_periods,)
treated_list = [u for u in treated_set if u in pivot.columns]
if len(treated_list) == 0:
raise ValueError("Treated units not found in pre-treatment data.")
Y_treated_mean = pivot.loc[valid_pre_periods, treated_list].mean(axis=1).values.astype(float)
# -------------------------------------------------------------------------
# Compute outcome trend scores
# -------------------------------------------------------------------------
# Informational `synthetic_weight` column. This is a RANKING HEURISTIC,
# not an estimator: it gives a rough "which controls would a synthetic
# regression weight heavily" signal that's reported alongside RMSE and
# covariate distance. The actual ranking (`quality_score`) is computed
# below from `outcome_trend_score` (RMSE-based) + `covariate_score`; the
# `synthetic_weight` column does NOT factor into the ranking decision.
#
# Solver choice. We use a single-pass uncentered Frank-Wolfe via the
# shared `_sc_weight_fw` dispatcher to solve:
#
# min_w ||Y_treated_mean - Y_control @ w||^2 + lambda_reg * ||w||^2
# s.t. w >= 0, sum(w) = 1
#
# Mapped to the FW objective `zeta^2 ||w||^2 + (1/N) ||Aw - b||^2` via
# `zeta = sqrt(lambda_reg / N)`. intercept=False because this QP does
# no column-centering, max_iter=1000 to bound ranking-loop cost,
# min_weight=1e-6 post-processing for interpretability.
#
# NOTE — this is INTENTIONALLY NOT the canonical SDID / R
# `synthdid::sc.weight.fw` two-pass unit-weight procedure (that uses
# intercept=TRUE, 100-iter -> sparsify -> 10000-iter). SDID estimation
# still uses that canonical path in `_sc_weight_fw_numpy` at
# `utils.py:_sc_weight_fw_numpy` via `compute_sdid_unit_weights`; this
# ranking heuristic uses a simpler single-pass call to the same solver
# for a cheap diagnostic score.
#
# Replaces the former `compute_synthetic_weights` wrapper whose Rust
# and Python backends had divergent PGD implementations (audit
# finding #22). Net effect: users on default `lambda_reg=0` with
# typical data see `synthetic_weight` values that agree with the old
# code to ~1e-7; extreme Y or `lambda_reg > 0` cases produce values
# that differ from the old code (which was mathematically wrong).
_Y_control = np.ascontiguousarray(Y_control, dtype=np.float64)
_Y_treated_mean = np.ascontiguousarray(Y_treated_mean, dtype=np.float64)
_n_pre, _n_control = _Y_control.shape
if _n_control == 0:
synthetic_weights = np.array([], dtype=np.float64)
elif _n_control == 1:
synthetic_weights = np.array([1.0])
else:
_zeta = float(np.sqrt(lambda_reg / _n_pre)) if lambda_reg > 0 else 0.0
# Scale stopping threshold by noise level so convergence stays
# meaningful at any data magnitude.
_sigma = _compute_noise_level(_Y_control)
_min_decrease = 1e-5 * max(_sigma, 1e-12)
_Y_fw = np.column_stack([_Y_control, _Y_treated_mean])
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message=r".*did not converge.*",
category=UserWarning,
)
synthetic_weights = _sc_weight_fw(
_Y_fw,
zeta=_zeta,
intercept=False,
min_decrease=_min_decrease,
max_iter=1000,
)
# Set small weights to zero for interpretability, then renormalize.
synthetic_weights = np.asarray(synthetic_weights, dtype=np.float64)
_min_weight = 1e-6
synthetic_weights[synthetic_weights < _min_weight] = 0.0
_total = float(np.sum(synthetic_weights))
if _total > 0:
synthetic_weights = synthetic_weights / _total
else:
synthetic_weights = np.ones(_n_control) / _n_control
# RMSE for each control vs treated mean (use nanmean to handle missing data)
rmse_scores = []
for j in range(len(control_candidates)):
y_c = Y_control[:, j]
rmse = np.sqrt(np.nanmean((y_c - Y_treated_mean) ** 2))
rmse_scores.append(rmse)
# Convert RMSE to similarity score (lower RMSE = higher score)
max_rmse = max(rmse_scores) if rmse_scores else 1.0
min_rmse = min(rmse_scores) if rmse_scores else 0.0
rmse_range = max_rmse - min_rmse
if rmse_range < 1e-10:
# All controls have identical/similar pre-trends (includes single control case)
outcome_trend_scores = [1.0] * len(rmse_scores)
else:
# Normalize so best control gets 1.0, worst gets 0.0
outcome_trend_scores = [1 - (rmse - min_rmse) / rmse_range for rmse in rmse_scores]
# -------------------------------------------------------------------------
# Compute covariate scores (if covariates provided)
# -------------------------------------------------------------------------
if covariates and len(covariates) > 0:
# Get unit-level covariate values (pre-treatment mean)
cov_data = pre_data.groupby(unit_column)[covariates].mean()
# Treated covariate profile (mean across treated units)
treated_cov = cov_data.loc[list(treated_set)].mean()
# Standardize covariates
cov_mean = cov_data.mean()
cov_std = cov_data.std().replace(0, 1) # Avoid division by zero
cov_standardized = (cov_data - cov_mean) / cov_std
treated_cov_std = (treated_cov - cov_mean) / cov_std
# Euclidean distance in standardized space (vectorized)
control_cov_matrix = cov_standardized.loc[control_candidates].values
treated_cov_vector = treated_cov_std.values
covariate_distances = np.sqrt(
np.sum((control_cov_matrix - treated_cov_vector) ** 2, axis=1)
)
# Convert distance to similarity score (min-max normalization)
max_dist = covariate_distances.max() if len(covariate_distances) > 0 else 1.0
min_dist = covariate_distances.min() if len(covariate_distances) > 0 else 0.0
dist_range = max_dist - min_dist
if dist_range < 1e-10:
# All controls have identical/similar covariate profiles
covariate_scores = [1.0] * len(covariate_distances)
else:
# Normalize so best control (closest) gets 1.0, worst gets 0.0
covariate_scores = (1 - (covariate_distances - min_dist) / dist_range).tolist()
else:
covariate_scores = [np.nan] * len(control_candidates)
# -------------------------------------------------------------------------
# Compute combined quality score
# -------------------------------------------------------------------------
# Normalize weights
total_weight = outcome_weight + covariate_weight
if total_weight > 0:
norm_outcome_weight = outcome_weight / total_weight
norm_covariate_weight = covariate_weight / total_weight
else:
norm_outcome_weight = 1.0
norm_covariate_weight = 0.0
quality_scores = []
for i in range(len(control_candidates)):
outcome_score = outcome_trend_scores[i]
cov_score = covariate_scores[i]
if np.isnan(cov_score):
# No covariates - use only outcome score
combined = outcome_score
else:
combined = norm_outcome_weight * outcome_score + norm_covariate_weight * cov_score
quality_scores.append(combined)
# -------------------------------------------------------------------------
# Build result DataFrame
# -------------------------------------------------------------------------
require_set = set(require_units) if require_units else set()
result = pd.DataFrame(
{
"unit": control_candidates,
"quality_score": quality_scores,
"outcome_trend_score": outcome_trend_scores,
"covariate_score": covariate_scores,
"synthetic_weight": synthetic_weights,
"pre_trend_rmse": rmse_scores,
"is_required": [u in require_set for u in control_candidates],
}
)
# Sort by quality score (descending)
result = result.sort_values("quality_score", ascending=False)
# Apply n_top limit if specified
if n_top is not None and n_top < len(result):
# Always include required units
required_df = result[result["is_required"]]
non_required_df = result[~result["is_required"]]
# Take top from non-required to fill remaining slots
remaining_slots = max(0, n_top - len(required_df))
top_non_required = non_required_df.head(remaining_slots)
result = pd.concat([required_df, top_non_required])
result = result.sort_values("quality_score", ascending=False)
return result.reset_index(drop=True)
def _suggest_treatment_candidates(
    data: pd.DataFrame,
    unit_column: str,
    time_column: str,
    outcome_column: str,
    pre_periods: List[Any],
    n_candidates: int,
) -> pd.DataFrame:
    """
    Identify units that would make good treatment candidates.

    A good treatment candidate:
    1. Has many similar control units available (for matching)
    2. Has stable pre-treatment trends (predictable counterfactual)
    3. Is not an extreme outlier

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    unit_column : str
        Unit identifier column.
    time_column : str
        Time period column.
    outcome_column : str
        Outcome variable column.
    pre_periods : list
        Pre-treatment periods.
    n_candidates : int
        Number of candidates to return.

    Returns
    -------
    pd.DataFrame
        Up to ``n_candidates`` rows ordered by ``treatment_candidate_score``
        (descending, clipped to [0, 1]) with columns ``unit``,
        ``avg_outcome_level``, ``outcome_trend`` (per-observation slope of
        the pre-period outcome, in chronological order) and
        ``n_similar_controls``. Returns an empty frame with these columns
        when no unit has pre-period observations.
    """
    all_units = list(data[unit_column].unique())
    pre_data = data[data[time_column].isin(pre_periods)]

    # Pre-period mean outcome per unit, computed once. Each unit's pool of
    # "other" units is this Series minus the unit itself, so the loop below
    # no longer re-runs a groupby per unit.
    unit_means = pre_data.groupby(unit_column)[outcome_column].mean()

    candidate_info = []
    for unit in all_units:
        unit_data = pre_data[pre_data[unit_column] == unit]
        if len(unit_data) == 0:
            continue

        # Average outcome level
        avg_outcome = unit_data[outcome_column].mean()

        # Trend: slope of outcome against observation rank. Sort by time
        # first so the rank really is chronological order regardless of how
        # the input rows are arranged (stable sort keeps ties in input order).
        ordered = unit_data.sort_values(time_column, kind="mergesort")
        outcomes = ordered[outcome_column].to_numpy(dtype=float, na_value=np.nan)
        ranks = np.arange(len(outcomes), dtype=float)
        # Skip missing outcomes but keep the remaining points at their true
        # chronological rank.
        finite = np.isfinite(outcomes)
        slope = 0.0
        if np.count_nonzero(finite) > 1:
            try:
                slope = float(np.polyfit(ranks[finite], outcomes[finite], 1)[0])
            except (np.linalg.LinAlgError, ValueError):
                slope = 0.0

        # Count similar potential controls among every other unit that has
        # pre-period data.
        other_means = unit_means[unit_means.index != unit]
        if len(other_means) > 0:
            sd = other_means.std()
            if sd > 0:
                n_similar = int(
                    np.sum(np.abs(other_means - avg_outcome) < _SIMILARITY_THRESHOLD_SD * sd)
                )
            else:
                # Zero or undefined spread: every other unit counts as similar.
                n_similar = len(other_means)
        else:
            n_similar = 0

        candidate_info.append(
            {
                "unit": unit,
                "avg_outcome_level": avg_outcome,
                "outcome_trend": slope,
                "n_similar_controls": n_similar,
            }
        )

    if len(candidate_info) == 0:
        return pd.DataFrame(
            columns=[
                "unit",
                "treatment_candidate_score",
                "avg_outcome_level",
                "outcome_trend",
                "n_similar_controls",
            ]
        )

    result = pd.DataFrame(candidate_info)

    # Score: prefer units with many similar controls and moderate outcome levels
    max_similar = result["n_similar_controls"].max()
    if max_similar > 0:
        similarity_score = result["n_similar_controls"] / max_similar
    else:
        similarity_score = pd.Series([0.0] * len(result))

    # Penalty for outliers in outcome level
    outcome_mean = result["avg_outcome_level"].mean()
    outcome_std = result["avg_outcome_level"].std()
    if outcome_std > 0:
        outcome_z = np.abs((result["avg_outcome_level"] - outcome_mean) / outcome_std)
    else:
        outcome_z = pd.Series([0.0] * len(result))

    result["treatment_candidate_score"] = (
        similarity_score - _OUTLIER_PENALTY_WEIGHT * outcome_z
    ).clip(0, 1)

    # Return top candidates
    result = result.nlargest(n_candidates, "treatment_candidate_score")
    return result.reset_index(drop=True)
def trim_weights(
    data: pd.DataFrame,
    weight_col: str,
    upper: Optional[float] = None,
    quantile: Optional[float] = None,
    lower: Optional[float] = None,
) -> pd.DataFrame:
    """Trim (winsorize) survey weights to reduce influence of extreme values.

    Clips the weight column to ``[lower, upper]``, where the ceiling is
    either an absolute value (``upper``) or a quantile of the observed
    weights (``quantile``). Useful for reducing variance from extreme survey
    weights before DiD estimation. Federal agencies (e.g., NCHS) recommend
    reviewing weights with CV > 30%.

    Parameters
    ----------
    data : pd.DataFrame
        Input DataFrame; it is not modified.
    weight_col : str
        Name of the weight column.
    upper : float, optional
        Absolute ceiling. Weights above it are replaced by it.
        Mutually exclusive with ``quantile``.
    quantile : float, optional
        Quantile (strictly between 0 and 1, e.g. 0.99) of the weights,
        ignoring NaN, used as the ceiling. Mutually exclusive with ``upper``.
    lower : float, optional
        Absolute floor. Weights below it are replaced by it. Combinable with
        either ``upper`` or ``quantile``.

    Returns
    -------
    pd.DataFrame
        Copy of ``data`` with ``weight_col`` clipped. NaN weights remain NaN.

    Raises
    ------
    ValueError
        If both ``upper`` and ``quantile`` are given; if ``weight_col`` is
        not a column of ``data``; if ``quantile`` is outside (0, 1); if a
        cap (including a quantile-resolved ceiling) is non-finite or
        negative; or if ``lower`` exceeds the resolved ceiling.
    """
    conflicting_caps = upper is not None and quantile is not None
    if conflicting_caps:
        raise ValueError("Specify either 'upper' or 'quantile', not both.")
    if weight_col not in data.columns:
        raise ValueError(f"Column '{weight_col}' not found in DataFrame.")

    out = data.copy()
    clipped = out[weight_col].values.copy()

    # Translate a quantile request into an absolute ceiling.
    if quantile is not None:
        if not (0 < quantile < 1):
            raise ValueError(f"quantile must be in (0, 1), got {quantile}")
        upper = float(np.nanquantile(clipped, quantile))

    # Both caps obey the same rule; checked in upper-then-lower order.
    for cap_name, cap_value in (("upper", upper), ("lower", lower)):
        if cap_value is None:
            continue
        if not np.isfinite(cap_value) or cap_value < 0:
            raise ValueError(f"{cap_name} must be finite and >= 0, got {cap_value}")

    if upper is not None and lower is not None:
        if lower > upper:
            raise ValueError(
                f"lower ({lower}) must be <= upper ({upper}). "
                f"When using quantile, the resolved upper cap may be below lower."
            )

    if upper is not None:
        clipped = np.minimum(clipped, upper)
    if lower is not None:
        clipped = np.maximum(clipped, lower)

    out[weight_col] = clipped
    return out
# ---------------------------------------------------------------------------
# Survey aggregation helpers
# ---------------------------------------------------------------------------
def _srs_fallback_variance(
    w_valid: np.ndarray,
    y_clean: np.ndarray,
    y_bar: float,
    n_valid: int,
) -> float:
    """Return the simple-random-sampling variance of a weighted cell mean.

    Positive weights are normalized to mean 1 so the fallback is
    scale-invariant (replicate designs preserve raw weight scale in
    ``diff_diff.survey``). Zero weights mark invalid rows; they stay zero and
    drop out of both numerator and denominator.

    Parameters
    ----------
    w_valid : np.ndarray
        Cell weights with invalid rows already set to 0.
    y_clean : np.ndarray
        Cell outcomes with invalid rows replaced by 0.
    y_bar : float
        Weighted cell mean computed from the same rows.
    n_valid : int
        Number of valid rows; must be >= 2 for the n/(n-1) correction.

    Returns
    -------
    float
        The (unclipped) SRS variance estimate.
    """
    w_norm = w_valid.copy()
    positive = w_norm > 0
    if np.any(positive):
        w_norm[positive] = w_norm[positive] / w_norm[positive].mean()
    sum_wn = float(np.sum(w_norm))
    resid_sq = w_norm * (y_clean - y_bar) ** 2
    return float(np.sum(resid_sq) / (sum_wn**2) * n_valid / (n_valid - 1))


def _cell_mean_variance(
    y_full: np.ndarray,
    full_resolved: ResolvedSurveyDesign,
    cell_mask: np.ndarray,
    min_n: int,
    scaffolding: Optional[_PsuScaffolding] = None,
) -> Tuple[float, float, int, bool]:
    """Compute design-based mean and variance of the weighted mean for one cell.

    Uses full-design domain estimation: the influence function is zero-padded
    outside the cell, preserving the full strata/PSU structure for variance
    estimation. This is the methodologically correct approach for domain
    estimation under complex survey designs (Lumley 2004, Section 3.4).

    Parameters
    ----------
    y_full : np.ndarray
        Outcome values for the full dataset (may contain NaN).
    full_resolved : ResolvedSurveyDesign
        Full-sample resolved survey design.
    cell_mask : np.ndarray
        Boolean mask identifying cell members in the full dataset.
    min_n : int
        Minimum valid observations for design-based variance. Below this
        threshold, SRS fallback is used.
    scaffolding : _PsuScaffolding, optional
        Precomputed stratum/PSU scaffolding for ``full_resolved``. When given
        (and the design is not replicate-based) the fast TSL variance path is
        used instead of ``compute_survey_if_variance``.

    Returns
    -------
    mean : float
        Design-weighted cell mean (NaN when no valid observation exists).
    variance : float
        Variance of the cell mean. NaN when ``n_valid < 2`` or the valid rows
        carry no positive weight; otherwise clipped to >= 0. Uses SRS
        fallback when the design-based estimate is NaN or n_valid < min_n.
    n_valid : int
        Number of observations in the cell with a non-missing outcome and a
        positive weight.
    used_srs_fallback : bool
        True if SRS variance was used instead of design-based.
    """
    y_cell = y_full[cell_mask]
    w_cell = full_resolved.weights[cell_mask]

    # Valid = non-missing AND positive weight (zero-weight rows are padding)
    valid = ~np.isnan(y_cell) & (w_cell > 0)
    n_valid = int(np.sum(valid))
    if n_valid == 0:
        return np.nan, np.nan, 0, False
    if n_valid < 2:
        # One observation identifies the mean but not its variance.
        return float(y_cell[valid][0]), np.nan, 1, False

    # Weighted mean from cell members (NaN-safe)
    w_valid = w_cell * valid.astype(np.float64)
    y_clean = np.where(valid, y_cell, 0.0)
    sum_w = float(np.sum(w_valid))
    if sum_w <= 0:
        return np.nan, np.nan, n_valid, False
    y_bar = float(np.sum(w_valid * y_clean) / sum_w)

    # SRS fallback if below min_n threshold
    if n_valid < min_n:
        variance = _srs_fallback_variance(w_valid, y_clean, y_bar, n_valid)
        return y_bar, max(variance, 0.0), n_valid, True

    # Full-design domain estimation: construct full-length psi with zeros
    # outside the cell, preserving full strata/PSU structure for variance
    psi = np.zeros(len(y_full))
    # Positions in full array where cell member has valid data
    valid_positions = np.where(cell_mask)[0][valid]
    psi[valid_positions] = w_valid[valid] * (y_clean[valid] - y_bar) / sum_w

    # Route to TSL or replicate variance using the full design. When a
    # design-level scaffolding is provided (aggregate_survey's fast path),
    # use it to skip the per-call pandas groupby / np.unique setup that
    # otherwise dominates runtime at large scale.
    if full_resolved.uses_replicate_variance:
        variance, _ = compute_replicate_if_variance(psi, full_resolved)
    elif scaffolding is not None:
        variance = _compute_if_variance_fast(psi, scaffolding)
    else:
        variance = compute_survey_if_variance(psi, full_resolved)

    # SRS fallback when design-based variance is unidentifiable
    used_srs = False
    if np.isnan(variance):
        variance = _srs_fallback_variance(w_valid, y_clean, y_bar, n_valid)
        used_srs = True

    return y_bar, max(float(variance), 0.0), n_valid, used_srs
[docs]
def aggregate_survey(
    data: pd.DataFrame,
    by: Union[str, List[str]],
    outcomes: Union[str, List[str]],
    survey_design: SurveyDesign,
    covariates: Optional[Union[str, List[str]]] = None,
    min_n: int = 2,
    lonely_psu: Optional[str] = None,
    second_stage_weights: str = "pweight",
) -> Tuple[pd.DataFrame, SurveyDesign]:
    """Aggregate survey microdata to geographic-period cells with design-based precision.

    Computes design-weighted cell means and their Taylor-linearized (or
    replicate-based) standard errors for each cell defined by the ``by``
    columns. Returns a panel-ready DataFrame and a pre-configured
    :class:`SurveyDesign` for second-stage DiD estimation.

    Each cell is treated as a subpopulation/domain of the full survey
    design: influence function values are zero-padded outside the cell,
    preserving full strata/PSU structure for variance estimation per
    Lumley (2004) Section 3.4.

    Parameters
    ----------
    data : pd.DataFrame
        Individual-level microdata. Must be non-empty; the ``by`` columns
        must not contain missing values.
    by : str or list of str
        Columns defining cells (e.g., ``["state", "year"]``). The first
        element is used as the clustering variable in the returned
        SurveyDesign (geographic unit for second-stage inference).
    outcomes : str or list of str
        Numeric outcome variable(s) to aggregate with full precision
        tracking. Each outcome produces ``{name}_mean``, ``{name}_se``,
        ``{name}_n``, and ``{name}_precision`` columns. When multiple
        outcomes are given, panel filtering (non-estimable cell
        removal, zero-weight PSU pruning) is based on the **first**
        outcome only, consistent with the returned SurveyDesign. For
        independent per-outcome support, call once per outcome.
    survey_design : SurveyDesign
        Survey design specification for the microdata.
    covariates : str or list of str, optional
        Additional numeric variables to aggregate as design-weighted means
        only (no SE/precision columns).
    min_n : int, default 2
        Minimum number of valid respondents (non-missing value and positive
        weight) per cell and outcome. Below this threshold simple random
        sampling variance is used as a fallback. Must be >= 1.
    lonely_psu : str, optional
        Override the survey design's ``lonely_psu`` setting. The override is
        applied when resolving the full design, which every cell's domain
        variance uses. One of ``"remove"``, ``"certainty"``, ``"adjust"``.
    second_stage_weights : str, default "pweight"
        Weight type for the returned second-stage ``SurveyDesign``:

        - ``"pweight"`` (default): Population weights - the mean of
          per-cell survey weight sums within each geographic unit
          (first ``by`` column), constant across periods. Targets
          population-weighted second-stage estimation. Compatible
          with all survey-capable estimators including those that
          require unit-constant survey columns.
        - ``"aweight"``: Precision weights - inverse variance
          (``1 / V(y_bar)``). Targets precision-weighted second-stage
          estimation via WLS. Compatible with estimators that accept ``aweight``
          (DifferenceInDifferences, TwoWayFixedEffects, MultiPeriodDiD,
          SunAbraham, ContinuousDiD, EfficientDiD); rejected by
          ``pweight``-only estimators.

    Returns
    -------
    panel_df : pd.DataFrame
        Aggregated panel with columns: grouping variables,
        ``{outcome}_mean``, ``{outcome}_se``, ``{outcome}_n``,
        ``{outcome}_precision`` (per outcome), ``{first_outcome}_weight``,
        ``{covariate}_mean``, ``cell_n``, ``cell_n_eff``,
        ``cell_sum_w``, ``srs_fallback``. The ``{first_outcome}_weight``
        column contains unit-constant population weights (mean of
        ``cell_sum_w`` within each geographic unit) in pweight mode,
        or cleaned precision (NaN/Inf mapped to 0.0) in aweight mode.
        ``cell_sum_w`` is always present as a diagnostic column
        containing the sum of survey weights per cell (proportional to
        estimated population).
    second_stage_design : SurveyDesign
        Pre-configured for second-stage estimation with the chosen
        ``weight_type``, weights from the first outcome, and
        geographic clustering via ``psu``.

    Raises
    ------
    TypeError
        If ``survey_design`` is not a :class:`SurveyDesign`.
    ValueError
        If ``by`` or ``outcomes`` is empty; a referenced column is missing
        or non-numeric; a column appears in both ``by`` and
        outcomes/covariates; ``second_stage_weights``, ``min_n`` or
        ``lonely_psu`` is invalid; ``data`` is empty; a grouping column
        has missing values; or no estimable cells remain.

    Warns
    -----
    UserWarning
        When SRS fallback variance is used, when a cell has zero variance
        for at least one outcome, and when non-estimable cells or
        zero-weight geographic units are dropped.

    Examples
    --------
    >>> design = SurveyDesign(weights="finalwt", strata="strat", psu="psu")
    >>> panel, stage2 = aggregate_survey(
    ...     microdata, by=["state", "year"],
    ...     outcomes="smoking_rate", survey_design=design,
    ... )
    >>> # stage2 has weight_type="pweight" — compatible with all estimators.
    >>> # Add treatment/time indicators at the panel level, then fit:
    >>> # panel["first_treat"] = panel["state"].map(policy_year).fillna(0)
    >>> # result = CallawaySantAnna().fit(
    >>> #     panel, outcome="smoking_rate_mean",
    >>> #     unit="state", time="year", first_treat="first_treat",
    >>> #     survey_design=stage2,
    >>> # )
    """
    from dataclasses import replace

    # --- Normalize inputs ---
    by_cols = [by] if isinstance(by, str) else list(by)
    outcome_cols = [outcomes] if isinstance(outcomes, str) else list(outcomes)
    if isinstance(covariates, str):
        cov_cols = [covariates]
    elif covariates:
        cov_cols = list(covariates)
    else:
        cov_cols = []

    # --- Validate ---
    if not by_cols:
        raise ValueError("'by' must specify at least one grouping column")
    if not outcome_cols:
        raise ValueError("'outcomes' must specify at least one outcome variable")
    all_cols = by_cols + outcome_cols + cov_cols
    missing = [c for c in all_cols if c not in data.columns]
    if missing:
        raise ValueError(f"Columns not found in DataFrame: {missing}")
    overlap = set(by_cols) & (set(outcome_cols) | set(cov_cols))
    if overlap:
        raise ValueError(f"Columns appear in both 'by' and outcomes/covariates: {overlap}")
    if not isinstance(survey_design, SurveyDesign):
        raise TypeError(
            f"survey_design must be a SurveyDesign instance, got {type(survey_design).__name__}"
        )
    _valid_second_stage = {"pweight", "aweight"}
    if second_stage_weights not in _valid_second_stage:
        raise ValueError(
            f"second_stage_weights must be one of {sorted(_valid_second_stage)}, "
            f"got '{second_stage_weights}'."
        )
    if min_n < 1:
        raise ValueError(f"min_n must be >= 1, got {min_n}")
    if lonely_psu is not None and lonely_psu not in ("remove", "certainty", "adjust"):
        raise ValueError(
            f"lonely_psu must be 'remove', 'certainty', or 'adjust', got '{lonely_psu}'"
        )

    # --- Empty-input guard ---
    if data.empty:
        raise ValueError("data must be non-empty")

    # --- Validate grouping columns have no missing values ---
    by_missing = data[by_cols].isna().any()
    cols_with_na = list(by_missing[by_missing].index)
    if cols_with_na:
        raise ValueError(
            f"Missing values in grouping column(s): {cols_with_na}. "
            f"Drop or fill NaN values before calling aggregate_survey()."
        )

    # --- Validate numeric outcomes/covariates ---
    # Done before resolving the design so a cheap input error does not pay
    # for the design resolve and scaffolding precompute below.
    all_vars = outcome_cols + cov_cols
    non_numeric = [v for v in all_vars if not pd.api.types.is_numeric_dtype(data[v])]
    if non_numeric:
        raise ValueError(
            f"Non-numeric column(s) in outcomes/covariates: {non_numeric}. "
            f"All outcome and covariate columns must be numeric."
        )

    # --- Resolve design once on full data ---
    effective_design = (
        replace(survey_design, lonely_psu=lonely_psu) if lonely_psu else survey_design
    )
    full_resolved = effective_design.resolve(data)
    # Precompute stratum/PSU scaffolding once per design. Amortizes
    # per-cell pandas groupby + np.unique + stratum FPC lookup that
    # otherwise dominate runtime at scale (see _compute_if_variance_fast).
    # Replicate-weight designs use a different variance surface and stay
    # on the legacy path.
    _tsl_scaffolding: Optional[_PsuScaffolding] = (
        _precompute_psu_scaffolding(full_resolved)
        if not full_resolved.uses_replicate_variance
        else None
    )

    # --- Precompute full-length outcome/covariate arrays ---
    n_total = len(data)
    y_arrays: Dict[str, np.ndarray] = {var: data[var].values.astype(np.float64) for var in all_vars}

    # --- Per-cell computation ---
    # Use groupby().indices for position-based cell membership (safe with
    # duplicate DataFrame indices, no column injection into user data)
    grouped = data.groupby(by_cols, sort=True)
    cell_indices = grouped.indices  # dict of cell_key -> positional indices
    rows: List[Dict[str, Any]] = []
    srs_cells: List[str] = []
    zero_var_cells: List[str] = []
    for cell_key, pos_idx in cell_indices.items():
        # Boolean mask for full-design domain estimation
        cell_mask = np.zeros(n_total, dtype=bool)
        cell_mask[pos_idx] = True
        cell_n = int(len(pos_idx))
        cell_key_str = str(cell_key)

        # Cell-level statistics (Kish ESS is a property of the cell)
        cell_w = full_resolved.weights[cell_mask]
        sum_w = float(np.sum(cell_w))
        sum_w2 = float(np.sum(cell_w**2))
        cell_n_eff = (sum_w**2 / sum_w2) if sum_w2 > 0 else 0.0

        # Build row dict with grouping columns (single-column keys are scalars)
        row: Dict[str, Any] = {}
        if len(by_cols) == 1:
            row[by_cols[0]] = cell_key
        else:
            for i, col in enumerate(by_cols):
                row[col] = cell_key[i]
        row["cell_n"] = cell_n
        row["cell_n_eff"] = cell_n_eff
        row["cell_sum_w"] = sum_w

        # Per-cell flags: a cell is reported once in each warning even when
        # several outcomes trigger the same condition.
        cell_srs_fallback = False
        cell_zero_var = False

        # Outcomes: mean + SE + n + precision (full-design domain estimation)
        for var in outcome_cols:
            y_bar, variance, n_valid, used_srs = _cell_mean_variance(
                y_arrays[var],
                full_resolved,
                cell_mask,
                min_n,
                scaffolding=_tsl_scaffolding,
            )
            se = float(np.sqrt(variance)) if not np.isnan(variance) else np.nan
            if used_srs:
                cell_srs_fallback = True
            # Zero or undefined variance -> precision NaN
            if se == 0.0:
                precision = np.nan
                cell_zero_var = True
            elif np.isnan(se):
                precision = np.nan
            else:
                precision = 1.0 / variance
            row[f"{var}_mean"] = y_bar
            row[f"{var}_se"] = se
            row[f"{var}_n"] = n_valid
            row[f"{var}_precision"] = precision

        # Covariates: design-weighted mean only
        for var in cov_cols:
            y_cell = y_arrays[var][cell_mask]
            valid = ~np.isnan(y_cell)
            w_valid = cell_w * valid.astype(np.float64)
            sw = float(np.sum(w_valid))
            if sw > 0:
                row[f"{var}_mean"] = float(np.sum(w_valid * np.where(valid, y_cell, 0.0)) / sw)
            else:
                row[f"{var}_mean"] = np.nan

        row["srs_fallback"] = cell_srs_fallback
        if cell_srs_fallback:
            srs_cells.append(cell_key_str)
        if cell_zero_var:
            zero_var_cells.append(cell_key_str)
        rows.append(row)

    # --- Warnings ---
    if srs_cells:
        warnings.warn(
            f"Design-based variance not estimable for {len(srs_cells)} cell(s); "
            f"using SRS fallback: {srs_cells[:5]}"
            + (f" ... and {len(srs_cells) - 5} more" if len(srs_cells) > 5 else ""),
            UserWarning,
            stacklevel=2,
        )
    if zero_var_cells:
        warnings.warn(
            f"Zero variance in {len(zero_var_cells)} cell(s) (precision set to NaN): "
            f"{zero_var_cells[:5]}"
            + (f" ... and {len(zero_var_cells) - 5} more" if len(zero_var_cells) > 5 else ""),
            UserWarning,
            stacklevel=2,
        )

    # --- Assemble output ---
    panel_df = pd.DataFrame(rows)
    # Sort by grouping columns
    panel_df = panel_df.sort_values(by_cols).reset_index(drop=True)

    # --- Drop non-estimable cells ---
    # Cells with non-finite mean (n_valid==0 or all-missing) cannot contribute
    # to second-stage estimation and would cause fit() to reject NaN outcomes.
    # Dropping them also removes all-zero-weight PSUs from the panel.
    first_outcome = outcome_cols[0]
    mean_col = f"{first_outcome}_mean"
    nonestimable = ~np.isfinite(panel_df[mean_col].values)
    if np.any(nonestimable):
        n_dropped = int(np.sum(nonestimable))
        dropped_keys = panel_df.loc[nonestimable, by_cols].values.tolist()
        # Warn about secondary outcomes losing valid data in dropped cells
        secondary_loss = []
        for var in outcome_cols[1:]:
            valid_secondary = np.isfinite(panel_df.loc[nonestimable, f"{var}_mean"].values)
            if np.any(valid_secondary):
                secondary_loss.append(var)
        msg = (
            f"Dropped {n_dropped} non-estimable cell(s) (based on first outcome "
            f"'{first_outcome}'): {dropped_keys[:5]}"
            + (f" ... and {n_dropped - 5} more" if n_dropped > 5 else "")
        )
        if secondary_loss:
            msg += (
                f". Note: {secondary_loss} had valid data in dropped cells. "
                f"For independent per-outcome support, call once per outcome."
            )
        warnings.warn(msg, UserWarning, stacklevel=2)
        panel_df = panel_df[~nonestimable].reset_index(drop=True)

    # --- Construct second-stage SurveyDesign ---
    geo_col = by_cols[0]
    weight_col = f"{first_outcome}_weight"
    if second_stage_weights == "pweight":
        # Unit-level population weight: average cell_sum_w across periods
        # within each geographic unit. This produces a unit-constant
        # weight that satisfies _validate_unit_constant_survey() for
        # panel estimators, while representing each unit's average
        # population share (averaging out period-to-period sampling
        # variability in per-cell weight sums).
        panel_df[weight_col] = panel_df.groupby(geo_col)["cell_sum_w"].transform("mean")
    else:
        # Precision weight: inverse variance, with NaN/Inf -> 0.0 so
        # downstream resolve() doesn't reject missing weights.
        # Diagnostic *_precision column is kept unchanged.
        panel_df[weight_col] = np.where(
            np.isfinite(panel_df[f"{first_outcome}_precision"]),
            panel_df[f"{first_outcome}_precision"],
            0.0,
        )

    # Drop geographic units (PSUs) with zero total weight — they would
    # inflate survey df and distort second-stage variance estimation.
    # Under pweight mode, unit-averaged cell_sum_w > 0 for all surviving
    # cells, so this block is a defensive no-op. Under aweight, NaN
    # precision maps to 0.0 and geographic units with all-zero precision
    # are pruned here.
    geo_weight = panel_df.groupby(geo_col)[weight_col].sum()
    zero_geos = geo_weight[geo_weight == 0].index
    if len(zero_geos) > 0:
        n_before = len(panel_df)
        panel_df = panel_df[~panel_df[geo_col].isin(zero_geos)].reset_index(drop=True)
        n_after = len(panel_df)
        warnings.warn(
            f"Dropped {n_before - n_after} cell(s) from {len(zero_geos)} "
            f"geographic unit(s) with zero total weight: "
            f"{list(zero_geos[:5])}"
            + (f" ... and {len(zero_geos) - 5} more" if len(zero_geos) > 5 else ""),
            UserWarning,
            stacklevel=2,
        )

    # Guard: all cells dropped
    if panel_df.empty:
        raise ValueError(
            "No estimable cells remain after aggregation. "
            "All cells had missing outcomes or zero effective weight."
        )

    second_stage_design = SurveyDesign(
        weights=weight_col,
        weight_type=second_stage_weights,
        psu=geo_col,
    )
    return panel_df, second_stage_design