Source code for spikelab.spikedata.spikeslicestack

import warnings

import numpy as np

__all__ = ["SpikeSliceStack"]

from .pairwise import PairwiseCompMatrix, PairwiseCompMatrixStack
from .spikedata import SpikeData
from concurrent.futures import ThreadPoolExecutor

from .utils import (
    _validate_time_start_to_end,
    _get_attr,
    get_sttc,
    compute_cross_correlation_with_lag,
    _resolve_n_jobs,
    _slice_to_slice_similarity_matrix,
)


[docs] class SpikeSliceStack: """A list of SpikeData objects, one per slice, with spike-based comparison capabilities. U is units (neurons) and S is slices (bursts, events, etc). Construct from either a single SpikeData with time specifications, or directly from a pre-built list of SpikeData objects. Parameters: data_obj (SpikeData or None): A SpikeData object to slice. Provide either this or spike_stack, not both. times_start_to_end (list or None): Each entry is a tuple (start, end) representing the start and end times of a desired slice. Each tuple must have the same duration. time_peaks (list or None): List of times as int or float where there is a burst peak or stimulation event. Must be paired with time_bounds. Alternative to times_start_to_end. time_bounds (tuple or None): Single tuple (left_bound, right_bound). For example, (250, 500) means 250 ms before peak and 500 ms after peak. Must be paired with time_peaks. spike_stack (list or None): List of SpikeData objects, one per slice. All must have the same number of units. Spike times must be relative to the slice (0-based or event-centered via start_time), not absolute recording times. Provide either this or data_obj, not both. neuron_attributes (list or None): List of attribute dicts, one per unit. If None, inherited from data_obj when available. drop_slice_attributes (bool): If True (default), neuron_attributes are removed from individual SpikeData slices after construction. The shared copy is stored at neuron_attributes. This avoids duplicating large per-unit data (e.g. waveform templates) across every slice. Set to False to keep per-slice attributes. Attributes: spike_stack (list): List of SpikeData objects, one per slice. Spike times are relative to the slice window. For 0-based slices, times run from 0 to duration. For event-centered slices, times run from -pre_ms to +post_ms with t=0 at the event. Use self.times for absolute recording time positions. 
times (list): List of (start, end) time bounds for each slice in absolute recording time, sorted chronologically. Length equals S. Example: [(100, 350), (500, 750), (1000, 1250)]. N (int): Number of units. neuron_attributes (list or None): List of attribute dicts, one per unit. None if not provided. When drop_slice_attributes is True (default), this is the only copy and individual slices will have neuron_attributes set to None. """
[docs] def __init__( self, data_obj=None, times_start_to_end=None, time_peaks=None, time_bounds=None, spike_stack=None, neuron_attributes=None, drop_slice_attributes=True, ): if data_obj is None and spike_stack is None: raise TypeError( "Must input either a SpikeData as data_obj (option 1) or spike_stack (option 2)" ) if data_obj is not None and spike_stack is not None: warnings.warn( "User input both data_obj and spike_stack. " "Ignoring data_obj and using spike_stack instead.", UserWarning, ) data_obj = None # Option 1: Using data_obj if data_obj is not None: if not isinstance(data_obj, SpikeData): raise TypeError("data_obj must be a SpikeData object") if times_start_to_end is None: if time_peaks is None or time_bounds is None: raise ValueError( "Must provide either times_start_to_end or " "both time_peaks and time_bounds" ) if not isinstance(time_bounds, tuple) or len(time_bounds) != 2: raise TypeError( "time_bounds must be a tuple of (before, after) durations" ) before, after = time_bounds time_peaks = sorted(time_peaks) times_start_to_end = [] for t in time_peaks: times_start_to_end.append((t - before, t + after)) rec_range = ( data_obj.start_time, data_obj.start_time + data_obj.length, ) times_start_to_end = _validate_time_start_to_end( times_start_to_end, recording_range=rec_range ) self.times = times_start_to_end self.spike_stack = [] if time_peaks is not None: # Event-centered: shift_to=peak so t=0 is the event for peak, (start, end) in zip(time_peaks, times_start_to_end): self.spike_stack.append(data_obj.subtime(start, end, shift_to=peak)) else: # Standard: shift_to=start so t=0 is the window start for start, end in times_start_to_end: self.spike_stack.append(data_obj.subtime(start, end)) if neuron_attributes is None: neuron_attributes = data_obj.neuron_attributes # Option 2: Using spike_stack directly if spike_stack is not None: if not isinstance(spike_stack, list): raise TypeError("spike_stack must be a list of SpikeData objects") for s in spike_stack: 
if not isinstance(s, SpikeData): raise TypeError("spike_stack must be a list of SpikeData objects") if len(spike_stack) == 0: raise ValueError("spike_stack must not be empty") N = spike_stack[0].N for s in spike_stack: if s.N != N: raise ValueError( "All SpikeData objects in spike_stack must have the same number of units" ) if times_start_to_end is None: t = 0.0 times_start_to_end = [] for s in spike_stack: times_start_to_end.append((t, t + s.length)) t += s.length else: warn_neg = spike_stack[0].start_time >= 0 times_start_to_end = _validate_time_start_to_end( times_start_to_end, warn_negative_start=warn_neg ) if len(times_start_to_end) != len(spike_stack): raise ValueError( "times_start_to_end must have the same length as spike_stack" ) self.spike_stack = list(spike_stack) self.times = times_start_to_end # Validate that all slices share the same ``start_time`` # convention. Mixing 0-based slices (start_time=0) with # event-centered slices (start_time=-pre) — or two event- # centered stacks with different ``pre`` values — silently # mis-aligns downstream raster outputs. Require uniformity. if len(self.spike_stack) > 1: start_times = [sd.start_time for sd in self.spike_stack] if len(set(start_times)) > 1: raise ValueError( "All slices in spike_stack must share the same " f"start_time convention; got {sorted(set(start_times))}. " "Mixing 0-based and event-centered slices (or two " "event-centered stacks with different pre-windows) " "would silently mis-align downstream raster outputs." ) # Validate that spike times are consistent with the slice # duration. Spike times must be relative to the slice (0-based # or event-centered), not absolute recording times. 
for i, (sd, (start, end)) in enumerate(zip(self.spike_stack, self.times)): duration = end - start expected_start = sd.start_time expected_end = sd.start_time + duration for u, train in enumerate(sd.train): if len(train) == 0: continue if train[0] < expected_start or train[-1] > expected_end: raise ValueError( f"Slice {i}, unit {u}: spike times " f"[{train[0]:.1f}, {train[-1]:.1f}] ms fall outside " f"expected range [{expected_start:.1f}, " f"{expected_end:.1f}] ms. " "Spike times must be relative to the slice (0-based " "or event-centered), not absolute recording times." ) self.N = self.spike_stack[0].N self.neuron_attributes = None if neuron_attributes is not None: self.neuron_attributes = neuron_attributes.copy() if len(self.neuron_attributes) != self.N: raise ValueError( f"neuron_attributes has {len(self.neuron_attributes)} items " f"but spike_stack has {self.N} units" ) # Strip per-slice neuron_attributes to avoid duplicating large data # (e.g. waveform templates) across every slice. if drop_slice_attributes: for sd in self.spike_stack: sd.neuron_attributes = None
def __repr__(self) -> str: S = len(self.spike_stack) return f"SpikeSliceStack(N={self.N}, S={S})" def __len__(self) -> int: return len(self.spike_stack) def __iter__(self): return iter(self.spike_stack)
[docs] def subslice(self, slices): """Extract a subset of slices from the spike stack. Parameters: slices (int or list): Slice index or list of slice indices to extract. Returns: result (SpikeSliceStack): New SpikeSliceStack containing only the specified slices. Shape changes from S to S_trimmed. All units and neuron_attributes are carried over. """ S = len(self.spike_stack) if isinstance(slices, int): slices = [slices] for s in slices: if s >= S or s < -S: raise ValueError(f"One or more slice indices out of range for S={S}") slices = sorted(slices) new_spike_stack = [] new_times = [] for s in slices: new_spike_stack.append(self.spike_stack[s]) new_times.append(self.times[s]) return SpikeSliceStack( spike_stack=new_spike_stack, times_start_to_end=new_times, neuron_attributes=self.neuron_attributes, )
[docs] def subset(self, units, by=None, preserve_order=False): """Extract a subset of units from every slice in the spike stack. Parameters: units (int, str, or list): Unit indices to extract. If by is None, must be int(s). If by is set, values to match in neuron_attributes. by (str or None): If set, select units by this neuron_attribute key instead of by index. preserve_order (bool): When False (default), output is sorted ascending by index — consistent with the other SpikeLab data classes. When True, output respects the order of the input ``units`` list. Duplicates are deduplicated either way. Returns: result (SpikeSliceStack): New SpikeSliceStack containing only the specified units across all slices. All slices and neuron_attributes are carried over. Notes: - If IDs are not unique (when using by), every matching neuron is included. """ if isinstance(units, (int, str)): units = [units] # Resolve which indices will be kept so we can update neuron_attributes if by is not None: # ``by`` resolves to whichever units carry the matching # attribute, in self.train order — caller-supplied order # cannot be honoured because the value list has no # positional correspondence to unit indices. if self.neuron_attributes is None: raise ValueError("can't use `by` without `neuron_attributes`") _missing = object() unit_set = set(units) kept_indices = [] for i in range(self.N): if _get_attr(self.neuron_attributes[i], by, _missing) in unit_set: kept_indices.append(i) else: for u in units: ui = int(u) if ui < 0 or ui >= self.N: raise ValueError( f"Unit index {ui} out of range for {self.N} units." ) if preserve_order: seen: set = set() ordered = [] for u in units: ui = int(u) if ui not in seen: seen.add(ui) ordered.append(ui) kept_indices = ordered else: kept_indices = sorted({int(u) for u in units}) new_spike_stack = [] for sd in self.spike_stack: # Forward preserve_order so per-slice subsets agree with # the SpikeSliceStack-level ordering decision. 
new_spike_stack.append(sd.subset(kept_indices, preserve_order=True)) new_neuron_attributes = None if self.neuron_attributes is not None: new_neuron_attributes = [] for i in kept_indices: new_neuron_attributes.append(self.neuron_attributes[i]) return SpikeSliceStack( spike_stack=new_spike_stack, times_start_to_end=self.times, neuron_attributes=new_neuron_attributes, )
[docs] def subtime_by_index(self, start_idx, end_idx): """Trim each slice to a sub-window specified by millisecond indices. Indices are measured from the start of each slice (1 index = 1 ms). Trims along the time axis while preserving all slices and units. Parameters: start_idx (int): Start index in ms from slice start (inclusive). Supports negative indexing. end_idx (int): End index in ms from slice start (exclusive). Supports negative indexing. Returns: result (SpikeSliceStack): New SpikeSliceStack where each slice is trimmed to the corresponding absolute time window. Absolute spike times are preserved (not shifted). self.times is updated to reflect the new absolute time bounds. Raises: ValueError: If the underlying slice duration (``times[0][1] - times[0][0]``) is not an integer number of milliseconds. Use ``SpikeData.subtime()`` with explicit ms bounds for non-integer windows. Notes: - Indices are relative to each slice's own start (index 0 = slice start ms). They are converted to absolute recording times internally before trimming. - Original absolute timestamps are preserved. To get shifted-to-zero timestamps, create a new SpikeSliceStack. - All slices and neuron_attributes are carried over from the original. """ slice_duration_ms = self.times[0][1] - self.times[0][0] # 1 index = 1 ms; non-integer durations would silently drop the # sub-ms tail. Push the rounding decision back to the caller. if abs(slice_duration_ms - round(slice_duration_ms)) > 1e-9: raise ValueError( f"slice_duration_ms ({slice_duration_ms}) must be an " f"integer number of milliseconds for subtime_by_index " f"(1 index = 1 ms). For non-integer windows, call " f"SpikeData.subtime() with explicit ms bounds, or " f"reconstruct the SpikeSliceStack with an integer slice " f"duration." 
) T = int(round(slice_duration_ms)) if start_idx < 0: start_idx += T if end_idx < 0: end_idx += T if start_idx < 0 or start_idx >= T: raise ValueError(f"start_idx {start_idx} out of range for T={T}") if end_idx <= start_idx or end_idx > T: raise ValueError(f"end_idx {end_idx} invalid for T={T}") new_spike_stack = [] new_times = [] for sd, t in zip(self.spike_stack, self.times): new_spike_stack.append( sd.subtime( sd.start_time + float(start_idx), sd.start_time + float(end_idx) ) ) abs_start = t[0] + float(start_idx) abs_end = t[0] + float(end_idx) new_times.append((abs_start, abs_end)) return SpikeSliceStack( spike_stack=new_spike_stack, times_start_to_end=new_times, neuron_attributes=self.neuron_attributes, )
[docs] def to_raster_array(self, bin_size=1.0, absolute_times=False): """Convert the spike stack into a 3D raster array of shape (N, T, S). Each slice is rasterized with the given bin size, producing a spike count matrix where entry (n, t, s) is the number of spikes unit n fired in time bin t of slice s. Parameters: bin_size (float): Time bin size in ms (default 1.0). absolute_times (bool): If False (default), time bin 0 corresponds to the start of each slice (0-based). If True, each slice's spikes are offset by its absolute start time from self.times, so bin indices reflect the original recording position. The T dimension is sized to cover the full time span from the earliest slice start to the latest slice end. **Caution:** this can produce very large arrays when the recording span is long and bin_size is small. Returns: raster_stack (np.ndarray): 3D array of shape (N, T, S) with non-negative integer spike counts. When absolute_times is True, T covers the full recording span and all slices share the same time axis. """ if bin_size <= 0: raise ValueError(f"bin_size must be > 0, got {bin_size}") if not absolute_times: dense_list = [] for sd in self.spike_stack: # Spike times are relative to each slice (0-based or event-centered). # sparse_raster handles start_time internally. dense_list.append(sd.sparse_raster(bin_size=bin_size).toarray()) return np.stack(dense_list, axis=2) # Absolute times: offset each slice by its start time so bin indices # reflect original recording position. All slices share the same # time axis spanning [min(start), max(end)]. 
global_start = min(start for start, _ in self.times) global_end = max(end for _, end in self.times) total_bins = int(np.ceil((global_end - global_start) / bin_size)) raster_stack = np.zeros((self.N, total_bins, len(self.spike_stack)), dtype=int) for s_idx, (sd, (start, _)) in enumerate(zip(self.spike_stack, self.times)): offset = start - global_start r = sd.sparse_raster(bin_size=bin_size, time_offset=offset).toarray() # Clamp r.shape[1] to total_bins. Both shapes come from # independent np.ceil calls on float arithmetic, so a # ULP-level difference between (global_end - global_start) # and (slice_length + offset) can leave r one bin larger # than the buffer — the unclamped assignment would raise # a broadcasting error. The reverse case (r smaller than # total_bins) is benign — trailing bins stay zero. n = min(r.shape[1], total_bins) raster_stack[:, :n, s_idx] = r[:, :n] return raster_stack
[docs] def baseline_normalized_raster( self, bin_size, baseline_window_ms, *, mode="subtract" ): """Per-slice raster normalized against a per-slice baseline rate. Wraps ``to_raster_array(bin_size)`` and converts each bin into a baseline-normalized response value. The baseline rate is computed from spikes inside ``baseline_window_ms`` (in milliseconds relative to each slice's time origin) and projected to each bin via ``rate * bin_size``. Output shape matches the raster: ``(U, T, S)``. Parameters: bin_size (float): Raster bin size in milliseconds. Passed to ``to_raster_array``. baseline_window_ms (tuple[float, float]): ``(start_ms, end_ms)`` window relative to slice origin used to estimate the per-slice baseline rate. mode (str): Normalization mode: - ``"subtract"`` (default) — counts above baseline expectation. - ``"ratio"`` — counts / expected_counts (NaN where expected is 0). - ``"zscore"`` — (counts - expected) / sqrt(expected), the Poisson z-score (NaN where expected is 0). Returns: normalized (np.ndarray): Float array of shape ``(U, T, S)``. Notes: - Baseline window is validated against each slice's actual time range. ``ValueError`` if any slice doesn't contain it. - For uniform-bin response counts (no normalization), use ``to_raster_array(bin_size)`` directly; this method adds the per-slice baseline correction on top. 
""" if mode not in ("subtract", "ratio", "zscore"): raise ValueError( f"mode must be 'subtract', 'ratio', or 'zscore', got {mode!r}" ) if ( not isinstance(baseline_window_ms, (tuple, list)) or len(baseline_window_ms) != 2 ): raise ValueError("baseline_window_ms must be a (start_ms, end_ms) tuple.") b_start, b_end = float(baseline_window_ms[0]), float(baseline_window_ms[1]) if b_end <= b_start: raise ValueError("baseline_window_ms end must be greater than start.") for s_idx, sd in enumerate(self.spike_stack): slice_start = sd.start_time slice_end = sd.start_time + (self.times[s_idx][1] - self.times[s_idx][0]) if b_start < slice_start - 1e-9 or b_end > slice_end + 1e-9: raise ValueError( f"baseline_window_ms ({b_start}, {b_end}) falls outside " f"slice {s_idx} time range [{slice_start}, {slice_end}]." ) counts = self.to_raster_array(bin_size=bin_size).astype(float) # (U, T, S) baseline_width = b_end - b_start baseline_counts = np.zeros((self.N, len(self.spike_stack)), dtype=float) for s_idx, sd in enumerate(self.spike_stack): for u in range(self.N): train = np.asarray(sd.train[u], dtype=float) if train.size == 0: continue baseline_counts[u, s_idx] = float( np.sum((train >= b_start) & (train < b_end)) ) # Per-slice expected counts per bin = rate * bin_size expected_per_bin = baseline_counts * (bin_size / baseline_width) # (U, S) expected = expected_per_bin[:, np.newaxis, :] # broadcasts over T if mode == "subtract": return counts - expected if mode == "ratio": with np.errstate(divide="ignore", invalid="ignore"): return np.where(expected > 0, counts / expected, np.nan) with np.errstate(divide="ignore", invalid="ignore"): return np.where( expected > 0, (counts - expected) / np.sqrt(expected), np.nan )
[docs] def responsive_units( self, bin_size, baseline_window_ms, *, response_window_ms=None, z_threshold=2.0, aggregator="mean", ): """Identify units that show a significant evoked response. Builds the Poisson-z-scored baseline-normalized raster, optionally restricts to a response time window, aggregates across slices (mean or max), and returns a unit mask where any time bin's aggregated z-score exceeds ``z_threshold``. Parameters: bin_size (float): Raster bin size in milliseconds. baseline_window_ms (tuple[float, float]): Baseline window ``(start_ms, end_ms)`` relative to slice origin used to estimate the per-slice baseline rate. response_window_ms (tuple[float, float] or None): Optional response window ``(start_ms, end_ms)`` relative to slice origin. When None (default), the full slice is searched. z_threshold (float): Z-score threshold (default 2.0). aggregator (str): How to combine z-scores across slices before thresholding. ``"mean"`` (default) or ``"max"``. Returns: mask (np.ndarray): Boolean array of shape ``(U,)``. True for responsive units. Notes: - Units with no baseline spikes in any slice are flagged non-responsive (z-scores are NaN there). """ if aggregator not in ("mean", "max"): raise ValueError(f"aggregator must be 'mean' or 'max', got {aggregator!r}") z = self.baseline_normalized_raster( bin_size, baseline_window_ms, mode="zscore" ) # (U, T, S) if response_window_ms is not None: if ( not isinstance(response_window_ms, (tuple, list)) or len(response_window_ms) != 2 ): raise ValueError( "response_window_ms must be a (start_ms, end_ms) tuple or None." 
) r_start = float(response_window_ms[0]) r_end = float(response_window_ms[1]) if r_end <= r_start: raise ValueError("response_window_ms end must be greater than start.") sd0 = self.spike_stack[0] bin_start = int(np.floor((r_start - sd0.start_time) / bin_size)) bin_end = int(np.ceil((r_end - sd0.start_time) / bin_size)) bin_start = max(0, bin_start) bin_end = min(z.shape[1], bin_end) if bin_end <= bin_start: raise ValueError( f"response_window_ms ({r_start}, {r_end}) maps to an empty " f"bin range; check it against the slice duration and bin_size." ) z = z[:, bin_start:bin_end, :] if aggregator == "mean": agg = np.nanmean(z, axis=2) else: agg = np.nanmax(z, axis=2) with np.errstate(invalid="ignore"): return np.any(agg > z_threshold, axis=1)
[docs] def decode_slice_labels( self, labels, response_window_ms, *, bin_size, baseline_window_ms=None, classifier="ridge", cv="loo", classifier_kwargs=None, random_state=None, ): """Decode per-slice labels (e.g. stim identity) from population responses. Builds an ``(S, U)`` feature matrix by summing per-unit spike counts in ``response_window_ms`` (optionally with baseline subtraction) and runs cross-validated classifier decoding via :func:`spikelab.spikedata.decoding.cross_validated_decode`. Parameters: labels (array-like): Per-slice labels of length ``S`` (e.g. stim electrode index, treatment category). response_window_ms (tuple[float, float]): Window relative to slice origin over which response counts are summed. bin_size (float): Raster bin size in ms. baseline_window_ms (tuple[float, float] or None): Optional per-slice baseline window; when provided, counts are baseline-subtracted via ``baseline_normalized_raster``. classifier (str): ``"ridge"`` (default), ``"mlp"``, or ``"random_forest"``. cv (str or int): ``"loo"`` (default) or int ``>= 2``. classifier_kwargs (dict or None): Forwarded to the sklearn classifier constructor. random_state (int or None): Reproducibility seed. Returns: result (dict): Same shape as :func:`spikelab.spikedata.decoding.cross_validated_decode` — ``accuracy``, ``predictions``, ``true_labels``, ``confusion_matrix``, ``per_fold_accuracy``, ``classes``, ``classifier_name``. Notes: - Requires ``scikit-learn`` (optional dependency). - For decoding from the full ``(U, T)`` raster (not just summed counts), call ``decoding.cross_validated_decode`` directly on ``self.to_raster_array(bin_size).reshape(U * T, S).T``. 
""" from .decoding import cross_validated_decode if ( not isinstance(response_window_ms, (tuple, list)) or len(response_window_ms) != 2 ): raise ValueError("response_window_ms must be a (start_ms, end_ms) tuple.") r_start = float(response_window_ms[0]) r_end = float(response_window_ms[1]) if r_end <= r_start: raise ValueError("response_window_ms end must be greater than start.") if baseline_window_ms is None: counts = self.to_raster_array(bin_size=bin_size).astype(float) else: counts = self.baseline_normalized_raster( bin_size, baseline_window_ms, mode="subtract" ) sd0 = self.spike_stack[0] bin_start = int(np.floor((r_start - sd0.start_time) / bin_size)) bin_end = int(np.ceil((r_end - sd0.start_time) / bin_size)) bin_start = max(0, bin_start) bin_end = min(counts.shape[1], bin_end) if bin_end <= bin_start: raise ValueError( f"response_window_ms ({r_start}, {r_end}) maps to an empty " f"bin range; check it against the slice duration and bin_size." ) # (U, S) per-unit summed response amplitude per slice -> (S, U) features X = counts[:, bin_start:bin_end, :].sum(axis=1).T labels = np.asarray(labels).ravel() if len(labels) != X.shape[0]: raise ValueError( f"labels must have length S={X.shape[0]}; got {len(labels)}." ) return cross_validated_decode( X, labels, classifier=classifier, cv=cv, classifier_kwargs=classifier_kwargs, random_state=random_state, )
[docs] def group_pair_similarity( self, stim_labels, *, metric="cosine", bin_size=1.0, slice_indices=None, ): """Pairwise similarity between mean response vectors for each stim class. For each unique stimulus label, averages the per-slice ``(U, T)`` raster across all slices that share that label, then computes a ``(K, K)`` similarity matrix between the resulting per-stim mean response vectors. Lets you ask: "how distinguishable are responses to different stims?". Parameters: stim_labels (array-like): Per-slice stim label of length ``S``. metric (str): ``"cosine"`` (default), ``"pearson"``, ``"euclidean"`` (distance), or ``"cross_entropy"``. bin_size (float): Raster bin size in ms (default 1.0). slice_indices (array-like or None): Optional subset of slice indices to use (e.g. an "early-cycle" or "late-cycle" window). When None (default), uses all slices. Returns: sim (PairwiseCompMatrix): ``(K, K)`` similarity matrix with ``labels`` set to the unique stim classes (sorted). Notes: - Stim classes that have no slices in ``slice_indices`` are dropped from the output. """ from .utils import _slice_to_slice_similarity_matrix stim_labels = np.asarray(stim_labels).ravel() if len(stim_labels) != len(self): raise ValueError( f"stim_labels must have length S={len(self)}; got {len(stim_labels)}." ) if slice_indices is None: slice_indices = np.arange(len(self)) else: slice_indices = np.asarray(slice_indices, dtype=int).ravel() if (slice_indices < 0).any() or (slice_indices >= len(self)).any(): raise IndexError(f"slice_indices out of range for S={len(self)}.") unique_labels = np.array(sorted(np.unique(stim_labels[slice_indices]))) if len(unique_labels) < 2: raise ValueError( "Need at least 2 distinct stim classes in the selected slices." 
) # (U, T, S_subset) raster = self.subslice(list(slice_indices)).to_raster_array(bin_size=bin_size) sub_labels = stim_labels[slice_indices] # Per-class mean across slices -> (U, T, K) mean_per_class = np.stack( [raster[:, :, sub_labels == cls].mean(axis=2) for cls in unique_labels], axis=2, ) sim = _slice_to_slice_similarity_matrix(mean_per_class, metric) return PairwiseCompMatrix( matrix=sim, labels=list(unique_labels), metadata={"metric": metric, "n_classes": len(unique_labels)}, )
[docs] def responsive_units_per_group( self, group_labels, bin_size, baseline_window_ms, *, response_window_ms=None, z_threshold=2.0, aggregator="mean", ): """Per-cycle responsive-unit mask for tracking responsiveness over time. For each unique cycle, runs ``responsive_units`` on the slices belonging to that cycle and returns a ``(U, n_cycles)`` boolean matrix. Use the per-cycle masks to compute gained / lost / preserved responsive units across cycle groups, or to correlate responsiveness changes with intrinsic activity changes per unit. Parameters: group_labels (array-like): Per-slice cycle index of length ``S``. bin_size (float): Raster bin size in ms. baseline_window_ms (tuple[float, float]): Baseline window for Poisson z-score normalization. response_window_ms (tuple[float, float] or None): Optional response window (default: full slice). z_threshold (float): Per-cycle z-threshold (default 2.0). aggregator (str): ``"mean"`` (default) or ``"max"`` across slices within each cycle. Returns: result (dict): - ``cycles`` (np.ndarray): Sorted unique cycle indices. - ``mask`` (np.ndarray): ``(U, n_cycles)`` boolean responsiveness mask. - ``responsive_count`` (np.ndarray): Per-cycle responsive unit count, shape ``(n_cycles,)``. """ group_labels = np.asarray(group_labels).ravel() if len(group_labels) != len(self): raise ValueError( f"group_labels must have length S={len(self)}; got {len(group_labels)}." ) groups = np.array(sorted(np.unique(group_labels))) mask = np.zeros((self.N, len(groups)), dtype=bool) for j, c in enumerate(groups): slice_idx = np.where(group_labels == c)[0] if slice_idx.size == 0: continue sub = self.subslice(slice_idx.tolist()) mask[:, j] = sub.responsive_units( bin_size=bin_size, baseline_window_ms=baseline_window_ms, response_window_ms=response_window_ms, z_threshold=z_threshold, aggregator=aggregator, ) return { "groups": groups, "mask": mask, "responsive_count": mask.sum(axis=0), }
[docs] def responsiveness_change( self, group_labels, early_groups, late_groups, bin_size, baseline_window_ms, *, response_window_ms=None, z_threshold=2.0, aggregator="mean", ): """Gained / lost / preserved responsive units between two cycle groups. Computes responsive-unit masks separately for slices in ``early_groups`` and ``late_groups`` (any iterables of cycle indices), and reports which units become responsive ("gained"), stop being responsive ("lost"), or stay responsive ("preserved"). Parameters: group_labels (array-like): Per-slice cycle index of length ``S``. early_groups (array-like): Cycle indices for the early group. late_groups (array-like): Cycle indices for the late group. bin_size (float): Raster bin size in ms. baseline_window_ms (tuple[float, float]): Baseline window. response_window_ms (tuple[float, float] or None): Response window (default: full slice). z_threshold (float): Per-group z-threshold. aggregator (str): ``"mean"`` (default) or ``"max"``. Returns: result (dict): - ``early_mask`` (np.ndarray ``(U,)`` bool): Responsive in early. - ``late_mask`` (np.ndarray ``(U,)`` bool): Responsive in late. - ``gained`` (np.ndarray ``(U,)`` bool): NOT responsive in early AND responsive in late. - ``lost`` (np.ndarray ``(U,)`` bool): Responsive in early AND NOT responsive in late. - ``preserved`` (np.ndarray ``(U,)`` bool): Responsive in BOTH. - ``early_count``, ``late_count``, ``gained_count``, ``lost_count``, ``preserved_count`` (int). Notes: - Pair this with intrinsic-activity changes per unit (e.g. ``cv_isi`` differences) and correlate via ``stat_utils.linear_regression`` to ask whether responsiveness changes track changes in baseline activity. """ group_labels = np.asarray(group_labels).ravel() if len(group_labels) != len(self): raise ValueError( f"group_labels must have length S={len(self)}; got {len(group_labels)}." 
) early_groups = np.asarray(early_groups).ravel() late_groups = np.asarray(late_groups).ravel() early_idx = np.where(np.isin(group_labels, early_groups))[0] late_idx = np.where(np.isin(group_labels, late_groups))[0] if early_idx.size == 0: raise ValueError("No slices match early_groups.") if late_idx.size == 0: raise ValueError("No slices match late_groups.") kwargs = dict( bin_size=bin_size, baseline_window_ms=baseline_window_ms, response_window_ms=response_window_ms, z_threshold=z_threshold, aggregator=aggregator, ) early_mask = self.subslice(early_idx.tolist()).responsive_units(**kwargs) late_mask = self.subslice(late_idx.tolist()).responsive_units(**kwargs) gained = (~early_mask) & late_mask lost = early_mask & (~late_mask) preserved = early_mask & late_mask return { "early_mask": early_mask, "late_mask": late_mask, "gained": gained, "lost": lost, "preserved": preserved, "early_count": int(early_mask.sum()), "late_count": int(late_mask.sum()), "gained_count": int(gained.sum()), "lost_count": int(lost.sum()), "preserved_count": int(preserved.sum()), }
[docs] def slice_to_slice_similarity(self, metric="cosine", *, bin_size=1.0): """Pairwise similarity between slice-wise population response vectors. Each slice is converted to a ``(U * T)`` flat vector via ``to_raster_array(bin_size).reshape(U*T, S).T`` and a square ``(S, S)`` similarity matrix is computed using the requested metric. Parameters: metric (str): One of ``"cosine"`` (default), ``"pearson"``, ``"euclidean"`` (distance), or ``"cross_entropy"`` (symmetric KL on normalized bin distributions). bin_size (float): Raster bin size in ms (default 1.0). Returns: sim (PairwiseCompMatrix): ``(S, S)`` similarity matrix. For cosine and pearson, higher = more similar (diagonal ~1.0); for euclidean and cross_entropy, lower = more similar (diagonal 0). Notes: - ``cosine`` and ``pearson`` return values in ``[-1, 1]``. - ``euclidean`` returns raw L2 distance. - ``cross_entropy`` returns symmetric KL divergence (i.e. ``(KL(p||q) + KL(q||p)) / 2``) between bin distributions normalized to sum to 1. - Use ``PairwiseCompMatrix.extract_lower_triangle()`` for feature extraction. """ stack = self.to_raster_array(bin_size=bin_size) # (U, T, S) sim = _slice_to_slice_similarity_matrix(stack, metric) return PairwiseCompMatrix(matrix=sim, metadata={"metric": metric})
def per_unit_response_regression(
    self,
    bin_size,
    response_window_ms,
    *,
    x_values=None,
    baseline_window_ms=None,
    min_valid_slices=3,
):
    """Per-unit OLS regression of evoked response amplitude across slices.

    For each slice and unit, the response amplitude is the sum of binned
    counts inside ``response_window_ms``, optionally after a per-slice
    baseline subtraction. A linear regression of amplitude against
    ``x_values`` is then fitted for every unit. Use this to detect
    facilitation / depression of the evoked response across cycles or
    stimulus intensities.

    Parameters:
        bin_size (float): Raster bin size in milliseconds (passed to
            ``to_raster_array`` / ``baseline_normalized_raster``).
        response_window_ms (tuple[float, float]): ``(start_ms, end_ms)``
            window relative to the slice origin over which response counts
            are summed.
        x_values (array-like or None): Per-slice x values for the
            regression (e.g. cycle index, stimulus intensity). Length must
            equal the number of slices ``S``. When None (default), uses
            ``np.arange(S)``.
        baseline_window_ms (tuple[float, float] or None): Optional baseline
            window for per-slice subtraction. When None (default), uses raw
            response counts. Otherwise the raster comes from
            ``baseline_normalized_raster(..., mode="subtract")``.
        min_valid_slices (int): Minimum number of valid (finite) ``(x, y)``
            pairs required to fit a regression. At least 2 pairs are always
            required. Units with fewer return NaN for all coefficients.
            Default 3.

    Returns:
        result (dict): Dictionary with keys:

            - ``slope`` (np.ndarray ``(U,)``): Slope per unit.
            - ``intercept`` (np.ndarray ``(U,)``): Intercept per unit.
            - ``r_squared`` (np.ndarray ``(U,)``): R² per unit.
            - ``p_value`` (np.ndarray ``(U,)``): Two-sided p-value of the
              slope per unit.
            - ``stderr`` (np.ndarray ``(U,)``): Standard error of the slope
              per unit.
            - ``amplitudes`` (np.ndarray ``(U, S)``): Per-slice response
              amplitude (raw or baseline-subtracted). NaN where every bin of
              the response window is NaN.
            - ``x_values`` (np.ndarray ``(S,)``): The x values used.

    Raises:
        ImportError: If scipy is not installed.
        ValueError: If ``response_window_ms`` is malformed or maps to an
            empty bin range, or if ``x_values`` has the wrong length.

    Notes:
        - Requires ``scipy`` (optional dependency).
        - Units with constant amplitudes get ``r_squared = 0``,
          ``slope = 0`` and ``p_value = 1.0`` (as returned by
          ``scipy.stats.linregress``).
        - Units whose valid x values are all identical keep NaN
          coefficients.
        - Slices whose response window is entirely NaN are treated as
          missing for that unit, not as zero amplitude.
    """
    try:
        from scipy import stats as sp_stats
    except ImportError as e:
        raise ImportError(
            "per_unit_response_regression requires 'scipy'. "
            "Install with: pip install scipy"
        ) from e

    if (
        not isinstance(response_window_ms, (tuple, list))
        or len(response_window_ms) != 2
    ):
        raise ValueError("response_window_ms must be a (start_ms, end_ms) tuple.")
    r_start = float(response_window_ms[0])
    r_end = float(response_window_ms[1])
    if r_end <= r_start:
        raise ValueError("response_window_ms end must be greater than start.")

    S = len(self.spike_stack)
    if x_values is None:
        x_values = np.arange(S, dtype=float)
    else:
        x_values = np.asarray(x_values, dtype=float).ravel()
        if x_values.size != S:
            raise ValueError(
                f"x_values must have length S={S}, got {x_values.size}."
            )

    if baseline_window_ms is None:
        counts = self.to_raster_array(bin_size=bin_size).astype(float)  # (U, T, S)
    else:
        counts = self.baseline_normalized_raster(
            bin_size, baseline_window_ms, mode="subtract"
        )

    # Map the window (in slice-relative ms) onto raster bin indices; bin 0
    # starts at the slice's start_time.
    sd0 = self.spike_stack[0]
    bin_start = int(np.floor((r_start - sd0.start_time) / bin_size))
    bin_end = int(np.ceil((r_end - sd0.start_time) / bin_size))
    bin_start = max(0, bin_start)
    bin_end = min(counts.shape[1], bin_end)
    if bin_end <= bin_start:
        raise ValueError(
            f"response_window_ms ({r_start}, {r_end}) maps to an empty "
            f"bin range; check it against the slice duration and bin_size."
        )

    window = counts[:, bin_start:bin_end, :]  # (U, W, S)
    summed = np.nansum(window, axis=1)  # (U, S)
    # nansum maps an all-NaN window to 0, which would silently enter the
    # regression as a genuine zero response. Keep such entries NaN so the
    # per-unit validity mask below drops them.
    amplitudes = np.where(np.all(np.isnan(window), axis=1), np.nan, summed)

    n_units = self.N
    slope = np.full(n_units, np.nan)
    intercept = np.full(n_units, np.nan)
    r_squared = np.full(n_units, np.nan)
    p_value = np.full(n_units, np.nan)
    stderr = np.full(n_units, np.nan)

    x_finite = np.isfinite(x_values)
    # A line needs at least two points; this also avoids np.ptp on an
    # empty array when min_valid_slices <= 0.
    required = max(int(min_valid_slices), 2)

    for u in range(n_units):
        y = amplitudes[u, :]
        valid = x_finite & np.isfinite(y)
        if int(np.sum(valid)) < required:
            continue
        xv = x_values[valid]
        yv = y[valid]
        # linregress raises ValueError when all x values are identical, so
        # skip such units (coefficients stay NaN). A constant *response* is
        # handled by linregress itself (slope 0, r 0, p 1.0).
        if np.ptp(xv) == 0:
            continue
        try:
            res = sp_stats.linregress(xv, yv)
        except (ValueError, FloatingPointError):
            continue
        slope[u] = float(res.slope)
        intercept[u] = float(res.intercept)
        r_squared[u] = float(res.rvalue) ** 2
        p_value[u] = float(res.pvalue)
        stderr[u] = float(res.stderr)

    return {
        "slope": slope,
        "intercept": intercept,
        "r_squared": r_squared,
        "p_value": p_value,
        "stderr": stderr,
        "amplitudes": amplitudes,
        "x_values": x_values,
    }
