# Source code for spikelab.spikedata.plot_utils

"""
Plotting utilities for SpikeLab.

Provides ``plot_recording`` for assembling multi-panel figures from SpikeData
objects, ``plot_heatmap`` for standalone 2-D heatmaps, ``plot_distribution``
for comparing per-unit metrics across conditions, ``plot_pvalue_matrix`` for
significance heatmaps (standalone or as inset), ``plot_scatter`` for pairwise
comparisons with optional regression, ``plot_lines`` for multi-trace line
plots, ``plot_burst_sensitivity`` for threshold sensitivity curves,
``plot_scatter_with_marginals`` for scatter plots with marginal histograms,
``plot_manifold`` for 2-D embedding (e.g. PCA or UMAP) scatter plots,
``plot_percentile_bands`` for percentile bands or per-unit lines across
ordered conditions,
``plot_aligned_slice_single_unit`` for event-aligned single-unit raster plots,
and ``plot_spatial_network`` for MEA spatial network visualisation.

Requires ``matplotlib`` (optional dependency).
"""

import numpy as np


def _import_matplotlib():
    """Import matplotlib and return (plt, mticker). Raises ImportError with message."""
    try:
        import matplotlib.pyplot as plt
        import matplotlib.ticker as mticker

        return plt, mticker
    except ImportError as e:
        raise ImportError(
            "plot_utils requires 'matplotlib'. " "Install with: pip install matplotlib"
        ) from e


def _add_colorbar(im, ax, label="", font_size=None, size="3%", pad=0.05):
    """Add a colorbar on a dedicated axes so the parent axes width is unchanged.

    Uses ``make_axes_locatable`` to append a thin axes to the right of ax.
    This avoids the width-stealing behaviour of ``fig.colorbar(im, ax=ax)``.
    When font_size is None, matplotlib rcParams control the label and tick
    sizes (``axes.labelsize`` for the label and ``xtick.labelsize`` for the
    tick labels).

    Parameters:
        im: The mappable artist (e.g. image returned by ``imshow``).
        ax (matplotlib.axes.Axes): Parent axes to attach the colorbar to.
        label (str): Colorbar label text.
        font_size (int or None): Font size for the label and tick labels.
            If None, uses matplotlib rcParams defaults.
        size (str): Width of the colorbar axes as a percentage of the
            parent axes (e.g. ``"3%"``).
        pad (float): Padding between the parent axes and the colorbar axes.

    Returns:
        cb (matplotlib.colorbar.Colorbar): The colorbar instance.
    """
    import matplotlib as mpl
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size=size, pad=pad)
    cb = ax.figure.colorbar(im, cax=cax)
    # Render colorbar at full opacity even when the mappable has alpha < 1
    if cb.solids is not None:
        cb.solids.set_alpha(1.0)
    if font_size is None:
        # Follow the documented rcParams: label size for the label, tick
        # size for the tick labels (the two can be configured separately).
        label_fs = mpl.rcParams["axes.labelsize"]
        tick_fs = mpl.rcParams["xtick.labelsize"]
    else:
        label_fs = tick_fs = font_size
    cb.set_label(label, fontsize=label_fs)
    cb.ax.tick_params(labelsize=tick_fs)
    return cb


def _apply_font_size(ax, font_size):
    """Apply font_size to axis labels and tick labels.

    Parameters:
        ax (matplotlib.axes.Axes): Target axes.
        font_size (int or float): Font size to apply.
    """
    ax.xaxis.label.set_fontsize(font_size)
    ax.yaxis.label.set_fontsize(font_size)
    ax.tick_params(axis="both", labelsize=font_size)


def _style_axes(ax):
    """Apply default axis styling: remove top/right spines, outward ticks.

    Parameters:
        ax (matplotlib.axes.Axes): Target axes.
    """
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.tick_params(axis="both", direction="out")


def _style_axes_heatmap(ax):
    """Apply heatmap axis styling: all four spines at 0.5 pt, outward ticks.

    Parameters:
        ax (matplotlib.axes.Axes): Target axes.
    """
    for spine in ax.spines.values():
        spine.set_linewidth(0.5)
    ax.tick_params(axis="both", direction="out")


# ---------------------------------------------------------------------------
# plot_distribution
# ---------------------------------------------------------------------------


def plot_distribution(
    ax,
    metric_data,
    labels=None,
    colors=None,
    ylabel="",
    xlabel="",
    style="violin",
    show_median=True,
    show_quartiles=True,
    show_data=False,
    data_alpha=0.3,
    data_size=4,
    log_scale=False,
    font_size=None,
):
    """Plot distributions of a per-unit metric across multiple groups/conditions.

    Parameters:
        ax (matplotlib.axes.Axes): Target axes (caller creates).
        metric_data (dict[str, np.ndarray] or list[np.ndarray]):
            Condition-labelled or ordered collection of per-unit value arrays.
            Values are converted to float (``None`` entries become NaN) and
            NaN values are stripped automatically before plotting.
        labels (list[str] or None): Ordered condition labels. If None, uses
            dict keys (for dict input) or integer indices (for list input).
        colors (list[str] or None): Per-condition colours. If None, uses the
            default matplotlib colour cycle. A list with fewer entries than
            there are conditions is cycled.
        ylabel (str): Y-axis label.
        xlabel (str): X-axis label.
        style (str): ``"violin"`` (default) or ``"boxplot"``.
        show_median (bool): Overlay a median dot on each distribution.
        show_quartiles (bool): Overlay IQR lines (25th-75th percentile) on
            each distribution.
        show_data (bool): Overlay individual data points on each
            distribution, jittered horizontally to reduce overlap.
        data_alpha (float): Alpha transparency for overlaid data points.
        data_size (float): Marker size for overlaid data points.
        log_scale (bool): Use a log scale on the y-axis.
        font_size (int or None): Font size for labels and ticks. If None,
            uses current rcParams.

    Returns:
        parts (dict): The violin or boxplot artist dict returned by
            matplotlib (``violinplot`` or ``boxplot``). In violin mode, when
            no group has 2 or more points, this is ``{"bodies": []}``.

    Raises:
        ValueError: If ``style`` is not ``"violin"`` or ``"boxplot"``, or if
            ``colors`` is an empty sequence while there are groups to plot.

    Notes:
        In violin mode, groups with fewer than 2 data points cannot produce a
        kernel density estimate. These groups are rendered as individual
        scatter points instead and excluded from the violin plot.
        Data-point jitter (``show_data=True``) uses a fixed random seed, so
        repeated calls produce identical plots.
    """
    _import_matplotlib()

    # --- Normalise input to list-of-arrays + labels -----------------------
    # dtype=float turns None entries into NaN and keeps np.isnan usable:
    # np.isnan raises TypeError on object-dtype arrays.
    if isinstance(metric_data, dict):
        keys = list(metric_data.keys())
        data_arrays = [np.asarray(metric_data[k], dtype=float) for k in keys]
        if labels is None:
            labels = keys
    else:
        data_arrays = [np.asarray(a, dtype=float) for a in metric_data]
        if labels is None:
            labels = [str(i) for i in range(len(data_arrays))]

    # Strip NaNs from each (flattened) array
    clean_data = []
    for arr in data_arrays:
        flat = arr.ravel()
        clean_data.append(flat[~np.isnan(flat)])

    n = len(clean_data)
    positions = list(range(n))

    # --- Resolve colours --------------------------------------------------
    if colors is None:
        import matplotlib.pyplot as _plt

        colors = _plt.rcParams["axes.prop_cycle"].by_key()["color"]
    if n and len(colors) == 0:
        raise ValueError("colors must contain at least one colour.")
    # Cycle the palette so a short colour list still covers every condition.
    colors = [colors[i % len(colors)] for i in range(n)]

    # --- Draw distribution ------------------------------------------------
    if style == "boxplot":
        parts = ax.boxplot(
            clean_data,
            positions=positions,
            widths=0.6,
            patch_artist=True,
            showfliers=True,
        )
        for i, box in enumerate(parts["boxes"]):
            box.set_facecolor(colors[i])
            box.set_alpha(0.8)
    elif style == "violin":
        # Separate groups with enough points for KDE from sparse groups
        violin_positions = []
        violin_data = []
        sparse_groups = []  # (position, data, color) for groups with < 2 points
        for pos, d, color in zip(positions, clean_data, colors):
            if len(d) >= 2:
                violin_positions.append(pos)
                violin_data.append(d)
            else:
                sparse_groups.append((pos, d, color))

        parts = {"bodies": []}
        if violin_data:
            parts = ax.violinplot(
                violin_data,
                positions=violin_positions,
                showmeans=False,
                showextrema=False,
            )
            # Positions are group indices, so they index straight into colors.
            for body, pos in zip(parts["bodies"], violin_positions):
                body.set_facecolor(colors[pos])
                body.set_edgecolor("black")
                body.set_linewidth(0.5)
                body.set_alpha(0.8)

        # Render sparse groups as scatter points
        for pos, d, color in sparse_groups:
            if len(d) > 0:
                ax.scatter(
                    np.full(len(d), pos),
                    d,
                    color=color,
                    s=data_size * 4,
                    zorder=3,
                    edgecolors="black",
                    linewidths=0.5,
                )
    else:
        raise ValueError(f"Unknown style '{style}'. Use 'violin' or 'boxplot'.")

    # --- Median dot + IQR lines -------------------------------------------
    if show_median or show_quartiles:
        for i, d in enumerate(clean_data):
            if len(d) == 0:
                continue
            q25, median, q75 = np.nanpercentile(d, [25, 50, 75])
            if show_median:
                ax.scatter(
                    positions[i],
                    median,
                    color="white",
                    s=15,
                    zorder=4,
                    edgecolors="black",
                    linewidths=0.5,
                )
            if show_quartiles:
                ax.vlines(positions[i], q25, q75, color="black", linewidth=1.5, zorder=3)

    # --- Overlay individual data points -----------------------------------
    if show_data:
        rng = np.random.default_rng(0)
        for i, d in enumerate(clean_data):
            if len(d) == 0:
                continue
            jitter = rng.uniform(-0.15, 0.15, size=len(d))
            ax.scatter(
                positions[i] + jitter,
                d,
                color=colors[i],
                s=data_size,
                alpha=data_alpha,
                zorder=2,
                edgecolors="none",
            )

    # --- Axes formatting --------------------------------------------------
    ax.set_xticks(positions)
    ax.set_xticklabels(labels)
    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)
    if log_scale:
        ax.set_yscale("log")
    _style_axes(ax)
    if font_size is not None:
        _apply_font_size(ax, font_size)

    return parts
