Source code for diff_diff.visualization._synthetic

"""Synthetic control visualization functions."""

from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

if TYPE_CHECKING:
    from diff_diff.results import SyntheticDiDResults


def plot_synth_weights(
    results: Optional["SyntheticDiDResults"] = None,
    *,
    weights: Optional[Dict[Any, float]] = None,
    weight_type: str = "unit",
    top_n: Optional[int] = None,
    min_weight: float = 0.001,
    figsize: Tuple[float, float] = (10, 6),
    title: Optional[str] = None,
    color: str = "#2563eb",
    ax: Optional[Any] = None,
    show: bool = True,
    backend: str = "matplotlib",
) -> Any:
    """
    Plot synthetic control weights as a bar chart.

    Visualizes the unit weights or time weights from a Synthetic
    Difference-in-Differences estimation. Weights are ranked by absolute
    value (largest first) and drawn as horizontal bars.

    Parameters
    ----------
    results : SyntheticDiDResults, optional
        Results from SyntheticDiD estimator. The ``unit_weights`` or
        ``time_weights`` attribute is read depending on ``weight_type``.
    weights : dict, optional
        Dictionary mapping unit/period IDs to weights. Used if results is None.
    weight_type : str, default="unit"
        Which weights to plot: ``"unit"`` for control unit weights or
        ``"time"`` for pre-treatment time weights. Also selects the
        auto-generated title and the axis label.
    top_n : int, optional
        Show only the top N weights by magnitude. Must be >= 1 when given.
        Useful when there are many control units.
    min_weight : float, default=0.001
        Minimum absolute weight for an entry to be displayed.
    figsize : tuple, default=(10, 6)
        Figure size (width, height) in inches (matplotlib only).
    title : str, optional
        Plot title. If None, auto-generated based on ``weight_type``.
    color : str, default="#2563eb"
        Bar color.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates new figure (matplotlib only).
    show : bool, default=True
        Whether to display the plot at the end.
    backend : str, default="matplotlib"
        Plotting backend: ``"matplotlib"`` or ``"plotly"``.

    Returns
    -------
    matplotlib.axes.Axes or plotly.graph_objects.Figure
        The axes object (matplotlib) or figure (plotly).

    Raises
    ------
    ValueError
        If both or neither of ``results`` and ``weights`` are given, if
        ``weight_type`` is not ``"unit"``/``"time"``, if ``backend`` is not
        ``"matplotlib"``/``"plotly"``, if ``top_n`` is less than 1, if there
        are no weights, or if no weight reaches ``min_weight`` in magnitude.
    """
    # --- Argument validation (before any plotting library is touched) ---
    if results is not None and weights is not None:
        raise ValueError("Provide either 'results' or 'weights', not both.")
    if weight_type not in ("unit", "time"):
        raise ValueError(f"weight_type must be 'unit' or 'time', got '{weight_type}'")
    if backend not in ("matplotlib", "plotly"):
        raise ValueError(f"backend must be 'matplotlib' or 'plotly', got '{backend}'")
    if top_n is not None and top_n < 1:
        # A non-positive slice bound would yield an empty chart (0) or silently
        # drop entries from the end (negative values).
        raise ValueError(f"top_n must be a positive integer, got {top_n}")

    # --- Resolve the weights mapping ---
    if results is not None:
        weights = results.unit_weights if weight_type == "unit" else results.time_weights
        if weights is None:
            # The results object exists but carries no weights of this kind.
            raise ValueError("No weights available to plot.")
    if weights is None:
        raise ValueError("Must provide either 'results' or 'weights'.")
    if not weights:
        raise ValueError("No weights available to plot.")

    # Keep entries whose magnitude reaches the display threshold.
    filtered = {k: v for k, v in weights.items() if abs(v) >= min_weight}
    if not filtered:
        raise ValueError(f"No weights >= {min_weight} to plot.")

    # Rank by magnitude so that top_n keeps the largest |weight| entries, as
    # documented. For non-negative weights this equals a plain descending sort.
    sorted_items = sorted(filtered.items(), key=lambda item: abs(item[1]), reverse=True)
    if top_n is not None:
        sorted_items = sorted_items[:top_n]

    labels = [str(k) for k, _ in sorted_items]
    values = [v for _, v in sorted_items]

    if title is None:
        title = (
            "Synthetic Control Unit Weights"
            if weight_type == "unit"
            else "Synthetic Control Time Weights"
        )

    if backend == "plotly":
        return _render_synth_weights_plotly(
            labels=labels,
            values=values,
            title=title,
            color=color,
            weight_type=weight_type,
            show=show,
        )
    return _render_synth_weights_mpl(
        labels=labels,
        values=values,
        figsize=figsize,
        title=title,
        color=color,
        weight_type=weight_type,
        ax=ax,
        show=show,
    )