def compute_frac_active(self, min_spikes=2):
    """Return, per unit, the fraction of slices in which it is active.

    A unit is active in a slice when at least ``min_spikes`` of its spike
    times fall inside ``[sd.start_time, sd.start_time + (end - start))``.
    Here ``(start, end)`` is that slice's entry in ``self.times``.

    Parameters:
        min_spikes (int): Spike-count threshold for "active" (default: 2).

    Returns:
        frac_active (np.ndarray): ``(U,)`` float array with values in
            ``[0, 1]``. All zeros when the stack has no slices.

    Notes:
        - The result can be supplied as ``frac_active`` to
          ``RateSliceStack.order_units_across_slices``,
          ``RateSliceStack.get_slice_to_slice_unit_corr_from_stack``,
          ``SpikeSliceStack.order_units_across_slices`` or
          ``SpikeSliceStack.get_slice_to_slice_unit_comparison`` to
          override their internal activity calculation.
        - ``SpikeData.get_frac_active`` yields a compatible ``(U,)`` array
          based on burst edges and can be used the same way.
    """
    n_units = self.N
    n_slices = len(self.spike_stack)
    hits = np.zeros(n_units, dtype=int)

    for slice_sd, (t_lo, t_hi) in zip(self.spike_stack, self.times):
        origin = slice_sd.start_time
        window_end = origin + (t_hi - t_lo)
        for unit in range(n_units):
            train = np.asarray(slice_sd.train[unit])
            inside = (train >= origin) & (train < window_end)
            if np.sum(inside) >= min_spikes:
                hits[unit] += 1

    if n_slices <= 0:
        return np.zeros(n_units)
    return hits / n_slices