# --------------------------------------------------------------------------- # plot_pvalue_matrix # ---------------------------------------------------------------------------
def plot_pvalue_matrix(
    pval_matrix,
    sig_matrix=None,
    labels=None,
    ax=None,
    parent_ax=None,
    inset_loc="upper left",
    inset_size="30%",
    inset_offset=0.08,
    cmap="viridis",
    sig_marker_color="red",
    sig_marker_size=2.5,
    show_colorbar=True,
    font_size=None,
):
    """Display a pairwise p-value matrix as a ``-log10(p)`` heatmap.

    Supports two rendering modes (mutually exclusive):

    - Standalone: pass ``ax`` to plot directly on existing axes.
    - Inset: pass ``parent_ax`` to create a small inset axes on a parent
      plot (e.g. a violin or distribution plot).

    Exactly one of ``ax`` or ``parent_ax`` must be provided.

    Parameters:
        pval_matrix (np.ndarray): (K, K) p-value matrix. Diagonal entries
            should be NaN (the diagonal is always rendered in black).
            Off-diagonal NaN entries (e.g. untested pairs) are also rendered
            in black; p-values of exactly 0 are shown at the colour-scale
            maximum.
        sig_matrix (np.ndarray or None): (K, K) boolean -- True where the
            comparison is significant. If None, computed as
            ``pval_matrix < 0.05``.
        labels (list[str] or None): Tick labels for each group. If None,
            integer indices are used.
        ax (matplotlib.axes.Axes or None): Target axes for standalone mode.
        parent_ax (matplotlib.axes.Axes or None): Parent axes on which to
            create an inset.
        inset_loc (str): Location string for ``inset_axes`` (e.g.
            ``"upper left"``, ``"lower left"``). Only used in inset mode.
        inset_size (str): Width and height of the inset as a percentage of
            the parent (e.g. ``"30%"``). Only used in inset mode.
        inset_offset (float): Horizontal offset of the inset bounding box
            from the parent axes edge. Only used in inset mode.
        cmap (str): Matplotlib colormap name.
        sig_marker_color (str): Colour for significance markers.
        sig_marker_size (float): Marker size for significance dots.
        show_colorbar (bool): Show a ``-log10(P)`` colorbar.
        font_size (int or None): Font size for labels and ticks. If None,
            uses current rcParams.

    Returns:
        target_ax (matplotlib.axes.Axes): The axes the matrix was drawn on
            (either ``ax`` or the newly created inset axes).

    Raises:
        ValueError: If ``pval_matrix`` is not square (K, K), if
            ``sig_matrix`` does not have the same shape, or if both or
            neither of ``ax`` / ``parent_ax`` are given.
    """
    plt, _ = _import_matplotlib()
    from mpl_toolkits.axes_grid1.inset_locator import inset_axes

    pval_matrix = np.asarray(pval_matrix, dtype=float)
    if pval_matrix.ndim != 2 or pval_matrix.shape[0] != pval_matrix.shape[1]:
        raise ValueError(
            f"pval_matrix must be a square (K, K) array, got shape {pval_matrix.shape}."
        )
    K = pval_matrix.shape[0]
    if sig_matrix is None:
        sig_matrix = pval_matrix < 0.05
    sig_matrix = np.asarray(sig_matrix, dtype=bool)
    if sig_matrix.shape != pval_matrix.shape:
        raise ValueError(
            f"sig_matrix shape {sig_matrix.shape} does not match "
            f"pval_matrix shape {pval_matrix.shape}."
        )
    if labels is None:
        labels = [str(i) for i in range(K)]

    # --- Resolve target axes ----------------------------------------------
    if ax is not None and parent_ax is not None:
        raise ValueError("Provide either 'ax' or 'parent_ax', not both.")
    if ax is None and parent_ax is None:
        raise ValueError("Provide either 'ax' (standalone) or 'parent_ax' (inset).")

    if parent_ax is not None:
        target_ax = inset_axes(
            parent_ax,
            width=inset_size,
            height=inset_size,
            loc=inset_loc,
            bbox_to_anchor=(inset_offset, 0, 1, 1),
            bbox_transform=parent_ax.transAxes,
            borderpad=1.0,
        )
    else:
        target_ax = ax

    # --- Compute -log10(p) ------------------------------------------------
    off_diag = ~np.eye(K, dtype=bool)
    # p == 0 -> +inf and p < 0 -> NaN are handled explicitly below, so the
    # corresponding numpy warnings are just noise.
    with np.errstate(divide="ignore", invalid="ignore"):
        neg_log_p = -np.log10(pval_matrix)
    # Colour range comes from displayed (off-diagonal) finite values only.
    finite_vals = neg_log_p[np.isfinite(neg_log_p) & off_diag]
    vmax = np.max(finite_vals) if finite_vals.size > 0 else 1
    # Cap only +inf (p == 0) at vmax. NaN p-values must stay NaN so they are
    # drawn via set_bad instead of as the most significant colour.
    neg_log_p = np.where(np.isposinf(neg_log_p), vmax, neg_log_p)
    # Diagonal -> NaN (rendered as black via set_bad)
    np.fill_diagonal(neg_log_p, np.nan)

    import matplotlib as mpl

    try:
        colormap = mpl.colormaps[cmap].copy()
    except (AttributeError, TypeError, KeyError, ValueError):
        # Matplotlib < 3.7 / older registry API
        colormap = mpl.cm.get_cmap(cmap).copy()
    colormap.set_bad(color="black")

    im = target_ax.imshow(
        neg_log_p,
        cmap=colormap,
        aspect="equal",
        interpolation="none",
        vmin=0,
        vmax=vmax,
    )

    # --- Significance markers ---------------------------------------------
    # One vectorised call instead of one Line2D per significant cell.
    rows, cols = np.nonzero(sig_matrix & off_diag)
    if rows.size:
        target_ax.plot(
            cols,
            rows,
            "o",
            color=sig_marker_color,
            markersize=sig_marker_size,
            markeredgewidth=0,
        )

    # --- Tick labels ------------------------------------------------------
    fs = font_size
    if fs is None:
        fs = plt.rcParams.get("xtick.labelsize", 10)
    if not isinstance(fs, (int, float)):
        # rcParams may hold a named size such as "medium"
        fs = 10
    tick_fs = fs - 1 if parent_ax is not None else fs

    target_ax.set_xticks(range(K))
    target_ax.set_xticklabels(labels, fontsize=tick_fs)
    target_ax.set_yticks(range(K))
    target_ax.set_yticklabels(labels, fontsize=tick_fs)
    _style_axes_heatmap(target_ax)

    # --- Colorbar ---------------------------------------------------------
    if show_colorbar:
        cbar_ax = inset_axes(
            target_ax,
            width="8%",
            height="100%",
            loc="center right",
            bbox_to_anchor=(0.18, 0, 1, 1),
            bbox_transform=target_ax.transAxes,
            borderpad=0,
        )
        cbar = target_ax.figure.colorbar(im, cax=cbar_ax)
        cbar.outline.set_linewidth(0.5)
        cbar.ax.tick_params(labelsize=tick_fs, width=0.5, length=1.5)
        cbar_label_fs = tick_fs if parent_ax is not None else fs
        cbar.set_label(r"$-\log_{10}(\mathrm{P})$", fontsize=cbar_label_fs)

    return target_ax
# --------------------------------------------------------------------------- # plot_scatter # ---------------------------------------------------------------------------
def plot_scatter(
    ax,
    x,
    y,
    xlabel="",
    ylabel="",
    color_vals=None,
    color_label="",
    cmap="viridis",
    vmin=None,
    vmax=None,
    show_identity=False,
    show_colorbar=True,
    fit=None,
    show_ci=False,
    show_r2=False,
    marker_size=8,
    alpha=0.7,
    groups=None,
    group_labels=None,
    group_colors=None,
    show_legend=True,
    font_size=None,
):
    """Scatter plot comparing two arrays with optional color coding and regression.

    Supports two colouring modes (mutually exclusive):

    - Continuous: pass ``color_vals`` for a colormap-based colour scale with
      an optional colorbar.
    - Discrete groups: pass ``groups`` (integer index per point) to colour
      each group separately with its own legend entry.

    When ``groups`` is provided, ``color_vals``, ``cmap``, ``vmin``,
    ``vmax``, ``color_label``, and ``show_colorbar`` are ignored.

    Parameters:
        ax (matplotlib.axes.Axes): Target axes (caller creates).
        x (np.ndarray): X-axis values.
        y (np.ndarray): Y-axis values.
        xlabel (str): X-axis label.
        ylabel (str): Y-axis label.
        color_vals (np.ndarray or str or None): Per-point values for
            continuous color mapping. Pass the string ``"density"`` to
            auto-compute KDE density and sort points so dense regions render
            on top (requires scipy). If None and ``groups`` is also None,
            all points are drawn in a uniform colour.
        color_label (str): Colorbar label (continuous mode only).
        cmap (str): Matplotlib colormap name (continuous mode only).
        vmin (float or None): Colormap minimum (continuous mode only).
        vmax (float or None): Colormap maximum (continuous mode only).
        show_identity (bool): Plot the x = y identity line.
        show_colorbar (bool): Add a colorbar when color_vals is provided
            (continuous mode only).
        fit (str or None): Regression to overlay. ``"linear"`` or None.
        show_ci (bool): Show confidence interval band on the fit.
        show_r2 (bool): Annotate R-squared on the plot.
        marker_size (float): Scatter marker size.
        alpha (float): Scatter alpha.
        groups (array-like or None): Per-point integer group index for
            discrete colouring. Each unique value is rendered as a separate
            scatter series with its own colour and legend entry.
        group_labels (list[str] or None): Label for each unique group value,
            ordered by ``np.unique(groups)``. If None, the group values are
            used as labels.
        group_colors (list or None): Colour for each unique group value,
            ordered by ``np.unique(groups)``. Any single matplotlib colour
            spec is accepted, including RGB/RGBA tuples. If None, uses the
            default matplotlib colour cycle.
        show_legend (bool): Show legend when ``groups`` is provided.
            Default True.
        font_size (int or None): Font size for labels/ticks. If None, uses
            current rcParams.

    Returns:
        sc (PathCollection or list[PathCollection]): In continuous mode, the
            single scatter artist (useful for shared colorbars). In group
            mode, a list of scatter artists (one per group).

    Raises:
        ImportError: If ``color_vals="density"`` and scipy is not installed.
        ValueError: If ``fit`` is not ``"linear"`` or None.

    Notes:
        When ``color_vals="density"``, non-finite values in x and y are
        removed before computing the KDE. Any regression overlay
        (``fit="linear"``) and the identity line are therefore computed on
        this filtered subset, not the original arrays.
    """
    _import_matplotlib()

    x = np.asarray(x, dtype=float).ravel()
    y = np.asarray(y, dtype=float).ravel()
    sc = None

    if groups is not None:
        # --- Discrete group colouring -------------------------------------
        groups = np.asarray(groups).ravel()
        unique_groups = np.unique(groups)
        n_groups = len(unique_groups)

        if group_colors is None:
            import matplotlib.pyplot as _plt

            cycle_colors = _plt.rcParams["axes.prop_cycle"].by_key()["color"]
            group_colors = [
                cycle_colors[i % len(cycle_colors)] for i in range(n_groups)
            ]
        if group_labels is None:
            group_labels = [str(g) for g in unique_groups]

        sc = []
        for i, g in enumerate(unique_groups):
            mask = groups == g
            s = ax.scatter(
                x[mask],
                y[mask],
                # ``color=`` (not ``c=``): with ``c=``, an RGB/RGBA tuple is
                # read as per-point values to colour-map whenever the group
                # has 3 or 4 points.
                color=group_colors[i],
                s=marker_size,
                alpha=alpha,
                edgecolors="none",
                label=group_labels[i],
                zorder=2,
            )
            sc.append(s)

        if show_legend:
            ax.legend(frameon=False)
    else:
        # --- Continuous / uniform colouring -------------------------------
        scatter_kw = dict(s=marker_size, alpha=alpha, edgecolors="none", zorder=2)
        if color_vals is not None:
            if isinstance(color_vals, str) and color_vals == "density":
                try:
                    from scipy.stats import gaussian_kde
                except ImportError as e:
                    raise ImportError(
                        "color_vals='density' requires scipy. "
                        "Install with: pip install scipy"
                    ) from e
                valid = np.isfinite(x) & np.isfinite(y)
                xy = np.vstack([x[valid], y[valid]])
                kde = gaussian_kde(xy)
                density = kde(xy)
                # Draw densest points last so they sit on top.
                sort_idx = density.argsort()
                x = x[valid][sort_idx]
                y = y[valid][sort_idx]
                color_vals = density[sort_idx]
            else:
                color_vals = np.asarray(color_vals, dtype=float).ravel()
            scatter_kw.update(c=color_vals, cmap=cmap)
            if vmin is not None:
                scatter_kw["vmin"] = vmin
            if vmax is not None:
                scatter_kw["vmax"] = vmax
        else:
            scatter_kw["c"] = "black"
        sc = ax.scatter(x, y, **scatter_kw)

    # --- Identity line ----------------------------------------------------
    if show_identity:
        lo = min(np.nanmin(x), np.nanmin(y))
        hi = max(np.nanmax(x), np.nanmax(y))
        ax.plot([lo, hi], [lo, hi], ls="--", color="grey", linewidth=0.8, zorder=1)

    # --- Regression fit ---------------------------------------------------
    if fit == "linear":
        from .stat_utils import linear_regression

        reg = linear_regression(x, y)
        ax.plot(reg["x_fit"], reg["y_fit"], color="red", linewidth=1.2, zorder=3)
        if show_ci:
            ax.fill_between(
                reg["x_fit"],
                reg["ci_lower"],
                reg["ci_upper"],
                color="red",
                alpha=0.15,
                zorder=1,
            )
        if show_r2:
            ax.annotate(
                f"$R^2 = {reg['r_squared']:.3f}$",
                xy=(0.05, 0.95),
                xycoords="axes fraction",
                ha="left",
                va="top",
                fontsize=font_size or 10,
            )
    elif fit is not None:
        raise ValueError(f"Unknown fit '{fit}'. Use 'linear' or None.")

    # --- Colorbar ---------------------------------------------------------
    if groups is None and color_vals is not None and show_colorbar:
        _add_colorbar(sc, ax, label=color_label, font_size=font_size)

    # --- Axes formatting --------------------------------------------------
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    _style_axes(ax)
    if font_size is not None:
        _apply_font_size(ax, font_size)

    return sc
