"""
Data preparation utilities for difference-in-differences analysis.
This module provides helper functions to prepare data for DiD estimation,
including creating treatment indicators, reshaping panel data, and
generating synthetic datasets for testing.
Data generation functions (generate_*) are defined in prep_dgp.py and
re-exported here for backward compatibility.
"""
from typing import Any, Dict, List, Optional, Union
import numpy as np
import pandas as pd
from diff_diff.utils import compute_synthetic_weights
# Re-export data generation functions from prep_dgp for backward compatibility
from diff_diff.prep_dgp import (
generate_continuous_did_data,
generate_did_data,
generate_staggered_data,
generate_factor_data,
generate_ddd_data,
generate_panel_data,
generate_event_study_data,
)
# Tuning constants for the treatment-candidate scoring reached via
# rank_control_units(suggest_treatment_candidates=True).

# Another unit counts as "similar" to a candidate when its pre-period mean
# outcome lies within this many standard deviations (SD taken across the other
# units' pre-period means) of the candidate's own pre-period mean.
_SIMILARITY_THRESHOLD_SD = 0.5
# Penalty weight for outcome outliers in treatment candidate scoring.
_OUTLIER_PENALTY_WEIGHT = 0.3
[docs]
def make_treatment_indicator(
    data: pd.DataFrame,
    column: str,
    treated_values: Optional[Union[Any, List[Any]]] = None,
    threshold: Optional[float] = None,
    above_threshold: bool = True,
    new_column: str = "treated"
) -> pd.DataFrame:
    """
    Add a 0/1 treatment flag derived from an existing column.

    Exactly one of ``treated_values`` (membership rule) or ``threshold``
    (cut-off rule) must be supplied. The input frame is never modified;
    a copy carrying the extra column is returned.

    Parameters
    ----------
    data : pd.DataFrame
        Input DataFrame.
    column : str
        Column the indicator is derived from.
    treated_values : Any or list, optional
        Value(s) marking treated rows. A scalar is treated as a one-element
        list; lists, tuples and sets are used as given. Matching rows get 1,
        all others 0.
    threshold : float, optional
        Numeric cut-off for a continuous column (e.g. treat firms above
        median size).
    above_threshold : bool, default=True
        With a threshold: if True rows with ``value >= threshold`` are
        treated, otherwise rows with ``value <= threshold``.
    new_column : str, default="treated"
        Name of the indicator column to create.

    Returns
    -------
    pd.DataFrame
        Copy of ``data`` with the integer indicator column added.

    Raises
    ------
    ValueError
        If both or neither of ``treated_values`` / ``threshold`` are given,
        or if ``column`` is not in the DataFrame.

    Examples
    --------
    Create treatment from categorical variable:

    >>> df = pd.DataFrame({'group': ['A', 'A', 'B', 'B'], 'y': [1, 2, 3, 4]})
    >>> df = make_treatment_indicator(df, 'group', treated_values='A')
    >>> df['treated'].tolist()
    [1, 1, 0, 0]

    Create treatment from numeric threshold:

    >>> df = pd.DataFrame({'size': [10, 50, 100, 200], 'y': [1, 2, 3, 4]})
    >>> df = make_treatment_indicator(df, 'size', threshold=75)
    >>> df['treated'].tolist()
    [0, 0, 1, 1]

    Treat units below a threshold:

    >>> df = make_treatment_indicator(df, 'size', threshold=75, above_threshold=False)
    >>> df['treated'].tolist()
    [1, 1, 0, 0]
    """
    result = data.copy()

    values_given = treated_values is not None
    threshold_given = threshold is not None

    if values_given and threshold_given:
        raise ValueError("Specify either 'treated_values' or 'threshold', not both.")
    if not (values_given or threshold_given):
        raise ValueError("Must specify either 'treated_values' or 'threshold'.")
    if column not in result.columns:
        raise ValueError(f"Column '{column}' not found in DataFrame.")

    source = result[column]
    if values_given:
        # A bare scalar is wrapped so that isin() always receives a collection.
        if isinstance(treated_values, (list, tuple, set)):
            members = treated_values
        else:
            members = [treated_values]
        flags = source.isin(members)
    elif above_threshold:
        flags = source >= threshold
    else:
        flags = source <= threshold

    result[new_column] = flags.astype(int)
    return result
