{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Synthetic Difference-in-Differences (SDID)\n", "\n", "This notebook demonstrates **Synthetic Difference-in-Differences** (Arkhangelsky et al., 2021), which combines:\n", "\n", "- **Synthetic Control**: Reweight control units so that their weighted average tracks the treated units over the pre-treatment periods\n", "- **Difference-in-Differences**: Compare changes over time, so that unobserved, time-invariant differences between units drop out\n", "\n", "SDID is particularly useful when:\n", "- You have few treated units (even just one)\n", "- You want a better counterfactual than a simple average of all control units\n", "- You doubt that parallel trends hold between the treated units and the full control group\n", "\n", "We'll cover:\n", "1. When to use Synthetic DiD\n", "2. Basic estimation and a comparison with standard DiD\n", "3. Unit weights, time weights, and pre-treatment fit\n", "4. Inference (placebo, bootstrap, and jackknife)\n", "5. Tuning regularization\n", "6. A single treated unit and covariates" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "from diff_diff import SyntheticDiD, DifferenceInDifferences\n", "\n", "# For nicer plots (optional)\n", "try:\n", "    import matplotlib.pyplot as plt\n", "    HAS_MATPLOTLIB = True\n", "except ImportError:\n", "    HAS_MATPLOTLIB = False\n", "    print(\"matplotlib not installed - visualization examples will be skipped\")\n", "\n", "# The 'seaborn-v0_8-*' style names exist only in matplotlib >= 3.6\n", "if HAS_MATPLOTLIB and 'seaborn-v0_8-whitegrid' in plt.style.available:\n", "    plt.style.use('seaborn-v0_8-whitegrid')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. 
When to Use Synthetic DiD\n", "\n", "Consider SDID when:\n", "- You have **few treated units** (1-10 is common)\n", "- You have a **reasonably long pre-treatment period** (5+ periods is ideal)\n", "- You have **many potential control units**\n", "- The parallel-trends assumption behind standard DiD is questionable for the full control group" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Construct panel data with explicit unit characteristics.\n", "# Only the observation-level noise is random; intercepts and trends are fixed by design.\n", "np.random.seed(42)\n", "\n", "n_treated = 3\n", "n_control = 40\n", "n_pre = 10\n", "n_post = 4\n", "n_periods = n_pre + n_post\n", "true_att = 5.0\n", "\n", "# Control unit design:\n", "# - \"Similar\" controls (units 0-9): intercepts near 50, trends near 2.0.\n", "#   SDID should identify these and give them high weights.\n", "# - \"Diverse\" controls (units 10-39): wide intercept range, lower trends.\n", "#   These bias standard DiD, but SDID should down-weight them.\n", "similar_intercepts = np.linspace(45, 55, 10)\n", "similar_trends = np.linspace(1.7, 2.3, 10)\n", "\n", "diverse_intercepts = np.linspace(30, 70, 30)\n", "diverse_trends = np.linspace(0.3, 1.0, 30)\n", "\n", "control_intercepts = np.concatenate([similar_intercepts, diverse_intercepts])\n", "control_trends = np.concatenate([similar_trends, diverse_trends])\n", "\n", "# Treated units: intercepts near 50, trends around 2.0\n", "treated_intercepts = [48, 50, 52]\n", "treated_trends = [1.8, 2.0, 2.2]\n", "\n", "data = []\n", "\n", "# Control units\n", "for i in range(n_control):\n", "    for period in range(n_periods):\n", "        y = control_intercepts[i] + control_trends[i] * period + np.random.normal(0, 0.5)\n", "        data.append({\n", "            'unit': i,\n", "            'period': period,\n", "            'treated': 0,\n", "            'outcome': y\n", "        })\n", "\n", "# Treated units (the effect true_att switches on in the first post-treatment period)\n", "for i in range(n_treated):\n", "    unit_id = n_control + i\n", "    for period in range(n_periods):\n", "        y = treated_intercepts[i] + treated_trends[i] * period\n", "        if period >= n_pre:\n", "            y += true_att\n", "        y += np.random.normal(0, 0.5)\n", "        data.append({\n", "            'unit': unit_id,\n", "            'period': period,\n", "            'treated': 1,\n", "            'outcome': y\n", "        })\n", "\n", "df = pd.DataFrame(data)\n", "print(f\"Dataset: {len(df)} observations\")\n", "print(f\"Treated units: {n_treated}\")\n", "print(f\"Control units: {n_control}\")\n", "print(f\"Pre-treatment periods: {n_pre}\")\n", "print(f\"Post-treatment periods: {n_post}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if HAS_MATPLOTLIB:\n", "    # Visualize the data\n", "    fig, ax = plt.subplots(figsize=(12, 6))\n", "\n", "    # Plot control units (gray, thin lines)\n", "    for unit in df[df['treated'] == 0]['unit'].unique():\n", "        unit_data = df[df['unit'] == unit]\n", "        ax.plot(unit_data['period'], unit_data['outcome'],\n", "                color='gray', alpha=0.3, linewidth=0.5)\n", "\n", "    # Plot treated units (colored, thick lines)\n", "    colors = ['red', 'orange', 'darkred']\n", "    for i, unit in enumerate(df[df['treated'] == 1]['unit'].unique()):\n", "        unit_data = df[df['unit'] == unit]\n", "        ax.plot(unit_data['period'], unit_data['outcome'],\n", "                color=colors[i], linewidth=2, label=f'Treated {i+1}')\n", "\n", "    # Mark treatment time\n", "    ax.axvline(x=n_pre - 0.5, color='black', linestyle='--', label='Treatment')\n", "\n", "    ax.set_xlabel('Period')\n", "    ax.set_ylabel('Outcome')\n", "    ax.set_title('Panel Data: Treated Units vs Control Units')\n", "    ax.legend(loc='upper left')\n", "    plt.tight_layout()\n", "    plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. 
Basic Synthetic DiD Estimation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Fit Synthetic DiD\n", "sdid = SyntheticDiD(\n", "    n_bootstrap=999,  # Replications for the variance estimator (placebo by default; see Section 7)\n", "    seed=42\n", ")\n", "\n", "results = sdid.fit(\n", "    df,\n", "    outcome=\"outcome\",\n", "    treatment=\"treated\",\n", "    unit=\"unit\",\n", "    time=\"period\",\n", "    post_periods=list(range(n_pre, n_periods))  # Post-treatment periods\n", ")\n", "\n", "print(results.summary())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Comparing SDID with Standard DiD\n", "\n", "Standard DiD compares the treated units with the unweighted average of all 40 controls. Because 30 of those controls follow flatter trends than the treated units, parallel trends does not hold for that average, and we expect standard DiD to overstate the effect. SDID instead reweights the controls toward those that share the treated units' trend." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create post indicator for standard DiD\n", "df['post'] = (df['period'] >= n_pre).astype(int)\n", "\n", "# Standard DiD\n", "did = DifferenceInDifferences()\n", "results_did = did.fit(\n", "    df,\n", "    outcome=\"outcome\",\n", "    treatment=\"treated\",\n", "    time=\"post\"\n", ")\n", "\n", "print(\"Comparison of Estimators\")\n", "print(\"=\" * 50)\n", "print(f\"True ATT: {true_att}\")\n", "print()\n", "print(\"Standard DiD:\")\n", "print(f\"  ATT:  {results_did.att:.4f}\")\n", "print(f\"  SE:   {results_did.se:.4f}\")\n", "print(f\"  Bias: {results_did.att - true_att:.4f}\")\n", "print()\n", "print(\"Synthetic DiD:\")\n", "print(f\"  ATT:  {results.att:.4f}\")\n", "print(f\"  SE:   {results.se:.4f}\")\n", "print(f\"  Bias: {results.att - true_att:.4f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Understanding Unit Weights\n", "\n", "SDID gives each control unit a non-negative weight, and the weights sum to one. They are chosen so that the weighted average of control outcomes tracks the treated units' average over the pre-treatment periods, up to a constant level difference (the weighting problem includes an intercept). Controls with higher weights contribute more to the counterfactual; in this simulation, the \"similar\" controls (units 0-9) are designed to receive most of the weight."
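, "\n",
"\n",
"The next cell is a minimal sketch of how such weights can be computed; it is not the library's implementation. It assumes the unit-weight objective from Arkhangelsky et al. (2021): minimize the mean squared pre-treatment gap between the treated average and the weighted control average after removing each series' mean (which is what the intercept does), plus a ridge penalty `zeta**2 * sum(w**2)`, over non-negative weights that sum to one. It uses plain Frank-Wolfe with exact line search and omits the sparsification step mentioned in the summary. The helper `fw_unit_weights` and the toy panel are hypothetical."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"def fw_unit_weights(Y_co_pre, y_tr_pre, zeta, n_iter=2000):\n",
"    # Hypothetical helper: Frank-Wolfe for SDID-style unit weights (sketch).\n",
"    # Y_co_pre: (N0, T0) control outcomes in the pre-treatment periods\n",
"    # y_tr_pre: (T0,) average treated outcome in the pre-treatment periods\n",
"    # Demeaning each series over time is equivalent to fitting an intercept\n",
"    A = (Y_co_pre - Y_co_pre.mean(axis=1, keepdims=True)).T  # (T0, N0)\n",
"    b = y_tr_pre - y_tr_pre.mean()                             # (T0,)\n",
"    T0, N0 = A.shape\n",
"    w = np.full(N0, 1.0 / N0)  # start from equal weights\n",
"    for _ in range(n_iter):\n",
"        resid = A @ w - b\n",
"        # Gradient of mean(resid**2) + zeta**2 * sum(w**2)\n",
"        grad = 2 * A.T @ resid / T0 + 2 * zeta**2 * w\n",
"        # Move toward the simplex vertex (single control) with the smallest gradient\n",
"        d = -w.copy()\n",
"        d[np.argmin(grad)] += 1.0\n",
"        Ad = A @ d\n",
"        denom = Ad @ Ad / T0 + zeta**2 * (d @ d)\n",
"        if denom <= 0:\n",
"            break\n",
"        # Exact line search for the quadratic objective, clipped to [0, 1]\n",
"        step = -(Ad @ resid / T0 + zeta**2 * (d @ w)) / denom\n",
"        step = min(max(step, 0.0), 1.0)\n",
"        if step == 0.0:\n",
"            break\n",
"        w = w + step * d\n",
"    return w\n",
"\n",
"# Toy panel (hypothetical): 5 controls, 8 pre-treatment periods\n",
"toy_rng = np.random.default_rng(0)\n",
"toy_t = np.arange(8)\n",
"toy_Y_co_pre = np.vstack([\n",
"    10 + 2.0 * toy_t,  # same trend as the treated unit, higher level\n",
"    20 + 2.0 * toy_t,  # same trend, even higher level\n",
"    15 + 0.5 * toy_t,  # flatter trend\n",
"    12 + 0.2 * toy_t,  # flatter trend\n",
"    30 - 1.0 * toy_t,  # declining trend\n",
"]) + toy_rng.normal(0, 0.1, size=(5, 8))\n",
"toy_y_tr_pre = 5 + 2.0 * toy_t + toy_rng.normal(0, 0.1, size=8)\n",
"\n",
"toy_w = fw_unit_weights(toy_Y_co_pre, toy_y_tr_pre, zeta=0.01)\n",
"# The level gaps are absorbed by the intercept, so most weight should go to the first two controls\n",
"print(np.round(toy_w, 3))"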
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# View unit weights\n", "weights_df = results.get_unit_weights_df()\n", "print(\"Top 10 control units by weight:\")\n", "print(weights_df.sort_values('weight', ascending=False).head(10))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Check weight properties\n", "w = weights_df['weight']\n", "print(\"\\nWeight statistics:\")\n", "print(f\"  Sum of weights: {w.sum():.6f}\")\n", "print(f\"  Weights above 0.01: {(w > 0.01).sum()}\")\n", "print(f\"  Max weight: {w.max():.4f}\")\n", "# Effective number of controls = 1 / sum(w^2) (inverse Herfindahl index);\n", "# it equals k when k controls share the weight equally\n", "print(f\"  Effective number of controls: {1 / (w ** 2).sum():.1f}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if HAS_MATPLOTLIB:\n", "    # Visualize unit weights (one bar per control unit)\n", "    fig, ax = plt.subplots(figsize=(12, 8))\n", "\n", "    sorted_weights = weights_df.sort_values('weight', ascending=True)\n", "    ax.barh(range(len(sorted_weights)), sorted_weights['weight'])\n", "    ax.set_yticks(range(len(sorted_weights)))\n", "    ax.set_yticklabels(sorted_weights['unit'])\n", "    ax.set_xlabel('Weight')\n", "    ax.set_ylabel('Control Unit')\n", "    ax.set_title('Synthetic Control Unit Weights')\n", "    plt.tight_layout()\n", "    plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Understanding Time Weights\n", "\n", "SDID also computes **time weights** λ that determine how much each pre-treatment period contributes to the baseline. They are fit on the control units only: λ is chosen so that, for each control unit, a weighted average of its pre-treatment outcomes (plus a common intercept) approximates its average post-treatment outcome. Pre-treatment periods that best predict post-treatment control outcomes get higher weight. With unit-specific linear trends like those simulated here, this tends to concentrate weight on the last pre-treatment periods."
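, "\n",
"\n",
"Given both sets of weights, the SDID point estimate is a weighted double difference (Arkhangelsky et al., 2021): for each unit, take its mean post-treatment outcome minus its λ-weighted pre-treatment baseline, then subtract the ω-weighted average of these control differences from the treated one. The next cell is a minimal NumPy sketch of that formula with hand-picked weights, not the library's code; the helper `sdid_point_estimate` and the toy panel are hypothetical. With equal unit and time weights, the formula reduces to standard DiD."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"def sdid_point_estimate(Y, n_co, n_pre, omega, lam):\n",
"    # Hypothetical helper: SDID point estimate for given weights (sketch).\n",
"    # Y: (N, T) outcomes; controls in the first n_co rows,\n",
"    #    pre-treatment periods in the first n_pre columns\n",
"    # omega: (n_co,) unit weights; lam: (n_pre,) time weights\n",
"    y_tr = Y[n_co:].mean(axis=0)  # average treated unit\n",
"    Y_co = Y[:n_co]\n",
"    # Post-treatment mean minus the lambda-weighted pre-treatment baseline\n",
"    delta_tr = y_tr[n_pre:].mean() - y_tr[:n_pre] @ lam\n",
"    delta_co = Y_co[:, n_pre:].mean(axis=1) - Y_co[:, :n_pre] @ lam\n",
"    return delta_tr - omega @ delta_co\n",
"\n",
"# Toy panel (hypothetical): 3 controls + 1 treated unit, 4 pre + 2 post periods\n",
"toy_t = np.arange(6)\n",
"toy_Y = np.vstack([\n",
"    1.0 + 1.0 * toy_t,                       # control, same trend as treated\n",
"    2.0 + 1.0 * toy_t,                       # control, same trend as treated\n",
"    0.0 + 3.0 * toy_t,                       # control with a much steeper trend\n",
"    5.0 + 1.0 * toy_t + 2.0 * (toy_t >= 4),  # treated: true effect of 2 from period 4\n",
"])\n",
"equal_lam = np.full(4, 1 / 4)\n",
"\n",
"# Equal weights reproduce standard DiD, which the steep control biases to about 0\n",
"tau_did = sdid_point_estimate(toy_Y, 3, 4, np.full(3, 1 / 3), equal_lam)\n",
"# Putting all unit weight on the two parallel controls recovers about 2\n",
"tau_sdid = sdid_point_estimate(toy_Y, 3, 4, np.array([0.5, 0.5, 0.0]), equal_lam)\n",
"print(f'DiD-style estimate: {tau_did:.3f}, reweighted estimate: {tau_sdid:.3f}')"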
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# View time weights\n", "time_weights_df = results.get_time_weights_df()\n", "print(\"Time weights:\")\n", "print(time_weights_df)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if HAS_MATPLOTLIB:\n", "    # Visualize time weights\n", "    fig, ax = plt.subplots(figsize=(10, 5))\n", "\n", "    ax.bar(time_weights_df['period'], time_weights_df['weight'])\n", "    ax.set_xlabel('Pre-treatment Period')\n", "    ax.set_ylabel('Weight')\n", "    ax.set_title('Time Weights for Pre-treatment Periods')\n", "    ax.set_xticks(time_weights_df['period'])\n", "    plt.tight_layout()\n", "    plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. Pre-treatment Fit\n", "\n", "A key diagnostic is how well the synthetic control matches the treated units in the pre-treatment period. Because SDID's unit weights allow a constant level difference, the two pre-treatment trajectories should move in parallel; they do not have to coincide." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(f\"Pre-treatment fit (RMSE): {results.pre_treatment_fit:.4f}\")\n", "print(\"\\nLower values indicate better fit.\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if HAS_MATPLOTLIB:\n", "    # Compare treated vs synthetic control trajectories\n", "    fig, ax = plt.subplots(figsize=(12, 6))\n", "\n", "    # Map each control unit to its SDID weight (0 if a unit has no weight)\n", "    weights_dict = dict(zip(weights_df['unit'], weights_df['weight']))\n", "\n", "    # Treated mean by period\n", "    treated_mean = df[df['treated'] == 1].groupby('period')['outcome'].mean()\n", "\n", "    # Synthetic control: weighted average of control outcomes by period\n", "    control_data = df[df['treated'] == 0].copy()\n", "    control_data['weight'] = control_data['unit'].map(weights_dict).fillna(0.0)\n", "    control_data['weighted_outcome'] = control_data['outcome'] * control_data['weight']\n", "    by_period = control_data.groupby('period')\n", "    synthetic = by_period['weighted_outcome'].sum() / by_period['weight'].sum()\n", "\n", "    # Simple (unweighted) average of all controls\n", "    simple_control = df[df['treated'] == 0].groupby('period')['outcome'].mean()\n", "\n", "    ax.plot(treated_mean.index, treated_mean.values, 'o-',\n", "            linewidth=2, markersize=8, color='red', label='Treated')\n", "    ax.plot(synthetic.index, synthetic.values, 's--',\n", "            linewidth=2, markersize=8, color='blue', label='Synthetic Control')\n", "    ax.plot(simple_control.index, simple_control.values, '^:',\n", "            linewidth=1, markersize=6, color='gray', alpha=0.7, label='Simple Average Control')\n", "\n", "    # Shade the post-treatment window without letting the shading rescale the y-axis\n", "    ymin, ymax = ax.get_ylim()\n", "    ax.axvline(x=n_pre - 0.5, color='black', linestyle='--', alpha=0.5)\n", "    ax.fill_between([n_pre - 0.5, n_periods - 0.5], ymin, ymax,\n", "                    alpha=0.1, color='green', label='Post-treatment')\n", "    ax.set_ylim(ymin, ymax)\n", "\n", "    ax.set_xlabel('Period')\n", "    ax.set_ylabel('Outcome')\n", "    ax.set_title('Treated vs Synthetic Control')\n", "    ax.legend()\n", "    plt.tight_layout()\n", "    plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": "## 7. Inference Methods\n\nSDID supports three variance estimators, selected with `variance_method`:\n\n1. **Placebo** (`variance_method=\"placebo\"`, default): Placebo-based variance following Algorithm 4 of Arkhangelsky et al. (2021). This is the library default, whereas R's `synthdid` defaults to the bootstrap. The library deviates because placebo inference is always available for survey designs that carry only probability weights (pweights), and it avoids refitting the weights on every bootstrap draw.\n2. **Bootstrap** (`variance_method=\"bootstrap\"`): The paper's pairs bootstrap (Algorithm 2, step 2), which re-estimates ω and λ with Frank-Wolfe on each draw. This matches the default behavior of R's `synthdid::vcov(method=\"bootstrap\")`. Expect it to be roughly 5–30× slower per fit than placebo, depending on panel size.\n3. **Jackknife** (`variance_method=\"jackknife\"`): Algorithm 3, a leave-one-unit-out jackknife with the weights held fixed. It is deterministic, so no bootstrap replications are needed."
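}, { "cell_type": "markdown", "metadata": {}, "source": [
"The placebo procedure (Algorithm 4 in Arkhangelsky et al., 2021) can be sketched in a few lines: drop the real treated units, repeatedly pretend that a random subset of controls (as many as there are treated units) was treated, re-estimate on that control-only panel, and use the spread of these placebo estimates as the variance. The next cell is a minimal sketch, not the library's code. In the paper the estimator inside the loop is SDID itself; here a plain DiD stand-in keeps the sketch short. The helpers `placebo_se` and `did_estimate` and the toy panel are hypothetical."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"def did_estimate(Y, n_co, n_pre):\n",
"    # Stand-in estimator (plain DiD) used instead of full SDID for brevity.\n",
"    # Y: (N, T); controls in the first n_co rows, pre-periods in the first n_pre columns\n",
"    change = Y[:, n_pre:].mean(axis=1) - Y[:, :n_pre].mean(axis=1)\n",
"    return change[n_co:].mean() - change[:n_co].mean()\n",
"\n",
"def placebo_se(Y_co, n_tr, n_pre, estimator, n_reps=200, seed=0):\n",
"    # Hypothetical helper: placebo standard error (sketch of Algorithm 4).\n",
"    # Y_co: (N0, T) outcomes of the control units only; n_tr: number of real treated units\n",
"    rng = np.random.default_rng(seed)\n",
"    n_co = Y_co.shape[0]\n",
"    if n_co <= n_tr:\n",
"        raise ValueError('placebo inference needs more control units than treated units')\n",
"    estimates = np.empty(n_reps)\n",
"    for b in range(n_reps):\n",
"        perm = rng.permutation(n_co)\n",
"        # Reorder rows so the n_tr placebo-treated controls come last\n",
"        Y_placebo = Y_co[np.concatenate([perm[n_tr:], perm[:n_tr]])]\n",
"        estimates[b] = estimator(Y_placebo, n_co - n_tr, n_pre)\n",
"    # Variance = mean squared deviation of the placebo estimates from their mean\n",
"    return np.sqrt(np.mean((estimates - estimates.mean()) ** 2)), estimates\n",
"\n",
"# Toy control panel (hypothetical): 30 controls, 8 pre + 4 post periods, no effect\n",
"toy_rng = np.random.default_rng(1)\n",
"toy_t = np.arange(12)\n",
"toy_Y_co = 10 + 0.5 * toy_t + toy_rng.normal(0, 1.0, size=(30, 12))\n",
"\n",
"toy_se, toy_placebo = placebo_se(toy_Y_co, n_tr=3, n_pre=8, estimator=did_estimate)\n",
"print(f'Placebo SE: {toy_se:.3f}; mean placebo effect: {toy_placebo.mean():.3f}')"
]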
}, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Placebo-based inference\n", "sdid_placebo = SyntheticDiD(\n", " variance_method=\"placebo\", # Use placebo inference\n", " n_bootstrap=200, # Number of placebo replications\n", " seed=42\n", ")\n", "\n", "results_placebo = sdid_placebo.fit(\n", " df,\n", " outcome=\"outcome\",\n", " treatment=\"treated\",\n", " unit=\"unit\",\n", " time=\"period\",\n", " post_periods=list(range(n_pre, n_periods))\n", ")\n", "\n", "print(\"Placebo-based inference:\")\n", "print(f\"ATT: {results_placebo.att:.4f}\")\n", "print(f\"SE: {results_placebo.se:.4f}\")\n", "print(f\"Number of placebo effects: {len(results_placebo.placebo_effects)}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if HAS_MATPLOTLIB:\n", " # Visualize placebo distribution\n", " fig, ax = plt.subplots(figsize=(10, 6))\n", " \n", " ax.hist(results_placebo.placebo_effects, bins=20, alpha=0.7, \n", " edgecolor='black', label='Placebo effects')\n", " ax.axvline(x=results_placebo.att, color='red', linewidth=2, \n", " linestyle='--', label=f'Actual ATT = {results_placebo.att:.2f}')\n", " ax.axvline(x=0, color='gray', linewidth=1, linestyle=':')\n", " \n", " ax.set_xlabel('Effect')\n", " ax.set_ylabel('Frequency')\n", " ax.set_title('Distribution of Placebo Effects')\n", " ax.legend()\n", " plt.tight_layout()\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 8. 
Tuning Regularization\n", "\n", "By default, SDID **auto-computes** regularization from the data's noise level, matching R's `synthdid` package:\n", "\n", "- `zeta_omega`: Unit weight regularization = `(N1 * T1)^0.25 * noise_level`\n", "- `zeta_lambda`: Time weight regularization = `1e-6 * noise_level`\n", "\n", "Here `N1` is the number of treated units and `T1` the number of post-treatment periods. In R's `synthdid`, `noise_level` is the standard deviation of the period-to-period changes (first differences) in control-unit outcomes over the pre-treatment periods. A larger `zeta_omega` pushes the unit weights toward equal weighting; a smaller one lets a few close matches dominate.\n", "\n", "You can override these with explicit values:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Compare different unit weight regularization levels\n", "results_list = []\n", "\n", "for zeta_omega in [0.1, 1.0, 10.0]:\n", "    sdid_reg = SyntheticDiD(\n", "        zeta_omega=zeta_omega,\n", "        variance_method=\"placebo\",\n", "        n_bootstrap=200,\n", "        seed=42\n", "    )\n", "\n", "    res = sdid_reg.fit(\n", "        df,\n", "        outcome=\"outcome\",\n", "        treatment=\"treated\",\n", "        unit=\"unit\",\n", "        time=\"period\",\n", "        post_periods=list(range(n_pre, n_periods))\n", "    )\n", "\n", "    # Effective number of controls = 1 / sum(w^2); 0 if all weights are zero\n", "    w = np.asarray(list(res.unit_weights.values()))\n", "    sum_sq = np.sum(w ** 2)\n", "    eff_n = 1 / sum_sq if sum_sq > 0 else 0.0\n", "\n", "    results_list.append({\n", "        'zeta_omega': zeta_omega,\n", "        'ATT': res.att,\n", "        'SE': res.se,\n", "        'Eff. N controls': eff_n,\n", "        'Pre-fit RMSE': res.pre_treatment_fit\n", "    })\n", "\n", "reg_df = pd.DataFrame(results_list)\n", "print(\"Effect of unit weight regularization:\")\n", "print(reg_df.to_string(index=False))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 9. Single Treated Unit Case\n", "\n", "SDID is particularly useful when you have only **one treated unit**, as in the classic synthetic control setting. The example below uses placebo inference; the jackknife is not defined with a single treated unit, because leaving that unit out removes every treated observation."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Keep all controls plus one treated unit (unit id n_control, the first treated unit)\n", "single_treated = df[(df['treated'] == 0) | (df['unit'] == n_control)].copy()\n", "\n", "print(\"Single treated unit analysis:\")\n", "print(\"  Treated units: 1\")\n", "print(f\"  Control units: {n_control}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Fit SDID with a single treated unit\n", "sdid_single = SyntheticDiD(\n", "    variance_method=\"placebo\",\n", "    n_bootstrap=200,\n", "    seed=42\n", ")\n", "\n", "results_single = sdid_single.fit(\n", "    single_treated,\n", "    outcome=\"outcome\",\n", "    treatment=\"treated\",\n", "    unit=\"unit\",\n", "    time=\"period\",\n", "    post_periods=list(range(n_pre, n_periods))\n", ")\n", "\n", "print(results_single.summary())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 10. Including Covariates\n", "\n", "You can pass covariates through the `covariates` argument. In Arkhangelsky et al. (2021), covariates are handled by regressing the outcome on them and applying SDID to the residuals. The two covariates simulated below are pure noise, unrelated to the outcome, so the ATT should stay close to the estimate without covariates." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Add covariates (pure noise, independent of the outcome)\n", "df['size'] = np.random.normal(100, 20, len(df))\n", "df['age'] = np.random.normal(10, 3, len(df))\n", "\n", "# Fit with covariates\n", "sdid_cov = SyntheticDiD(\n", "    n_bootstrap=199,\n", "    seed=42\n", ")\n", "\n", "results_cov = sdid_cov.fit(\n", "    df,\n", "    outcome=\"outcome\",\n", "    treatment=\"treated\",\n", "    unit=\"unit\",\n", "    time=\"period\",\n", "    post_periods=list(range(n_pre, n_periods)),\n", "    covariates=[\"size\", \"age\"]\n", ")\n", "\n", "print(\"With covariates:\")\n", "print(f\"  ATT: {results_cov.att:.4f}\")\n", "print(f\"  SE:  {results_cov.se:.4f}\")\n", "print(f\"  ATT without covariates (Section 2): {results.att:.4f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": "## Summary\n\nKey takeaways for Synthetic DiD:\n\n1. **Best use cases**: Few treated units, many controls, long pre-period\n2. 
**Unit weights**: Identify which controls track the treated units most closely (Frank-Wolfe with sparsification)\n3. **Time weights**: Determine which pre-treatment periods best predict post-treatment control outcomes (Frank-Wolfe on the collapsed form)\n4. **Pre-treatment fit**: Lower RMSE indicates a better synthetic match\n5. **Inference options**:\n   - Placebo (`variance_method=\"placebo\"`, default): Placebo-based variance computed from the control units. This is the library default; R's default is the bootstrap, and the library deviates for availability under survey designs and for speed.\n   - Bootstrap (`variance_method=\"bootstrap\"`): The paper's pairs bootstrap, re-estimating ω and λ with Frank-Wolfe on each draw (Algorithm 2, step 2; matches R's default `vcov`). Roughly 5–30× slower than placebo.\n   - Jackknife (`variance_method=\"jackknife\"`): Algorithm 3, a leave-one-unit-out jackknife with fixed weights.\n6. **Regularization**: Auto-computed from the data's noise level by default. Override with `zeta_omega`/`zeta_lambda`.\n\nReference:\n- Arkhangelsky, D., Athey, S., Hirshberg, D. A., Imbens, G. W., & Wager, S. (2021). Synthetic difference-in-differences. *American Economic Review*, 111(12), 4088-4118." } ], "metadata": { "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 4 }