# --------------------------------------------------------------------------- # plot_scatter_with_marginals # ---------------------------------------------------------------------------
def plot_scatter_with_marginals(
    gs_slot,
    fig,
    x,
    y,
    xlabel="",
    ylabel="",
    title=None,
    marginal_bins=60,
    marginal_color="0.4",
    show_zero_lines=False,
    height_ratios=None,
    width_ratios=None,
    **scatter_kwargs,
):
    """Scatter plot with marginal histograms on the top and right edges.

    Creates a 2x2 sub-GridSpec inside gs_slot (top-left: x histogram,
    bottom-left: scatter, bottom-right: y histogram, top-right: empty).
    All scatter options are forwarded to :func:`plot_scatter`.

    When a colorbar is shown (continuous ``color_vals`` with
    ``show_colorbar=True``), it is placed to the right of the right
    marginal histogram rather than on the scatter axes.

    Parameters:
        gs_slot: A GridSpec slot (e.g. ``gs[0]``) to place the sub-layout in.
        fig (matplotlib.Figure): Parent figure.
        x (np.ndarray): X-axis values.
        y (np.ndarray): Y-axis values.
        xlabel (str): X-axis label for the scatter.
        ylabel (str): Y-axis label for the scatter.
        title (str or None): Title placed above the top marginal histogram.
        marginal_bins (int): Number of histogram bins spanning each scatter
            axis range.
        marginal_color (str): Histogram bar colour.
        show_zero_lines (bool): Draw vertical/horizontal zero reference lines
            on the marginal histograms.
        height_ratios (list or None): ``[hist, scatter]`` height ratios.
            Default ``[1, 4]``.
        width_ratios (list or None): ``[scatter, hist]`` width ratios.
            Default ``[4, 1]``.
        **scatter_kwargs: Additional keyword arguments forwarded to
            :func:`plot_scatter` (e.g. ``color_vals``, ``show_identity``,
            ``marker_size``, ``cmap``).

    Returns:
        ax_scatter (matplotlib.axes.Axes): The main scatter axes.
        ax_histx (matplotlib.axes.Axes): Top marginal histogram axes.
        ax_histy (matplotlib.axes.Axes): Right marginal histogram axes.
        sc: Return value from :func:`plot_scatter`.
    """
    _import_matplotlib()
    from matplotlib.gridspec import GridSpecFromSubplotSpec

    if height_ratios is None:
        height_ratios = [1, 4]
    if width_ratios is None:
        width_ratios = [4, 1]

    # Intercept show_colorbar so we can place it on the right marginal axis
    # instead of on the scatter axes.
    wants_colorbar = scatter_kwargs.pop("show_colorbar", True)
    scatter_kwargs["show_colorbar"] = False

    inner = GridSpecFromSubplotSpec(
        2,
        2,
        subplot_spec=gs_slot,
        height_ratios=height_ratios,
        width_ratios=width_ratios,
        hspace=0.05,
        wspace=0.05,
    )

    ax_scatter = fig.add_subplot(inner[1, 0])
    ax_histx = fig.add_subplot(inner[0, 0], sharex=ax_scatter)
    ax_histy = fig.add_subplot(inner[1, 1], sharey=ax_scatter)
    ax_corner = fig.add_subplot(inner[0, 1])
    ax_corner.axis("off")

    # Plot scatter
    sc = plot_scatter(ax_scatter, x, y, xlabel=xlabel, ylabel=ylabel, **scatter_kwargs)
    ax_scatter.set_aspect("equal", adjustable="box")

    # Determine axis range from scatter
    xlim = ax_scatter.get_xlim()
    ylim = ax_scatter.get_ylim()

    # Marginal histograms
    x_arr = np.asarray(x, dtype=float).ravel()
    y_arr = np.asarray(y, dtype=float).ravel()
    valid = np.isfinite(x_arr) & np.isfinite(y_arr)

    # n bins need n + 1 edges; linspace(lo, hi, n) would yield only n - 1.
    bins_x = np.linspace(xlim[0], xlim[1], marginal_bins + 1)
    bins_y = np.linspace(ylim[0], ylim[1], marginal_bins + 1)

    ax_histx.hist(x_arr[valid], bins=bins_x, color=marginal_color, edgecolor="none")
    ax_histy.hist(
        y_arr[valid],
        bins=bins_y,
        color=marginal_color,
        edgecolor="none",
        orientation="horizontal",
    )

    # Style marginal axes
    ax_histx.set_yticks([])
    ax_histy.set_xticks([])
    ax_histx.tick_params(labelbottom=False, bottom=False)
    ax_histy.tick_params(labelleft=False, left=False)
    for spine in ["top", "right", "left"]:
        ax_histx.spines[spine].set_visible(False)
    for spine in ["top", "right", "bottom"]:
        ax_histy.spines[spine].set_visible(False)

    if show_zero_lines:
        ax_histx.axvline(0, ls=":", color="red", lw=1.5)
        ax_histy.axhline(0, ls=":", color="red", lw=1.5)

    # Title above the top marginal histogram
    if title is not None:
        font_size = scatter_kwargs.get("font_size")
        ax_histx.set_title(title, fontsize=font_size)

    # Colorbar on the right marginal axis (outside the histograms)
    color_vals = scatter_kwargs.get("color_vals")
    groups = scatter_kwargs.get("groups")
    if wants_colorbar and groups is None and color_vals is not None:
        font_size = scatter_kwargs.get("font_size")
        color_label = scatter_kwargs.get("color_label", "")
        _add_colorbar(
            sc,
            ax_histy,
            label=color_label,
            font_size=font_size,
            size="5%",
            pad=0.08,
        )

    return ax_scatter, ax_histx, ax_histy, sc
# --------------------------------------------------------------------------- # plot_manifold # ---------------------------------------------------------------------------
def plot_manifold(
    ax,
    embedding,
    pc_x=0,
    pc_y=1,
    var_explained=None,
    bg_mask=None,
    bg_color="0.85",
    bg_alpha=0.05,
    bg_size=0.3,
    color_vals=None,
    color_label="",
    cmap="viridis",
    vmin=None,
    vmax=None,
    groups=None,
    group_labels=None,
    group_colors=None,
    marker_size=3,
    alpha=0.5,
    show_colorbar=True,
    show_legend=True,
    xlabel=None,
    ylabel=None,
    font_size=None,
):
    """Plot a 2-D embedding (PCA, UMAP, etc.) with flexible point coloring.

    Supports three foreground coloring modes (same as :func:`plot_scatter`):

    - Continuous: pass ``color_vals`` for colormap-scaled values.
    - Discrete groups: pass ``groups`` for per-group colours.
    - Uniform: neither provided -- all foreground points drawn in black.

    If both ``color_vals`` and ``groups`` are given, ``color_vals`` wins and
    ``groups`` is ignored (note: this is the opposite precedence to
    :func:`plot_scatter`, where ``groups`` wins).

    An optional background mask renders selected points in a faint colour
    before the foreground, useful for separating non-event from event points
    (e.g. non-burst vs burst time bins).

    Parameters:
        ax (matplotlib.axes.Axes): Target axes (caller creates).
        embedding (np.ndarray): Shape ``(T, >=2)`` embedding coordinates.
        pc_x (int): Column index for the x-axis. Default 0.
        pc_y (int): Column index for the y-axis. Default 1.
        var_explained (np.ndarray or None): Explained variance ratio per
            component. When provided, axis labels are auto-generated as
            ``"PC{n} (X.X%)"``; overridden by explicit ``xlabel``/``ylabel``.
        bg_mask (np.ndarray or None): Boolean mask, shape ``(T,)``. True for
            background points. These are drawn first in ``bg_color``.
        bg_color (str): Colour for background points.
        bg_alpha (float): Alpha for background points.
        bg_size (float): Marker size for background points.
        color_vals (np.ndarray or str or None): Per-point values (length T)
            for continuous colour mapping; only the foreground entries are
            used. Pass ``"density"`` for KDE-based density colouring.
        color_label (str): Colorbar label (continuous mode).
        cmap (str): Matplotlib colourmap name (continuous mode).
        vmin (float or None): Colourmap minimum.
        vmax (float or None): Colourmap maximum.
        groups (array-like or None): Per-point integer group index (length
            T) for discrete colouring; only the foreground entries are used.
        group_labels (list[str] or None): Labels per unique group value.
        group_colors (list[str] or None): Colours per unique group value.
        marker_size (float): Marker size for foreground points. Default 3.
        alpha (float): Alpha for foreground points.
        show_colorbar (bool): Add a colorbar (continuous mode only).
        show_legend (bool): Show a legend (group mode only).
        xlabel (str or None): X-axis label. Overrides auto-label from
            ``var_explained``.
        ylabel (str or None): Y-axis label. Overrides auto-label from
            ``var_explained``.
        font_size (int or None): Font size for labels and ticks. If None,
            uses current rcParams.

    Returns:
        sc: The foreground scatter artist(s) -- a single ``PathCollection``
            (continuous/uniform) or a ``list[PathCollection]`` (group mode).
            Useful for adding shared colorbars or custom legends.
    """
    _import_matplotlib()

    embedding = np.asarray(embedding)
    x = embedding[:, pc_x]
    y = embedding[:, pc_y]

    # --- Background points ------------------------------------------------
    if bg_mask is not None:
        bg_mask = np.asarray(bg_mask, dtype=bool)
        ax.scatter(
            x[bg_mask],
            y[bg_mask],
            s=bg_size,
            c=bg_color,
            alpha=bg_alpha,
            rasterized=True,
            edgecolors="none",
            zorder=1,
        )
        fg_mask = ~bg_mask
    else:
        fg_mask = np.ones(len(x), dtype=bool)

    # --- Foreground: delegate to plot_scatter -----------------------------
    fg_x = x[fg_mask]
    fg_y = y[fg_mask]

    scatter_kw = dict(
        marker_size=marker_size,
        alpha=alpha,
        show_colorbar=show_colorbar,
        show_legend=show_legend,
        font_size=font_size,
    )

    if color_vals is not None:
        # The "density" string is passed through; arrays are restricted to
        # the foreground points so they stay aligned with fg_x / fg_y.
        if not isinstance(color_vals, str):
            color_vals = np.asarray(color_vals).ravel()[fg_mask]
        scatter_kw.update(
            color_vals=color_vals,
            color_label=color_label,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
        )
    elif groups is not None:
        groups = np.asarray(groups).ravel()[fg_mask]
        scatter_kw.update(
            groups=groups,
            group_labels=group_labels,
            group_colors=group_colors,
        )

    sc = plot_scatter(ax, fg_x, fg_y, **scatter_kw)

    # --- Axis labels ------------------------------------------------------
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    elif var_explained is not None:
        ax.set_xlabel(f"PC{pc_x + 1} ({var_explained[pc_x]:.1%})")

    if ylabel is not None:
        ax.set_ylabel(ylabel)
    elif var_explained is not None:
        ax.set_ylabel(f"PC{pc_y + 1} ({var_explained[pc_y]:.1%})")

    _style_axes(ax)
    if font_size is not None:
        _apply_font_size(ax, font_size)

    return sc
