{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Testing Parallel Trends and DiD Diagnostics\n", "\n", "The **parallel trends assumption** is the key identifying assumption for Difference-in-Differences. It states that, in the absence of treatment, the average outcomes of the treated and control groups would have followed the same trend over time.\n", "\n", "This notebook covers:\n", "1. Visual inspection of parallel trends\n", "2. A simple statistical test of pre-treatment trends\n", "3. Distributional comparison (Wasserstein)\n", "4. Equivalence testing (TOST)\n", "5. Placebo tests and diagnostics\n", "6. Event studies as a pre-trends check\n", "7. What to do when parallel trends fails" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "from diff_diff import DifferenceInDifferences, MultiPeriodDiD\n", "from diff_diff.utils import (\n", "    check_parallel_trends,\n", "    check_parallel_trends_robust,\n", "    equivalence_test_trends\n", ")\n", "from diff_diff.diagnostics import (\n", "    run_placebo_test,\n", "    placebo_timing_test,\n", "    placebo_group_test,\n", "    permutation_test,\n", "    run_all_placebo_tests\n", ")\n", "\n", "# For plots\n", "try:\n", "    import matplotlib.pyplot as plt\n", "    plt.style.use('seaborn-v0_8-whitegrid')\n", "    HAS_MATPLOTLIB = True\n", "except ImportError:\n", "    HAS_MATPLOTLIB = False\n", "    print(\"matplotlib not installed - visualization examples will be skipped\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. 
Create Example Data\n", "\n", "We'll create two datasets:\n", "- One where parallel trends **holds**\n", "- One where parallel trends is **violated**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Generate panel data using the library function\n", "from diff_diff import generate_panel_data\n", "\n", "# Generate data with parallel trends\n", "df_parallel = generate_panel_data(\n", " n_units=100,\n", " n_periods=8,\n", " treatment_period=4,\n", " treatment_fraction=0.5,\n", " treatment_effect=5.0,\n", " parallel_trends=True, # Parallel trends holds\n", " unit_fe_sd=2.0,\n", " noise_sd=0.5,\n", " seed=42\n", ")\n", "\n", "# Generate data with non-parallel trends (violation)\n", "df_nonparallel = generate_panel_data(\n", " n_units=100,\n", " n_periods=8,\n", " treatment_period=4,\n", " treatment_fraction=0.5,\n", " treatment_effect=5.0,\n", " parallel_trends=False, # Treated has steeper trend\n", " trend_violation=1.0, # Differential trend = 1.0 per period\n", " unit_fe_sd=2.0,\n", " noise_sd=0.5,\n", " seed=42\n", ")\n", "\n", "print(\"Generated two datasets:\")\n", "print(f\" - df_parallel: Parallel trends holds\")\n", "print(f\" - df_nonparallel: Parallel trends violated\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Visual Inspection\n", "\n", "The first step is always to **plot the data**. 
Look for:\n", "- Similar slopes in pre-treatment periods\n", "- Divergence only after treatment begins" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def plot_trends(df, title, ax):\n", "    \"\"\"Plot mean outcomes by group over time.\"\"\"\n", "    means = df.groupby(['period', 'treated'])['outcome'].mean().unstack()\n", "\n", "    # First period flagged as post-treatment\n", "    treatment_time = df[df['post'] == 1]['period'].min()\n", "\n", "    ax.plot(means.index, means[0], 'o-', label='Control', color='blue')\n", "    ax.plot(means.index, means[1], 's-', label='Treated', color='red')\n", "    ax.axvline(x=treatment_time - 0.5, color='gray', linestyle='--',\n", "               label='Treatment start')\n", "    ax.set_xlabel('Period')\n", "    ax.set_ylabel('Mean Outcome')\n", "    ax.set_title(title)\n", "    ax.legend()\n", "\n", "if HAS_MATPLOTLIB:\n", "    fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n", "\n", "    plot_trends(df_parallel, 'Parallel Trends Holds', axes[0])\n", "    plot_trends(df_nonparallel, 'Parallel Trends Violated', axes[1])\n", "\n", "    plt.tight_layout()\n", "    plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Simple Parallel Trends Test\n", "\n", "The `check_parallel_trends()` function estimates the pre-treatment trend (slope) of the outcome separately for the treated and control groups, then tests whether the difference between the two slopes is statistically significant." 
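] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before calling the library, it helps to see what a pre-trend comparison involves. The next cell is a minimal from-scratch sketch, **not** the implementation of `check_parallel_trends()`: it assumes the comparison can be framed as a Welch two-sample t-test on unit-level pre-treatment slopes, and it uses its own hypothetical simulated data instead of `df_parallel`. Fitting each unit's slope after subtracting that unit's mean removes the unit fixed effect, and treating each unit as a single observation accounts for the repeated measurements of the same unit." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"from scipy import stats\n",
"\n",
"# Hypothetical simulated pre-treatment data (separate from df_parallel):\n",
"# 50 treated and 50 control units observed in pre-periods 0-3.\n",
"rng = np.random.default_rng(0)\n",
"pre_periods = np.arange(4)\n",
"n_units = 50\n",
"\n",
"def simulate(slope):\n",
"    \"\"\"Unit fixed effect + linear time trend + noise; one row per unit.\"\"\"\n",
"    unit_fe = rng.normal(0.0, 2.0, size=(n_units, 1))\n",
"    noise = rng.normal(0.0, 0.5, size=(n_units, pre_periods.size))\n",
"    return unit_fe + slope * pre_periods + noise\n",
"\n",
"def unit_slopes(y):\n",
"    \"\"\"OLS slope of each unit's outcome on time (rows = units).\"\"\"\n",
"    tc = pre_periods - pre_periods.mean()  # centred time\n",
"    yc = y - y.mean(axis=1, keepdims=True)  # subtract each unit's mean (its fixed effect)\n",
"    return yc @ tc / (tc @ tc)\n",
"\n",
"treated = simulate(slope=1.0)\n",
"control = simulate(slope=1.0)  # same slope: parallel trends holds by construction\n",
"\n",
"slopes_t = unit_slopes(treated)\n",
"slopes_c = unit_slopes(control)\n",
"diff = slopes_t.mean() - slopes_c.mean()\n",
"t_stat, p_value = stats.ttest_ind(slopes_t, slopes_c, equal_var=False)  # Welch t-test\n",
"\n",
"print(f\"Treated trend: {slopes_t.mean():.3f}, control trend: {slopes_c.mean():.3f}\")\n",
"print(f\"Difference: {diff:.3f}, t = {t_stat:.2f}, p = {p_value:.3f}\")"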
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Test for parallel trends (parallel case)\n", "results_pt_parallel = check_parallel_trends(\n", " df_parallel,\n", " outcome='outcome',\n", " time='period',\n", " treatment_group='treated',\n", " pre_periods=[0, 1, 2, 3] # Pre-treatment periods\n", ")\n", "\n", "print(\"Parallel Trends Test (parallel case):\")\n", "print(\"=\" * 50)\n", "print(f\"Treated trend: {results_pt_parallel['treated_trend']:.4f} \"\n", " f\"(SE: {results_pt_parallel['treated_trend_se']:.4f})\")\n", "print(f\"Control trend: {results_pt_parallel['control_trend']:.4f} \"\n", " f\"(SE: {results_pt_parallel['control_trend_se']:.4f})\")\n", "print(f\"Difference: {results_pt_parallel['trend_difference']:.4f} \"\n", " f\"(SE: {results_pt_parallel['trend_difference_se']:.4f})\")\n", "print(f\"t-statistic: {results_pt_parallel['t_statistic']:.4f}\")\n", "print(f\"p-value: {results_pt_parallel['p_value']:.4f}\")\n", "print(f\"\\nParallel trends plausible: {results_pt_parallel['parallel_trends_plausible']}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Test for parallel trends (non-parallel case)\n", "results_pt_nonparallel = check_parallel_trends(\n", " df_nonparallel,\n", " outcome='outcome',\n", " time='period',\n", " treatment_group='treated',\n", " pre_periods=[0, 1, 2, 3]\n", ")\n", "\n", "print(\"\\nParallel Trends Test (non-parallel case):\")\n", "print(\"=\" * 50)\n", "print(f\"Treated trend: {results_pt_nonparallel['treated_trend']:.4f}\")\n", "print(f\"Control trend: {results_pt_nonparallel['control_trend']:.4f}\")\n", "print(f\"Difference: {results_pt_nonparallel['trend_difference']:.4f}\")\n", "print(f\"p-value: {results_pt_nonparallel['p_value']:.4f}\")\n", "print(f\"\\nParallel trends plausible: {results_pt_nonparallel['parallel_trends_plausible']}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. 
Robust Parallel Trends Test (Wasserstein)\n", "\n", "The `check_parallel_trends_robust()` function uses the Wasserstein (Earth Mover's) distance to compare the **full distribution** of pre-treatment outcome changes in the treated and control groups, not just their means. It also reports a Kolmogorov-Smirnov (KS) statistic and a permutation-based p-value for the Wasserstein distance (controlled by `n_permutations`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Robust test (parallel case)\n", "results_robust_parallel = check_parallel_trends_robust(\n", "    df_parallel,\n", "    outcome='outcome',\n", "    time='period',\n", "    treatment_group='treated',\n", "    unit='unit',\n", "    pre_periods=[0, 1, 2, 3],\n", "    n_permutations=999,\n", "    seed=42\n", ")\n", "\n", "print(\"Robust Parallel Trends Test (parallel case):\")\n", "print(\"=\" * 50)\n", "print(f\"Wasserstein distance: {results_robust_parallel['wasserstein_distance']:.4f}\")\n", "print(f\"Wasserstein (normalized): {results_robust_parallel['wasserstein_normalized']:.4f}\")\n", "print(f\"Wasserstein p-value: {results_robust_parallel['wasserstein_p_value']:.4f}\")\n", "print(f\"KS statistic: {results_robust_parallel['ks_statistic']:.4f}\")\n", "print(f\"KS p-value: {results_robust_parallel['ks_p_value']:.4f}\")\n", "print(f\"Mean difference: {results_robust_parallel['mean_difference']:.4f}\")\n", "print(f\"Variance ratio: {results_robust_parallel['variance_ratio']:.4f}\")\n", "print(f\"\\nParallel trends plausible: {results_robust_parallel['parallel_trends_plausible']}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Robust test (non-parallel case)\n", "results_robust_nonparallel = check_parallel_trends_robust(\n", "    df_nonparallel,\n", "    outcome='outcome',\n", "    time='period',\n", "    treatment_group='treated',\n", "    unit='unit',\n", "    pre_periods=[0, 1, 2, 3],\n", "    n_permutations=999,\n", "    seed=42\n", ")\n", "\n", "print(\"\\nRobust Parallel Trends Test (non-parallel case):\")\n", "print(\"=\" * 50)\n", "print(f\"Wasserstein distance: {results_robust_nonparallel['wasserstein_distance']:.4f}\")\n", 
"print(f\"Wasserstein p-value: {results_robust_nonparallel['wasserstein_p_value']:.4f}\")\n", "print(f\"\\nParallel trends plausible: {results_robust_nonparallel['parallel_trends_plausible']}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if HAS_MATPLOTLIB:\n", "    # Visualize the distribution of outcome changes\n", "    fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n", "\n", "    for i, (results, title) in enumerate([\n", "        (results_robust_parallel, 'Parallel Trends'),\n", "        (results_robust_nonparallel, 'Non-Parallel Trends')\n", "    ]):\n", "        ax = axes[i]\n", "        ax.hist(results['treated_changes'], bins=20, alpha=0.5,\n", "                label='Treated', color='red')\n", "        ax.hist(results['control_changes'], bins=20, alpha=0.5,\n", "                label='Control', color='blue')\n", "        ax.set_xlabel('Outcome Change')\n", "        ax.set_ylabel('Frequency')\n", "        ax.set_title(f'{title}\\n(Wasserstein p={results[\"wasserstein_p_value\"]:.3f})')\n", "        ax.legend()\n", "\n", "    plt.tight_layout()\n", "    plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Equivalence Testing (TOST)\n", "\n", 
"Standard pre-trend tests take parallel trends as the **null hypothesis**, so failing to reject it is not evidence that the trends are parallel: with few pre-treatment periods or noisy data, such a test can have **low power** to detect even a meaningful violation. A better approach is **equivalence testing** using the Two One-Sided Tests (TOST) procedure, which makes a meaningful violation the null hypothesis instead.\n", "\n", "Instead of asking \"Can we reject that the trends are different?\", we ask:\n", "\"Can we reject that the trend difference is at least as large as some practically meaningful margin?\"\n", "\n", "If both one-sided tests reject (the difference is significantly below +margin and significantly above -margin), the pre-treatment trends are declared equivalent within that margin. Choosing the margin is a substantive judgment about what size of trend difference would matter for the application." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Equivalence test (parallel case)\n", "results_equiv_parallel = equivalence_test_trends(\n", "    df_parallel,\n", "    outcome='outcome',\n", "    time='period',\n", "    treatment_group='treated',\n", "    unit='unit',\n", "    pre_periods=[0, 1, 2, 3],\n", "    equivalence_margin=0.5  # Trend differences within +/- 0.5 count as \"equivalent\"\n", ")\n", "\n", "print(\"Equivalence Test (parallel case):\")\n", "print(\"=\" * 50)\n", "print(f\"Mean difference: {results_equiv_parallel['mean_difference']:.4f}\")\n", "print(f\"SE: {results_equiv_parallel['se_difference']:.4f}\")\n", "print(f\"Equivalence margin: +/- {results_equiv_parallel['equivalence_margin']:.4f}\")\n", "print(f\"TOST p-value: {results_equiv_parallel['tost_p_value']:.4f}\")\n", "print(f\"\\nTrends are equivalent (at alpha=0.05): {results_equiv_parallel['equivalent']}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Equivalence test (non-parallel case)\n", "results_equiv_nonparallel = equivalence_test_trends(\n", "    df_nonparallel,\n", "    outcome='outcome',\n", "    time='period',\n", "    treatment_group='treated',\n", "    unit='unit',\n", "    pre_periods=[0, 1, 2, 3],\n", "    equivalence_margin=0.5\n", ")\n", "\n", "print(\"\\nEquivalence Test (non-parallel case):\")\n", "print(\"=\" * 50)\n", "print(f\"Mean difference: {results_equiv_nonparallel['mean_difference']:.4f}\")\n", "print(f\"TOST p-value: {results_equiv_nonparallel['tost_p_value']:.4f}\")\n", "print(f\"\\nTrends are equivalent: {results_equiv_nonparallel['equivalent']}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. 
Placebo Tests\n", "\n", "Placebo tests check whether the method finds \"effects\" where none should exist. This notebook uses three checks:\n", "\n", "1. **Timing placebo**: Pretend treatment began earlier, inside the pre-treatment window, and estimate DiD on pre-treatment data only\n", "2. **Group placebo**: Split the never-treated units into fake \"treated\" and control groups and estimate DiD among them\n", "3. **Permutation test**: Randomly reassign treatment across units many times to build a reference distribution of estimates, then check how extreme the actual estimate is relative to that distribution" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# First, fit the main model\n", "did = DifferenceInDifferences()\n", "main_results = did.fit(\n", "    df_parallel,\n", "    outcome='outcome',\n", "    treatment='treated',\n", "    time='post'\n", ")\n", "\n", "print(\"Main DiD Results:\")\n", "print(f\"ATT: {main_results.att:.4f} (SE: {main_results.se:.4f})\")\n", "print(f\"p-value: {main_results.p_value:.4f}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Placebo timing test\n", "# Estimate DiD with a fake treatment time in the pre-period\n", "placebo_timing = placebo_timing_test(\n", "    df_parallel,\n", "    outcome='outcome',\n", "    treatment='treated',\n", "    time='period',\n", "    fake_treatment_period=2,  # Pretend treatment started at period 2\n", "    post_periods=[4, 5, 6, 7]  # Actual post-treatment periods, excluded from the placebo\n", ")\n", "\n", "print(\"\\nPlacebo Timing Test:\")\n", "print(\"=\" * 50)\n", "print(f\"Placebo ATT: {placebo_timing.placebo_effect:.4f}\")\n", "print(f\"SE: {placebo_timing.se:.4f}\")\n", "print(f\"p-value: {placebo_timing.p_value:.4f}\")\n", "print(f\"\\nPass (effect not significant): {not placebo_timing.is_significant}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Placebo group test\n", "# Estimate DiD using only never-treated units (some randomly designated as \"fake treated\")\n", "# First, identify control units (never-treated)\n", "control_units = df_parallel[df_parallel['treated'] == 0]['unit'].unique()\n", "\n", "# Randomly select half of control units as \"fake 
treated\"\n", "np.random.seed(42)\n", "fake_treated = np.random.choice(control_units, size=len(control_units)//2, replace=False).tolist()\n", "\n", "placebo_group = placebo_group_test(\n", " df_parallel,\n", " outcome='outcome',\n", " time='period',\n", " unit='unit',\n", " fake_treated_units=fake_treated,\n", " post_periods=[4, 5, 6, 7] # Periods to use as post-treatment\n", ")\n", "\n", "print(\"\\nPlacebo Group Test:\")\n", "print(\"=\" * 50)\n", "print(f\"Placebo ATT: {placebo_group.placebo_effect:.4f}\")\n", "print(f\"SE: {placebo_group.se:.4f}\")\n", "print(f\"p-value: {placebo_group.p_value:.4f}\")\n", "print(f\"\\nPass (effect not significant): {not placebo_group.is_significant}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Permutation test\n", "perm_results = permutation_test(\n", " df_parallel,\n", " outcome='outcome',\n", " treatment='treated',\n", " time='post',\n", " unit='unit',\n", " n_permutations=999,\n", " seed=42\n", ")\n", "\n", "print(\"\\nPermutation Test:\")\n", "print(\"=\" * 50)\n", "print(f\"Observed ATT: {perm_results.placebo_effect:.4f}\")\n", "print(f\"Permutation p-value: {perm_results.p_value:.4f}\")\n", "print(f\"Number of permutations: {len(perm_results.permutation_distribution)}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if HAS_MATPLOTLIB:\n", " # Visualize permutation distribution\n", " fig, ax = plt.subplots(figsize=(10, 6))\n", " \n", " ax.hist(perm_results.permutation_distribution, bins=30, alpha=0.7, \n", " edgecolor='black', label='Permuted effects')\n", " ax.axvline(x=perm_results.placebo_effect, color='red', linewidth=2, \n", " linestyle='--', label=f'Observed = {perm_results.placebo_effect:.2f}')\n", " ax.axvline(x=0, color='gray', linewidth=1, linestyle=':')\n", " \n", " ax.set_xlabel('Effect')\n", " ax.set_ylabel('Frequency')\n", " ax.set_title(f'Permutation Test Distribution\\n(p-value = 
{perm_results.p_value:.3f})')\n", " ax.legend()\n", " plt.tight_layout()\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 7. Comprehensive Diagnostics\n", "\n", "Run all placebo tests at once with `run_all_placebo_tests()`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Run comprehensive diagnostics\n", "all_tests = run_all_placebo_tests(\n", "    df_parallel,\n", "    outcome='outcome',\n", "    treatment='treated',\n", "    time='period',\n", "    unit='unit',\n", "    pre_periods=[0, 1, 2, 3],  # Pre-treatment periods\n", "    post_periods=[4, 5, 6, 7],  # Post-treatment periods\n", "    n_permutations=499,\n", "    seed=42\n", ")\n", "\n", "print(\"Comprehensive Placebo Test Results:\")\n", "print(\"=\" * 60)\n", "print(f\"{'Test':<25} {'Effect':>10} {'p-value':>10} {'Pass':>10}\")\n", "print(\"-\" * 60)\n", "\n", "for test_name, result in all_tests.items():\n", "    if isinstance(result, dict) and 'error' in result:\n", "        # The test could not be run; show a truncated error message\n", "        print(f\"{test_name:<25} {'ERROR':>10} {'-':>10} {result['error'][:20]}\")\n", "    else:\n", "        passed = not result.is_significant  # Pass if NOT significant\n", "        print(f\"{test_name:<25} {result.placebo_effect:>10.4f} {result.p_value:>10.4f} {str(passed):>10}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 8. Event Study as a Parallel Trends Check\n", "\n", "An **event study** estimates a separate treated-versus-control difference for each period, measured relative to a reference period. If parallel trends holds, the pre-treatment coefficients should be close to zero; a steady drift in them is a warning sign." 
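] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To see what the event-study coefficients measure, the next cell computes them by hand on hypothetical simulated data (not `df_parallel`). It is a minimal sketch that assumes a balanced panel with no covariates: for each period it takes the treated-minus-control gap in mean outcomes and subtracts the same gap in the reference period. In that setting this matches the treated-by-period coefficients of a two-way fixed effects regression that omits the reference period; `MultiPeriodDiD` may estimate the coefficients or their standard errors differently, and the helper `event_study_coefs` is hypothetical." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"# Hypothetical simulated panel: periods 0-7, treatment starts in period 4,\n",
"# true effect 5.0, and a common pre-treatment trend for both groups.\n",
"rng = np.random.default_rng(1)\n",
"n_units, n_periods, treat_start, effect = 100, 8, 4, 5.0\n",
"periods = np.arange(n_periods)\n",
"is_treated = np.arange(n_units) < n_units // 2  # first half of units are treated\n",
"\n",
"unit_fe = rng.normal(0.0, 2.0, size=(n_units, 1))\n",
"y = unit_fe + 0.5 * periods + rng.normal(0.0, 0.5, size=(n_units, n_periods))\n",
"y[np.ix_(is_treated, periods >= treat_start)] += effect  # add the effect in post periods\n",
"\n",
"def event_study_coefs(y, is_treated, reference_period):\n",
"    \"\"\"Treated-minus-control gap in mean outcomes per period, relative to the reference period.\"\"\"\n",
"    gap = y[is_treated].mean(axis=0) - y[~is_treated].mean(axis=0)\n",
"    return gap - gap[reference_period]\n",
"\n",
"coefs = event_study_coefs(y, is_treated, reference_period=3)\n",
"for t, c in zip(periods, coefs):\n",
"    label = \"pre \" if t < treat_start else \"post\"\n",
"    print(f\"period {t} ({label}): {c:6.3f}\")"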
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Event study\n", "mp_did = MultiPeriodDiD()\n", "event_results = mp_did.fit(\n", "    df_parallel,\n", "    outcome='outcome',\n", "    treatment='treated',\n", "    time='period',\n", "    post_periods=[4, 5, 6, 7],\n", "    reference_period=3  # Last pre-treatment period serves as the reference\n", ")\n", "\n", "print(event_results.summary())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from diff_diff.visualization import plot_event_study\n", "\n", "if HAS_MATPLOTLIB:\n", "    fig, ax = plt.subplots(figsize=(10, 6))\n", "    plot_event_study(\n", "        results=event_results,\n", "        ax=ax,\n", "        title='Event Study: Check Pre-trends',\n", "        xlabel='Period',\n", "        ylabel='Effect'\n", "    )\n", "    plt.tight_layout()\n", "    plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 9. What to Do If Parallel Trends Fails\n", "\n", "If the diagnostics above suggest that parallel trends is violated, consider:\n", "\n", "1. **Add covariates** that might explain the differential trends\n", "2. **Use Synthetic DiD**, which reweights control units and pre-treatment periods so that the weighted controls track the treated group's pre-treatment path\n", "3. **Use bounds/sensitivity analysis** (Rambachan-Roth), which reports confidence sets that remain valid under bounded violations of parallel trends\n", "4. 
**Consider alternative designs** (RDD, IV, etc.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Example: Compare standard DiD vs Synthetic DiD on non-parallel data\n", "from diff_diff import SyntheticDiD\n", "\n", "true_att = 5.0  # treatment_effect used to generate df_nonparallel\n", "\n", "# Standard DiD (biased when trends differ)\n", "did_np = DifferenceInDifferences()\n", "results_did_np = did_np.fit(\n", "    df_nonparallel,\n", "    outcome='outcome',\n", "    treatment='treated',\n", "    time='post'\n", ")\n", "\n", "# Synthetic DiD (may be less biased)\n", "sdid = SyntheticDiD(n_bootstrap=99, seed=42)\n", "results_sdid = sdid.fit(\n", "    df_nonparallel,\n", "    outcome='outcome',\n", "    treatment='treated',\n", "    unit='unit',\n", "    time='period',\n", "    post_periods=[4, 5, 6, 7]\n", ")\n", "\n", "print(\"Comparison on Non-Parallel Trends Data\")\n", "print(\"=\" * 50)\n", "print(f\"True ATT: {true_att}\")\n", "print()\n", "print(\"Standard DiD:\")\n", "print(f\"  ATT: {results_did_np.att:.4f} (Bias: {results_did_np.att - true_att:.4f})\")\n", "print()\n", "print(\"Synthetic DiD:\")\n", "print(f\"  ATT: {results_sdid.att:.4f} (Bias: {results_sdid.att - true_att:.4f})\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "\n", "**Key takeaways for parallel trends testing:**\n", "\n", "1. **Always visualize** the data first\n", "\n", "2. **Simple tests** (`check_parallel_trends`):\n", "   - Compare pre-treatment slopes\n", "   - Easy to interpret, but only compare average linear trends\n", "\n", "3. **Robust tests** (`check_parallel_trends_robust`):\n", "   - Compare the full distributions of outcome changes with the Wasserstein distance\n", "   - Can detect differences in the spread or shape of outcome changes, not just in their mean\n", "\n", "4. **Equivalence testing** (`equivalence_test_trends`):\n", "   - Tests whether trend differences fall within a practically negligible margin\n", "   - Gives positive evidence of similar trends, rather than relying on a failure to reject\n", "\n", "5. 
**Placebo tests**:\n", "   - Timing: fake treatment date in the pre-period\n", "   - Group: DiD among never-treated units only\n", "   - Permutation: randomized treatment assignment\n", "\n", "6. **Event studies**: pre-treatment coefficients should be close to zero\n", "\n", "7. **If parallel trends fails**, consider:\n", "   - Adding covariates\n", "   - Synthetic DiD\n", "   - Sensitivity analysis\n", "   - Alternative identification strategies" ] } ], "metadata": { "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 4 }