def order_units_across_slices(
    self,
    agg_func="median",
    timing="median",
    min_spikes=2,
    min_frac_active=0.0,
    frac_active=None,
    timing_matrix=None,
):
    """Reorder units by their typical spike timing across slices.

    For each unit in each slice, this computes a representative spike time
    (median, mean, or first spike) relative to the slice's time origin.
    These per-slice values are aggregated across slices to obtain a single
    typical timing per unit. Units are then sorted by this value from
    earliest to latest. Optionally they are split into a highly-active
    group and a low-activity group.

    Parameters:
        agg_func (str): How to aggregate per-slice timing values across
            slices. ``"median"`` (default) or ``"mean"``.
        timing (str): Which spike time to extract per unit per slice.
            ``"median"`` (default), ``"mean"``, or ``"first"`` (onset
            latency). Ignored when ``timing_matrix`` is provided.
        min_spikes (int): Minimum number of spikes for a unit to count as
            active in a slice (default: 2). Ignored when ``timing_matrix``
            is provided. It is also used for ``compute_frac_active`` when
            splitting.
        min_frac_active (float or None): Minimum fraction of slices a unit
            must be active in to be placed in the highly-active group.
            ``0.0`` or ``None`` (default: 0.0) skips the split entirely and
            places all units in the highly-active group without computing
            activity fractions.
        frac_active (np.ndarray or None): Optional pre-computed
            fraction-active array of shape ``(U,)`` that overrides the
            internal calculation for the group split. Only used when
            ``min_frac_active > 0``. Compatible sources:
            ``SpikeSliceStack.compute_frac_active`` and
            ``SpikeData.get_frac_active`` (``frac_per_unit`` output).
        timing_matrix (np.ndarray or None): Optional pre-computed
            ``(U, S)`` timing matrix from ``get_unit_timing_per_slice``.
            When provided, ``timing`` and ``min_spikes`` are ignored for
            timing and this matrix is used directly.

    Returns:
        reordered_stacks (tuple): Two ``SpikeSliceStack`` objects
            ``(highly_active, low_active)`` with units reordered by typical
            timing. A stack is ``None`` when its group is empty.
        unit_ids_in_order (tuple): Two ``ndarray``
            ``(highly_active, low_active)`` of original unit indices in the
            reordered sequence.
        unit_std (tuple): Two ``ndarray`` ``(highly_active, low_active)``
            of the standard deviation of per-slice timing values. Lower
            values indicate more consistent timing across slices.
        unit_peak_times_ms (tuple): Two ``ndarray``
            ``(highly_active, low_active)`` of the aggregated typical timing
            in milliseconds relative to the slice origin. NaN for units with
            no active slices.
        unit_frac_active (tuple): Two ``ndarray``
            ``(highly_active, low_active)`` of the fraction of slices each
            unit was active in. All ones when the split is skipped.

    Raises:
        ValueError: If ``agg_func`` is invalid, or ``timing_matrix`` /
            ``frac_active`` have the wrong shape.

    Notes:
        - Call ``get_unit_timing_per_slice`` first to pre-compute the
          timing matrix if you want to reuse it across multiple calls (e.g.
          ``rank_order_correlation`` and this method).
        - When ``frac_active`` is None and ``min_frac_active > 0``, the
          activity fraction is computed via ``compute_frac_active``.
        - Units that are inactive in every slice (all-NaN timing rows)
          yield NaN timing/std without emitting numpy RuntimeWarnings. They
          sort after all other units in their group.
        - Analogous to ``RateSliceStack.order_units_across_slices`` but
          operates on raw spike trains instead of firing rate curves.
    """
    if agg_func not in ("median", "mean"):
        raise ValueError(f"agg_func must be 'median' or 'mean', got {agg_func!r}")

    num_units = self.N
    num_slices = len(self.spike_stack)

    if timing_matrix is not None:
        timing_matrix = np.asarray(timing_matrix, dtype=float)
        if timing_matrix.shape != (num_units, num_slices):
            raise ValueError(
                f"timing_matrix must have shape ({num_units}, {num_slices}), "
                f"got {timing_matrix.shape}"
            )
    else:
        timing_matrix = self.get_unit_timing_per_slice(
            timing=timing, min_spikes=min_spikes
        )

    # Rows of units that never fire produce NaN here by design. numpy's
    # nan-reductions would emit "All-NaN slice" / "Degrees of freedom"
    # RuntimeWarnings for them, so silence those locally.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        # Standard deviation across slices
        unit_std_values = np.nanstd(timing_matrix, axis=1)
        # Aggregate across slices
        if agg_func == "median":
            unit_timing = np.nanmedian(timing_matrix, axis=1)
        else:
            unit_timing = np.nanmean(timing_matrix, axis=1)

    # Compute or validate frac_active only when splitting is requested
    skip_split = not min_frac_active
    if skip_split:
        frac_active = np.ones(num_units)
        ha_units = np.arange(num_units)
        la_units = np.array([], dtype=int)
    else:
        if frac_active is not None:
            frac_active = np.asarray(frac_active, dtype=float)
            if frac_active.shape != (num_units,):
                raise ValueError(
                    f"frac_active must have shape ({num_units},), "
                    f"got {frac_active.shape}"
                )
        else:
            frac_active = self.compute_frac_active(min_spikes=min_spikes)
        highly_active_mask = frac_active >= min_frac_active
        ha_units = np.where(highly_active_mask)[0]
        la_units = np.where(~highly_active_mask)[0]

    # Sort within each group by typical timing (argsort places NaN last)
    ha_order = ha_units[np.argsort(unit_timing[ha_units])]
    la_order = la_units[np.argsort(unit_timing[la_units])]

    def _reorder_stack(unit_indices):
        # Build a unit-reordered stack via subset(); None for an empty group.
        if len(unit_indices) == 0:
            return None
        return self.subset(list(unit_indices))

    ha_stack = _reorder_stack(ha_order)
    la_stack = _reorder_stack(la_order)

    return (
        (ha_stack, la_stack),
        (ha_order, la_order),
        (unit_std_values[ha_order], unit_std_values[la_order]),
        (unit_timing[ha_order], unit_timing[la_order]),
        (frac_active[ha_order], frac_active[la_order]),
    )