# --------------------------------------------------------------------------- # plot_lines # ---------------------------------------------------------------------------
def plot_lines(
    ax,
    traces,
    x=None,
    labels=None,
    colors=None,
    xlabel="",
    ylabel="",
    linewidth=1.5,
    show_legend=True,
    font_size=None,
):
    """Plot one or more 1-D traces on a shared x-axis.

    Parameters:
        ax (matplotlib.axes.Axes): Target axes (caller creates).
        traces (dict[str, np.ndarray] or list[np.ndarray]): Line data.
            Dict keys are used as labels; for list input, supply ``labels``
            separately (otherwise integer indices are used as labels).
        x (np.ndarray or None): Shared x-axis values. If None, integer
            indices (``0 … len-1`` of the first trace) are used.
        labels (list[str] or None): Per-trace labels for list input. If
            None, ``"0"``, ``"1"``, ... are used. Ignored for dict input
            (keys are used instead).
        colors (list[str] or dict[str, str] or None): Per-trace colours.
            May be a dict keyed by the trace labels or a list in trace
            order. If None, uses the default matplotlib colour cycle.
        xlabel (str): X-axis label.
        ylabel (str): Y-axis label.
        linewidth (float): Line width for all traces.
        show_legend (bool): Show legend. Default True. Skipped when there
            are no traces.
        font_size (int or None): Font size for labels and ticks. If None,
            uses current rcParams.

    Returns:
        lines (list[Line2D]): The line artists (empty if ``traces`` is empty).

    Raises:
        ValueError: If any trace's length differs from the x-axis length
            (the explicit ``x``, or the first trace when ``x`` is None).
    """
    _import_matplotlib()

    # --- Normalise input to ordered (label, array) pairs ------------------
    if isinstance(traces, dict):
        ordered_labels = list(traces.keys())
        ordered_data = [np.asarray(traces[k]).ravel() for k in ordered_labels]
    else:
        ordered_data = [np.asarray(a).ravel() for a in traces]
        if labels is not None:
            ordered_labels = list(labels)
        else:
            ordered_labels = [str(i) for i in range(len(ordered_data))]

    n = len(ordered_data)

    # --- Resolve x-axis ---------------------------------------------------
    if x is None:
        x = np.arange(len(ordered_data[0]) if n else 0)
    else:
        x = np.asarray(x).ravel()

    # Validate every trace, including against the implicit index axis, so
    # a length mismatch gives a clear error instead of failing in ax.plot.
    for i in range(n):
        if len(ordered_data[i]) != len(x):
            raise ValueError(
                f"Trace '{ordered_labels[i]}' has length "
                f"{len(ordered_data[i])} but x has length {len(x)}"
            )

    # --- Resolve colours --------------------------------------------------
    if colors is None:
        import matplotlib.pyplot as _plt

        cycle_colors = _plt.rcParams["axes.prop_cycle"].by_key()["color"]
        resolved_colors = [cycle_colors[i % len(cycle_colors)] for i in range(n)]
    elif isinstance(colors, dict):
        resolved_colors = [colors[lbl] for lbl in ordered_labels]
    else:
        resolved_colors = list(colors)

    # --- Draw lines -------------------------------------------------------
    lines = []
    for i in range(n):
        (line,) = ax.plot(
            x,
            ordered_data[i],
            color=resolved_colors[i],
            linewidth=linewidth,
            label=ordered_labels[i],
        )
        lines.append(line)

    # --- Axes formatting --------------------------------------------------
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    # A legend with no labelled artists only produces a matplotlib warning.
    if show_legend and lines:
        ax.legend(frameon=False)
    _style_axes(ax)
    if font_size is not None:
        _apply_font_size(ax, font_size)

    return lines
# --------------------------------------------------------------------------- # plot_percentile_bands # ---------------------------------------------------------------------------
def plot_percentile_bands(
    ax,
    metric_data,
    labels=None,
    normalize=False,
    summary="mean",
    bands=None,
    band_color="0.3",
    band_alphas=None,
    style="bands",
    line_color="0.5",
    line_alpha=0.3,
    line_width=0.5,
    summary_color="black",
    summary_linewidth=1.5,
    show_zero_line=True,
    xlabel="",
    ylabel="",
    ylim_range=None,
    show_legend=False,
    font_size=None,
):
    """Plot percentile bands or individual lines across ordered groups/conditions.

    For each unit a value is computed per condition. Optionally, values are
    normalized relative to the first group using symmetric normalization:
    ``N = (x' - d0') / (x' + d0')`` where ``x' = max(x, 0)`` and ``d0'`` is
    the clipped first-group value of the same unit.

    Parameters:
        ax (matplotlib.axes.Axes): Target axes (caller creates).
        metric_data (dict[str, np.ndarray] or list[np.ndarray]): Per-condition
            1-D arrays of per-unit values. All arrays must have the same
            length (one entry per unit). Dict keys or ``labels`` define the
            x-axis order.
        labels (list[str] or None): Ordered condition labels. If None, uses
            dict keys (for dict input) or integer indices (for list input).
            For dict input, when every label is a key of ``metric_data`` the
            labels select and order the plotted conditions; otherwise they
            are used as display names for the dict's own key order. Must
            contain exactly one entry per plotted condition.
        normalize (bool): Apply symmetric normalization relative to the first
            group. Units with non-positive or NaN baseline values are
            excluded.
        summary (str): Summary line type: ``"mean"`` (default) or
            ``"median"``.
        bands (list[tuple[int, int]] or None): Percentile band definitions as
            ``(lo, hi)`` pairs, ordered from widest to narrowest. Default is
            ``[(5, 95), (10, 90), (25, 75)]``.
        band_color (str): Fill colour for all bands.
        band_alphas (list[float] or None): Alpha transparency per band. Must
            match length of ``bands``. Default is linearly increasing from
            0.15 to 0.40.
        style (str): ``"bands"`` (default) draws shaded percentile regions;
            ``"lines"`` draws one line per unit.
        line_color (str): Line colour when ``style="lines"``.
        line_alpha (float): Line alpha when ``style="lines"``.
        line_width (float): Line width when ``style="lines"``.
        summary_color (str): Colour for the summary line.
        summary_linewidth (float): Line width for the summary line.
        show_zero_line (bool): Draw a dashed horizontal line at y=0 when
            ``normalize=True``.
        xlabel (str): X-axis label.
        ylabel (str): Y-axis label.
        ylim_range (float or None): Symmetric y-axis limits ``(-val, val)``.
            If None and ``normalize=True``, derived from the 5th/95th
            percentile of the data.
        show_legend (bool): Show legend.
        font_size (int or None): Font size for labels and ticks. If None,
            uses current rcParams.

    Returns:
        artists (dict): Keys ``"summary"`` (Line2D), and either ``"bands"``
            (list of PolyCollection) or ``"lines"`` (list of Line2D).

    Raises:
        ValueError: If ``style`` or ``summary`` is not an accepted value,
            ``metric_data`` is empty, the per-condition arrays differ in
            length, ``labels`` does not have one entry per condition, or
            ``band_alphas`` does not match ``bands`` in length.
    """
    _import_matplotlib()

    # Validate option strings before touching the axes so a bad value never
    # leaves a half-drawn plot behind.
    if style not in ("bands", "lines"):
        raise ValueError(f"style must be 'bands' or 'lines', got {style!r}")
    if summary not in ("mean", "median"):
        raise ValueError(f"summary must be 'mean' or 'median', got {summary!r}")

    # --- Normalise input to list-of-arrays + labels -----------------------
    if isinstance(metric_data, dict):
        keys = list(metric_data.keys())
        if labels is None:
            labels = keys
            order = keys
        elif all(lbl in metric_data for lbl in labels):
            # Labels name dict keys: they define which conditions are plotted
            # and in which order, so data and tick labels stay aligned.
            order = list(labels)
        else:
            # Labels are display names for the dict's own key order.
            order = keys
        data_arrays = [np.asarray(metric_data[k]).ravel() for k in order]
    else:
        data_arrays = [np.asarray(a).ravel() for a in metric_data]
        if labels is None:
            labels = [str(i) for i in range(len(data_arrays))]
    labels = list(labels)

    n_groups = len(data_arrays)
    if n_groups == 0:
        raise ValueError("metric_data must contain at least one condition.")
    if len(labels) != n_groups:
        raise ValueError(
            f"labels has {len(labels)} entries but metric_data has "
            f"{n_groups} conditions."
        )
    n_values = len(data_arrays[0])
    if any(len(arr) != n_values for arr in data_arrays):
        raise ValueError(
            "All conditions in metric_data must contain the same number of "
            "per-unit values."
        )
    x = np.arange(n_groups)

    # --- Build (n_units, n_groups) matrix, optionally normalized ----------
    if normalize:
        d0 = np.maximum(data_arrays[0], 0)
        # Units need a strictly positive baseline (denominator) and no NaN
        # in any condition.
        valid = (d0 > 0) & ~np.isnan(data_arrays[0])
        for arr in data_arrays[1:]:
            valid &= ~np.isnan(arr)
        n_units = int(np.sum(valid))
        mat = np.zeros((n_units, n_groups))
        for j, arr in enumerate(data_arrays):
            vals = np.maximum(arr[valid], 0)
            mat[:, j] = (vals - d0[valid]) / (vals + d0[valid])
    else:
        # Keep all units that are non-NaN across every group
        valid = np.ones(n_values, dtype=bool)
        for arr in data_arrays:
            valid &= ~np.isnan(arr)
        n_units = int(np.sum(valid))
        mat = np.column_stack([arr[valid] for arr in data_arrays])

    # --- Plot bands or individual lines -----------------------------------
    artists = {}
    if style == "bands":
        if bands is None:
            bands = [(5, 95), (10, 90), (25, 75)]
        if band_alphas is None:
            n_bands = len(bands)
            band_alphas = [
                0.15 + (0.40 - 0.15) * i / max(n_bands - 1, 1)
                for i in range(n_bands)
            ]
        elif len(band_alphas) != len(bands):
            # zip() would otherwise silently drop the unmatched bands.
            raise ValueError(
                f"band_alphas has {len(band_alphas)} entries but bands has "
                f"{len(bands)}; they must match."
            )
        band_artists = []
        for (lo_pct, hi_pct), alpha in zip(bands, band_alphas):
            lo_vals = np.nanpercentile(mat, lo_pct, axis=0)
            hi_vals = np.nanpercentile(mat, hi_pct, axis=0)
            label = f"{lo_pct}\u2013{hi_pct}th"
            fill = ax.fill_between(
                x,
                lo_vals,
                hi_vals,
                color=band_color,
                alpha=alpha,
                zorder=1,
                label=label,
            )
            band_artists.append(fill)
        artists["bands"] = band_artists
    else:  # style == "lines" (validated above)
        line_artists = []
        for i in range(n_units):
            (ln,) = ax.plot(
                x,
                mat[i, :],
                color=line_color,
                alpha=line_alpha,
                linewidth=line_width,
                zorder=1,
            )
            line_artists.append(ln)
        artists["lines"] = line_artists

    # --- Summary line -----------------------------------------------------
    if summary == "mean":
        summary_vals = np.nanmean(mat, axis=0)
    else:  # summary == "median" (validated above)
        summary_vals = np.nanmedian(mat, axis=0)
    (summary_line,) = ax.plot(
        x,
        summary_vals,
        color=summary_color,
        linewidth=summary_linewidth,
        zorder=3,
        label=summary.capitalize(),
    )
    artists["summary"] = summary_line

    # --- Zero reference line ----------------------------------------------
    if show_zero_line and normalize:
        ax.axhline(0, color="0.4", linewidth=0.5, linestyle="--", zorder=0)

    # --- Axes formatting --------------------------------------------------
    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    if n_groups > 1:
        ax.set_xlim(x[0], x[-1])
    else:
        # A single condition would give identical (singular) x-limits.
        ax.set_xlim(-0.5, 0.5)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    if ylim_range is not None:
        ax.set_ylim(-ylim_range, ylim_range)
    elif normalize and mat.size > 0:
        p5 = np.nanpercentile(mat, 5, axis=0)
        p95 = np.nanpercentile(mat, 95, axis=0)
        ylim = max(abs(p5.min()), abs(p95.max())) * 1.15
        if np.isfinite(ylim) and ylim > 0:
            ax.set_ylim(-ylim, ylim)

    if show_legend:
        handles = [summary_line]
        if style == "bands":
            handles += artists["bands"]
        ax.legend(handles=handles, loc="upper left", frameon=False)

    _style_axes(ax)
    if font_size is not None:
        _apply_font_size(ax, font_size)

    return artists
