# Source code for diff_diff.datasets

"""
Real-world datasets for Difference-in-Differences analysis.

This module provides functions to load classic econometrics datasets
commonly used for teaching and demonstrating DiD methods.

All datasets are downloaded from public sources and cached locally
for subsequent use.
"""

from io import StringIO
from pathlib import Path
from typing import Dict
from urllib.error import HTTPError, URLError
from urllib.request import urlopen

import numpy as np
import pandas as pd


# Cache directory for downloaded datasets: ``.cache/diff_diff/datasets``
# under the current user's home directory.
_CACHE_DIR = Path.home() / ".cache" / "diff_diff" / "datasets"


def _get_cache_path(name: str) -> Path:
    """Return the CSV cache file path for dataset ``name``.

    The cache directory (and any missing parents) is created first, so the
    parent of the returned path exists when this function returns.
    """
    cache_dir = _CACHE_DIR
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir.joinpath(f"{name}.csv")


def _download_with_cache(
    url: str,
    name: str,
    force_download: bool = False,
) -> str:
    """Return the text of a dataset, downloading it on a cache miss.

    Parameters
    ----------
    url : str
        Location to fetch the dataset from.
    name : str
        Dataset name; selects the cache file via ``_get_cache_path``.
    force_download : bool, default=False
        If True, ignore an existing cache file and fetch again. A cached
        copy is still used as a last resort if that fetch fails.

    Returns
    -------
    str
        The dataset content, decoded as UTF-8.

    Raises
    ------
    RuntimeError
        If the download fails and no cached copy exists.
    """
    # Local import: only needed to recognise truncated/garbled HTTP bodies.
    from http.client import HTTPException

    cache_path = _get_cache_path(name)

    if cache_path.exists() and not force_download:
        # Read with the same encoding used for writing, independent of locale.
        return cache_path.read_text(encoding="utf-8")

    try:
        with urlopen(url, timeout=30) as response:
            content = response.read().decode("utf-8")
    except (HTTPError, URLError, OSError, HTTPException) as e:
        # HTTPError/URLError cover failures while opening the URL. OSError
        # additionally covers socket timeouts and connection resets raised
        # while reading the body, and HTTPException covers incomplete reads.
        # All of them must surface as RuntimeError so callers can fall back.
        if cache_path.exists():
            # Use cached version if download fails
            return cache_path.read_text(encoding="utf-8")
        raise RuntimeError(
            f"Failed to download dataset '{name}' from {url}: {e}\n"
            "Check your internet connection or try again later."
        ) from e

    # Written outside the try block so that a cache write error is not
    # misreported as a download failure.
    cache_path.write_text(content, encoding="utf-8")
    return content


def clear_cache() -> None:
    """Delete every cached ``*.csv`` dataset file from the cache directory."""
    # NOTE(review): the confirmation message is assumed to be printed only
    # when the cache directory exists - the original indentation was lost.
    if not _CACHE_DIR.exists():
        return
    for cached_file in _CACHE_DIR.glob("*.csv"):
        cached_file.unlink()
    print(f"Cleared cache at {_CACHE_DIR}")