[docs]
def make_post_indicator(
    data: pd.DataFrame,
    time_column: str,
    post_periods: Optional[Union[Any, List[Any]]] = None,
    treatment_start: Optional[Any] = None,
    new_column: str = "post"
) -> pd.DataFrame:
    """
    Add a 0/1 post-treatment flag derived from a time column.

    Exactly one of ``post_periods`` (explicit list of post periods) or
    ``treatment_start`` (first post period) must be supplied. The input
    frame is left untouched; a copy with the extra column is returned.

    Parameters
    ----------
    data : pd.DataFrame
        Input DataFrame.
    time_column : str
        Name of the time/period column.
    post_periods : Any or list, optional
        Period value(s) that count as post-treatment. A scalar is treated
        as a one-element list. Matching rows get 1, all others 0.
    treatment_start : Any, optional
        First post-treatment period; rows with ``time >= treatment_start``
        get 1. When the time column has a datetime dtype the value is
        passed through ``pd.to_datetime`` before comparing.
    new_column : str, default="post"
        Name of the indicator column to create.

    Returns
    -------
    pd.DataFrame
        Copy of ``data`` with the integer indicator column added.

    Raises
    ------
    ValueError
        If both or neither of ``post_periods`` / ``treatment_start`` are
        given, or if ``time_column`` is not in the DataFrame.

    Examples
    --------
    Using specific post periods:

    >>> df = pd.DataFrame({'year': [2018, 2019, 2020, 2021], 'y': [1, 2, 3, 4]})
    >>> df = make_post_indicator(df, 'year', post_periods=[2020, 2021])
    >>> df['post'].tolist()
    [0, 0, 1, 1]

    Using treatment start:

    >>> df = make_post_indicator(df, 'year', treatment_start=2020)
    >>> df['post'].tolist()
    [0, 0, 1, 1]

    Works with date columns:

    >>> df = pd.DataFrame({'date': pd.to_datetime(['2020-01-01', '2020-06-01', '2021-01-01'])})
    >>> df = make_post_indicator(df, 'date', treatment_start='2020-06-01')
    """
    result = data.copy()

    periods_given = post_periods is not None
    start_given = treatment_start is not None

    if periods_given and start_given:
        raise ValueError("Specify either 'post_periods' or 'treatment_start', not both.")
    if not (periods_given or start_given):
        raise ValueError("Must specify either 'post_periods' or 'treatment_start'.")
    if time_column not in result.columns:
        raise ValueError(f"Column '{time_column}' not found in DataFrame.")

    times = result[time_column]
    if periods_given:
        # A bare scalar is wrapped so that isin() always receives a collection.
        if isinstance(post_periods, (list, tuple, set)):
            wanted = post_periods
        else:
            wanted = [post_periods]
        is_post = times.isin(wanted)
    else:
        cutoff = treatment_start
        # Datetime columns accept string/other date-likes for the cut-off.
        if pd.api.types.is_datetime64_any_dtype(times.dtype):
            cutoff = pd.to_datetime(cutoff)
        is_post = times >= cutoff

    result[new_column] = is_post.astype(int)
    return result
[docs]
def wide_to_long(
    data: pd.DataFrame,
    value_columns: List[str],
    id_column: str,
    time_name: str = "period",
    value_name: str = "value",
    time_values: Optional[List[Any]] = None
) -> pd.DataFrame:
    """
    Reshape wide panel data (one row per unit) into long format.

    Each column in ``value_columns`` holds the outcome for one period. The
    result has one row per unit-period, sorted by unit then period, with any
    remaining columns of ``data`` carried along unchanged.

    Parameters
    ----------
    data : pd.DataFrame
        Wide-format DataFrame with one row per unit.
    value_columns : list of str
        Columns holding the outcome for each period, in chronological order.
    id_column : str
        Column name for the unit identifier.
    time_name : str, default="period"
        Name for the new time period column.
    value_name : str, default="value"
        Name for the new value/outcome column.
    time_values : list, optional
        Period labels, positionally matched to ``value_columns``. Defaults
        to ``0, 1, 2, ...``. Must have the same length as ``value_columns``.

    Returns
    -------
    pd.DataFrame
        Long-format DataFrame with columns ordered as
        ``[id_column, time_name, value_name, <other columns>]``.

    Raises
    ------
    ValueError
        If ``value_columns`` is empty, a named column is missing, or
        ``time_values`` has the wrong length.

    Examples
    --------
    >>> wide_df = pd.DataFrame({
    ...     'firm_id': [1, 2, 3],
    ...     'sales_2019': [100, 150, 200],
    ...     'sales_2020': [110, 160, 210],
    ...     'sales_2021': [120, 170, 220]
    ... })
    >>> long_df = wide_to_long(
    ...     wide_df,
    ...     value_columns=['sales_2019', 'sales_2020', 'sales_2021'],
    ...     id_column='firm_id',
    ...     time_name='year',
    ...     value_name='sales',
    ...     time_values=[2019, 2020, 2021]
    ... )
    >>> len(long_df)
    9
    >>> long_df.columns.tolist()
    ['firm_id', 'year', 'sales']
    """
    if not value_columns:
        raise ValueError("value_columns cannot be empty.")
    if id_column not in data.columns:
        raise ValueError(f"Column '{id_column}' not found in DataFrame.")
    for name in value_columns:
        if name not in data.columns:
            raise ValueError(f"Column '{name}' not found in DataFrame.")

    periods = list(range(len(value_columns))) if time_values is None else time_values
    if len(periods) != len(value_columns):
        raise ValueError(
            f"time_values length ({len(periods)}) must match "
            f"value_columns length ({len(value_columns)})."
        )

    # Everything that is neither the id nor an outcome column rides along.
    passthrough = [
        name for name in data.columns
        if not (name == id_column or name in value_columns)
    ]

    melted = pd.melt(
        data,
        id_vars=[id_column, *passthrough],
        value_vars=value_columns,
        var_name='_temp_var',
        value_name=value_name
    )

    # Translate the original column labels into their period labels.
    period_lookup = dict(zip(value_columns, periods))
    melted[time_name] = melted['_temp_var'].map(period_lookup)
    melted = melted.drop(columns='_temp_var')

    ordered = [id_column, time_name, value_name, *passthrough]
    tidy = melted[ordered].sort_values([id_column, time_name])
    return tidy.reset_index(drop=True)