def apply(self, func, *args, **kwargs):
    """Run ``func`` on every slice and stack the per-slice outputs.

    ``func(sd, *args, **kwargs)`` is evaluated once per SpikeData in
    ``self.spike_stack``, in stack order. The outputs are joined by
    ``np.stack`` along a new leading axis of length S.

    Parameters:
        func (callable): Called with a SpikeData as its first argument.
            It must return a numeric value (scalar, 1-D or 2-D array) whose
            shape is the same for every slice.
        *args: Extra positional arguments passed through to ``func``.
        **kwargs: Extra keyword arguments passed through to ``func``.

    Returns:
        result (np.ndarray): Array of shape ``(S, ...)``.

    Notes:
        - Intended for stacks built by ``SpikeData.frames``,
          ``SpikeData.align_to_events``, ``SpikeData.spike_shuffle_stack``
          or ``SpikeData.subset_stack``. Combine with ``shuffle_z_score``,
          ``shuffle_percentile``, ``slice_trend`` or ``slice_stability``
          from ``utils`` to interpret the output.
    """
    per_slice_outputs = [
        func(slice_sd, *args, **kwargs) for slice_sd in self.spike_stack
    ]
    return np.stack(per_slice_outputs)
def unit_to_unit_comparison(
    self,
    metric="ccg",
    delt=20.0,
    bin_size=1.0,
    max_lag=350,
    n_jobs=-1,
):
    """Pairwise unit-to-unit similarity within each slice, from spike trains.

    Every slice yields a ``(U, U)`` similarity matrix over all unit pairs.
    The per-slice matrices are stacked along a trailing slice axis into a
    ``PairwiseCompMatrixStack`` of shape ``(U, U, S)``.

    Parameters:
        metric (str): ``"ccg"`` (default) for the cross-correlogram from
            ``SpikeData.get_pairwise_ccg``, or ``"sttc"`` for the spike time
            tiling coefficient from ``SpikeData.spike_time_tilings``.
        delt (float): STTC time window in ms (default: 20.0). Only used
            for ``"sttc"``.
        bin_size (float): Raster bin size in ms (default: 1.0). Only used
            for ``"ccg"``.
        max_lag (float): Largest lag in ms searched for the peak
            correlation (default: 350). Only used for ``"ccg"``.
        n_jobs (int): Passed unchanged to ``get_pairwise_ccg``
            (default: -1). Only used for ``"ccg"``.

    Returns:
        corr_stack (PairwiseCompMatrixStack): ``(U, U, S)`` similarity
            scores for every unit pair in every slice.
        lag_stack (PairwiseCompMatrixStack or None): ``(U, U, S)`` lag of
            peak similarity. ``None`` for ``"sttc"`` (STTC has no lag).
        av_corr (np.ndarray): ``(S,)`` NaN-ignoring mean over the strict
            lower triangle of each slice's matrix.
        av_lag (np.ndarray or None): ``(S,)`` mean lag, computed the same
            way. ``None`` for ``"sttc"``.

    Raises:
        ValueError: If ``metric`` is neither ``"sttc"`` nor ``"ccg"``.

    Notes:
        - With fewer than 2 units, a ``RuntimeWarning`` is issued and
          all-NaN results of the same shapes are returned.
        - Analogous to ``RateSliceStack.unit_to_unit_correlation`` but works
          on raw spike trains rather than firing-rate time series.
    """
    if metric not in ("sttc", "ccg"):
        raise ValueError(f"metric must be 'sttc' or 'ccg', got {metric!r}")

    use_ccg = metric == "ccg"
    n_units = self.N
    n_slices = len(self.spike_stack)

    if n_units < 2:
        warnings.warn(
            "Cannot compute unit-to-unit comparison with fewer than "
            "2 units. Returning NaN.",
            RuntimeWarning,
        )
        blank_scores = np.full((n_units, n_units, n_slices), np.nan)
        blank_means = np.full(n_slices, np.nan)
        corr_out = PairwiseCompMatrixStack(stack=blank_scores, times=self.times)
        if use_ccg:
            lag_out = PairwiseCompMatrixStack(
                stack=blank_scores.copy(), times=self.times
            )
            lag_means = blank_means.copy()
        else:
            lag_out = None
            lag_means = None
        return corr_out, lag_out, blank_means, lag_means

    corr_mats = []
    lag_mats = []
    for slice_sd in self.spike_stack:
        if not use_ccg:
            corr_mats.append(slice_sd.spike_time_tilings(delt=delt).matrix)
            continue
        corr_pcm, lag_pcm = slice_sd.get_pairwise_ccg(
            bin_size=bin_size, max_lag=max_lag, n_jobs=n_jobs
        )
        corr_mats.append(corr_pcm.matrix)
        lag_mats.append(lag_pcm.matrix)

    rows, cols = np.tril_indices(n_units, k=-1)

    # (S, U, U) -> (U, U, S): slice axis moved to the end.
    corr_array = np.moveaxis(np.stack(corr_mats, axis=0), 0, 2)
    av_corr = np.nanmean(corr_array[rows, cols, :], axis=0)
    corr_stack = PairwiseCompMatrixStack(stack=corr_array, times=self.times)

    if not use_ccg:
        return corr_stack, None, av_corr, None

    lag_array = np.moveaxis(np.stack(lag_mats, axis=0), 0, 2)
    av_lag = np.nanmean(lag_array[rows, cols, :], axis=0)
    lag_stack = PairwiseCompMatrixStack(stack=lag_array, times=self.times)
    return corr_stack, lag_stack, av_corr, av_lag