def load_card_krueger(force_download: bool = False) -> pd.DataFrame:
    """
    Load the Card & Krueger (1994) minimum wage dataset.

    This classic dataset examines the effect of New Jersey's 1992 minimum
    wage increase on employment in fast-food restaurants, using Pennsylvania
    as a control group. The study is a canonical example of the
    Difference-in-Differences method.

    Parameters
    ----------
    force_download : bool, default=False
        If True, re-download the dataset even if cached.

    Returns
    -------
    pd.DataFrame
        Store-level dataset. The offline fallback provides the columns
        below; the downloaded file is expected to carry the same
        information (a ``sheet`` column is renamed to ``store_id`` and
        ``state`` is derived from an ``nj`` indicator when necessary):

        - store_id : int - Unique store identifier
        - state : str - 'NJ' (New Jersey, treated) or 'PA' (Pennsylvania, control)
        - chain : str - Fast food chain ('bk', 'kfc', 'roys', 'wendys')
        - emp_pre : float - Full-time equivalent employment before (Feb 1992)
        - emp_post : float - Full-time equivalent employment after (Nov 1992)
        - wage_pre : float - Starting wage before
        - wage_post : float - Starting wage after
        - treated : int - 1 if NJ, 0 if PA
        - emp_change : float - Change in employment (emp_post - emp_pre)

    Warns
    -----
    UserWarning
        If the dataset cannot be downloaded and no cached copy exists. In
        that case simulated data based on published summary statistics is
        returned instead of the real dataset.

    Notes
    -----
    The minimum wage in New Jersey increased from $4.25 to $5.05 on April 1, 1992.
    Pennsylvania's minimum wage remained at $4.25.

    Original finding: No significant negative effect of minimum wage increase
    on employment (ATT ≈ +2.8 FTE employees).

    References
    ----------
    Card, D., & Krueger, A. B. (1994). Minimum Wages and Employment: A Case Study
    of the Fast-Food Industry in New Jersey and Pennsylvania. *American Economic
    Review*, 84(4), 772-793.

    Examples
    --------
    >>> from diff_diff.datasets import load_card_krueger
    >>> from diff_diff import DifferenceInDifferences
    >>>
    >>> # Load and prepare data
    >>> ck = load_card_krueger()
    >>> ck_long = ck.melt(
    ...     id_vars=['store_id', 'state', 'treated'],
    ...     value_vars=['emp_pre', 'emp_post'],
    ...     var_name='period', value_name='employment'
    ... )
    >>> ck_long['post'] = (ck_long['period'] == 'emp_post').astype(int)
    >>>
    >>> # Estimate DiD
    >>> did = DifferenceInDifferences()
    >>> results = did.fit(ck_long, outcome='employment', treatment='treated', time='post')
    """
    import warnings

    # Card-Krueger data as hosted in the causaldata GitHub repository
    url = "https://raw.githubusercontent.com/causaldata/causal_datasets/main/card_krueger/card_krueger.csv"

    try:
        content = _download_with_cache(url, "card_krueger", force_download)
        df = pd.read_csv(StringIO(content))
    except RuntimeError:
        # Fallback: construct from embedded summary statistics. Warn so that
        # simulated numbers are never mistaken for the real dataset.
        warnings.warn(
            "Could not download the Card-Krueger dataset; returning simulated "
            "data based on published summary statistics instead.",
            UserWarning,
            stacklevel=2,
        )
        df = _construct_card_krueger_data()

    # Standardize column names
    df = df.rename(columns={"sheet": "store_id"})

    # Add convenience columns, each only when its inputs are available
    if "state" not in df.columns and "nj" in df.columns:
        df["state"] = np.where(df["nj"] == 1, "NJ", "PA")

    if "treated" not in df.columns and "state" in df.columns:
        df["treated"] = (df["state"] == "NJ").astype(int)

    if "emp_change" not in df.columns and "emp_post" in df.columns and "emp_pre" in df.columns:
        df["emp_change"] = df["emp_post"] - df["emp_pre"]

    return df