# --------------------------------------------------------------------------- # plot_burst_sensitivity # ---------------------------------------------------------------------------
def plot_burst_sensitivity(
    ax,
    thresholds,
    burst_counts,
    dist_values=None,
    labels=None,
    colors=None,
    xlabel="RMS mult.",
    ylabel="Number of bursts",
    show_legend=True,
    show_colorbar=True,
    cmap="hot",
    font_size=None,
):
    """Plot burst detection sensitivity as line plots or heatmaps.

    Automatically selects the visualisation based on the dimensionality of
    the burst count arrays:

    - 1-D arrays (single-parameter sweep): one line per condition on a
      shared axes.
    - 2-D arrays (two-parameter sweep over thresholds x distances, as
      returned by ``SpikeData.burst_sensitivity()``): one heatmap per
      condition via :func:`plot_heatmap`. A single condition is plotted on
      ax; multiple conditions produce a row of subplots on a new figure.
      Each heatmap pixel is centred on its (threshold, distance) value;
      this assumes evenly spaced ``thresholds`` and ``dist_values``.

    Parameters:
        ax (matplotlib.axes.Axes or None): Target axes. Used directly for
            1-D line plots and single-condition 2-D heatmaps. For
            multi-condition 2-D heatmaps this parameter is ignored and a new
            figure is created (pass None explicitly in that case).
        thresholds (np.ndarray): 1-D array of threshold values (x-axis).
        burst_counts (dict[str, np.ndarray] or array-like): Burst counts per
            condition. Dict mapping condition labels to arrays, or a bare
            array (ndarray or nested list) for a single unnamed condition.
            Arrays can be 1-D (line plot) or 2-D of shape
            ``(len(thresholds), len(dist_values))`` (heatmap).
        dist_values (np.ndarray or None): 1-D array of distance parameter
            values. Required when burst counts are 2-D (used as y-axis tick
            labels on the heatmap). Ignored for 1-D line plots.
        labels (list[str] or None): Ordered condition labels. If None, uses
            dict keys. For a bare (non-dict) input, a single label names that
            condition.
        colors (list[str] or None): Per-condition line colours (line plots
            only). If None, uses the default matplotlib colour cycle.
        xlabel (str): X-axis label.
        ylabel (str): Y-axis label. For 2-D heatmaps, defaults to
            ``"Min. burst dist. (bins)"`` when not explicitly changed from
            the default.
        show_legend (bool): Whether to show a legend (line plots only).
        show_colorbar (bool): Whether to show a colorbar (heatmaps only).
        cmap (str): Colormap for heatmaps.
        font_size (int or None): Font size for labels/ticks. If None, uses
            current rcParams.

    Returns:
        result: For 1-D line plots: ``list[Line2D]`` artists. For a single
            2-D heatmap: the axes returned by :func:`plot_heatmap`. For
            multiple 2-D heatmaps: ``(fig, axes_list)`` where axes_list
            contains one axes per condition.

    Raises:
        ValueError: If ``burst_counts`` holds no conditions, or if the counts
            are 2-D and ``dist_values`` is None.
    """
    plt, _ = _import_matplotlib()

    thresholds = np.asarray(thresholds).ravel()

    # --- Normalise burst_counts to an ordered dict ------------------------
    if not isinstance(burst_counts, dict):
        # A single unnamed condition. Its dimensionality is preserved so a
        # nested list describing a 2-D sweep behaves like the equivalent
        # ndarray (1-D inputs are flattened later on the line-plot path).
        single_key = labels[0] if labels is not None and len(labels) == 1 else ""
        burst_counts = {single_key: np.asarray(burst_counts)}

    if labels is None:
        labels = list(burst_counts.keys())
    if len(labels) == 0:
        raise ValueError("burst_counts must contain at least one condition.")

    first_val = np.asarray(burst_counts[labels[0]])

    # --- 1-D line plot path -----------------------------------------------
    if first_val.ndim != 2:
        traces = {label: np.asarray(burst_counts[label]).ravel() for label in labels}
        return plot_lines(
            ax,
            traces,
            x=thresholds,
            colors=colors,
            xlabel=xlabel,
            ylabel=ylabel,
            show_legend=show_legend,
            font_size=font_size,
        )

    # --- 2-D heatmap path -------------------------------------------------
    if dist_values is None:
        raise ValueError("dist_values is required for 2-D burst sensitivity heatmaps.")
    dist_values = np.asarray(dist_values).ravel()

    heatmap_ylabel = "Min. burst dist. (bins)" if ylabel == "Number of bursts" else ylabel

    def _pixel_edges(values):
        """Return imshow edges that centre one pixel on each (evenly spaced) value."""
        if len(values) > 1:
            half_step = (values[-1] - values[0]) / (2.0 * (len(values) - 1))
        else:
            # A single value would give a zero-width extent; pad by 0.5.
            half_step = 0.5
        return values[0] - half_step, values[-1] + half_step

    def _ticks(values, fmt):
        """Return (locations, labels) with at most 6 evenly spaced ticks."""
        locs = np.linspace(values[0], values[-1], min(len(values), 6))
        return locs, [format(v, fmt) for v in locs]

    x_lo, x_hi = _pixel_edges(thresholds)
    y_lo, y_hi = _pixel_edges(dist_values)
    extent = (x_lo, x_hi, y_lo, y_hi)
    xticks = _ticks(thresholds, ".1f")
    yticks = _ticks(dist_values, ".0f")

    # Single condition — plot on the provided ax. Counts are
    # (thresholds, distances); transpose so distances run along y.
    if len(labels) == 1:
        return plot_heatmap(
            np.asarray(burst_counts[labels[0]]).T,
            ax=ax,
            cmap=cmap,
            origin="lower",
            extent=extent,
            xlabel=xlabel,
            ylabel=heatmap_ylabel,
            show_colorbar=show_colorbar,
            colorbar_label="Burst count",
            xticks=xticks,
            yticks=yticks,
            font_size=font_size,
        )

    # Multiple conditions — create a subplot row sharing one colour range.
    n_cond = len(labels)
    fig, axes_row = plt.subplots(1, n_cond, figsize=(5 * n_cond, 4), squeeze=False)
    axes_row = axes_row[0]

    all_arrays = [np.asarray(burst_counts[label]) for label in labels]
    shared_vmin = min(np.nanmin(arr) for arr in all_arrays)
    shared_vmax = max(np.nanmax(arr) for arr in all_arrays)

    blank_yticks = (yticks[0], [""] * len(yticks[1]))
    for i, (label, counts) in enumerate(zip(labels, all_arrays)):
        is_first = i == 0
        plot_heatmap(
            counts.T,
            ax=axes_row[i],
            cmap=cmap,
            origin="lower",
            vmin=shared_vmin,
            vmax=shared_vmax,
            extent=extent,
            xlabel=xlabel,
            ylabel=heatmap_ylabel if is_first else "",
            # Only the rightmost panel carries the (shared) colorbar.
            show_colorbar=show_colorbar and i == n_cond - 1,
            colorbar_label="Burst count",
            xticks=xticks,
            yticks=yticks if is_first else blank_yticks,
            font_size=font_size,
        )
        axes_row[i].set_title(label, fontsize=font_size)

    fig.tight_layout()
    return fig, list(axes_row)
# --------------------------------------------------------------------------- # plot_aligned_slice_single_unit # ---------------------------------------------------------------------------
def plot_aligned_slice_single_unit(
    ax,
    spike_times_per_slice,
    color_vals=None,
    color_label="",
    cmap="viridis",
    time_offset=0,
    xlabel="Rel. time (ms)",
    ylabel="Burst",
    x_range=None,
    vlines=None,
    show_colorbar=True,
    marker_size=20,
    font_size=None,
    style="scatter",
    invert_y=False,
    linewidths=0.5,
):
    """Raster plot of one unit's spike times across multiple event-aligned slices.

    Each row corresponds to one slice/burst, x-axis is time relative to the
    alignment point. Optionally colour-codes rows by a per-slice variable
    (scatter style only).

    Parameters:
        ax (matplotlib.axes.Axes): Target axes (caller creates).
        spike_times_per_slice (list[np.ndarray]): List of 1-D arrays, one
            per slice, containing spike times relative to the alignment
            point.
        color_vals (array-like or None): Per-slice colour values, at least
            one per slice. If None, all spikes are drawn in black. Only used
            when ``style="scatter"``.
        color_label (str): Colorbar label.
        cmap (str): Matplotlib colormap name.
        time_offset (float): Value subtracted from every spike time before
            plotting. Slices from ``align_to_events`` are already
            event-centered (spike times in ``[-pre_ms, +post_ms]``), so use
            the default ``time_offset=0``. Only set a non-zero value when
            spike times are not already centered on the event.
        xlabel (str): X-axis label.
        ylabel (str): Y-axis label.
        x_range (tuple or None): ``(xmin, xmax)`` for the x-axis. If None,
            auto-scales to the span of the data (left to matplotlib when all
            spikes share one time).
        vlines (list[dict] or None): Vertical reference lines. Each dict
            must contain ``'x'`` (required) and may optionally include
            ``'color'`` (default ``'red'``), ``'linestyle'`` (default
            ``'--'``), and ``'linewidth'`` (default ``1.5``).
        show_colorbar (bool): Add a colorbar when color_vals is provided.
        marker_size (float): Scatter marker size (only used when
            ``style="scatter"``).
        font_size (int or None): Font size for labels/ticks. If None, uses
            current rcParams.
        style (str): ``"scatter"`` for dot markers (default), or
            ``"eventplot"`` for vertical line markers.
        invert_y (bool): If True, the first slice is plotted at the top and
            the last at the bottom. Default False (first slice at bottom).
        linewidths (float): Line width for eventplot markers (only used
            when ``style="eventplot"``). Default 0.5.

    Returns:
        sc (PathCollection or None): The scatter artist when color_vals is
            provided and ``style="scatter"``, otherwise None.

    Raises:
        ValueError: If ``color_vals`` has fewer entries than there are
            slices (scatter style), or a vlines dict lacks ``'x'``.
    """
    _import_matplotlib()

    n_slices = len(spike_times_per_slice)

    # Shift spike times (as float so integer inputs work too).
    shifted_per_slice = [
        np.asarray(times, dtype=float).ravel() - time_offset
        for times in spike_times_per_slice
    ]

    sc = None
    if style == "eventplot":
        # eventplot puts slice idx on row y = idx (default line offsets).
        ax.eventplot(
            shifted_per_slice,
            colors="black",
            linewidths=linewidths,
        )
    elif n_slices > 0:
        if color_vals is not None and len(color_vals) < n_slices:
            raise ValueError(
                f"color_vals has {len(color_vals)} entries but there are "
                f"{n_slices} slices; provide one colour value per slice."
            )
        # Flatten spike times into (x, y[, c]) arrays: one point per spike,
        # y = slice index.
        x_all = np.concatenate(shifted_per_slice)
        y_all = np.concatenate(
            [np.full(len(shifted), idx) for idx, shifted in enumerate(shifted_per_slice)]
        )
        if color_vals is not None:
            c_all = np.concatenate(
                [
                    np.full(len(shifted), color_vals[idx])
                    for idx, shifted in enumerate(shifted_per_slice)
                ]
            )
            sc = ax.scatter(x_all, y_all, c=c_all, cmap=cmap, s=marker_size, zorder=2)
            if show_colorbar:
                _add_colorbar(sc, ax, label=color_label, font_size=font_size)
        else:
            ax.scatter(x_all, y_all, c="black", s=marker_size, zorder=2)

    # --- Vertical reference lines -----------------------------------------
    if vlines is not None:
        for vl in vlines:
            if "x" not in vl:
                raise ValueError("Each vlines dict must contain an 'x' key.")
            ax.axvline(
                x=vl["x"],
                color=vl.get("color", "red"),
                linestyle=vl.get("linestyle", "--"),
                linewidth=vl.get("linewidth", 1.5),
                zorder=0,
            )

    # --- Axes formatting (applied even when there is nothing to draw) -----
    # With zero slices, keep one row of height so the limits are not singular.
    ax.set_ylim(-0.5, max(n_slices, 1) - 0.5)
    if invert_y:
        ax.invert_yaxis()
    if x_range is not None:
        ax.set_xlim(x_range)
    else:
        non_empty = [s for s in shifted_per_slice if len(s) > 0]
        if non_empty:
            all_shifted = np.concatenate(non_empty)
            x_lo, x_hi = np.min(all_shifted), np.max(all_shifted)
            # Identical or NaN bounds are invalid limits; let matplotlib
            # autoscale in that case.
            if x_hi > x_lo:
                ax.set_xlim(x_lo, x_hi)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    _style_axes(ax)
    if font_size is not None:
        _apply_font_size(ax, font_size)

    return sc