def _render_synth_weights_mpl(*, labels, values, figsize, title, color, weight_type, ax, show):
    """Render synthetic control weights as a horizontal matplotlib bar chart.

    Bars are drawn in the order given, first entry at the top. Each bar is
    annotated with its value to four decimals.

    Parameters
    ----------
    labels : list of str
        Bar labels (unit or period identifiers).
    values : list of float
        Bar lengths, aligned with ``labels``.
    figsize : tuple
        Size of the figure created when ``ax`` is None.
    title : str
        Axes title.
    color : str
        Bar color.
    weight_type : str
        ``"unit"`` labels the y-axis "Control Unit"; anything else labels it
        "Time Period".
    ax : matplotlib.axes.Axes or None
        Axes to draw on; a new figure is created when None.
    show : bool
        Whether to call ``plt.show()``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn on.
    """
    from diff_diff.visualization._common import _require_matplotlib

    plt = _require_matplotlib()

    # Only a figure created here gets tight_layout(); a caller-supplied ax may
    # belong to a figure with its own layout engine that we must not override.
    owned_fig = None
    if ax is None:
        owned_fig, ax = plt.subplots(figsize=figsize)

    # Horizontal bar chart
    y_pos = range(len(labels))
    ax.barh(y_pos, values, color=color, alpha=0.8, edgecolor="white")
    ax.set_yticks(y_pos)
    ax.set_yticklabels(labels)
    ax.invert_yaxis()  # First (highest-ranked) entry at the top

    ax.set_xlabel("Weight")
    ax.set_ylabel("Control Unit" if weight_type == "unit" else "Time Period")
    ax.set_title(title)
    ax.grid(True, alpha=0.3, axis="x")

    # Value labels just past each bar's end. The offset is in points rather
    # than data units, so it looks the same for any weight scale, and it flips
    # to the left side for negative bars instead of overlapping them.
    for i, v in enumerate(values):
        positive = v >= 0
        ax.annotate(
            f"{v:.4f}",
            xy=(v, i),
            xytext=(3 if positive else -3, 0),
            textcoords="offset points",
            ha="left" if positive else "right",
            va="center",
            fontsize=9,
        )

    if owned_fig is not None:
        owned_fig.tight_layout()
    if show:
        plt.show()
    return ax


def _render_synth_weights_plotly(*, labels, values, title, color, weight_type, show):
    """Render synthetic control weights as a horizontal plotly bar chart.

    Bars are drawn in the order given, first entry at the top. Each bar shows
    its value to four decimals outside the bar.

    Parameters
    ----------
    labels : list of str
        Bar labels (unit or period identifiers).
    values : list of float
        Bar lengths, aligned with ``labels``.
    title : str
        Figure title.
    color : str
        Bar color.
    weight_type : str
        ``"unit"`` labels the y-axis "Control Unit"; anything else labels it
        "Time Period".
    show : bool
        Whether to call ``fig.show()``.

    Returns
    -------
    plotly.graph_objects.Figure
        The rendered figure.
    """
    from diff_diff.visualization._common import _plotly_default_layout, _require_plotly

    go = _require_plotly()

    fig = go.Figure()
    fig.add_trace(
        go.Bar(
            y=labels,
            x=values,
            orientation="h",
            marker_color=color,
            opacity=0.8,
            text=[f"{v:.4f}" for v in values],
            textposition="outside",
        )
    )

    ylabel = "Control Unit" if weight_type == "unit" else "Time Period"
    _plotly_default_layout(
        fig,
        title=title,
        xlabel="Weight",
        ylabel=ylabel,
        show_legend=False,
    )

    # Force a categorical y-axis. Without this, plotly auto-types numeric-looking
    # labels (integer unit IDs, years) as a linear axis, which places bars at
    # their numeric positions instead of in ranked order. Reversing a category
    # axis puts the first (highest-ranked) label at the top.
    fig.update_yaxes(type="category", autorange="reversed")

    if show:
        fig.show()
    return fig