def _construct_card_krueger_data() -> pd.DataFrame:
    """
    Construct a simulated Card-Krueger dataset from summary statistics.

    This is a fallback when the online source is unavailable. Store-level
    values are random draws around the means quoted in the comments below,
    so the DiD estimate is only approximately that of the paper.

    Returns
    -------
    pd.DataFrame
        One row per store with columns ``store_id``, ``state``, ``chain``,
        ``emp_pre``, ``emp_post``, ``wage_pre``, ``wage_post``, ``treated``
        and ``emp_change``. The result is identical on every call.
    """
    # A private generator seeded with the publication year keeps the output
    # reproducible without reseeding (and clobbering) NumPy's global RNG.
    # RandomState yields the same stream as np.random.seed(1994) would.
    rng = np.random.RandomState(1994)

    # Per-state simulation parameters as (mean, standard deviation) pairs.
    state_specs = [
        {
            # New Jersey stores (treated) - summary stats from paper
            # Mean emp before: 20.44, after: 21.03
            # Mean wage before: 4.61, after: 5.08
            "state": "NJ",
            "counts": {"bk": 85, "kfc": 62, "roys": 48, "wendys": 36},
            "emp_pre": (20.44, 8.5),
            "emp_delta": (0.59, 7.0),  # Change ≈ 0.59
            "wage_pre": (4.61, 0.35),
            "wage_post": (5.08, 0.12),
        },
        {
            # Pennsylvania stores (control) - summary stats from paper
            # Mean emp before: 23.33, after: 21.17
            # Mean wage before: 4.63, after: 4.62
            "state": "PA",
            "counts": {"bk": 30, "kfc": 20, "roys": 14, "wendys": 15},
            "emp_pre": (23.33, 8.2),
            "emp_delta": (-2.16, 7.0),  # Change ≈ -2.16
            "wage_pre": (4.63, 0.35),
            "wage_post": (4.62, 0.35),
        },
    ]

    stores = []
    store_id = 1
    for spec in state_specs:
        for chain, n_stores in spec["counts"].items():
            for _ in range(n_stores):
                # Draw order (emp_pre, change, wage_pre, wage_post) is fixed:
                # it determines the generated values.
                emp_pre = rng.normal(*spec["emp_pre"])
                # The change is applied to the unclamped pre-period draw.
                emp_post = emp_pre + rng.normal(*spec["emp_delta"])
                wage_pre = rng.normal(*spec["wage_pre"])
                wage_post = rng.normal(*spec["wage_post"])

                stores.append(
                    {
                        "store_id": store_id,
                        "state": spec["state"],
                        "chain": chain,
                        # Employment cannot be negative
                        "emp_pre": round(max(0, emp_pre), 1),
                        "emp_post": round(max(0, emp_post), 1),
                        "wage_pre": round(wage_pre, 2),
                        "wage_post": round(wage_post, 2),
                    }
                )
                store_id += 1

    df = pd.DataFrame(stores)
    df["treated"] = (df["state"] == "NJ").astype(int)
    df["emp_change"] = df["emp_post"] - df["emp_pre"]
    return df
def load_castle_doctrine(force_download: bool = False) -> pd.DataFrame:
    """
    Load Castle Doctrine / Stand Your Ground laws dataset.

    This dataset tracks the staggered adoption of Castle Doctrine (Stand Your
    Ground) laws across U.S. states, which expanded self-defense rights.
    It's commonly used to demonstrate heterogeneous treatment timing methods
    like Callaway-Sant'Anna or Sun-Abraham.

    Parameters
    ----------
    force_download : bool, default=False
        If True, re-download the dataset even if cached.

    Returns
    -------
    pd.DataFrame
        Panel dataset. The offline fallback provides the columns below; for
        the downloaded file, ``sid``/``cdl`` are renamed to
        ``state_id``/``treated`` and ``first_treat`` is derived from
        ``effyear`` when those columns are present:

        - state : str - State abbreviation
        - year : int - Year (2000-2010)
        - first_treat : int - Year of law adoption (0 = never adopted)
        - homicide_rate : float - Homicides per 100,000 population
        - population : int - State population
        - income : float - Per capita income
        - treated : int - 1 if law in effect, 0 otherwise
        - cohort : int - Alias for first_treat

    Warns
    -----
    UserWarning
        If the dataset cannot be downloaded and no cached copy exists. In
        that case simulated data is returned instead of the real dataset.

    Notes
    -----
    Castle Doctrine laws remove the duty to retreat before using deadly force
    in self-defense. States adopted these laws at different times between
    2005 and 2009, creating a staggered treatment design.

    References
    ----------
    Cheng, C., & Hoekstra, M. (2013). Does Strengthening Self-Defense Law Deter
    Crime or Escalate Violence? Evidence from Expansions to Castle Doctrine.
    *Journal of Human Resources*, 48(3), 821-854.

    Examples
    --------
    >>> from diff_diff.datasets import load_castle_doctrine
    >>> from diff_diff import CallawaySantAnna
    >>>
    >>> castle = load_castle_doctrine()
    >>> cs = CallawaySantAnna(control_group="never_treated")
    >>> results = cs.fit(
    ...     castle,
    ...     outcome="homicide_rate",
    ...     unit="state",
    ...     time="year",
    ...     first_treat="first_treat"
    ... )
    """
    import warnings

    url = "https://raw.githubusercontent.com/causaldata/causal_datasets/main/castle/castle.csv"

    try:
        content = _download_with_cache(url, "castle_doctrine", force_download)
        df = pd.read_csv(StringIO(content))
    except RuntimeError:
        # Fallback: construct from documented patterns. Warn so that
        # simulated numbers are never mistaken for the real dataset.
        warnings.warn(
            "Could not download the Castle Doctrine dataset; returning "
            "simulated data instead.",
            UserWarning,
            stacklevel=2,
        )
        df = _construct_castle_doctrine_data()

    # Standardize column names (rename ignores keys that are not present)
    df = df.rename(columns={"sid": "state_id", "cdl": "treated"})

    # Add convenience columns
    if "first_treat" not in df.columns and "effyear" in df.columns:
        # A missing effective year means the state never adopted the law
        df["first_treat"] = df["effyear"].fillna(0).astype(int)

    if "cohort" not in df.columns and "first_treat" in df.columns:
        df["cohort"] = df["first_treat"]

    # Ensure treated indicator exists
    if "treated" not in df.columns and "first_treat" in df.columns:
        df["treated"] = ((df["first_treat"] > 0) & (df["year"] >= df["first_treat"])).astype(int)

    return df
