Source code for diff_diff.visualization._diagnostic

"""Diagnostic visualization functions (sensitivity, Bacon decomposition)."""

from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import numpy as np

if TYPE_CHECKING:
    from diff_diff.bacon import BaconDecompositionResults
    from diff_diff.honest_did import SensitivityResults


def plot_sensitivity(
    sensitivity_results: "SensitivityResults",
    *,
    show_bounds: bool = True,
    show_ci: bool = True,
    breakdown_line: bool = True,
    figsize: Tuple[float, float] = (10, 6),
    title: str = "Honest DiD Sensitivity Analysis",
    xlabel: str = "M (restriction parameter)",
    ylabel: str = "Treatment Effect",
    bounds_color: str = "#2563eb",
    bounds_alpha: float = 0.3,
    ci_color: str = "#2563eb",
    ci_linewidth: float = 1.5,
    breakdown_color: str = "#dc2626",
    original_color: str = "#1f2937",
    ax: Optional[Any] = None,
    show: bool = True,
    backend: str = "matplotlib",
) -> Any:
    """
    Plot sensitivity analysis results from Honest DiD.

    Shows how treatment effect bounds and confidence intervals change as
    the restriction parameter M varies.

    Parameters
    ----------
    sensitivity_results : SensitivityResults
        Results from HonestDiD.sensitivity_analysis(). The attributes
        ``M_values``, ``bounds``, ``robust_cis``, ``original_estimate`` and
        ``breakdown_M`` are read from it.
    show_bounds : bool, default=True
        Whether to show the identified set bounds as shaded region.
    show_ci : bool, default=True
        Whether to show robust confidence interval lines.
    breakdown_line : bool, default=True
        Whether to show vertical line at breakdown value.
    figsize : tuple, default=(10, 6)
        Figure size (width, height) in inches. Only forwarded to the
        matplotlib renderer.
    title : str
        Plot title.
    xlabel : str
        X-axis label.
    ylabel : str
        Y-axis label.
    bounds_color : str
        Color for identified set shading.
    bounds_alpha : float
        Transparency for identified set shading.
    ci_color : str
        Color for confidence interval lines.
    ci_linewidth : float
        Line width for CI lines.
    breakdown_color : str
        Color for breakdown value line.
    original_color : str
        Color for original estimate line.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates new figure. Only forwarded to the
        matplotlib renderer.
    show : bool, default=True
        Whether to display the plot before returning.
    backend : str, default="matplotlib"
        Plotting backend: ``"matplotlib"`` or ``"plotly"``.

    Returns
    -------
    matplotlib.axes.Axes or plotly.graph_objects.Figure
        The axes object (matplotlib) or figure (plotly).

    Raises
    ------
    ValueError
        If ``backend`` is neither ``"matplotlib"`` nor ``"plotly"``.

    Examples
    --------
    >>> from diff_diff import MultiPeriodDiD
    >>> from diff_diff.honest_did import HonestDiD
    >>> from diff_diff.visualization import plot_sensitivity
    >>>
    >>> # Fit event study and run sensitivity analysis
    >>> results = MultiPeriodDiD().fit(data, ...)
    >>> honest = HonestDiD(method='relative_magnitude')
    >>> sensitivity = honest.sensitivity_analysis(results)
    >>>
    >>> # Create sensitivity plot
    >>> plot_sensitivity(sensitivity)
    """
    # Fail fast on a mistyped backend instead of silently rendering with
    # matplotlib; done before touching the results object.
    if backend not in ("matplotlib", "plotly"):
        raise ValueError(f"Unknown backend: {backend}. Use 'matplotlib' or 'plotly'.")

    # Everything both renderers need; built once so the two call sites
    # cannot drift apart.
    shared_kwargs = dict(
        M=sensitivity_results.M_values,
        bounds_arr=np.array(sensitivity_results.bounds),
        ci_arr=np.array(sensitivity_results.robust_cis),
        original_estimate=sensitivity_results.original_estimate,
        breakdown_M=sensitivity_results.breakdown_M,
        show_bounds=show_bounds,
        show_ci=show_ci,
        breakdown_line=breakdown_line,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        bounds_color=bounds_color,
        bounds_alpha=bounds_alpha,
        ci_color=ci_color,
        ci_linewidth=ci_linewidth,
        breakdown_color=breakdown_color,
        original_color=original_color,
        show=show,
    )

    if backend == "plotly":
        return _render_sensitivity_plotly(**shared_kwargs)

    # figsize and ax only make sense for matplotlib.
    return _render_sensitivity_mpl(figsize=figsize, ax=ax, **shared_kwargs)