def get_slice_to_slice_unit_comparison(
    self,
    metric="ccg",
    delt=20.0,
    bin_size=1.0,
    max_lag=350,
    min_spikes=2,
    min_frac=0.3,
    frac_active=None,
    n_jobs=-1,
):
    """Compute slice-to-slice similarity for each unit using spike-based metrics.

    For each unit independently, this compares its spike train across every
    pair of slices. It asks: "Does unit X fire in the same temporal pattern
    across repeated events?" Returns a ``PairwiseCompMatrixStack (S, S, U)``.

    Parameters:
        metric (str): Similarity metric to use. ``"ccg"`` for
            cross-correlogram on binned rasters (default), ``"sttc"`` for
            spike time tiling coefficient.
        delt (float): STTC time window in milliseconds (default: 20.0).
            Only used when metric is ``"sttc"``.
        bin_size (float): Bin size in milliseconds for the binary raster
            (default: 1.0). Only used when metric is ``"ccg"``.
        max_lag (float): Maximum lag in milliseconds to search for the
            peak correlation (default: 350). Only used when metric is
            ``"ccg"``. It is converted to bins with
            ``round(max_lag / bin_size)``. A ``UserWarning`` is issued when
            a positive value rounds to 0 bins.
        min_spikes (int): Minimum number of spikes in a slice for a unit
            to be considered valid in that slice (default: 2). The count is
            ``len`` of the unit's stored train in that slice; no additional
            time-window filtering is applied here.
        min_frac (float): Maximum fraction of slices that can be invalid
            before a unit's average is set to NaN (default: 0.3).
        frac_active (np.ndarray or None): Optional pre-computed
            fraction-active array of shape ``(U,)`` that overrides the
            internal per-unit validity check for computing averages. When
            provided, a unit's average is set to NaN if
            ``frac_active[u] < (1 - min_frac)``. ``min_spikes`` still
            controls which individual slice pairs are computed. Compatible
            sources: ``SpikeSliceStack.compute_frac_active`` and
            ``SpikeData.get_frac_active`` (``frac_per_unit`` output).
        n_jobs (int): Number of threads for parallel computation. -1 uses
            all cores (default), 1 disables parallelism, None is serial.
            Parallelism is only used when more than one unit exists.

    Returns:
        all_corr (PairwiseCompMatrixStack): Pairwise similarity between all
            slice pairs for each unit. Shape is ``(S, S, U)``. The diagonal
            holds each slice compared with itself. Pairs where either slice
            has fewer than ``min_spikes`` spikes are NaN.
        all_lag (PairwiseCompMatrixStack or None): Lag at which maximum
            similarity occurs for each slice pair per unit, shape
            ``(S, S, U)``. Entry ``[i, j]`` is the negation of entry
            ``[j, i]``. ``None`` when metric is ``"sttc"``.
        av_corr (np.ndarray): Average (NaN-ignoring) similarity per unit
            over the strict lower triangle of slice pairs. Shape is
            ``(U,)``. NaN for units deemed invalid.
        av_lag (np.ndarray or None): Average lag per unit over the same
            lower-triangle entries, shape ``(U,)``. ``None`` when metric is
            ``"sttc"``.

    Raises:
        ValueError: If ``metric`` is invalid or ``frac_active`` has the
            wrong shape.

    Notes:
        - With fewer than 2 slices, a ``RuntimeWarning`` is issued and
          all-NaN results are returned.
        - Analogous to
          ``RateSliceStack.get_slice_to_slice_unit_corr_from_stack`` but
          operates on raw spike trains.
        - Spike times within each slice are relative to the slice time
          origin (0-based or event-centered) for aligned comparison.
    """
    if metric not in ("sttc", "ccg"):
        raise ValueError(f"metric must be 'sttc' or 'ccg', got {metric!r}")

    num_units = self.N
    num_slices = len(self.spike_stack)

    # A slice-to-slice comparison needs at least one pair of slices.
    if num_slices < 2:
        warnings.warn(
            "Cannot compute slice-to-slice unit comparison with fewer than "
            "2 slices. Returning NaN.",
            RuntimeWarning,
        )
        av_corr = np.full(num_units, np.nan)
        nan_stack = np.full((num_slices, num_slices, num_units), np.nan)
        return (
            PairwiseCompMatrixStack(stack=nan_stack),
            (
                PairwiseCompMatrixStack(stack=nan_stack.copy())
                if metric == "ccg"
                else None
            ),
            av_corr,
            av_corr.copy() if metric == "ccg" else None,
        )

    # Pre-compute spike trains per slice (relative to slice time origin)
    # and per-slice rasters for CCG
    # NOTE(review): no shift is applied here. Trains are used exactly as
    # stored in each slice; they are already slice-relative per the class
    # contract.
    shifted_trains = []  # list of S lists, each containing U spike arrays
    slice_durations = []
    slice_rasters = []  # only populated for CCG
    for sd, (start, end) in zip(self.spike_stack, self.times):
        duration = end - start
        slice_durations.append(duration)
        trains = []
        for u in range(num_units):
            trains.append(np.asarray(sd.train[u]))
        shifted_trains.append(trains)
        if metric == "ccg":
            # Build shifted SpikeData for raster computation
            # (length taken from self.times, origin from the slice)
            temp_sd = SpikeData(
                trains, length=duration, start_time=sd.start_time, N=num_units
            )
            slice_rasters.append(temp_sd.raster(bin_size))

    max_lag_bins = int(round(max_lag / bin_size)) if metric == "ccg" else 0
    # Warn when a positive max_lag rounds down to zero bins for the
    # ccg path — the user's lag request is silently discarded
    # otherwise. Matches the guard in ``SpikeData.get_pairwise_ccg``.
    if metric == "ccg" and max_lag > 0 and max_lag_bins == 0:
        warnings.warn(
            f"max_lag={max_lag} ms is smaller than bin_size={bin_size} ms; "
            f"max_lag_bins collapsed to 0 (zero-lag only). To resolve "
            f"sub-bin lags, decrease bin_size.",
            UserWarning,
            stacklevel=2,
        )

    # Validate frac_active override if provided
    if frac_active is not None:
        frac_active = np.asarray(frac_active, dtype=float)
        if frac_active.shape != (num_units,):
            raise ValueError(
                f"frac_active must have shape ({num_units},), "
                f"got {frac_active.shape}"
            )

    # Initialize result arrays: (U, S, S), will transpose to (S, S, U) at end
    all_corr_scores = np.full((num_units, num_slices, num_slices), np.nan)
    all_lag_scores = (
        np.full((num_units, num_slices, num_slices), np.nan)
        if metric == "ccg"
        else None
    )
    av_corr = np.full(num_units, np.nan)
    av_lag = np.full(num_units, np.nan) if metric == "ccg" else None
    # Strict lower triangle (k=-1): the self-comparison diagonal is excluded
    # from the per-unit averages.
    lower_tri = np.tril_indices(num_slices, k=-1)
    start_times = [sd.start_time for sd in self.spike_stack]

    def _process_unit(unit):
        # Fill one unit's (S, S) score/lag matrices and its averages. This
        # only reads shared state, so it can run in worker threads; results
        # are written into the shared arrays by the caller.
        unit_corr = np.full((num_slices, num_slices), np.nan)
        unit_lag = (
            np.full((num_slices, num_slices), np.nan) if metric == "ccg" else None
        )
        invalid_count = 0
        for ref_s in range(num_slices):
            ref_train = shifted_trains[ref_s][unit]
            if len(ref_train) < min_spikes:
                # Each slice is visited exactly once as ref_s, so this counts
                # invalid slices (not invalid pairs).
                invalid_count += 1
                continue
            # comp_s starts at ref_s: each unordered pair (plus the diagonal)
            # is computed once and mirrored.
            for comp_s in range(ref_s, num_slices):
                comp_train = shifted_trains[comp_s][unit]
                if len(comp_train) < min_spikes:
                    continue
                if metric == "sttc":
                    length = max(slice_durations[ref_s], slice_durations[comp_s])
                    # start_time from the ref slice is correct here:
                    # all slices share the same start_time (event-centered
                    # data has -pre_ms, frames() produces 0-based slices).
                    score = get_sttc(
                        ref_train,
                        comp_train,
                        delt=delt,
                        length=length,
                        start_time=start_times[ref_s],
                    )
                    unit_corr[ref_s, comp_s] = score
                    unit_corr[comp_s, ref_s] = score
                else:
                    ref_signal = slice_rasters[ref_s][unit, :]
                    comp_signal = slice_rasters[comp_s][unit, :]
                    score, lag = compute_cross_correlation_with_lag(
                        ref_signal, comp_signal, max_lag=max_lag_bins
                    )
                    unit_corr[ref_s, comp_s] = score
                    unit_corr[comp_s, ref_s] = score
                    # Lag matrix is antisymmetric. On the diagonal
                    # (ref_s == comp_s) the second write wins and stores -lag.
                    unit_lag[ref_s, comp_s] = lag
                    unit_lag[comp_s, ref_s] = -lag

        if frac_active is not None:
            unit_valid = frac_active[unit] >= (1 - min_frac)
        else:
            unit_valid = invalid_count / num_slices <= min_frac

        av_c = (
            np.nanmean(unit_corr[lower_tri[0], lower_tri[1]])
            if unit_valid
            else np.nan
        )
        av_l = np.nan
        if metric == "ccg" and unit_valid:
            # Lower-triangle entries [comp_s, ref_s] hold -lag, so this is
            # the mean of the negated lags returned for (earlier, later).
            av_l = np.nanmean(unit_lag[lower_tri[0], lower_tri[1]])
        return unit, unit_corr, unit_lag, av_c, av_l

    n_workers = _resolve_n_jobs(n_jobs)
    if n_workers > 1 and num_units > 1:
        # Leaving the ``with`` block waits for all submitted tasks; the lazy
        # map iterator is then consumed below.
        with ThreadPoolExecutor(max_workers=n_workers) as pool:
            results = pool.map(_process_unit, range(num_units))
    else:
        results = map(_process_unit, range(num_units))

    for unit, unit_corr, unit_lag, av_c, av_l in results:
        all_corr_scores[unit] = unit_corr
        av_corr[unit] = av_c
        if metric == "ccg":
            all_lag_scores[unit] = unit_lag
            av_lag[unit] = av_l

    # Transpose from (U, S, S) to (S, S, U)
    all_corr_scores = np.moveaxis(all_corr_scores, 0, 2)
    all_corr_stack = PairwiseCompMatrixStack(stack=all_corr_scores)
    if metric == "ccg":
        all_lag_scores = np.moveaxis(all_lag_scores, 0, 2)
        all_lag_stack = PairwiseCompMatrixStack(stack=all_lag_scores)
    else:
        all_lag_stack = None

    return all_corr_stack, all_lag_stack, av_corr, av_lag