def _construct_castle_doctrine_data() -> pd.DataFrame:
    """
    Construct a simulated Castle Doctrine dataset from documented patterns.

    This is a fallback when the online source is unavailable.

    Returns
    -------
    pd.DataFrame
        State-year panel for 2000-2010 with columns ``state``, ``year``,
        ``first_treat``, ``homicide_rate``, ``population``, ``income``,
        ``treated`` and ``cohort``. The result is identical on every call.
    """
    # A private generator seeded with the Cheng-Hoekstra publication year
    # keeps the output reproducible without reseeding (and clobbering)
    # NumPy's global RNG.
    rng = np.random.RandomState(2013)

    # States and their Castle Doctrine adoption years
    # 0 = never adopted during the study period
    state_adoption = {
        "AL": 2006, "AK": 2006, "AZ": 2006, "FL": 2005, "GA": 2006,
        "IN": 2006, "KS": 2006, "KY": 2006, "LA": 2006, "MI": 2006,
        "MS": 2006, "MO": 2007, "MT": 2009, "NH": 2011, "NC": 2011,
        "ND": 2007, "OH": 2008, "OK": 2006, "PA": 2011, "SC": 2006,
        "SD": 2006, "TN": 2007, "TX": 2007, "UT": 2010, "WV": 2008,
        # Control states (never adopted or adopted after 2010)
        "CA": 0, "CO": 0, "CT": 0, "DE": 0, "HI": 0, "ID": 0,
        "IL": 0, "IA": 0, "ME": 0, "MD": 0, "MA": 0, "MN": 0,
        "NE": 0, "NV": 0, "NJ": 0, "NM": 0, "NY": 0, "OR": 0,
        "RI": 0, "VT": 0, "VA": 0, "WA": 0, "WI": 0, "WY": 0,
    }

    # States that adopted after 2010 are untreated throughout the panel, so
    # they are recoded as never adopted.
    state_adoption = {k: (v if v <= 2010 else 0) for k, v in state_adoption.items()}

    data = []
    for state, first_treat in state_adoption.items():
        # State-level baseline characteristics
        base_homicide = rng.uniform(3.0, 8.0)
        pop = rng.randint(500000, 20000000)
        base_income = rng.uniform(30000, 50000)

        for year in range(2000, 2011):
            # Time trend
            time_effect = (year - 2005) * 0.1

            in_effect = first_treat > 0 and year >= first_treat
            # Treatment effect (approximately +8% increase in homicide rate)
            treatment_effect = base_homicide * 0.08 if in_effect else 0

            # Draw order (homicide, population, income) is fixed: it
            # determines the generated values.
            homicide = max(
                0, base_homicide + time_effect + treatment_effect + rng.normal(0, 0.5)
            )
            # Growth is measured from the first panel year. The previous
            # `year * 10000` term added about 20 million to every state.
            population = pop + (year - 2000) * 10000 + rng.randint(-5000, 5000)
            income = round(
                base_income * (1 + 0.02 * (year - 2000)) + rng.normal(0, 1000), 0
            )

            data.append(
                {
                    "state": state,
                    "year": year,
                    "first_treat": first_treat,
                    "homicide_rate": round(homicide, 2),
                    "population": population,
                    "income": income,
                    "treated": int(in_effect),
                }
            )

    df = pd.DataFrame(data)
    df["cohort"] = df["first_treat"]
    return df