def _render_sensitivity_mpl(
    *,
    M,
    bounds_arr,
    ci_arr,
    original_estimate,
    breakdown_M,
    show_bounds,
    show_ci,
    breakdown_line,
    figsize,
    title,
    xlabel,
    ylabel,
    bounds_color,
    bounds_alpha,
    ci_color,
    ci_linewidth,
    breakdown_color,
    original_color,
    ax,
    show,
):
    """Draw the sensitivity plot on matplotlib axes and return those axes.

    ``bounds_arr`` and ``ci_arr`` are indexed as ``[:, 0]`` and ``[:, 1]``,
    so both must be two-dimensional with at least two columns
    (presumably lower/upper values per entry of ``M`` -- confirm with caller).
    When ``ax`` is None a new figure of size ``figsize`` is created;
    otherwise the figure owning ``ax`` is reused. ``plt.show()`` is called
    only when ``show`` is truthy.
    """
    from diff_diff.visualization._common import _require_matplotlib

    plt = _require_matplotlib()

    # Reuse the caller's axes when supplied, else start a fresh figure.
    if ax is not None:
        fig = ax.get_figure()
    else:
        fig, ax = plt.subplots(figsize=figsize)

    # Horizontal references: the point estimate and the zero baseline.
    ax.axhline(
        y=original_estimate,
        color=original_color,
        linestyle="-",
        linewidth=1.5,
        label="Original estimate",
        alpha=0.7,
    )
    ax.axhline(y=0, color="gray", linestyle="--", linewidth=1, alpha=0.5)

    # Shaded band between the two bound columns.
    if show_bounds:
        band_first = bounds_arr[:, 0]
        band_second = bounds_arr[:, 1]
        ax.fill_between(
            M,
            band_first,
            band_second,
            alpha=bounds_alpha,
            color=bounds_color,
            label="Identified set",
        )

    # Two CI curves share one style; only the first carries a legend label.
    if show_ci:
        ci_style = {"color": ci_color, "linewidth": ci_linewidth}
        ax.plot(M, ci_arr[:, 0], label="Robust CI", **ci_style)
        ax.plot(M, ci_arr[:, 1], **ci_style)

    # Vertical marker at the breakdown value, when one exists.
    if breakdown_line and breakdown_M is not None:
        breakdown_label = f"Breakdown (M={breakdown_M:.2f})"
        ax.axvline(
            x=breakdown_M,
            color=breakdown_color,
            linestyle=":",
            linewidth=2,
            label=breakdown_label,
        )

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(loc="best")
    ax.grid(True, alpha=0.3)

    fig.tight_layout()

    if show:
        plt.show()

    return ax