[docs]
def balance_panel(
    data: pd.DataFrame,
    unit_column: str,
    time_column: str,
    method: str = "inner",
    fill_value: Optional[float] = None
) -> pd.DataFrame:
    """
    Balance a panel dataset so that every unit has every time period.

    Parameters
    ----------
    data : pd.DataFrame
        Unbalanced panel data.
    unit_column : str
        Column name for unit identifier.
    time_column : str
        Column name for time period.
    method : str, default="inner"
        Balancing method:

        - "inner": keep only units observed in every (non-missing) period;
          incomplete units are dropped.
        - "outer": build the full unit x period grid and left-join the
          data onto it, so missing observations appear as NaN rows.
        - "fill": like "outer", then fill the gaps (see ``fill_value``).
    fill_value : float, optional
        Only used with ``method="fill"``. When given, NaN in the *numeric*
        non-key columns is replaced by this value (non-numeric columns are
        left as NaN). When None, non-key columns are forward-filled and
        then backward-filled within each unit.

    Returns
    -------
    pd.DataFrame
        Balanced panel DataFrame. For "inner" this is a filtered copy of
        ``data``; for "outer"/"fill" it is a new frame on the full grid.

    Raises
    ------
    ValueError
        If a key column is missing or ``method`` is not recognised.

    Examples
    --------
    Keep only complete units:

    >>> df = pd.DataFrame({
    ...     'unit': [1, 1, 1, 2, 2, 3, 3, 3],
    ...     'period': [1, 2, 3, 1, 2, 1, 2, 3],
    ...     'y': [10, 11, 12, 20, 21, 30, 31, 32]
    ... })
    >>> balanced = balance_panel(df, 'unit', 'period', method='inner')
    >>> balanced['unit'].unique().tolist()
    [1, 3]

    Include all combinations:

    >>> balanced = balance_panel(df, 'unit', 'period', method='outer')
    >>> len(balanced)
    9
    """
    if unit_column not in data.columns:
        raise ValueError(f"Column '{unit_column}' not found in DataFrame.")
    if time_column not in data.columns:
        raise ValueError(f"Column '{time_column}' not found in DataFrame.")
    if method not in ("inner", "outer", "fill"):
        raise ValueError(f"method must be 'inner', 'outer', or 'fill', got '{method}'")

    if method == "inner":
        # Count periods with nunique() on BOTH sides. unique() would count a
        # missing time value as an extra period, which no unit could ever
        # match (per-unit nunique() ignores NaN), emptying the result.
        n_periods = data[time_column].nunique()
        periods_per_unit = data.groupby(unit_column)[time_column].nunique()
        complete_units = periods_per_unit[periods_per_unit == n_periods].index
        return data[data[unit_column].isin(complete_units)].copy()

    # method is "outer" or "fill": build the full unit x period grid.
    all_units = data[unit_column].unique()
    all_periods = sorted(data[time_column].unique())
    full_index = pd.MultiIndex.from_product(
        [all_units, all_periods],
        names=[unit_column, time_column]
    )
    full_df = pd.DataFrame(index=full_index).reset_index()

    # Left-join so grid cells without an observation become NaN rows.
    result = full_df.merge(data, on=[unit_column, time_column], how="left")

    if method == "fill":
        key_cols = (unit_column, time_column)
        cols_to_fill = [c for c in result.columns if c not in key_cols]

        if fill_value is not None:
            # A scalar fill only makes sense for numeric columns.
            numeric_cols = result.select_dtypes(include=[np.number]).columns
            for col in numeric_cols:
                if col in cols_to_fill:
                    result[col] = result[col].fillna(fill_value)
        elif cols_to_fill:
            # Carry the last observation forward within each unit, then
            # back-fill whatever is still missing at the start of a unit.
            result = result.sort_values([unit_column, time_column])
            result[cols_to_fill] = result.groupby(unit_column)[cols_to_fill].ffill()
            result[cols_to_fill] = result.groupby(unit_column)[cols_to_fill].bfill()

    return result