def load_divorce_laws(force_download: bool = False) -> pd.DataFrame:
    """
    Load unilateral divorce laws dataset.

    This dataset tracks the staggered adoption of unilateral (no-fault) divorce
    laws across U.S. states. It's a classic example for studying staggered
    DiD methods and was used in Stevenson & Wolfers (2006).

    Parameters
    ----------
    force_download : bool, default=False
        If True, re-download the dataset even if cached.

    Returns
    -------
    pd.DataFrame
        Panel dataset. The offline fallback provides the columns below; for
        the downloaded file, ``stfips`` is renamed to ``state_id`` and
        ``first_treat``/``treated`` are derived from a ``unilateral``
        indicator when those columns are present:

        - state : str - State abbreviation
        - year : int - Year
        - first_treat : int - Year unilateral divorce became available (0 = never)
        - divorce_rate : float - Divorces per 1,000 population
        - female_lfp : float - Female labor force participation rate
        - suicide_rate : float - Female suicide rate
        - treated : int - 1 if law in effect, 0 otherwise
        - cohort : int - Alias for first_treat

    Warns
    -----
    UserWarning
        If the dataset cannot be downloaded and no cached copy exists. In
        that case simulated data is returned instead of the real dataset.

    Notes
    -----
    Unilateral divorce laws allow one spouse to obtain a divorce without the
    other's consent. States adopted these laws at different times, primarily
    between 1969 and 1985.

    References
    ----------
    Stevenson, B., & Wolfers, J. (2006). Bargaining in the Shadow of the Law:
    Divorce Laws and Family Distress. *Quarterly Journal of Economics*,
    121(1), 267-288.

    Wolfers, J. (2006). Did Unilateral Divorce Laws Raise Divorce Rates?
    A Reconciliation and New Results. *American Economic Review*, 96(5), 1802-1820.

    Examples
    --------
    >>> from diff_diff.datasets import load_divorce_laws
    >>> from diff_diff import CallawaySantAnna, SunAbraham
    >>>
    >>> divorce = load_divorce_laws()
    >>> cs = CallawaySantAnna(control_group="never_treated")
    >>> results = cs.fit(
    ...     divorce,
    ...     outcome="divorce_rate",
    ...     unit="state",
    ...     time="year",
    ...     first_treat="first_treat"
    ... )
    """
    import warnings

    # Try to load from causaldata repository
    url = "https://raw.githubusercontent.com/causaldata/causal_datasets/main/divorce/divorce.csv"

    try:
        content = _download_with_cache(url, "divorce_laws", force_download)
        df = pd.read_csv(StringIO(content))
    except RuntimeError:
        # Fallback to constructed data. Warn so that simulated numbers are
        # never mistaken for the real dataset.
        warnings.warn(
            "Could not download the divorce laws dataset; returning "
            "simulated data instead.",
            UserWarning,
            stacklevel=2,
        )
        df = _construct_divorce_laws_data()

    # Standardize column names
    if "stfips" in df.columns:
        df = df.rename(columns={"stfips": "state_id"})

    if "first_treat" not in df.columns and "unilateral" in df.columns:
        # First treatment year = earliest year with unilateral == 1 in each
        # state. States with no such year are absent from the lookup, map to
        # NaN and end up as 0 (never treated).
        adoption_year = df.loc[df["unilateral"] == 1].groupby("state")["year"].min()
        df["first_treat"] = df["state"].map(adoption_year).fillna(0).astype(int)

    if "cohort" not in df.columns and "first_treat" in df.columns:
        df["cohort"] = df["first_treat"]

    if "treated" not in df.columns:
        if "unilateral" in df.columns:
            df["treated"] = df["unilateral"]
        elif "first_treat" in df.columns:
            df["treated"] = ((df["first_treat"] > 0) & (df["year"] >= df["first_treat"])).astype(
                int
            )

    return df