def _render_sensitivity_plotly(
    *,
    M,
    bounds_arr,
    ci_arr,
    original_estimate,
    breakdown_M,
    show_bounds,
    show_ci,
    breakdown_line,
    title,
    xlabel,
    ylabel,
    bounds_color,
    bounds_alpha,
    ci_color,
    ci_linewidth,
    breakdown_color,
    original_color,
    show,
):
    """Build the sensitivity plot as a plotly figure and return it.

    ``bounds_arr`` and ``ci_arr`` are indexed as ``[:, 0]`` and ``[:, 1]``,
    so both must be two-dimensional with at least two columns.
    ``fig.show()`` is called only when ``show`` is truthy.
    """
    from diff_diff.visualization._common import (
        _color_to_rgba,
        _plotly_default_layout,
        _require_plotly,
    )

    go = _require_plotly()
    fig = go.Figure()

    # List concatenation and reversal below need a real list.
    x_vals = M if isinstance(M, list) else list(M)

    # Horizontal references: the point estimate and the zero baseline.
    fig.add_hline(
        y=original_estimate,
        line_color=original_color,
        line_width=1.5,
        opacity=0.7,
        annotation_text="Original estimate",
    )
    fig.add_hline(y=0, line_dash="dash", line_color="gray", line_width=1, opacity=0.5)

    # The band is one closed polygon: out along column 1, back along
    # column 0 in reverse, filled with "toself".
    if show_bounds:
        outline_x = x_vals + x_vals[::-1]
        outline_y = list(bounds_arr[:, 1]) + list(bounds_arr[:, 0])[::-1]
        band = go.Scatter(
            x=outline_x,
            y=outline_y,
            fill="toself",
            fillcolor=_color_to_rgba(bounds_color, bounds_alpha),
            line=dict(color="rgba(0,0,0,0)"),
            name="Identified set",
        )
        fig.add_trace(band)

    # Two CI curves; only the first appears in the legend.
    if show_ci:
        legend_options = (
            (0, {"name": "Robust CI"}),
            (1, {"showlegend": False}),
        )
        for column, legend_kwargs in legend_options:
            fig.add_trace(
                go.Scatter(
                    x=x_vals,
                    y=list(ci_arr[:, column]),
                    mode="lines",
                    line=dict(color=ci_color, width=ci_linewidth),
                    **legend_kwargs,
                )
            )

    # Vertical marker at the breakdown value, when one exists.
    if breakdown_line and breakdown_M is not None:
        fig.add_vline(
            x=breakdown_M,
            line_dash="dot",
            line_color=breakdown_color,
            line_width=2,
            annotation_text=f"Breakdown (M={breakdown_M:.2f})",
        )

    _plotly_default_layout(fig, title=title, xlabel=xlabel, ylabel=ylabel)

    if show:
        fig.show()

    return fig
def plot_bacon(
    results: "BaconDecompositionResults",
    *,
    plot_type: str = "scatter",
    figsize: Tuple[float, float] = (10, 6),
    title: Optional[str] = None,
    xlabel: str = "2x2 DiD Estimate",
    ylabel: str = "Weight",
    colors: Optional[Dict[str, str]] = None,
    marker: str = "o",
    markersize: int = 80,
    alpha: float = 0.7,
    show_weighted_avg: bool = True,
    show_twfe_line: bool = True,
    ax: Optional[Any] = None,
    show: bool = True,
    backend: str = "matplotlib",
) -> Any:
    """
    Visualize Goodman-Bacon decomposition results.

    Creates either a scatter plot showing the weight and estimate for each
    2x2 comparison, or a stacked bar chart showing total weight by
    comparison type.

    Parameters
    ----------
    results : BaconDecompositionResults
        Results from BaconDecomposition.fit() or bacon_decompose().
    plot_type : str, default="scatter"
        Type of plot to create:

        - "scatter": Scatter plot with estimates on x-axis, weights on y-axis
        - "bar": Stacked bar chart of weights by comparison type
    figsize : tuple, default=(10, 6)
        Figure size (width, height) in inches. Only forwarded to the
        matplotlib renderer.
    title : str, optional
        Plot title. If None, uses a default based on plot_type.
    xlabel : str, default="2x2 DiD Estimate"
        X-axis label (scatter plot only).
    ylabel : str, default="Weight"
        Y-axis label.
    colors : dict, optional
        Dictionary mapping comparison types to colors. Keys are:
        "treated_vs_never", "earlier_vs_later", "later_vs_earlier".
        If None, uses default colors. Any key left out of the mapping
        falls back to its default color.
    marker : str, default="o"
        Marker style for scatter plot.
    markersize : int, default=80
        Marker size for scatter plot.
    alpha : float, default=0.7
        Transparency for markers/bars.
    show_weighted_avg : bool, default=True
        Whether to show weighted average lines for each comparison type
        (scatter plot only).
    show_twfe_line : bool, default=True
        Whether to show a vertical line at the TWFE estimate (scatter plot only).
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates new figure. Only forwarded to the
        matplotlib renderer.
    show : bool, default=True
        Whether to display the plot before returning.
    backend : str, default="matplotlib"
        Plotting backend: ``"matplotlib"`` or ``"plotly"``.

    Returns
    -------
    matplotlib.axes.Axes or plotly.graph_objects.Figure
        The axes object (matplotlib) or figure (plotly).

    Raises
    ------
    ValueError
        If ``plot_type`` is not ``"scatter"`` or ``"bar"``, or if ``backend``
        is not ``"matplotlib"`` or ``"plotly"``.

    Examples
    --------
    Scatter plot (default):

    >>> from diff_diff import bacon_decompose, plot_bacon
    >>> results = bacon_decompose(data, outcome='y', unit='id',
    ...                           time='t', first_treat='first_treat')
    >>> plot_bacon(results)

    Bar chart of weights by type:

    >>> plot_bacon(results, plot_type='bar')

    Notes
    -----
    The scatter plot is particularly useful for understanding:

    1. **Distribution of estimates**: Are 2x2 estimates clustered or spread?
       Wide spread suggests heterogeneous treatment effects.

    2. **Weight concentration**: Do a few comparisons dominate the TWFE?
       Points with high weights have more influence.

    3. **Forbidden comparison problem**: Red points (later_vs_earlier) show
       comparisons using already-treated units as controls. If these have
       different estimates than clean comparisons, TWFE may be biased.

    See Also
    --------
    bacon_decompose : Perform the decomposition
    BaconDecomposition : Class-based interface
    """
    # Validate the selectors before doing any other work so a typo fails
    # fast rather than silently falling through to a default branch.
    if plot_type not in ("scatter", "bar"):
        raise ValueError(f"Unknown plot_type: {plot_type}. Use 'scatter' or 'bar'.")
    if backend not in ("matplotlib", "plotly"):
        raise ValueError(f"Unknown backend: {backend}. Use 'matplotlib' or 'plotly'.")

    default_colors = {
        "treated_vs_never": "#22c55e",  # Green - clean comparison
        "earlier_vs_later": "#3b82f6",  # Blue - valid comparison
        "later_vs_earlier": "#ef4444",  # Red - forbidden comparison
    }
    # Overlay caller colors on the defaults: the renderers index this mapping
    # by every comparison type, so a partial dict would otherwise raise
    # KeyError. Building a new dict also leaves the caller's object untouched.
    colors = {**default_colors, **(colors or {})}

    if title is None:
        title = (
            "Goodman-Bacon Decomposition"
            if plot_type == "scatter"
            else "TWFE Weight by Comparison Type"
        )

    # Arguments common to both renderers.
    shared_kwargs = dict(
        results=results,
        plot_type=plot_type,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        colors=colors,
        marker=marker,
        markersize=markersize,
        alpha=alpha,
        show_weighted_avg=show_weighted_avg,
        show_twfe_line=show_twfe_line,
        show=show,
    )

    if backend == "plotly":
        return _render_bacon_plotly(**shared_kwargs)

    # figsize and ax only make sense for matplotlib.
    return _render_bacon_mpl(figsize=figsize, ax=ax, **shared_kwargs)