def get_unit_timing_per_slice(
    self,
    timing="median",
    min_spikes=2,
):
    """Compute a representative spike time for each unit in each slice.

    Returns a ``(U, S)`` matrix where entry ``[u, s]`` is the timing value
    for unit u in slice s. Values are in milliseconds relative to the
    slice's time origin; for event-centered slices, t=0 is the event. Only
    spikes inside ``[sd.start_time, sd.start_time + (end - start))`` are
    considered, where ``(start, end)`` is the slice's entry in
    ``self.times``. Units with fewer than ``min_spikes`` such spikes (or
    none at all) in a slice are marked NaN.

    Parameters:
        timing (str): Which spike time to extract per unit per slice.
            ``"median"`` (default), ``"mean"``, or ``"first"`` (onset
            latency, i.e. the earliest in-window spike).
        min_spikes (int): Minimum number of spikes for a unit to count as
            active in a slice (default: 2).

    Returns:
        timing_matrix (np.ndarray): Array of shape ``(U, S)`` with timing
            values in milliseconds relative to each slice's time origin.
            NaN where the unit is inactive.

    Raises:
        ValueError: If ``timing`` is not one of the supported options.

    Notes:
        - Values are in milliseconds, not bin indices. This differs from
          ``RateSliceStack.get_unit_timing_per_slice``, which returns bin
          indices (suitable for direct indexing into the event stack). Both
          representations preserve rank order, so
          ``rank_order_correlation`` produces identical results either way.
        - ``"first"`` takes the minimum spike time, so it does not depend on
          the stored train being sorted.
        - The returned matrix can be passed to ``rank_order_correlation``
          to compute Spearman rank correlations between slice pairs, or
          used as input to ``order_units_across_slices``.
    """
    if timing not in ("median", "mean", "first"):
        raise ValueError(
            f"timing must be 'median', 'mean', or 'first', got {timing!r}"
        )

    num_units = self.N
    num_slices = len(self.spike_stack)
    timing_matrix = np.full((num_units, num_slices), np.nan)

    for s_idx, (sd, (start, end)) in enumerate(zip(self.spike_stack, self.times)):
        # Window bounds depend only on the slice, so compute them once.
        window_lo = sd.start_time
        window_hi = sd.start_time + (end - start)
        for u in range(num_units):
            spikes = np.asarray(sd.train[u])
            spikes = spikes[(spikes >= window_lo) & (spikes < window_hi)]
            # An empty window is always inactive, even when min_spikes <= 0
            # (otherwise "first" would index into an empty array).
            if spikes.size == 0 or spikes.size < min_spikes:
                continue
            if timing == "median":
                timing_matrix[u, s_idx] = np.median(spikes)
            elif timing == "mean":
                timing_matrix[u, s_idx] = np.mean(spikes)
            else:
                timing_matrix[u, s_idx] = np.min(spikes)
    return timing_matrix