def _construct_divorce_laws_data() -> pd.DataFrame:
    """
    Construct a simulated divorce laws dataset from documented patterns.

    This is a fallback when the online source is unavailable.

    Returns
    -------
    pd.DataFrame
        State-year panel for 1968-1988 with columns ``state``, ``year``,
        ``first_treat``, ``divorce_rate``, ``female_lfp``, ``suicide_rate``,
        ``cohort`` and ``treated``. The result is identical on every call.
    """
    # A private generator seeded with the Stevenson-Wolfers publication year
    # keeps the output reproducible without reseeding (and clobbering)
    # NumPy's global RNG.
    rng = np.random.RandomState(2006)

    # State adoption years for unilateral divorce (from Wolfers 2006)
    # 0 = never adopted
    state_adoption = {
        "AK": 1935, "AL": 1971, "AZ": 1973, "CA": 1970, "CO": 1972,
        "CT": 1973, "DE": 1968, "FL": 1971, "GA": 1973, "HI": 1973,
        "IA": 1970, "ID": 1971, "IN": 1973, "KS": 1969, "KY": 1972,
        "MA": 1975, "ME": 1973, "MI": 1972, "MN": 1974, "MO": 0,
        "MT": 1975, "NC": 0, "ND": 1971, "NE": 1972, "NH": 1971,
        "NJ": 0, "NM": 1973, "NV": 1967, "NY": 0, "OH": 0,
        "OK": 1975, "OR": 1971, "PA": 0, "RI": 1975, "SD": 1985,
        "TN": 0, "TX": 1970, "UT": 1987, "VA": 0, "WA": 1973,
        "WI": 1978, "WV": 1984, "WY": 1977,
    }

    # Keep never-adopters and states adopting within 1968-1990. States that
    # adopted before 1968 (AK, NV) are dropped from the panel entirely.
    state_adoption = {k: v for k, v in state_adoption.items() if v == 0 or (1968 <= v <= 1990)}

    data = []
    for state, first_treat in state_adoption.items():
        # State-level baselines
        base_divorce = rng.uniform(2.0, 6.0)
        base_lfp = rng.uniform(0.35, 0.55)
        base_suicide = rng.uniform(4.0, 8.0)

        for year in range(1968, 1989):
            # Time trends
            time_trend = (year - 1978) * 0.05

            # Treatment effects (from literature)
            # Short-run increase in divorce rate, then return to trend
            if first_treat > 0 and year >= first_treat:
                years_since = year - first_treat
                # Initial spike then fade out
                if years_since <= 2:
                    divorce_effect = 0.5
                elif years_since <= 5:
                    divorce_effect = 0.3
                elif years_since <= 10:
                    divorce_effect = 0.1
                else:
                    divorce_effect = 0.0
                # Small positive effect on female LFP
                lfp_effect = 0.02
                # Reduction in female suicide
                suicide_effect = -0.5
            else:
                divorce_effect = 0
                lfp_effect = 0
                suicide_effect = 0

            # Draw order (divorce, LFP, suicide) is fixed: it determines the
            # generated values.
            divorce_rate = max(0, base_divorce + time_trend + divorce_effect + rng.normal(0, 0.3))
            # Participation is a share, so clamp it to [0, 1]
            female_lfp = min(
                1,
                max(0, base_lfp + 0.01 * (year - 1968) + lfp_effect + rng.normal(0, 0.02)),
            )
            suicide_rate = max(0, base_suicide + suicide_effect + rng.normal(0, 0.5))

            data.append(
                {
                    "state": state,
                    "year": year,
                    # After the filter above this is either 0 or >= 1968
                    "first_treat": first_treat,
                    "divorce_rate": round(divorce_rate, 2),
                    "female_lfp": round(female_lfp, 3),
                    "suicide_rate": round(suicide_rate, 2),
                }
            )

    df = pd.DataFrame(data)
    df["cohort"] = df["first_treat"]
    df["treated"] = ((df["first_treat"] > 0) & (df["year"] >= df["first_treat"])).astype(int)
    return df