[docs]
def validate_did_data(
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    time: str,
    unit: Optional[str] = None,
    raise_on_error: bool = True
) -> Dict[str, Any]:
    """
    Check that a DataFrame is usable for DiD estimation.

    The checks run in stages: required columns exist, the outcome is
    numeric, the treatment is binary, there are no missing values, every
    treatment-time cell (or group) has enough observations, and, when
    ``unit`` is given, the panel structure looks sensible.

    Parameters
    ----------
    data : pd.DataFrame
        Data to validate.
    outcome : str
        Name of outcome variable column.
    treatment : str
        Name of treatment indicator column.
    time : str
        Name of time/post indicator column.
    unit : str, optional
        Name of unit identifier column (enables panel checks).
    raise_on_error : bool, default=True
        If True, raise ``ValueError`` listing the failures. If False,
        return the result dictionary instead.

    Returns
    -------
    dict
        Validation results with keys:

        - valid: bool, True when no errors were recorded
        - errors: list of error messages
        - warnings: list of warning messages
        - summary: dict of counts (empty when validation failed early)

    Raises
    ------
    ValueError
        When ``raise_on_error`` is True and any check fails.

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'y': [1, 2, 3, 4],
    ...     'treated': [0, 0, 1, 1],
    ...     'post': [0, 1, 0, 1]
    ... })
    >>> result = validate_did_data(df, 'y', 'treated', 'post', raise_on_error=False)
    >>> result['valid']
    True
    """
    problems = []
    notes = []

    # --- Stage 1: required columns -----------------------------------------
    needed = [outcome, treatment, time]
    if unit is not None:
        needed.append(unit)

    problems.extend(
        f"Required column '{name}' not found in DataFrame."
        for name in needed
        if name not in data.columns
    )
    if problems:
        # Nothing else can be checked without the columns; bail out now.
        if raise_on_error:
            raise ValueError("\n".join(problems))
        return {"valid": False, "errors": problems, "warnings": notes, "summary": {}}

    # --- Stage 2: column contents -------------------------------------------
    if not pd.api.types.is_numeric_dtype(data[outcome]):
        problems.append(
            f"Outcome column '{outcome}' must be numeric. "
            f"Got type: {data[outcome].dtype}"
        )

    treatment_vals = data[treatment].dropna().unique()
    if not set(treatment_vals) <= {0, 1}:
        problems.append(
            f"Treatment column '{treatment}' must be binary (0 or 1). "
            f"Found values: {sorted(treatment_vals)}"
        )

    time_vals = data[time].dropna().unique()
    two_periods = len(time_vals) == 2
    if two_periods and not set(time_vals) <= {0, 1}:
        notes.append(
            f"Time column '{time}' has 2 values but they are not 0 and 1: {sorted(time_vals)}. "
            "For basic DiD, use 0 for pre-treatment and 1 for post-treatment."
        )

    for name in needed:
        n_missing = data[name].isna().sum()
        if n_missing > 0:
            problems.append(
                f"Column '{name}' has {n_missing} missing values. "
                "Please handle missing data before fitting."
            )

    # --- Stage 3: counts and cell coverage (only on otherwise clean data) ---
    summary = {}
    if not problems:
        is_treated = data[treatment] == 1
        is_control = data[treatment] == 0

        summary = {
            "n_obs": len(data),
            "n_treated": int(is_treated.sum()),
            "n_control": int(is_control.sum()),
            "n_periods": len(time_vals),
        }
        if unit is not None:
            summary["n_units"] = data[unit].nunique()

        if summary["n_treated"] == 0:
            problems.append("No treated observations found (treatment column is all 0).")
        if summary["n_control"] == 0:
            problems.append("No control observations found (treatment column is all 1).")

        arms = ((0, is_control), (1, is_treated))
        if two_periods:
            # Classic 2x2 design: all four cells must be populated.
            for t_val, in_arm in arms:
                for p_val in time_vals:
                    if not (in_arm & (data[time] == p_val)).any():
                        problems.append(
                            f"No observations for treatment={t_val}, time={p_val}. "
                            "DiD requires observations in all treatment-time cells."
                        )
        else:
            # Multi-period design: each group needs at least two periods.
            for t_val, in_arm in arms:
                n_periods_with_obs = data[in_arm][time].nunique()
                if n_periods_with_obs < 2:
                    group_name = "Treated" if t_val == 1 else "Control"
                    problems.append(
                        f"{group_name} group has observations in only {n_periods_with_obs} period(s). "
                        "DiD requires multiple periods per group."
                    )

    # --- Stage 4: panel structure (warnings only) ---------------------------
    if unit is not None and not problems:
        by_unit = data.groupby(unit)

        levels_per_unit = by_unit[treatment].nunique()
        n_switchers = len(levels_per_unit[levels_per_unit > 1])
        if n_switchers > 0:
            notes.append(
                f"Treatment varies within {n_switchers} unit(s). "
                "For standard DiD, treatment should be constant within units. "
                "This may be intentional for staggered adoption designs."
            )

        periods_per_unit = by_unit[time].nunique()
        fewest, most = periods_per_unit.min(), periods_per_unit.max()
        if fewest != most:
            notes.append(
                f"Unbalanced panel detected. Units have between "
                f"{fewest} and {most} periods. "
                "Consider using balance_panel() to balance the data."
            )

    valid = len(problems) == 0
    if raise_on_error and not valid:
        raise ValueError("Data validation failed:\n" + "\n".join(problems))

    return {
        "valid": valid,
        "errors": problems,
        "warnings": notes,
        "summary": summary
    }
[docs]
def summarize_did_data(
    data: pd.DataFrame,
    outcome: str,
    treatment: str,
    time: str,
    unit: Optional[str] = None
) -> pd.DataFrame:
    """
    Tabulate outcome statistics for each treatment-group / period cell.

    For every (treatment, time) combination the table reports the count,
    mean, standard deviation, minimum and maximum of the outcome, rounded
    to four decimals. When the time column has exactly two distinct values,
    the smaller is labelled "Pre", the larger "Post", and a final
    "DiD Estimate" row holds the raw difference-in-differences of the cell
    means (its other columns are filled with "-").

    Parameters
    ----------
    data : pd.DataFrame
        Input data.
    outcome : str
        Name of outcome variable column.
    treatment : str
        Name of treatment indicator column (1 is labelled "Treated",
        anything else "Control").
    time : str
        Name of time/period column.
    unit : str, optional
        Name of unit identifier column. Accepted for interface symmetry;
        not used by the computation.

    Returns
    -------
    pd.DataFrame
        One row per treatment-time cell, indexed by a readable label, plus
        the "DiD Estimate" row in the two-period case.

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'y': [10, 11, 12, 13, 20, 21, 22, 23],
    ...     'treated': [0, 0, 1, 1, 0, 0, 1, 1],
    ...     'post': [0, 1, 0, 1, 0, 1, 0, 1]
    ... })
    >>> summary = summarize_did_data(df, 'y', 'treated', 'post')
    >>> print(summary)
    """
    stat_spec = [
        ("n", "count"),
        ("mean", "mean"),
        ("std", "std"),
        ("min", "min"),
        ("max", "max")
    ]
    table = data.groupby([treatment, time])[outcome].agg(stat_spec).round(4)

    periods = sorted(data[time].unique())

    def arm_name(flag: Any) -> str:
        return 'Treated' if flag == 1 else 'Control'

    # Anything other than exactly two periods just gets per-period labels.
    if len(periods) != 2:
        table.index = table.index.map(
            lambda key: f"{arm_name(key[0])} - Period {key[1]}"
        )
        return table

    # Two periods: earlier sorted value is "Pre", later is "Post".
    pre_val, post_val = periods[0], periods[1]

    def cell_label(key: tuple) -> str:
        phase = 'Post' if key[1] == post_val else 'Pre'
        return f"{arm_name(key[0])} - {phase}"

    table.index = table.index.map(cell_label)

    def cell_mean(arm: int, period: Any) -> float:
        in_cell = (data[treatment] == arm) & (data[time] == period)
        return data[in_cell][outcome].mean()

    # DiD = (treated post - treated pre) - (control post - control pre),
    # computed from the unrounded cell means.
    treated_change = cell_mean(1, post_val) - cell_mean(1, pre_val)
    control_change = cell_mean(0, post_val) - cell_mean(0, pre_val)
    did_estimate = treated_change - control_change

    did_row = pd.DataFrame(
        {
            "n": ["-"],
            "mean": [did_estimate],
            "std": ["-"],
            "min": ["-"],
            "max": ["-"]
        },
        index=["DiD Estimate"]
    )
    return pd.concat([table, did_row])