def _render_bacon_mpl(
    *,
    results,
    plot_type,
    figsize,
    title,
    xlabel,
    ylabel,
    colors,
    marker,
    markersize,
    alpha,
    show_weighted_avg,
    show_twfe_line,
    ax,
    show,
):
    """Draw the Bacon decomposition on matplotlib axes and return those axes.

    Dispatches to the scatter helper when ``plot_type == "scatter"`` and to
    the bar helper for any other value. When ``ax`` is None a new figure of
    size ``figsize`` is created. ``plt.show()`` runs only when ``show`` is
    truthy.
    """
    from diff_diff.visualization._common import _require_matplotlib

    plt = _require_matplotlib()

    # Reuse the caller's axes when supplied, else start a fresh figure.
    if ax is not None:
        fig = ax.get_figure()
    else:
        fig, ax = plt.subplots(figsize=figsize)

    if plot_type != "scatter":
        _plot_bacon_bar(ax, results, colors, alpha, ylabel, title)
    else:
        _plot_bacon_scatter(
            ax,
            results,
            colors,
            marker,
            markersize,
            alpha,
            show_weighted_avg,
            show_twfe_line,
            xlabel,
            ylabel,
            title,
        )

    fig.tight_layout()

    if show:
        plt.show()

    return ax


def _plot_bacon_scatter(
    ax: Any,
    results: "BaconDecompositionResults",
    colors: Dict[str, str],
    marker: str,
    markersize: int,
    alpha: float,
    show_weighted_avg: bool,
    show_twfe_line: bool,
    xlabel: str,
    ylabel: str,
    title: str,
) -> None:
    """Scatter each 2x2 comparison (estimate on x, weight on y) onto ``ax``.

    Reads ``results.comparisons`` (items exposing ``comparison_type``,
    ``estimate`` and ``weight``), ``results.twfe_estimate`` and, when
    ``show_weighted_avg`` is set, ``results.effect_by_type()``.
    """
    legend_names = {
        "treated_vs_never": "Treated vs Never-treated",
        "earlier_vs_later": "Earlier vs Later treated",
        "later_vs_earlier": "Later vs Earlier (forbidden)",
    }

    # Bucket (estimate, weight) pairs by comparison type, in legend order.
    grouped: Dict[str, List[Tuple[float, float]]] = {key: [] for key in legend_names}
    for comparison in results.comparisons:
        pair = (comparison.estimate, comparison.weight)
        grouped[comparison.comparison_type].append(pair)

    # One scatter series per non-empty bucket.
    for kind, pairs in grouped.items():
        if pairs:
            xs = [estimate for estimate, _ in pairs]
            ys = [weight for _, weight in pairs]
            ax.scatter(
                xs,
                ys,
                c=colors[kind],
                label=legend_names[kind],
                marker=marker,
                s=markersize,
                alpha=alpha,
                edgecolors="white",
                linewidths=0.5,
            )

    # Dashed vertical line at each type's average effect, drawn only for
    # types that actually have points on the plot.
    if show_weighted_avg:
        for kind, mean_effect in results.effect_by_type().items():
            if mean_effect is None or not grouped[kind]:
                continue
            ax.axvline(
                x=mean_effect,
                color=colors[kind],
                linestyle="--",
                alpha=0.5,
                linewidth=1.5,
            )

    # Solid vertical line at the overall TWFE estimate.
    if show_twfe_line:
        twfe_label = f"TWFE = {results.twfe_estimate:.4f}"
        ax.axvline(
            x=results.twfe_estimate,
            color="black",
            linestyle="-",
            linewidth=2,
            label=twfe_label,
        )

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(loc="best")
    ax.grid(True, alpha=0.3)

    # Zero reference line, added after the legend is built.
    ax.axvline(x=0, color="gray", linestyle=":", alpha=0.5)