# --------------------------------------------------------------------------- # plot_heatmap # ---------------------------------------------------------------------------
def plot_heatmap(
    data_mat,
    ax=None,
    norm=False,
    vmin=None,
    vmax=None,
    cmap="hot",
    aspect="auto",
    origin="upper",
    extent=None,
    xlabel="Time (ms)",
    ylabel="Unit",
    xticks=None,
    yticks=None,
    vlines=None,
    hlines=None,
    show_colorbar=True,
    colorbar_label="Rate (Hz)",
    font_size=14,
    save_path=None,
):
    """Plot a 2-D matrix as a heatmap.

    Parameters:
        data_mat (array-like): 2-D array to display, shape ``(rows, cols)``.
        ax (matplotlib.axes.Axes or None): Target axes. If None a standalone
            figure is created.
        norm (bool or str): Row normalisation. ``False`` for none, ``'row'``
            to scale each row to [0, 1]. Constant rows (no range to scale)
            map to 0; NaN entries stay NaN.
        vmin (float or None): Colormap minimum.
        vmax (float or None): Colormap maximum. Defaults to 1.0 when
            ``norm='row'``.
        cmap (str): Matplotlib colormap name.
        aspect (str): Aspect ratio passed to ``imshow``. ``"auto"``
            stretches to fill the axes, ``"equal"`` preserves square pixels.
        origin (str): Origin for ``imshow``. ``"upper"`` places row 0 at the
            top (default), ``"lower"`` places row 0 at the bottom.
        extent (tuple or None): ``(left, right, bottom, top)`` passed to
            ``imshow``. If None, pixel indices are used.
        xlabel (str): X-axis label.
        ylabel (str): Y-axis label.
        xticks (tuple or None): ``(locations, labels)`` for x-axis ticks.
        yticks (tuple or None): ``(locations, labels)`` for y-axis ticks. If
            None, the first and last rows are labelled ``1`` and ``N``; the
            tick positions follow ``extent`` and ``origin`` so they sit on
            those rows' centres.
        vlines (list[dict] or None): Vertical lines. Each dict must contain
            ``'x'`` and may contain ``'color'``, ``'linestyle'``,
            ``'linewidth'``.
        hlines (list[dict] or None): Horizontal lines (same keys as vlines
            but with a required ``'y'``).
        show_colorbar (bool): Whether to draw a colorbar.
        colorbar_label (str): Label for the colorbar.
        font_size (int or None): Font size for labels and ticks. If None,
            matplotlib rcParams are left in control.
        save_path (str or None): If provided (and ``ax`` is None), save the
            figure to this path and close it.

    Returns:
        result: ``(fig, ax)`` when ``ax`` is None, otherwise just ``ax``.

    Raises:
        ValueError: If a vlines dict lacks ``'x'`` or an hlines dict lacks
            ``'y'``.
    """
    plt, _ = _import_matplotlib()

    data_mat = np.asarray(data_mat)

    standalone = ax is None
    if standalone:
        fig, ax = plt.subplots(figsize=(8, 6))

    # Normalise
    if norm == "row":
        scaled = np.zeros_like(data_mat, dtype=float)
        for i, row in enumerate(data_mat):
            row_min, row_max = np.nanmin(row), np.nanmax(row)
            if row_max > row_min:
                scaled[i] = (row - row_min) / (row_max - row_min)
            else:
                # No range to scale (constant or all-NaN row): map finite
                # values to 0 so the result stays in [0, 1]; keep NaN gaps.
                scaled[i] = np.where(np.isnan(row), np.nan, 0.0)
        data_mat = scaled
        colorbar_label = "Norm. " + colorbar_label
        if vmax is None:
            vmax = 1.0

    im_kwargs = dict(cmap=cmap, aspect=aspect, origin=origin, interpolation="none")
    if extent is not None:
        im_kwargs["extent"] = extent
    if vmin is not None:
        im_kwargs["vmin"] = vmin
    if vmax is not None:
        im_kwargs["vmax"] = vmax

    im = ax.imshow(data_mat, **im_kwargs)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    if yticks is not None:
        ax.set_yticks(yticks[0])
        ax.set_yticklabels(yticks[1])
    else:
        n_rows = data_mat.shape[0]
        if extent is None:
            # Without an extent, row i is centred at pixel coordinate i.
            first_pos, last_pos = 0, n_rows - 1
        elif n_rows > 0:
            # imshow places row 0 against the 'top' extent edge for
            # origin='upper' and against 'bottom' for origin='lower'.
            bottom, top = extent[2], extent[3]
            if origin != "lower":
                start_edge, end_edge = top, bottom
            else:
                start_edge, end_edge = bottom, top
            row_step = (end_edge - start_edge) / n_rows
            first_pos = start_edge + 0.5 * row_step
            last_pos = start_edge + (n_rows - 0.5) * row_step
        if n_rows > 1:
            ax.set_yticks([first_pos, last_pos])
            ax.set_yticklabels([1, n_rows])
        elif n_rows == 1:
            ax.set_yticks([first_pos])
            ax.set_yticklabels([1])

    if xticks is not None:
        ax.set_xticks(xticks[0])
        ax.set_xticklabels(xticks[1])

    if vlines:
        for vl in vlines:
            if "x" not in vl:
                raise ValueError("Each vlines dict must contain an 'x' key.")
            ax.axvline(
                x=vl["x"],
                color=vl.get("color", "green"),
                linestyle=vl.get("linestyle", "dotted"),
                linewidth=vl.get("linewidth", 2),
            )

    if hlines:
        for hl in hlines:
            if "y" not in hl:
                raise ValueError("Each hlines dict must contain a 'y' key.")
            ax.axhline(
                y=hl["y"],
                color=hl.get("color", "green"),
                linestyle=hl.get("linestyle", "dotted"),
                linewidth=hl.get("linewidth", 2),
            )

    if show_colorbar:
        _add_colorbar(
            im, ax, label=colorbar_label, font_size=font_size, size="10%", pad=0.08
        )

    _style_axes_heatmap(ax)
    if font_size is not None:
        _apply_font_size(ax, font_size)

    if standalone:
        plt.tight_layout()
        if save_path is not None:
            fig.savefig(save_path)
            plt.close(fig)
        return fig, ax
    return ax