[docs]
def create_event_time(
    data: pd.DataFrame,
    time_column: str,
    treatment_time_column: str,
    new_column: str = "event_time"
) -> pd.DataFrame:
    """
    Add a column measuring time relative to each row's treatment date.

    Event time is ``time_column - treatment_time_column``. It is intended
    for event-study designs where units are treated at different times.
    The input frame is not modified.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    time_column : str
        Name of the calendar time column.
    treatment_time_column : str
        Name of the column holding each unit's treatment time. Rows where
        it is NaN (or, for numeric columns, infinite) are treated as
        never-treated.
    new_column : str, default="event_time"
        Name of the new event-time column.

    Returns
    -------
    pd.DataFrame
        Copy of ``data`` with the event-time column added. Values are:

        - Negative for pre-treatment periods
        - 0 for the treatment period
        - Positive for post-treatment periods
        - NaN for never-treated rows

    Raises
    ------
    ValueError
        If either named column is missing from the DataFrame.

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'unit': [1, 1, 1, 2, 2, 2],
    ...     'year': [2018, 2019, 2020, 2018, 2019, 2020],
    ...     'treatment_year': [2019, 2019, 2019, 2020, 2020, 2020]
    ... })
    >>> df = create_event_time(df, 'year', 'treatment_year')
    >>> df['event_time'].tolist()
    [-1, 0, 1, -2, -1, 0]
    """
    result = data.copy()

    for required in (time_column, treatment_time_column):
        if required not in result.columns:
            raise ValueError(f"Column '{required}' not found in DataFrame.")

    result[new_column] = result[time_column] - result[treatment_time_column]

    # Never-treated rows: missing treatment time, or +/-inf when numeric.
    onset = result[treatment_time_column]
    untreated = onset.isna()
    if pd.api.types.is_numeric_dtype(onset):
        untreated = untreated | np.isinf(onset)

    result.loc[untreated, new_column] = np.nan
    return result
[docs]
def aggregate_to_cohorts(
    data: pd.DataFrame,
    unit_column: str,
    time_column: str,
    treatment_column: str,
    outcome: str,
    covariates: Optional[List[str]] = None
) -> pd.DataFrame:
    """
    Collapse unit-level panel data to treatment-group-by-period means.

    Rows are grouped by ``(treatment_column, time_column)``. Within each
    group the outcome (and any covariates) are averaged and the number of
    distinct units is counted. Handy for plotting group trajectories.

    Parameters
    ----------
    data : pd.DataFrame
        Unit-level panel data.
    unit_column : str
        Name of unit identifier column.
    time_column : str
        Name of time period column.
    treatment_column : str
        Name of treatment indicator column.
    outcome : str
        Name of outcome variable column.
    covariates : list of str, optional
        Additional columns to average; they keep their original names.

    Returns
    -------
    pd.DataFrame
        One row per treatment-period group with the grouping columns,
        ``mean_<outcome>``, ``n_units`` and any covariate means.

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'unit': [1, 1, 2, 2, 3, 3, 4, 4],
    ...     'period': [0, 1, 0, 1, 0, 1, 0, 1],
    ...     'treated': [1, 1, 1, 1, 0, 0, 0, 0],
    ...     'y': [10, 15, 12, 17, 8, 10, 9, 11]
    ... })
    >>> cohort_df = aggregate_to_cohorts(df, 'unit', 'period', 'treated', 'y')
    >>> len(cohort_df)
    4
    """
    # Outcome mean and distinct-unit count first, then covariate means.
    spec = {outcome: "mean", unit_column: "nunique"}
    spec.update({name: "mean" for name in covariates or []})

    grouped = data.groupby([treatment_column, time_column])
    cohorts = grouped.agg(spec).reset_index()

    return cohorts.rename(columns={
        unit_column: "n_units",
        outcome: f"mean_{outcome}"
    })