def _plot_bacon_bar(
    ax: Any,
    results: "BaconDecompositionResults",
    colors: Dict[str, str],
    alpha: float,
    ylabel: str,
    title: str,
) -> None:
    """Draw one bar per comparison type, with height equal to its weight.

    Reads ``results.weight_by_type()``, ``results.effect_by_type()`` and
    ``results.twfe_estimate``. Bars whose weight exceeds 0.01 get a
    percentage label on top and, when an effect is available, a
    ``β = ...`` label in the middle.
    """
    share_by_type = results.weight_by_type()

    ordered_types = ["treated_vs_never", "earlier_vs_later", "later_vs_earlier"]
    tick_text = {
        "treated_vs_never": "Treated vs Never-treated",
        "earlier_vs_later": "Earlier vs Later",
        "later_vs_earlier": "Later vs Earlier\n(forbidden)",
    }

    heights = [share_by_type[kind] for kind in ordered_types]
    rects = ax.bar(
        [tick_text[kind] for kind in ordered_types],
        heights,
        color=[colors[kind] for kind in ordered_types],
        alpha=alpha,
        edgecolor="white",
        linewidth=1,
    )

    # Percentage label above each bar that is taller than 1%.
    for rect, share in zip(rects, heights):
        if share > 0.01:
            top = rect.get_height()
            center = rect.get_x() + rect.get_width() / 2
            ax.annotate(
                f"{share:.1%}",
                xy=(center, top),
                xytext=(0, 3),
                textcoords="offset points",
                ha="center",
                va="bottom",
                fontsize=10,
                fontweight="bold",
            )

    # Average effect printed inside each sufficiently tall bar.
    effect_lookup = results.effect_by_type()
    for rect, kind in zip(rects, ordered_types):
        mean_effect = effect_lookup[kind]
        if mean_effect is not None and share_by_type[kind] > 0.01:
            middle = (rect.get_x() + rect.get_width() / 2, rect.get_height() / 2)
            ax.annotate(
                f"β = {mean_effect:.3f}",
                xy=middle,
                ha="center",
                va="center",
                fontsize=9,
                color="white",
                fontweight="bold",
            )

    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_ylim(0, 1.1)

    # Reference line at total weight = 1.
    ax.axhline(y=1.0, color="gray", linestyle="--", alpha=0.5)

    # TWFE estimate in a boxed note at the top-right corner of the axes.
    note_box = dict(boxstyle="round", facecolor="white", alpha=0.8)
    ax.text(
        0.98,
        0.98,
        f"TWFE = {results.twfe_estimate:.4f}",
        transform=ax.transAxes,
        ha="right",
        va="top",
        fontsize=10,
        bbox=note_box,
    )