# --------------------------------------------------------------------------- # plot_recording # ---------------------------------------------------------------------------
def plot_recording(
    sd,
    # --- panel toggles ---
    show_raster=True,
    show_pop_rate=False,
    show_fr_rates=False,
    show_model_states=False,
    # --- data inputs (None = auto-compute where possible) ---
    pop_rate=None,
    pop_rate_params=None,
    fr_rates=None,
    fr_rate_sigma_ms=10.0,
    model_states=None,
    cont_prob=None,
    gplvm_result=None,
    # --- display options ---
    time_range=None,
    sort_indices=None,
    raster_style="eventplot",
    raster_bin_size_ms=1.0,
    raster_vmax=5,
    burst_times=None,
    burst_edges=None,
    burst_colors=None,
    # --- heatmap options ---
    vmin_heatmap=None,
    vmax_heatmap=None,
    heatmap_clip_pct=80,
    norm_heatmap=False,
    model_states_cmap="viridis",
    model_states_vmin=0,
    model_states_vmax=1,
    # --- layout ---
    axes=None,
    figsize=None,
    height_ratios=None,
    absolute_xticks=True,
    font_size=14,
    show=True,
    save_path=None,
):
    """Assemble a multi-panel column figure from a SpikeData object.

    Each panel is optional -- the caller selects which panels to include and
    they are stacked vertically with a shared x-axis. Available panels (in
    display order):

    1. Spike raster -- eventplot or binned-count image.
    2. Population rate -- smoothed firing rate curve; if ``cont_prob`` is
       also provided it is overlaid on the same axes.
    3. Firing-rate heatmap -- per-unit instantaneous rates as a colour map.
    4. Model states -- latent-state posterior from a GPLVM fit.

    Parameters:
        sd (SpikeData): Source spike data object.
        show_raster (bool): Include the spike raster panel.
        show_pop_rate (bool): Include the population-rate panel.
            Automatically enabled when ``pop_rate`` or ``cont_prob`` is
            provided.
        show_fr_rates (bool): Include the firing-rate heatmap. Automatically
            enabled when ``fr_rates`` is provided.
        show_model_states (bool): Include the model-states panel.
            Automatically enabled when ``model_states`` is provided.
        pop_rate (array-like or None): Pre-computed smoothed population rate
            in spikes per bin (as returned by ``sd.get_pop_rate()``), shape
            ``(T,)``. Automatically converted to Hz/unit for display. If None
            and panel is enabled, computed via ``sd.get_pop_rate()``.
        pop_rate_params (dict or None): Keyword arguments forwarded to
            ``sd.get_pop_rate()`` when ``pop_rate`` is None. Defaults:
            ``square_width=8, gauss_sigma=8``.
        fr_rates (array-like, RateData or None): Pre-computed per-unit
            instantaneous firing rates, shape ``(U, T)``. If None and
            ``show_fr_rates`` is True, computed via ``sd.resampled_isi()``.
        fr_rate_sigma_ms (float): Gaussian sigma in ms for
            ``sd.resampled_isi()`` when auto-computing firing rates.
        model_states (array-like or None): Latent-state posterior, shape
            ``(S, T)`` where S is the number of latent states. If None and
            ``show_model_states`` is True, extracted from ``gplvm_result``.
        cont_prob (array-like or None): Continuous-dynamics probability,
            shape ``(T,)``. Overlaid on the population-rate panel. If None
            and ``gplvm_result`` is provided, extracted automatically.
        gplvm_result (dict or None): Result dictionary from
            ``SpikeData.fit_gplvm()``. Used to auto-extract
            ``model_states`` and ``cont_prob`` when those are not provided
            directly.
        time_range (tuple or None): ``(start_ms, end_ms)`` display window.
            None shows the full recording.
        sort_indices (np.ndarray or None): Unit reordering indices applied
            to the raster and firing-rate heatmap.
        raster_style (str): ``'eventplot'`` (default) or ``'imshow'``.
            Controls how the raster panel is rendered. Both styles display
            unit 0 at the top, but achieve this differently: ``'imshow'``
            uses ``origin='upper'`` while ``'eventplot'`` inverts the y-axis.
            As a result, y-coordinates are reversed in eventplot mode (ylim
            goes from N to 0). Keep this in mind when overlaying custom
            artists.
        raster_bin_size_ms (float): Bin size in ms for the raster (used for
            both eventplot and imshow styles). All panels are drawn in
            raster-bin coordinates; tick labels convert bin indices back to
            time as ``bin_index * raster_bin_size_ms`` (plus the window
            offset when ``absolute_xticks=True``).
        raster_vmax (int): Maximum spike count for colormap when
            ``raster_style='imshow'``. Default is 5.
        burst_times (array-like or None): Burst peak positions as raster bin
            indices (as returned by ``get_bursts``), plotted as markers on
            the population-rate panel. With the default
            ``raster_bin_size_ms=1.0``, bin indices equal milliseconds.
        burst_edges (array-like or None): Per-burst ``[start_bin, end_bin]``
            boundaries as raster bin indices, shape ``(B, 2)``. Plotted as
            shaded spans on the population-rate panel.
        burst_colors (array-like or None): Per-burst color values, one per
            burst (length B, matching ``burst_times`` / ``burst_edges``).
            When provided, each burst span and peak marker is drawn in its
            assigned color. When None, spans default to blue and peaks to
            black.
        vmin_heatmap (float or None): Colormap minimum for the FR heatmap.
        vmax_heatmap (float or None): Colormap maximum for the FR heatmap.
            When None, set to ``heatmap_clip_pct`` percent of the data
            maximum.
        heatmap_clip_pct (float or None): Fraction of the data maximum (as a
            percentage) used as the colormap maximum when ``vmax_heatmap``
            is None. Default 80 (i.e. 80% of the max value). With
            ``norm_heatmap='row'`` the maximum is that of the normalised
            data (1.0). Set to None to use the absolute maximum.
        norm_heatmap (bool or str): Row normalisation for the FR heatmap
            (``False`` or ``'row'``).
        model_states_cmap (str): Colormap for the model-states panel.
        model_states_vmin (float): Colormap minimum for model states.
        model_states_vmax (float): Colormap maximum for model states.
        axes (list or None): Pre-created axes to plot onto instead of
            creating a new figure. Must be a list of ``(ax_panel, ax_cbar)``
            tuples, one per enabled panel, in display order (raster,
            pop_rate, fr_heatmap, model_states -- only those that are
            enabled). ``ax_cbar`` is used for colorbars on heatmap/imshow
            panels; pass a hidden axes if no colorbar is needed for that
            panel. When provided, the function skips figure creation,
            ``tight_layout``, ``show``, and ``save_path``. ``figsize`` and
            ``height_ratios`` are ignored. Raises ``ValueError`` if the
            length does not match the number of enabled panels.
        figsize (tuple or None): Figure size ``(width, height)``. Ignored
            when ``axes`` is provided.
        height_ratios (list or None): Relative panel heights. Length must
            match the number of enabled panels.
        absolute_xticks (bool): If True, x-tick labels show absolute time
            from recording start. If False, labels are relative to the
            display window. The unit is ms when the displayed range is
            shorter than 1000 ms, otherwise seconds.
        font_size (int): Font size for labels and tick labels.
        show (bool): If True and ``save_path`` is None, call ``plt.show()``.
        save_path (str or None): Save figure to this path and close it.

    Returns:
        fig (matplotlib.Figure): The assembled figure.

    Raises:
        ValueError: If no panel is enabled, or ``axes`` does not provide one
            ``(ax, cbar_ax)`` pair per enabled panel.

    Notes:
        At least one panel must be enabled, otherwise ``ValueError`` is
        raised. When ``gplvm_result`` is provided, ``model_states`` and
        ``cont_prob`` are extracted from the ``decode_res`` sub-dict (keys
        ``posterior_latent_marg`` and ``posterior_dynamics_marg``). Arrays
        with a different time resolution (e.g. GPLVM binned output) are
        automatically cropped to match the ms-based ``time_range``.
    """
    plt, mticker = _import_matplotlib()

    # ------------------------------------------------------------------
    # 0. Auto-extract from gplvm_result
    # ------------------------------------------------------------------
    if gplvm_result is not None:
        decode = gplvm_result.get("decode_res", {})
        if model_states is None and "posterior_latent_marg" in decode:
            model_states = np.asarray(decode["posterior_latent_marg"]).T
        if cont_prob is None and "posterior_dynamics_marg" in decode:
            dyn = np.asarray(decode["posterior_dynamics_marg"])
            cont_prob = dyn[:, 0] if dyn.ndim == 2 else dyn

    # Accept list-like inputs: later code relies on .shape and arithmetic.
    if model_states is not None:
        model_states = np.asarray(model_states)
    if cont_prob is not None:
        cont_prob = np.asarray(cont_prob)

    # ------------------------------------------------------------------
    # 1. Resolve panel flags — auto-enable when data is provided
    # ------------------------------------------------------------------
    if pop_rate is not None or cont_prob is not None:
        show_pop_rate = True
    if fr_rates is not None:
        show_fr_rates = True
    if model_states is not None:
        show_model_states = True

    panels = [
        name
        for name, enabled in (
            ("raster", show_raster),
            ("pop_rate", show_pop_rate),
            ("fr_heatmap", show_fr_rates),
            ("model_states", show_model_states),
        )
        if enabled
    ]
    if not panels:
        raise ValueError(
            "No panels enabled. Set at least one of show_raster, "
            "show_pop_rate, show_fr_rates, or show_model_states to True, "
            "or provide data for auto-detection."
        )
    n_panels = len(panels)

    # ------------------------------------------------------------------
    # 2. Build spike matrix
    # ------------------------------------------------------------------
    spk_mat = sd.sparse_raster(bin_size=raster_bin_size_ms).toarray()

    # Auto-compute firing rates if requested (sampled every 1 ms; cropped
    # proportionally below if that differs from the raster resolution).
    if show_fr_rates and fr_rates is None:
        times_arr = np.arange(sd.start_time, sd.start_time + sd.length, 1.0)
        fr_rates = sd.resampled_isi(times_arr, sigma_ms=fr_rate_sigma_ms)

    # resampled_isi returns RateData; plotting expects array-like (U, T).
    from .ratedata import RateData

    if fr_rates is not None:
        if isinstance(fr_rates, RateData):
            fr_rates = fr_rates.inst_Frate_data
        fr_rates = np.asarray(fr_rates)

    # Apply unit reordering
    if sort_indices is not None:
        spk_mat = spk_mat[sort_indices, :]
        if fr_rates is not None:
            fr_rates = fr_rates[sort_indices, :]

    # Auto-compute population rate if requested
    if show_pop_rate and pop_rate is None:
        params = {
            "square_width": 8,
            "gauss_sigma": 8,
            "raster_bin_size_ms": raster_bin_size_ms,
        }
        if pop_rate_params is not None:
            params.update(pop_rate_params)
        pop_rate = sd.get_pop_rate(**params)
    if pop_rate is not None:
        pop_rate = np.asarray(pop_rate)

    # ------------------------------------------------------------------
    # 3. Crop to time_range
    # ------------------------------------------------------------------
    if time_range is not None:
        # Convert ms time_range to bin indices relative to the raster.
        # Bin 0 corresponds to sd.start_time.
        start = int((time_range[0] - sd.start_time) / raster_bin_size_ms)
        end = int((time_range[1] - sd.start_time) / raster_bin_size_ms)
        start = max(0, min(start, spk_mat.shape[1]))
        end = max(start, min(end, spk_mat.shape[1]))
    else:
        start, end = 0, spk_mat.shape[1]

    spk_mat_view = spk_mat[:, start:end]
    n_samples = end - start

    # Crop arrays whose time axis matches the raster resolution. Arrays
    # with a different time resolution (e.g. GPLVM binned output) are
    # cropped using proportional index conversion so the correct time
    # window is displayed.
    raster_T = spk_mat.shape[1]

    def _rescaled_range(arr_len):
        """Convert raster-resolution [start, end) to indices for an array of length arr_len."""
        if arr_len == raster_T:
            return start, end
        scale = arr_len / raster_T
        s = max(0, min(int(round(start * scale)), arr_len))
        e = max(s, min(int(round(end * scale)), arr_len))
        return s, e

    def _crop_1d(arr):
        """Crop a (T,) array to the display window (None passes through)."""
        if arr is None:
            return None
        s, e = _rescaled_range(len(arr))
        return arr[s:e]

    def _crop_2d(arr):
        """Crop the last (time) axis of a (K, T) array to the display window."""
        if arr is None:
            return None
        s, e = _rescaled_range(arr.shape[-1])
        return arr[:, s:e]

    pop_rate_view = _crop_1d(pop_rate)
    fr_rates_view = _crop_2d(fr_rates)
    cont_prob_view = _crop_1d(cont_prob)
    model_states_view = _crop_2d(model_states)

    # Burst peaks: keep those inside range, shift to window coords
    burst_times_view = None
    burst_colors_times_view = None
    if burst_times is not None:
        burst_times = np.asarray(burst_times)
        mask = (burst_times >= start) & (burst_times <= end)
        burst_times_view = np.round(burst_times[mask] - start).astype(int)
        if burst_colors is not None:
            burst_colors_times_view = np.asarray(burst_colors)[mask]

    # Burst edges: clip to range, skip fully-outside bursts
    burst_edges_view = None
    burst_colors_edges_view = None
    if burst_edges is not None:
        clipped = []
        edge_color_list = []
        colors_arr = np.asarray(burst_colors) if burst_colors is not None else None
        for idx, (t0, t1) in enumerate(burst_edges):
            if t1 < start or t0 > end:
                continue
            clipped.append((max(t0, start) - start, min(t1, end) - start))
            if colors_arr is not None:
                edge_color_list.append(colors_arr[idx])
        burst_edges_view = clipped if clipped else None
        if colors_arr is not None and edge_color_list:
            burst_colors_edges_view = edge_color_list

    # ------------------------------------------------------------------
    # 4. Create figure with two-column GridSpec (panels + colorbars)
    # ------------------------------------------------------------------
    from matplotlib.gridspec import GridSpec

    external_axes = axes is not None

    if external_axes:
        if len(axes) != n_panels:
            raise ValueError(
                f"Expected {n_panels} (ax, cbar_ax) pairs for the enabled "
                f"panels {panels}, got {len(axes)}."
            )
        main_axes = [pair[0] for pair in axes]
        cbar_axes = [pair[1] for pair in axes]
        fig = main_axes[0].figure
    else:
        default_ratio_map = {
            "raster": 2,
            "pop_rate": 1,
            "fr_heatmap": 2,
            "model_states": 2,
        }
        default_ratios = [default_ratio_map[p] for p in panels]
        default_height = sum(default_ratios) * 1.7
        default_figsize = (12, default_height)

        fig = plt.figure(figsize=figsize or default_figsize)
        gs = GridSpec(
            n_panels,
            2,
            figure=fig,
            height_ratios=height_ratios or default_ratios,
            width_ratios=[1, 0.02],
            wspace=0.03,
        )

        # Create main panel axes with shared x
        main_axes = []
        for i in range(n_panels):
            ax = fig.add_subplot(gs[i, 0], sharex=main_axes[0] if main_axes else None)
            main_axes.append(ax)

        # Create colorbar axes (one per row, hidden by default)
        cbar_axes = []
        for i in range(n_panels):
            cax = fig.add_subplot(gs[i, 1])
            cax.axis("off")
            cbar_axes.append(cax)

    panel_axes = dict(zip(panels, main_axes))
    panel_cbar = dict(zip(panels, cbar_axes))

    # ------------------------------------------------------------------
    # 5. Raster panel
    # ------------------------------------------------------------------
    if "raster" in panel_axes:
        ax = panel_axes["raster"]
        if raster_style == "imshow":
            im = ax.imshow(
                spk_mat_view,
                aspect="auto",
                cmap="Greys",
                vmin=0,
                vmax=raster_vmax,
                origin="upper",
            )
            cax = panel_cbar["raster"]
            cax.axis("on")
            fig.colorbar(im, cax=cax, label="Spike Count")
            _apply_font_size(cax, font_size)
        else:
            spike_times_list = [
                np.where(spk_mat_view[i, :] >= 1)[0]
                for i in range(spk_mat_view.shape[0])
            ]
            ax.eventplot(spike_times_list, colors="black", linewidths=0.5)
        ax.set_ylim([-0.5, spk_mat_view.shape[0] - 0.5])
        if raster_style == "imshow":
            _style_axes_heatmap(ax)
        else:
            # Put unit 0 at the top, matching the imshow style.
            ax.invert_yaxis()
            _style_axes(ax)
        ax.set_ylabel("Unit")
        _apply_font_size(ax, font_size)

    # ------------------------------------------------------------------
    # 6. Population rate panel (+ cont_prob overlay)
    # ------------------------------------------------------------------
    if "pop_rate" in panel_axes:
        ax = panel_axes["pop_rate"]
        pop_rate_hz = None

        if pop_rate_view is not None:
            # Convert from spikes/bin (summed over units) to Hz/unit
            n_units = spk_mat.shape[0]
            bin_duration_s = raster_bin_size_ms / 1000.0
            pop_rate_hz = pop_rate_view / (bin_duration_s * n_units)
            x_pop = np.linspace(0, n_samples, len(pop_rate_hz), endpoint=False)
            ax.plot(x_pop, pop_rate_hz, color="blue", label="Pop. rate")
            ax.set_ylabel("Pop. rate (Hz/unit)")

        if cont_prob_view is not None:
            ax2 = ax.twinx()
            x_cont = np.linspace(0, n_samples, len(cont_prob_view), endpoint=False)
            ax2.plot(x_cont, cont_prob_view, color="red", alpha=0.7, label="P(cont.)")
            ax2.set_ylabel("P(continuous)", color="red")
            ax2.tick_params(axis="y", labelcolor="red")
            _apply_font_size(ax2, font_size)

        # Burst overlays (skipped for an empty window: nothing to scale)
        if (
            burst_times_view is not None
            and pop_rate_hz is not None
            and n_samples > 0
            and len(pop_rate_hz) > 0
        ):
            # Scale burst times from raster-bin coords to pop_rate coords
            scale = len(pop_rate_hz) / n_samples
            bt_scaled = np.round(burst_times_view * scale).astype(int)
            in_range = bt_scaled < len(pop_rate_hz)
            bt_scaled = bt_scaled[in_range]
            bt_plot = burst_times_view[in_range]  # x position in raster coords
            if burst_colors_times_view is not None:
                ax.scatter(
                    bt_plot,
                    pop_rate_hz[bt_scaled],
                    c=burst_colors_times_view[in_range],
                    s=40,
                    zorder=9,
                )
            else:
                ax.scatter(bt_plot, pop_rate_hz[bt_scaled], c="k", zorder=9)

        if burst_edges_view is not None:
            for i, (t0, t1) in enumerate(burst_edges_view):
                color = (
                    burst_colors_edges_view[i]
                    if burst_colors_edges_view is not None
                    else "b"
                )
                ax.axvspan(t0, t1, color=color, alpha=0.2)

        _style_axes(ax)
        _apply_font_size(ax, font_size)

    # ------------------------------------------------------------------
    # 7. Firing-rate heatmap panel
    # ------------------------------------------------------------------
    if "fr_heatmap" in panel_axes:
        ax = panel_axes["fr_heatmap"]
        fr_extent = None
        if fr_rates_view.shape[1] != n_samples:
            # Stretch a differently-sampled matrix over the raster window.
            fr_extent = (0, n_samples, 0, fr_rates_view.shape[0])

        # Auto-clip vmax to a fraction of the maximum when not explicitly set
        vmax_fr = vmax_heatmap
        if vmax_fr is None and heatmap_clip_pct is not None:
            if norm_heatmap == "row":
                # plot_heatmap rescales each row to [0, 1]; clip relative to
                # that range, not to the raw rates (which would wash out the
                # normalised image).
                data_max = 1.0
            elif fr_rates_view.size > 0:
                data_max = np.nanmax(fr_rates_view)
            else:
                data_max = None
            if data_max is not None:
                vmax_fr = data_max * (heatmap_clip_pct / 100.0)

        plot_heatmap(
            fr_rates_view,
            ax=ax,
            norm=norm_heatmap,
            vmin=vmin_heatmap,
            vmax=vmax_fr,
            origin="upper",
            extent=fr_extent,
            xlabel="Time (s)",
            ylabel="Unit",
            show_colorbar=False,
            font_size=font_size,
        )
        cax = panel_cbar["fr_heatmap"]
        cax.axis("on")
        cb_label = "Norm. Rate (Hz)" if norm_heatmap == "row" else "Rate (Hz)"
        # The heatmap just drawn is the most recent image on this axes.
        fig.colorbar(ax.images[-1], cax=cax, label=cb_label)
        _apply_font_size(cax, font_size)

    # ------------------------------------------------------------------
    # 8. Model states panel
    # ------------------------------------------------------------------
    if "model_states" in panel_axes:
        ax = panel_axes["model_states"]
        ms_extent = None
        if model_states_view.shape[1] != n_samples:
            ms_extent = (0, n_samples, 0, model_states_view.shape[0])
        plot_heatmap(
            model_states_view,
            ax=ax,
            cmap=model_states_cmap,
            vmin=model_states_vmin,
            vmax=model_states_vmax,
            origin="lower",
            extent=ms_extent,
            xlabel="Time (s)",
            ylabel="State",
            show_colorbar=False,
            font_size=font_size,
        )
        cax = panel_cbar["model_states"]
        cax.axis("on")
        fig.colorbar(ax.images[-1], cax=cax, label="Probability")
        _apply_font_size(cax, font_size)

    # ------------------------------------------------------------------
    # 9. X-axis formatting
    # ------------------------------------------------------------------
    # Set x limits on all axes (sharex propagates when axes are created
    # internally, but external axes may not be linked)
    for ax in main_axes:
        ax.set_xlim(0, n_samples)

    # Hide tick labels on all but the bottom panel. tick_params persists for
    # ticks created later, unlike hiding the current Text objects.
    for ax in main_axes[:-1]:
        ax.tick_params(axis="x", labelbottom=False)
        ax.set_xlabel("")

    # Choose x-axis unit: ms when the plotted range is < 1000 ms, else
    # seconds. Bin 0 in the view corresponds to
    # ms = sd.start_time + start * raster_bin_size_ms.
    ms_offset = sd.start_time + start * raster_bin_size_ms
    range_ms = n_samples * raster_bin_size_ms
    use_ms = range_ms < 1000.0

    if use_ms:
        xlabel = "Time (ms)"
        if absolute_xticks:
            formatter = mticker.FuncFormatter(
                lambda x, _: f"{x * raster_bin_size_ms + ms_offset:.1f}"
            )
        else:
            formatter = mticker.FuncFormatter(
                lambda x, _: f"{x * raster_bin_size_ms:.1f}"
            )
    else:
        xlabel = "Time (s)"
        bin_to_s = raster_bin_size_ms / 1000.0
        if absolute_xticks:
            s_offset = ms_offset / 1000.0
            formatter = mticker.FuncFormatter(
                lambda x, _: f"{x * bin_to_s + s_offset:.1f}"
            )
        else:
            formatter = mticker.FuncFormatter(lambda x, _: f"{x * bin_to_s:.1f}")

    main_axes[-1].set_xlabel(xlabel)
    _apply_font_size(main_axes[-1], font_size)
    for ax in main_axes:
        ax.xaxis.set_major_formatter(formatter)

    # ------------------------------------------------------------------
    # 10. Output
    # ------------------------------------------------------------------
    if not external_axes:
        gs.tight_layout(fig)
        if save_path is not None:
            fig.savefig(save_path, bbox_inches="tight")
            plt.close(fig)
        elif show:
            plt.show()

    return fig