def rank_order_correlation(
    self,
    timing_matrix=None,
    timing="median",
    min_spikes=2,
    min_overlap=3,
    n_shuffles=100,
    min_overlap_frac=None,
    seed=1,
    n_jobs=-1,
):
    """Spearman rank-order correlation of unit timing for every slice pair.

    Only units active in both slices of a pair (non-NaN in both timing
    columns) take part. Pairs whose overlap is below the required minimum
    are NaN. With ``n_shuffles > 0``, rank orders are shuffled that many
    times per pair to build a null distribution. The raw correlation is
    then z-scored against it.

    Parameters:
        timing_matrix (np.ndarray or None): ``(U, S)`` timing values per
            unit per slice, with NaN marking inactive units. This is
            typically the output of ``get_unit_timing_per_slice``. When
            None, it is computed from ``timing`` and ``min_spikes``.
        timing (str): ``"median"`` (default), ``"mean"`` or ``"first"``.
            Only used when ``timing_matrix`` is None.
        min_spikes (int): Minimum spikes for activity (default: 2). Only
            used when ``timing_matrix`` is None.
        min_overlap (int): Minimum number of units active in both slices
            (default: 3).
        n_shuffles (int): Shuffle iterations for z-scoring (default: 100).
            0 returns raw Spearman correlations. Values between 1 and 4 are
            rejected (minimum 5 for a meaningful null distribution).
        min_overlap_frac (float or None): Minimum fraction of all units
            that must be active in both slices (default: None). When given,
            the threshold is ``max(min_overlap, ceil(min_overlap_frac * U))``.
        seed (int or None): Random seed for the shuffles (default: 1).
        n_jobs (int): Thread count. -1 uses all cores (default), 1 disables
            parallelism, None is serial.

    Returns:
        corr_matrix (PairwiseCompMatrix): ``(S, S)`` matrix of z-scores
            when ``n_shuffles > 0``, or raw Spearman correlations when
            ``n_shuffles == 0``.
        av_corr (float): Mean correlation (or z-score) over valid
            lower-triangle pairs.
        overlap_matrix (PairwiseCompMatrix): ``(S, S)`` fraction of units
            active in both slices.
    """
    from .utils import _rank_order_correlation_from_timing as _from_timing

    if timing_matrix is None:
        timing_matrix = self.get_unit_timing_per_slice(
            timing=timing, min_spikes=min_spikes
        )

    options = dict(
        min_overlap=min_overlap,
        min_overlap_frac=min_overlap_frac,
        n_shuffles=n_shuffles,
        seed=seed,
        n_jobs=n_jobs,
    )
    return _from_timing(timing_matrix, **options)