def load_mpdta(force_download: bool = False) -> pd.DataFrame:
    """
    Load the Minimum Wage Panel Dataset for DiD Analysis (mpdta).

    This is a simulated dataset from the R `did` package that mimics
    county-level employment data under staggered minimum wage increases.
    It's designed specifically for teaching the Callaway-Sant'Anna estimator.

    Parameters
    ----------
    force_download : bool, default=False
        If True, re-download the dataset even if cached.

    Returns
    -------
    pd.DataFrame
        Panel dataset. The offline fallback provides the columns below; for
        the downloaded file, ``first.treat`` is renamed to ``first_treat``
        and ``cohort`` is added when missing:

        - countyreal : int - County identifier
        - year : int - Year (2003-2007)
        - lpop : float - Log population
        - lemp : float - Log employment (outcome)
        - first_treat : int - Year of minimum wage increase (0 = never)
        - treat : int - 1 if ever treated, 0 otherwise
        - cohort : int - Alias for first_treat

    Warns
    -----
    UserWarning
        If the dataset cannot be downloaded and no cached copy exists. In
        that case locally simulated data is returned instead of the
        dataset shipped with the R package.

    Notes
    -----
    This dataset is included in the R `did` package and is commonly used
    in tutorials demonstrating the Callaway-Sant'Anna estimator.

    References
    ----------
    Callaway, B., & Sant'Anna, P. H. (2021). Difference-in-differences with
    multiple time periods. *Journal of Econometrics*, 225(2), 200-230.

    Examples
    --------
    >>> from diff_diff.datasets import load_mpdta
    >>> from diff_diff import CallawaySantAnna
    >>>
    >>> mpdta = load_mpdta()
    >>> cs = CallawaySantAnna()
    >>> results = cs.fit(
    ...     mpdta,
    ...     outcome="lemp",
    ...     unit="countyreal",
    ...     time="year",
    ...     first_treat="first_treat"
    ... )
    """
    import warnings

    # mpdta is available from the did package documentation
    url = "https://raw.githubusercontent.com/bcallaway11/did/master/data-raw/mpdta.csv"

    try:
        content = _download_with_cache(url, "mpdta", force_download)
        df = pd.read_csv(StringIO(content))
    except RuntimeError:
        # Fallback to locally constructed data. Warn so that these numbers
        # are never mistaken for the dataset shipped with the R package.
        warnings.warn(
            "Could not download the mpdta dataset; returning locally "
            "simulated data instead.",
            UserWarning,
            stacklevel=2,
        )
        df = _construct_mpdta_data()

    # Standardize column names
    if "first.treat" in df.columns:
        df = df.rename(columns={"first.treat": "first_treat"})

    # Ensure cohort column exists
    if "cohort" not in df.columns and "first_treat" in df.columns:
        df["cohort"] = df["first_treat"]

    return df