def _render_bacon_plotly(
    *,
    results,
    plot_type,
    title,
    xlabel,
    ylabel,
    colors,
    marker,
    markersize,
    alpha,
    show_weighted_avg,
    show_twfe_line,
    show,
):
    """Build the Bacon decomposition plot as a plotly figure and return it.

    Produces the scatter view when ``plot_type == "scatter"`` and the bar
    view for any other value. ``fig.show()`` runs only when ``show`` is
    truthy.
    """
    from diff_diff.visualization._common import (
        _mpl_marker_to_plotly_symbol,
        _plotly_default_layout,
        _require_plotly,
    )

    go = _require_plotly()
    fig = go.Figure()

    if plot_type != "scatter":
        # Bar view: one bar per comparison type, labelled with its share.
        share_by_type = results.weight_by_type()
        ordered_types = ["treated_vs_never", "earlier_vs_later", "later_vs_earlier"]
        tick_text = {
            "treated_vs_never": "Treated vs Never-treated",
            "earlier_vs_later": "Earlier vs Later",
            "later_vs_earlier": "Later vs Earlier (forbidden)",
        }

        bar_trace = go.Bar(
            x=[tick_text[kind] for kind in ordered_types],
            y=[share_by_type[kind] for kind in ordered_types],
            marker_color=[colors[kind] for kind in ordered_types],
            opacity=alpha,
            text=[f"{share_by_type[kind]:.1%}" for kind in ordered_types],
            textposition="outside",
        )
        fig.add_trace(bar_trace)

        fig.update_layout(yaxis_range=[0, 1.1])
        _plotly_default_layout(fig, title=title, xlabel=None, ylabel=ylabel, show_legend=False)
    else:
        legend_names = {
            "treated_vs_never": "Treated vs Never-treated",
            "earlier_vs_later": "Earlier vs Later treated",
            "later_vs_earlier": "Later vs Earlier (forbidden)",
        }

        # Bucket (estimate, weight) pairs by comparison type, in legend order.
        grouped = {key: [] for key in legend_names}
        for comparison in results.comparisons:
            pair = (comparison.estimate, comparison.weight)
            grouped[comparison.comparison_type].append(pair)

        # matplotlib's scatter size is an area (points^2) while plotly wants
        # a diameter (px), hence the square root.
        diameter = max(1, int(round(markersize**0.5)))
        plotly_symbol = _mpl_marker_to_plotly_symbol(marker)

        # One marker trace per non-empty bucket.
        for kind, pairs in grouped.items():
            if pairs:
                marker_style = dict(
                    color=colors[kind],
                    size=diameter,
                    symbol=plotly_symbol,
                    opacity=alpha,
                )
                fig.add_trace(
                    go.Scatter(
                        x=[estimate for estimate, _ in pairs],
                        y=[weight for _, weight in pairs],
                        mode="markers",
                        marker=marker_style,
                        name=legend_names[kind],
                    )
                )

        # Dashed vertical line at each type's average effect, drawn only
        # for types that actually have points on the plot.
        if show_weighted_avg:
            for kind, mean_effect in results.effect_by_type().items():
                if mean_effect is None or not grouped[kind]:
                    continue
                fig.add_vline(
                    x=mean_effect,
                    line_dash="dash",
                    line_color=colors[kind],
                    opacity=0.5,
                    line_width=1.5,
                )

        # Solid vertical line at the overall TWFE estimate.
        if show_twfe_line:
            fig.add_vline(
                x=results.twfe_estimate,
                line_color="black",
                line_width=2,
                annotation_text=f"TWFE = {results.twfe_estimate:.4f}",
            )

        # Zero reference line.
        fig.add_vline(x=0, line_dash="dot", line_color="gray", opacity=0.5)

        _plotly_default_layout(fig, title=title, xlabel=xlabel, ylabel=ylabel)

    if show:
        fig.show()

    return fig