def plot_aligned_slice_single_unit(
    self,
    unit_idx,
    ax=None,
    color_vals=None,
    color_label="",
    cmap="viridis",
    time_offset=0,
    xlabel="Rel. time (ms)",
    ylabel="Burst",
    x_range=None,
    vlines=None,
    show_colorbar=True,
    marker_size=20,
    font_size=None,
    style="scatter",
    invert_y=False,
    linewidths=0.5,
):
    """Plot a single unit's spike times across all slices as a raster.

    Extracts the spike train for ``unit_idx`` from every slice and delegates
    to :func:`~SpikeLab.spikedata.plot_utils.plot_aligned_slice_single_unit`.

    Parameters:
        unit_idx (int): Index of the unit to plot; must satisfy
            ``0 <= unit_idx < self.N``.
        ax (matplotlib.axes.Axes or None): Target axes. If None, a new
            ``(8, 6)`` figure and axes are created.
        color_vals (np.ndarray or None): Per-slice colour values.
        color_label (str): Colorbar label.
        cmap (str): Matplotlib colormap name.
        time_offset (float): Value subtracted from every spike time before
            plotting. Slices from ``align_to_events`` are already
            event-centered (spike times in ``[-pre_ms, +post_ms]``), so use
            the default ``time_offset=0``. Only set a non-zero value when
            spike times are not already centered on the event.
        xlabel (str): X-axis label.
        ylabel (str): Y-axis label.
        x_range (tuple or None): ``(xmin, xmax)`` for the x-axis.
        vlines (list[dict] or None): Vertical reference lines. Each dict
            must contain ``'x'`` and may optionally include ``'color'``,
            ``'linestyle'``, ``'linewidth'``.
        show_colorbar (bool): Add a colorbar when color_vals is provided.
        marker_size (float): Scatter marker size.
        font_size (int or None): Font size for labels/ticks.
        style (str): ``"scatter"`` for dot markers, ``"eventplot"`` for
            vertical line markers.
        invert_y (bool): If True, first slice at top, last at bottom.
        linewidths (float): Line width for eventplot markers.

    Returns:
        result: ``(fig, ax, sc)`` when ax is None, otherwise just sc. sc is
            the scatter ``PathCollection`` (or None if no colour coding).

    Raises:
        ImportError: If matplotlib is not installed.
        IndexError: If ``unit_idx`` is out of range.
    """
    # Check matplotlib before importing plot_utils. If that module may need
    # matplotlib itself, importing it first would surface a bare
    # ImportError and skip the installation hint below.
    try:
        import matplotlib.pyplot as plt
    except ImportError as e:
        raise ImportError(
            "plot_aligned_slice_single_unit requires 'matplotlib'. "
            "Install with: pip install matplotlib"
        ) from e

    from .plot_utils import (
        plot_aligned_slice_single_unit as _plot_aligned_slice_single_unit,
    )

    if unit_idx < 0 or unit_idx >= self.N:
        raise IndexError(f"unit_idx {unit_idx} out of range for {self.N} units.")

    # Extract per-slice spike times for this unit
    spike_times_per_slice = [sd.train[unit_idx] for sd in self.spike_stack]

    standalone = ax is None
    if standalone:
        fig, ax = plt.subplots(figsize=(8, 6))

    sc = _plot_aligned_slice_single_unit(
        ax,
        spike_times_per_slice,
        color_vals=color_vals,
        color_label=color_label,
        cmap=cmap,
        time_offset=time_offset,
        xlabel=xlabel,
        ylabel=ylabel,
        x_range=x_range,
        vlines=vlines,
        show_colorbar=show_colorbar,
        marker_size=marker_size,
        font_size=font_size,
        style=style,
        invert_y=invert_y,
        linewidths=linewidths,
    )

    if standalone:
        plt.tight_layout()
        return fig, ax, sc
    return sc