def _construct_mpdta_data() -> pd.DataFrame:
    """
    Construct a simulated stand-in for the R `did` package's mpdta dataset.

    This is a fallback when the online source is unavailable. The values
    are generated here; they share the column layout used in
    Callaway-Sant'Anna tutorials but are not the R package's data.

    Returns
    -------
    pd.DataFrame
        County-year panel (500 counties, 2003-2007) with columns
        ``countyreal``, ``year``, ``lpop``, ``lemp``, ``first_treat``,
        ``treat`` and ``cohort``. The result is identical on every call.
    """
    # A private generator seeded with the Callaway-Sant'Anna publication
    # year keeps the output reproducible without reseeding (and clobbering)
    # NumPy's global RNG.
    rng = np.random.RandomState(2021)

    n_counties = 500
    years = [2003, 2004, 2005, 2006, 2007]

    # Treatment cohorts: 2004, 2006, 2007, or never (0)
    cohorts = [0, 2004, 2006, 2007]
    cohort_probs = [0.4, 0.2, 0.2, 0.2]

    data = []
    for county in range(1, n_counties + 1):
        # Draw order (cohort, population, employment gap) is fixed: it
        # determines the generated values.
        first_treat = rng.choice(cohorts, p=cohort_probs)
        base_lpop = rng.normal(12.0, 1.0)
        base_lemp = base_lpop - rng.uniform(1.5, 2.5)

        for year in years:
            time_effect = (year - 2003) * 0.02

            # Treatment effect (heterogeneous by cohort)
            if first_treat > 0 and year >= first_treat:
                if first_treat == 2004:
                    te = -0.04 + (year - first_treat) * 0.01
                elif first_treat == 2006:
                    te = -0.03 + (year - first_treat) * 0.01
                else:  # 2007
                    te = -0.025
            else:
                te = 0

            data.append(
                {
                    "countyreal": county,
                    "year": year,
                    "lpop": round(base_lpop + rng.normal(0, 0.05), 4),
                    "lemp": round(base_lemp + time_effect + te + rng.normal(0, 0.02), 4),
                    "first_treat": first_treat,
                    "treat": int(first_treat > 0),
                }
            )

    df = pd.DataFrame(data)
    df["cohort"] = df["first_treat"]
    return df
def list_datasets() -> Dict[str, str]:
    """
    List available real-world datasets.

    Returns
    -------
    dict
        Mapping of dataset name to a one-line description.

    Examples
    --------
    >>> from diff_diff.datasets import list_datasets
    >>> for name, desc in list_datasets().items():
    ...     print(f"{name}: {desc}")
    """
    catalog = [
        ("card_krueger", "Card & Krueger (1994) minimum wage dataset - classic 2x2 DiD"),
        ("castle_doctrine", "Castle Doctrine laws - staggered adoption across states"),
        ("divorce_laws", "Unilateral divorce laws - staggered adoption (Stevenson-Wolfers)"),
        ("mpdta", "Minimum wage panel data - simulated CS example from R `did` package"),
    ]
    return dict(catalog)
def load_dataset(name: str, force_download: bool = False) -> pd.DataFrame:
    """
    Load a dataset by name.

    Parameters
    ----------
    name : str
        Name of the dataset. Use `list_datasets()` to see available datasets.
    force_download : bool, default=False
        If True, re-download the dataset even if cached.

    Returns
    -------
    pd.DataFrame
        The requested dataset.

    Raises
    ------
    ValueError
        If the dataset name is not recognized.

    Examples
    --------
    >>> from diff_diff.datasets import load_dataset, list_datasets
    >>> print(list_datasets())
    >>> df = load_dataset("card_krueger")
    """
    # Dispatch table: dataset name -> loader function
    registry = {
        "card_krueger": load_card_krueger,
        "castle_doctrine": load_castle_doctrine,
        "divorce_laws": load_divorce_laws,
        "mpdta": load_mpdta,
    }

    loader = registry.get(name)
    if loader is None:
        available = ", ".join(registry)
        raise ValueError(f"Unknown dataset '{name}'. Available: {available}")

    return loader(force_download=force_download)