[docs]
def rank_control_units(
    data: pd.DataFrame,
    unit_column: str,
    time_column: str,
    outcome_column: str,
    treatment_column: Optional[str] = None,
    treated_units: Optional[List[Any]] = None,
    pre_periods: Optional[List[Any]] = None,
    covariates: Optional[List[str]] = None,
    outcome_weight: float = 0.7,
    covariate_weight: float = 0.3,
    exclude_units: Optional[List[Any]] = None,
    require_units: Optional[List[Any]] = None,
    n_top: Optional[int] = None,
    suggest_treatment_candidates: bool = False,
    n_treatment_candidates: int = 5,
    lambda_reg: float = 0.0,
) -> pd.DataFrame:
    """
    Rank potential control units by their suitability for DiD analysis.

    Each candidate control is scored on how closely its pre-treatment
    outcome path tracks the mean path of the treated units and, optionally,
    on how close its pre-treatment covariate means are to the treated
    profile. Both component scores are min-max normalised across the
    candidates (best = 1, worst = 0) and combined with the given weights.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data in long format.
    unit_column : str
        Column name for unit identifier.
    time_column : str
        Column name for time periods.
    outcome_column : str
        Column name for outcome variable.
    treatment_column : str, optional
        Column with binary treatment indicator (0/1). The first observed
        value per unit decides whether the unit is treated.
    treated_units : list, optional
        Explicit list of treated unit IDs. Alternative to treatment_column.
    pre_periods : list, optional
        Pre-treatment periods for comparison. If None, uses the first half
        of the sorted periods.
    covariates : list of str, optional
        Covariate columns for matching. Similarity is based on
        pre-treatment means, standardised across units.
    outcome_weight : float, default=0.7
        Weight for pre-treatment outcome trend similarity (0-1).
    covariate_weight : float, default=0.3
        Weight for covariate distance (0-1). Ignored if no covariates.
        The two weights are rescaled to sum to one.
    exclude_units : list, optional
        Units that cannot be in control group.
    require_units : list, optional
        Units that must be in control group (kept even when ``n_top``
        would otherwise cut them).
    n_top : int, optional
        Return only top N control units. If None, return all. Required
        units are always kept, so the result can exceed N rows.
    suggest_treatment_candidates : bool, default=False
        If True and no treated units are found, identify potential
        treatment candidates instead of ranking controls.
    n_treatment_candidates : int, default=5
        Number of treatment candidates to suggest.
    lambda_reg : float, default=0.0
        Regularization passed to the synthetic-weight solver. Higher
        values give more uniform weights across controls.

    Returns
    -------
    pd.DataFrame
        Ranked control units (best first) with columns:

        - unit: Unit identifier
        - quality_score: Combined quality score (0-1, higher is better)
        - outcome_trend_score: Pre-treatment outcome trend similarity
        - covariate_score: Covariate match score (NaN if no covariates)
        - synthetic_weight: Weight from synthetic control optimization
        - pre_trend_rmse: RMSE of pre-treatment outcome vs treated mean
        - is_required: Whether unit was in require_units

        If suggest_treatment_candidates=True (and no treated units):

        - unit: Unit identifier
        - treatment_candidate_score: Suitability as treatment unit
        - avg_outcome_level: Pre-treatment outcome mean
        - outcome_trend: Pre-treatment trend slope
        - n_similar_controls: Count of similar potential controls

    Raises
    ------
    ValueError
        On missing columns, out-of-range weights, conflicting arguments,
        or when no pre-periods, treated units or control units remain.

    Examples
    --------
    Rank controls against treated units:

    >>> data = generate_did_data(n_units=30, n_periods=6, seed=42)
    >>> ranking = rank_control_units(
    ...     data,
    ...     unit_column='unit',
    ...     time_column='period',
    ...     outcome_column='outcome',
    ...     treatment_column='treated',
    ...     n_top=10
    ... )
    >>> ranking['quality_score'].is_monotonic_decreasing
    True

    With covariates:

    >>> data['size'] = np.random.randn(len(data))
    >>> ranking = rank_control_units(
    ...     data,
    ...     unit_column='unit',
    ...     time_column='period',
    ...     outcome_column='outcome',
    ...     treatment_column='treated',
    ...     covariates=['size']
    ... )

    Filter data for SyntheticDiD:

    >>> top_controls = ranking['unit'].tolist()
    >>> filtered = data[(data['treated'] == 1) | (data['unit'].isin(top_controls))]
    """
    # -------------------------------------------------------------------------
    # Input validation
    # -------------------------------------------------------------------------
    for col in (unit_column, time_column, outcome_column):
        if col not in data.columns:
            raise ValueError(f"Column '{col}' not found in DataFrame.")

    if treatment_column is not None and treatment_column not in data.columns:
        raise ValueError(f"Treatment column '{treatment_column}' not found in DataFrame.")

    covariate_cols = list(covariates) if covariates else []
    for cov in covariate_cols:
        if cov not in data.columns:
            raise ValueError(f"Covariate column '{cov}' not found in DataFrame.")

    if not 0 <= outcome_weight <= 1:
        raise ValueError("outcome_weight must be between 0 and 1")
    if not 0 <= covariate_weight <= 1:
        raise ValueError("covariate_weight must be between 0 and 1")

    if treated_units is not None and treatment_column is not None:
        raise ValueError("Specify either 'treated_units' or 'treatment_column', not both.")

    # Sets give O(1) membership tests in the filters below.
    exclude_set = set(exclude_units) if exclude_units else set()
    require_set = set(require_units) if require_units else set()

    invalid_required = [u for u in (require_units or []) if u in exclude_set]
    if invalid_required:
        raise ValueError(f"Units cannot be both required and excluded: {invalid_required}")

    # -------------------------------------------------------------------------
    # Determine pre-treatment periods
    # -------------------------------------------------------------------------
    if pre_periods is None:
        all_periods = sorted(data[time_column].unique())
        pre_periods = all_periods[:len(all_periods) // 2]
    else:
        pre_periods = list(pre_periods)

    if len(pre_periods) == 0:
        raise ValueError("No pre-treatment periods specified or inferred.")

    # -------------------------------------------------------------------------
    # Identify treated and control units
    # -------------------------------------------------------------------------
    all_units = list(data[unit_column].unique())

    if treated_units is not None:
        treated_set = set(treated_units)
    elif treatment_column is not None:
        # First observed indicator value per unit decides treatment status.
        unit_treatment = data.groupby(unit_column)[treatment_column].first()
        treated_set = set(unit_treatment[unit_treatment == 1].index)
    elif suggest_treatment_candidates:
        # Treatment candidate discovery mode - no treated units
        treated_set = set()
    else:
        raise ValueError(
            "Must specify treated_units, treatment_column, or set "
            "suggest_treatment_candidates=True"
        )

    # -------------------------------------------------------------------------
    # Treatment candidate discovery mode
    # -------------------------------------------------------------------------
    if suggest_treatment_candidates and len(treated_set) == 0:
        return _suggest_treatment_candidates(
            data, unit_column, time_column, outcome_column,
            pre_periods, n_treatment_candidates
        )

    if len(treated_set) == 0:
        raise ValueError("No treated units found.")

    control_candidates = [
        u for u in all_units
        if u not in treated_set and u not in exclude_set
    ]
    if len(control_candidates) == 0:
        raise ValueError("No control units available after exclusions.")

    # -------------------------------------------------------------------------
    # Create outcome matrices (pre-treatment)
    # -------------------------------------------------------------------------
    pre_data = data[data[time_column].isin(pre_periods)]
    pivot = pre_data.pivot(index=time_column, columns=unit_column, values=outcome_column)

    # Keep only the requested pre-periods that actually occur in the data.
    valid_pre_periods = [p for p in pre_periods if p in pivot.index]
    if len(valid_pre_periods) == 0:
        raise ValueError("No data found for specified pre-treatment periods.")

    # Unbalanced panels: drop candidates with no pre-period observations.
    control_candidates = [c for c in control_candidates if c in pivot.columns]
    if len(control_candidates) == 0:
        raise ValueError("No control units found in pre-treatment data.")

    # Control outcomes: shape (n_pre_periods, n_control_candidates)
    Y_control = pivot.loc[valid_pre_periods, control_candidates].values.astype(float)

    # Treated units that have pre-period data; reused for covariates below.
    treated_list = [u for u in treated_set if u in pivot.columns]
    if len(treated_list) == 0:
        raise ValueError("Treated units not found in pre-treatment data.")

    # Treated outcomes mean: shape (n_pre_periods,)
    Y_treated_mean = pivot.loc[valid_pre_periods, treated_list].mean(axis=1).values.astype(float)

    # -------------------------------------------------------------------------
    # Compute outcome trend scores
    # -------------------------------------------------------------------------
    # Synthetic weights (higher = better match)
    synthetic_weights = compute_synthetic_weights(
        Y_control, Y_treated_mean, lambda_reg=lambda_reg
    )

    # Per-control RMSE against the treated mean path; nanmean skips periods
    # where the control has no observation.
    rmse_scores = [
        np.sqrt(np.nanmean((y_c - Y_treated_mean) ** 2))
        for y_c in Y_control.T
    ]

    # Convert RMSE to similarity score (lower RMSE = higher score)
    max_rmse = max(rmse_scores) if rmse_scores else 1.0
    min_rmse = min(rmse_scores) if rmse_scores else 0.0
    rmse_range = max_rmse - min_rmse

    if rmse_range < 1e-10:
        # All controls fit equally well (includes the single-control case).
        outcome_trend_scores = [1.0] * len(rmse_scores)
    else:
        # Normalize so best control gets 1.0, worst gets 0.0
        outcome_trend_scores = [1 - (rmse - min_rmse) / rmse_range for rmse in rmse_scores]

    # -------------------------------------------------------------------------
    # Compute covariate scores (if covariates provided)
    # -------------------------------------------------------------------------
    if covariate_cols:
        # Unit-level covariate values (pre-treatment mean)
        cov_data = pre_data.groupby(unit_column)[covariate_cols].mean()

        # Treated covariate profile. Use treated_list (not treated_set):
        # a treated unit without pre-period rows is absent from cov_data
        # and would otherwise raise KeyError.
        treated_cov = cov_data.loc[treated_list].mean()

        # Standardize covariates across all units with pre-period data.
        cov_mean = cov_data.mean()
        cov_std = cov_data.std().replace(0, 1)  # Avoid division by zero
        cov_standardized = (cov_data - cov_mean) / cov_std
        treated_cov_std = (treated_cov - cov_mean) / cov_std

        # Euclidean distance in standardized space (vectorized)
        control_cov_matrix = cov_standardized.loc[control_candidates].values
        treated_cov_vector = treated_cov_std.values
        covariate_distances = np.sqrt(
            np.sum((control_cov_matrix - treated_cov_vector) ** 2, axis=1)
        )

        # Convert distance to similarity score (min-max normalization)
        max_dist = covariate_distances.max() if len(covariate_distances) > 0 else 1.0
        min_dist = covariate_distances.min() if len(covariate_distances) > 0 else 0.0
        dist_range = max_dist - min_dist

        if dist_range < 1e-10:
            # All controls have identical/similar covariate profiles
            covariate_scores = [1.0] * len(covariate_distances)
        else:
            # Normalize so best control (closest) gets 1.0, worst gets 0.0
            covariate_scores = (1 - (covariate_distances - min_dist) / dist_range).tolist()
    else:
        covariate_scores = [np.nan] * len(control_candidates)

    # -------------------------------------------------------------------------
    # Compute combined quality score
    # -------------------------------------------------------------------------
    # Rescale the two weights so they sum to one.
    total_weight = outcome_weight + covariate_weight
    if total_weight > 0:
        norm_outcome_weight = outcome_weight / total_weight
        norm_covariate_weight = covariate_weight / total_weight
    else:
        norm_outcome_weight = 1.0
        norm_covariate_weight = 0.0

    # A NaN covariate score (no covariates) falls back to the outcome score.
    quality_scores = [
        outcome_score if np.isnan(cov_score)
        else norm_outcome_weight * outcome_score + norm_covariate_weight * cov_score
        for outcome_score, cov_score in zip(outcome_trend_scores, covariate_scores)
    ]

    # -------------------------------------------------------------------------
    # Build result DataFrame
    # -------------------------------------------------------------------------
    result = pd.DataFrame({
        'unit': control_candidates,
        'quality_score': quality_scores,
        'outcome_trend_score': outcome_trend_scores,
        'covariate_score': covariate_scores,
        'synthetic_weight': synthetic_weights,
        'pre_trend_rmse': rmse_scores,
        'is_required': [u in require_set for u in control_candidates]
    })

    # Sort by quality score (descending)
    result = result.sort_values('quality_score', ascending=False)

    # Apply n_top limit if specified
    if n_top is not None and n_top < len(result):
        # Required units are always kept; the best non-required units fill
        # whatever slots remain.
        required_df = result[result['is_required']]
        non_required_df = result[~result['is_required']]

        remaining_slots = max(0, n_top - len(required_df))
        top_non_required = non_required_df.head(remaining_slots)

        result = pd.concat([required_df, top_non_required])
        result = result.sort_values('quality_score', ascending=False)

    return result.reset_index(drop=True)
def _suggest_treatment_candidates(
    data: pd.DataFrame,
    unit_column: str,
    time_column: str,
    outcome_column: str,
    pre_periods: List[Any],
    n_candidates: int
) -> pd.DataFrame:
    """
    Identify units that would make good treatment candidates.

    A good treatment candidate:

    1. Has many similar control units available (for matching)
    2. Has stable pre-treatment trends (predictable counterfactual)
    3. Is not an extreme outlier

    Each unit with at least one pre-period observation is scored as::

        score = clip(similarity - _OUTLIER_PENALTY_WEIGHT * |z|, 0, 1)

    where ``similarity`` is the unit's count of similar controls divided by
    the largest such count across units, and ``z`` is the z-score of the
    unit's mean pre-period outcome among all scored units. Another unit
    counts as "similar" when its mean pre-period outcome lies within
    ``_SIMILARITY_THRESHOLD_SD`` standard deviations (of the other units'
    means) of this unit's mean.

    Parameters
    ----------
    data : pd.DataFrame
        Panel data.
    unit_column : str
        Unit identifier column.
    time_column : str
        Time period column.
    outcome_column : str
        Outcome variable column.
    pre_periods : list
        Pre-treatment periods.
    n_candidates : int
        Number of candidates to return.

    Returns
    -------
    pd.DataFrame
        Treatment candidates with scores, at most ``n_candidates`` rows,
        ordered by descending ``treatment_candidate_score``. Contains the
        columns ``unit``, ``avg_outcome_level``, ``outcome_trend``,
        ``n_similar_controls`` and ``treatment_candidate_score``. An empty
        frame with those columns is returned when no unit has pre-period
        observations.

    Notes
    -----
    ``outcome_trend`` is the least-squares slope of the outcome against the
    observation's ordinal position (0, 1, 2, ...) after ordering the unit's
    pre-period rows by ``time_column``; i.e. it is a per-step change, not a
    per-time-unit change. It is reported for information and does not enter
    the score.
    """
    empty_columns = [
        'unit', 'treatment_candidate_score', 'avg_outcome_level',
        'outcome_trend', 'n_similar_controls'
    ]

    all_units = list(data[unit_column].unique())
    pre_data = data[data[time_column].isin(pre_periods)]
    if pre_data.empty:
        return pd.DataFrame(columns=empty_columns)

    # Mean pre-period outcome per unit, computed once. The per-unit loop below
    # only needs "everyone else's" means, which is this series minus one row,
    # so there is no need to re-filter and re-group the frame for every unit.
    unit_means = pre_data.groupby(unit_column)[outcome_column].mean()

    candidate_info = []
    for unit in all_units:
        unit_data = pre_data[pre_data[unit_column] == unit]
        if len(unit_data) == 0:
            continue

        # Average outcome level
        avg_outcome = unit_data[outcome_column].mean()

        # Trend (simple linear regression slope). The rows must be in time
        # order for the ordinal regressor to be meaningful; a stable sort
        # keeps the original row order among duplicate periods.
        outcomes = unit_data.sort_values(
            time_column, kind='mergesort'
        )[outcome_column].values
        if len(outcomes) > 1:
            try:
                slope = np.polyfit(np.arange(len(outcomes)), outcomes, 1)[0]
            except (np.linalg.LinAlgError, ValueError):
                # Best effort: an unfittable series is reported as flat.
                slope = 0.0
        else:
            slope = 0.0

        # Count similar potential controls
        other_means = unit_means[unit_means.index != unit]
        if len(other_means) > 0:
            sd = other_means.std()
            if sd > 0:
                n_similar = int(
                    (np.abs(other_means - avg_outcome)
                     < _SIMILARITY_THRESHOLD_SD * sd).sum()
                )
            else:
                # Zero (or undefined, e.g. a single other unit) spread:
                # treat every other unit as similar.
                n_similar = len(other_means)
        else:
            n_similar = 0

        candidate_info.append({
            'unit': unit,
            'avg_outcome_level': avg_outcome,
            'outcome_trend': slope,
            'n_similar_controls': n_similar
        })

    if not candidate_info:
        return pd.DataFrame(columns=empty_columns)

    result = pd.DataFrame(candidate_info)

    # Score: prefer units with many similar controls and moderate outcome levels
    max_similar = result['n_similar_controls'].max()
    if max_similar > 0:
        similarity_score = result['n_similar_controls'] / max_similar
    else:
        similarity_score = pd.Series(0.0, index=result.index)

    # Penalty for outliers in outcome level
    outcome_mean = result['avg_outcome_level'].mean()
    outcome_std = result['avg_outcome_level'].std()
    if outcome_std > 0:
        outcome_z = np.abs(
            (result['avg_outcome_level'] - outcome_mean) / outcome_std
        )
    else:
        # Covers both identical levels and the single-candidate case (NaN std).
        outcome_z = pd.Series(0.0, index=result.index)

    result['treatment_candidate_score'] = (
        similarity_score - _OUTLIER_PENALTY_WEIGHT * outcome_z
    ).clip(0, 1)

    # Return top candidates
    result = result.nlargest(n_candidates, 'treatment_candidate_score')
    return result.reset_index(drop=True)