def plot_spatial_network(
    ax,
    positions,
    matrix,
    edge_threshold=None,
    top_pct=None,
    node_size_range=(2, 20),
    node_cmap="viridis",
    node_linewidth=0.2,
    edge_color="red",
    edge_linewidth=0.6,
    edge_alpha_range=(0.15, 1.0),
    scale_bar_um=500,
    font_size=None,
):
    """Plot a spatial network of units on their MEA positions.

    Units are drawn as scatter markers sized and coloured by their mean
    pairwise value (row mean of ``matrix`` with the diagonal excluded).
    Edges are drawn between unit pairs whose matrix value is at least
    ``edge_threshold``, or lies in the top ``top_pct`` percent of all finite
    off-diagonal values, with alpha increasing linearly with edge strength.

    Exactly one of edge_threshold or top_pct must be provided.

    Parameters:
        ax (matplotlib.axes.Axes): Target axes.
        positions (np.ndarray): Unit positions, shape ``(N, 2)`` (extra
            columns are ignored) with columns ``[x, y]`` in micrometres.
        matrix (np.ndarray): Symmetric ``(N, N)`` pairwise matrix (e.g.
            correlation, STTC). Diagonal values are ignored; only the upper
            triangle is used to select edges. Non-finite entries never
            produce edges.
        edge_threshold (float or None): Minimum matrix value to draw an edge.
        top_pct (float or None): Percentage of top edges to draw (e.g.
            ``1.0`` draws the top 1 %). Must lie in ``[0, 100]``.
        node_size_range (tuple): ``(min_size, max_size)`` in points squared
            for the scatter markers.
        node_cmap (str): Matplotlib colourmap for node colour.
        node_linewidth (float): Outline width of node markers.
        edge_color (str): Colour for network edges.
        edge_linewidth (float): Line width for network edges.
        edge_alpha_range (tuple): ``(min_alpha, max_alpha)`` for edge
            transparency, scaled by edge strength.
        scale_bar_um (float): Length of the spatial scale bar in
            micrometres. Set to 0 or None to omit.
        font_size (int or None): Font size for the scale bar label. If None,
            a size of 8 is used.

    Returns:
        scatter (matplotlib.collections.PathCollection): The scatter artist,
            useful for adding a colorbar.

    Raises:
        ValueError: If neither or both of edge_threshold / top_pct are given,
            if top_pct is outside ``[0, 100]``, or if ``positions`` and
            ``matrix`` have invalid or mismatched shapes.
    """
    import warnings

    # Check availability first so a missing matplotlib yields the friendly
    # ImportError rather than a raw one from the imports below.
    _import_matplotlib()
    from matplotlib.collections import LineCollection
    from matplotlib.colors import to_rgb

    if edge_threshold is None and top_pct is None:
        raise ValueError("Provide either edge_threshold or top_pct.")
    if edge_threshold is not None and top_pct is not None:
        raise ValueError("Provide only one of edge_threshold or top_pct.")
    if top_pct is not None and not 0 <= top_pct <= 100:
        raise ValueError(f"top_pct must be in [0, 100], got {top_pct}.")

    positions = np.asarray(positions)
    if positions.ndim != 2 or positions.shape[1] < 2:
        raise ValueError(
            f"positions must be 2D with at least 2 columns (N, 2+), "
            f"got shape {positions.shape}"
        )
    matrix = np.asarray(matrix, dtype=float)
    n = len(positions)
    if matrix.shape != (n, n):
        raise ValueError(
            f"matrix shape {matrix.shape} does not match "
            f"positions length {n}."
        )

    x, y = positions[:, 0], positions[:, 1]

    # Mean value per unit, excluding the self-pair on the diagonal. A unit
    # with no finite off-diagonal entries (e.g. when n == 1) gets NaN; mute
    # numpy's "Mean of empty slice" warning for that expected case.
    mat = matrix.copy()
    np.fill_diagonal(mat, np.nan)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        mean_val = np.nanmean(mat, axis=1)

    # Candidate edges: finite values in the strict upper triangle, so each
    # unordered pair is considered once.
    triu_i, triu_j = np.triu_indices(n, k=1)
    vals = mat[triu_i, triu_j]
    finite = np.isfinite(vals)
    triu_i, triu_j, vals = triu_i[finite], triu_j[finite], vals[finite]

    if top_pct is not None:
        # np.percentile fails on an empty array; with no candidate pairs
        # there is nothing to select, so use a threshold nothing can reach.
        threshold = np.percentile(vals, 100 - top_pct) if vals.size else np.inf
    else:
        threshold = edge_threshold

    edge_mask = vals >= threshold
    edge_vals = vals[edge_mask]
    edge_i = triu_i[edge_mask]
    edge_j = triu_j[edge_mask]

    # Draw edges with alpha rising linearly from the threshold to the
    # strongest selected edge.
    alpha_lo, alpha_hi = edge_alpha_range
    if edge_vals.size > 0:
        e_min, e_max = threshold, np.max(edge_vals)
        if e_max > e_min:
            edge_alpha = alpha_lo + (alpha_hi - alpha_lo) * (edge_vals - e_min) / (
                e_max - e_min
            )
        else:
            edge_alpha = np.full_like(edge_vals, (alpha_lo + alpha_hi) / 2)
        edge_alpha = np.clip(edge_alpha, alpha_lo, alpha_hi)

        # segments has shape (E, 2, 2): segments[k] = [[x_i, y_i], [x_j, y_j]].
        segments = np.stack(
            [
                np.column_stack([x[edge_i], y[edge_i]]),
                np.column_stack([x[edge_j], y[edge_j]]),
            ],
            axis=1,
        )

        # One RGBA row per edge: shared RGB, per-edge alpha.
        colors = np.empty((edge_vals.size, 4))
        colors[:, :3] = to_rgb(edge_color)
        colors[:, 3] = edge_alpha

        # zorder=2 draws the edges above the nodes (zorder=1).
        lc = LineCollection(
            segments, colors=colors, linewidths=edge_linewidth, zorder=2
        )
        ax.add_collection(lc)

    # Draw nodes sized by mean value, normalised over the non-NaN means.
    # With no spread (single value, all NaN, or no units) every node gets
    # the midpoint size.
    size_min, size_max = node_size_range
    present = mean_val[~np.isnan(mean_val)]
    if present.size and present.max() > present.min():
        mc_min, mc_max = present.min(), present.max()
        sizes = size_min + (size_max - size_min) * (mean_val - mc_min) / (
            mc_max - mc_min
        )
    else:
        sizes = np.full_like(mean_val, (size_min + size_max) / 2)

    sc = ax.scatter(
        x,
        y,
        s=sizes,
        c=mean_val,
        cmap=node_cmap,
        edgecolors="face",
        linewidths=node_linewidth,
        zorder=1,
    )
    ax.set_aspect("equal")
    ax.axis("off")

    # Scale bar in the lower-right corner, inset 2 % from the data limits.
    if scale_bar_um:
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        bar_x_end = xlim[1] - (xlim[1] - xlim[0]) * 0.02
        bar_x_start = bar_x_end - scale_bar_um
        bar_y = ylim[0] + (ylim[1] - ylim[0]) * 0.02
        ax.plot(
            [bar_x_start, bar_x_end],
            [bar_y, bar_y],
            color="black",
            linewidth=2.0,
            clip_on=False,
            solid_capstyle="butt",
        )
        fs = 8 if font_size is None else font_size
        ax.text(
            (bar_x_start + bar_x_end) / 2,
            bar_y - (ylim[1] - ylim[0]) * 0.03,
            f"{scale_bar_um}\u00b5m",
            ha="center",
            va="top",
            fontsize=fs,
        )

    return sc