Source code for spikelab.spikedata.spikeslicestack

import warnings
from concurrent.futures import ThreadPoolExecutor

import numpy as np

from .pairwise import PairwiseCompMatrix, PairwiseCompMatrixStack
from .spikedata import SpikeData
from .utils import (
    _validate_time_start_to_end,
    _get_attr,
    get_sttc,
    compute_cross_correlation_with_lag,
    _resolve_n_jobs,
)

__all__ = ["SpikeSliceStack"]


class SpikeSliceStack:
    """A list of SpikeData objects, one per slice, with spike-based
    comparison capabilities.

    U is units (neurons) and S is slices (bursts, events, etc.).

    Construct from either a single SpikeData with time specifications, or
    directly from a pre-built list of SpikeData objects.

    Parameters:
        data_obj (SpikeData or None): A SpikeData object to slice. Provide
            either this or spike_stack, not both.
        times_start_to_end (list or None): Each entry is a tuple
            (start, end) representing the start and end times of a desired
            slice. Each tuple must have the same duration.
        time_peaks (list or None): List of times as int or float where there
            is a burst peak or stimulation event. Must be paired with
            time_bounds. Alternative to times_start_to_end.
        time_bounds (tuple or None): Single tuple (left_bound, right_bound).
            For example, (250, 500) means 250 ms before peak and 500 ms
            after peak. Must be paired with time_peaks.
        spike_stack (list or None): List of SpikeData objects, one per
            slice. All must have the same number of units. Spike times must
            be relative to the slice (0-based or event-centered via
            start_time), not absolute recording times. Provide either this
            or data_obj, not both.
        neuron_attributes (list or None): List of attribute dicts, one per
            unit. If None, inherited from data_obj when available.
        drop_slice_attributes (bool): If True (default), neuron_attributes
            are removed from individual SpikeData slices after construction.
            The shared copy is stored at neuron_attributes. This avoids
            duplicating large per-unit data (e.g. waveform templates) across
            every slice. Set to False to keep per-slice attributes.

    Attributes:
        spike_stack (list): List of SpikeData objects, one per slice. Spike
            times are relative to the slice window. For 0-based slices,
            times run from 0 to duration. For event-centered slices, times
            run from -pre_ms to +post_ms with t=0 at the event. Use
            self.times for absolute recording time positions.
        times (list): List of (start, end) time bounds for each slice in
            absolute recording time, sorted chronologically. Length equals
            S. Example: [(100, 350), (500, 750), (1000, 1250)].
        N (int): Number of units.
        neuron_attributes (list or None): List of attribute dicts, one per
            unit. None if not provided. When drop_slice_attributes is True
            (default), this is the only copy and individual slices will
            have neuron_attributes set to None.
    """

    def __init__(
        self,
        data_obj=None,
        times_start_to_end=None,
        time_peaks=None,
        time_bounds=None,
        spike_stack=None,
        neuron_attributes=None,
        drop_slice_attributes=True,
    ):
        if data_obj is None and spike_stack is None:
            raise TypeError(
                "Must input either a SpikeData as data_obj (option 1) "
                "or spike_stack (option 2)"
            )
        if data_obj is not None and spike_stack is not None:
            warnings.warn(
                "User input both data_obj and spike_stack. "
                "Ignoring data_obj and using spike_stack instead.",
                UserWarning,
            )
            data_obj = None

        # Option 1: Using data_obj
        if data_obj is not None:
            if not isinstance(data_obj, SpikeData):
                raise TypeError("data_obj must be a SpikeData object")
            if times_start_to_end is None:
                if time_peaks is None or time_bounds is None:
                    raise ValueError(
                        "Must provide either times_start_to_end or "
                        "both time_peaks and time_bounds"
                    )
                if not isinstance(time_bounds, tuple) or len(time_bounds) != 2:
                    raise TypeError(
                        "time_bounds must be a tuple of (before, after) durations"
                    )
                before, after = time_bounds
                time_peaks = sorted(time_peaks)
                times_start_to_end = []
                for t in time_peaks:
                    times_start_to_end.append((t - before, t + after))
            rec_range = (
                data_obj.start_time,
                data_obj.start_time + data_obj.length,
            )
            times_start_to_end = _validate_time_start_to_end(
                times_start_to_end, recording_range=rec_range
            )
            self.times = times_start_to_end
            self.spike_stack = []
            if time_peaks is not None:
                # Event-centered: shift_to=peak so t=0 is the event
                for peak, (start, end) in zip(time_peaks, times_start_to_end):
                    self.spike_stack.append(
                        data_obj.subtime(start, end, shift_to=peak)
                    )
            else:
                # Standard: shift_to=start so t=0 is the window start
                for start, end in times_start_to_end:
                    self.spike_stack.append(data_obj.subtime(start, end))
            if neuron_attributes is None:
                neuron_attributes = data_obj.neuron_attributes

        # Option 2: Using spike_stack directly
        if spike_stack is not None:
            if not isinstance(spike_stack, list):
                raise TypeError("spike_stack must be a list of SpikeData objects")
            for s in spike_stack:
                if not isinstance(s, SpikeData):
                    raise TypeError(
                        "spike_stack must be a list of SpikeData objects"
                    )
            if len(spike_stack) == 0:
                raise ValueError("spike_stack must not be empty")
            N = spike_stack[0].N
            for s in spike_stack:
                if s.N != N:
                    raise ValueError(
                        "All SpikeData objects in spike_stack must have "
                        "the same number of units"
                    )
            if times_start_to_end is None:
                t = 0.0
                times_start_to_end = []
                for s in spike_stack:
                    times_start_to_end.append((t, t + s.length))
                    t += s.length
            else:
                warn_neg = spike_stack[0].start_time >= 0
                times_start_to_end = _validate_time_start_to_end(
                    times_start_to_end, warn_negative_start=warn_neg
                )
                if len(times_start_to_end) != len(spike_stack):
                    raise ValueError(
                        "times_start_to_end must have the same length "
                        "as spike_stack"
                    )
            self.spike_stack = list(spike_stack)
            self.times = times_start_to_end

            # Validate that spike times are consistent with the slice
            # duration. Spike times must be relative to the slice (0-based
            # or event-centered), not absolute recording times.
            for i, (sd, (start, end)) in enumerate(
                zip(self.spike_stack, self.times)
            ):
                duration = end - start
                expected_start = sd.start_time
                expected_end = sd.start_time + duration
                for u, train in enumerate(sd.train):
                    if len(train) == 0:
                        continue
                    if train[0] < expected_start or train[-1] > expected_end:
                        raise ValueError(
                            f"Slice {i}, unit {u}: spike times "
                            f"[{train[0]:.1f}, {train[-1]:.1f}] ms fall outside "
                            f"expected range [{expected_start:.1f}, "
                            f"{expected_end:.1f}] ms. "
                            "Spike times must be relative to the slice (0-based "
                            "or event-centered), not absolute recording times."
                        )

        self.N = self.spike_stack[0].N
        self.neuron_attributes = None
        if neuron_attributes is not None:
            self.neuron_attributes = neuron_attributes.copy()
            if len(self.neuron_attributes) != self.N:
                raise ValueError(
                    f"neuron_attributes has {len(self.neuron_attributes)} items "
                    f"but spike_stack has {self.N} units"
                )

        # Strip per-slice neuron_attributes to avoid duplicating large data
        # (e.g. waveform templates) across every slice.
        if drop_slice_attributes:
            for sd in self.spike_stack:
                sd.neuron_attributes = None

    def __repr__(self) -> str:
        S = len(self.spike_stack)
        return f"SpikeSliceStack(N={self.N}, S={S})"

    def __len__(self) -> int:
        return len(self.spike_stack)

    def __iter__(self):
        return iter(self.spike_stack)

    def subslice(self, slices):
        """Extract a subset of slices from the spike stack.

        Parameters:
            slices (int or list): Slice index or list of slice indices to
                extract.

        Returns:
            result (SpikeSliceStack): New SpikeSliceStack containing only
                the specified slices. Shape changes from S to S_trimmed.
                All units and neuron_attributes are carried over.
        """
        S = len(self.spike_stack)
        if isinstance(slices, int):
            slices = [slices]
        for s in slices:
            if s >= S or s < -S:
                raise ValueError(
                    f"One or more slice indices out of range for S={S}"
                )
        slices = sorted(slices)
        new_spike_stack = []
        new_times = []
        for s in slices:
            new_spike_stack.append(self.spike_stack[s])
            new_times.append(self.times[s])
        return SpikeSliceStack(
            spike_stack=new_spike_stack,
            times_start_to_end=new_times,
            neuron_attributes=self.neuron_attributes,
        )
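
    # Usage sketch (illustrative): keep only the first and third slices.
    #
    #   trimmed = stack.subslice([0, 2])
    #   len(trimmed)    # -> 2
    #   trimmed.times   # -> the matching (start, end) pairs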

    def subset(self, units, by=None):
        """Extract a subset of units from every slice in the spike stack.

        Parameters:
            units (int, str, or list): Unit indices to extract. If by is
                None, must be int(s). If by is set, values to match in
                neuron_attributes.
            by (str or None): If set, select units by this neuron_attribute
                key instead of by index.

        Returns:
            result (SpikeSliceStack): New SpikeSliceStack containing only
                the specified units across all slices. All slices and
                neuron_attributes are carried over.

        Notes:
            - Units are included in the output in the order they appear in
              the train (ascending index order), not the order listed in
              units.
            - If IDs are not unique (when using by), every matching neuron
              is included.
        """
        if isinstance(units, (int, str)):
            units = [units]

        # Resolve which indices will be kept so we can update neuron_attributes
        if by is not None:
            if self.neuron_attributes is None:
                raise ValueError("can't use `by` without `neuron_attributes`")
            _missing = object()
            unit_set = set(units)
            kept_indices = []
            for i in range(self.N):
                if _get_attr(self.neuron_attributes[i], by, _missing) in unit_set:
                    kept_indices.append(i)
        else:
            kept_indices = sorted(set(int(u) for u in units))
            for u in kept_indices:
                if u < 0 or u >= self.N:
                    raise ValueError(
                        f"Unit index {u} out of range for {self.N} units."
                    )

        new_spike_stack = []
        for sd in self.spike_stack:
            new_spike_stack.append(sd.subset(kept_indices))

        new_neuron_attributes = None
        if self.neuron_attributes is not None:
            new_neuron_attributes = []
            for i in kept_indices:
                new_neuron_attributes.append(self.neuron_attributes[i])

        return SpikeSliceStack(
            spike_stack=new_spike_stack,
            times_start_to_end=self.times,
            neuron_attributes=new_neuron_attributes,
        )
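
    # Usage sketch (illustrative). "quality" is a hypothetical
    # neuron_attributes key used only for this example:
    #
    #   by_index = stack.subset([3, 7, 9])              # keep three units
    #   by_attr = stack.subset(["good"], by="quality")  # keep every unit
    #                                                   # whose attr matches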

    def subtime_by_index(self, start_idx, end_idx):
        """Trim each slice to a sub-window specified by millisecond indices.

        Indices are measured from the start of each slice (1 index = 1 ms).
        Trims along the time axis while preserving all slices and units.

        Parameters:
            start_idx (int): Start index in ms from slice start (inclusive).
                Supports negative indexing.
            end_idx (int): End index in ms from slice start (exclusive).
                Supports negative indexing.

        Returns:
            result (SpikeSliceStack): New SpikeSliceStack where each slice
                is trimmed to the corresponding absolute time window.
                Absolute spike times are preserved (not shifted). self.times
                is updated to reflect the new absolute time bounds.

        Notes:
            - Indices are relative to each slice's own start (index 0 =
              slice start ms). They are converted to absolute recording
              times internally before trimming.
            - Original absolute timestamps are preserved. To get
              shifted-to-zero timestamps, create a new SpikeSliceStack.
            - All slices and neuron_attributes are carried over from the
              original.
        """
        slice_duration_ms = self.times[0][1] - self.times[0][0]
        T = int(round(slice_duration_ms))
        if start_idx < 0:
            start_idx += T
        if end_idx < 0:
            end_idx += T
        if start_idx < 0 or start_idx >= T:
            raise ValueError(f"start_idx {start_idx} out of range for T={T}")
        if end_idx <= start_idx or end_idx > T:
            raise ValueError(f"end_idx {end_idx} invalid for T={T}")
        new_spike_stack = []
        new_times = []
        for sd, t in zip(self.spike_stack, self.times):
            new_spike_stack.append(
                sd.subtime(
                    sd.start_time + float(start_idx),
                    sd.start_time + float(end_idx),
                )
            )
            abs_start = t[0] + float(start_idx)
            abs_end = t[0] + float(end_idx)
            new_times.append((abs_start, abs_end))
        return SpikeSliceStack(
            spike_stack=new_spike_stack,
            times_start_to_end=new_times,
            neuron_attributes=self.neuron_attributes,
        )
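
    # Usage sketch (illustrative): drop the first 100 ms and last 50 ms of
    # every slice. Negative indices count back from the slice duration T:
    #
    #   trimmed = stack.subtime_by_index(100, -50)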

    def to_raster_array(self, bin_size=1.0, absolute_times=False):
        """Convert the spike stack into a 3D raster array of shape (N, T, S).

        Each slice is rasterized with the given bin size, producing a spike
        count matrix where entry (n, t, s) is the number of spikes unit n
        fired in time bin t of slice s.

        Parameters:
            bin_size (float): Time bin size in ms (default 1.0).
            absolute_times (bool): If False (default), time bin 0
                corresponds to the start of each slice (0-based). If True,
                each slice's spikes are offset by its absolute start time
                from self.times, so bin indices reflect the original
                recording position. The T dimension is sized to cover the
                full time span from the earliest slice start to the latest
                slice end. **Caution:** this can produce very large arrays
                when the recording span is long and bin_size is small.

        Returns:
            raster_stack (np.ndarray): 3D array of shape (N, T, S) with
                non-negative integer spike counts. When absolute_times is
                True, T covers the full recording span and all slices share
                the same time axis.
        """
        if not absolute_times:
            dense_list = []
            for sd in self.spike_stack:
                # Spike times are relative to each slice (0-based or
                # event-centered). sparse_raster handles start_time internally.
                dense_list.append(sd.sparse_raster(bin_size=bin_size).toarray())
            return np.stack(dense_list, axis=2)

        # Absolute times: offset each slice by its start time so bin indices
        # reflect original recording position. All slices share the same
        # time axis spanning [min(start), max(end)].
        global_start = min(start for start, _ in self.times)
        global_end = max(end for _, end in self.times)
        total_bins = int(np.ceil((global_end - global_start) / bin_size))
        raster_stack = np.zeros(
            (self.N, total_bins, len(self.spike_stack)), dtype=int
        )
        for s_idx, (sd, (start, _)) in enumerate(
            zip(self.spike_stack, self.times)
        ):
            offset = start - global_start
            r = sd.sparse_raster(bin_size=bin_size, time_offset=offset).toarray()
            raster_stack[:, : r.shape[1], s_idx] = r
        return raster_stack
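
    # Usage sketch (illustrative): 5 ms bins on a slice-relative time axis.
    #
    #   raster = stack.to_raster_array(bin_size=5.0)  # shape (N, T, S)
    #   pooled = raster.sum(axis=2)                   # counts pooled over slices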

    def compute_frac_active(self, min_spikes=2):
        """Compute the fraction of slices each unit is active in.

        A unit counts as active in a slice if it has at least min_spikes
        spikes within that slice's time window.

        Parameters:
            min_spikes (int): Minimum number of spikes for a unit to count
                as active in a slice (default: 2).

        Returns:
            frac_active (np.ndarray): 1-D array of shape ``(U,)`` with the
                fraction of slices each unit is active in (values in
                [0, 1]).

        Notes:
            - The returned array can be passed as ``frac_active`` to
              ``RateSliceStack.order_units_across_slices``,
              ``RateSliceStack.get_slice_to_slice_unit_corr_from_stack``,
              ``SpikeSliceStack.order_units_across_slices``, or
              ``SpikeSliceStack.get_slice_to_slice_unit_comparison`` to
              override their internal activity calculation.
            - ``SpikeData.get_frac_active`` produces a compatible ``(U,)``
              array based on burst edges and can be used in the same way.
        """
        num_units = self.N
        num_slices = len(self.spike_stack)
        active_count = np.zeros(num_units, dtype=int)
        for sd, (start, end) in zip(self.spike_stack, self.times):
            for u in range(num_units):
                spikes = np.asarray(sd.train[u])
                n_valid = np.sum(
                    (spikes >= sd.start_time)
                    & (spikes <= sd.start_time + (end - start))
                )
                if n_valid >= min_spikes:
                    active_count[u] += 1
        return (
            active_count / num_slices if num_slices > 0 else np.zeros(num_units)
        )
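
    # Usage sketch (illustrative): pre-compute activity fractions once and
    # feed them into order_units_across_slices to override its internal
    # calculation:
    #
    #   fa = stack.compute_frac_active(min_spikes=3)
    #   out = stack.order_units_across_slices(min_frac_active=0.5,
    #                                         frac_active=fa)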

    def order_units_across_slices(
        self,
        agg_func="median",
        timing="median",
        min_spikes=2,
        min_frac_active=0.0,
        frac_active=None,
        timing_matrix=None,
    ):
        """Reorder units by their typical spike timing across slices.

        For each unit in each slice, computes a representative spike time
        (median, mean, or first spike) relative to the slice's time origin.
        These per-slice values are aggregated across slices to obtain a
        single typical timing per unit. Units are then sorted by this value
        from earliest to latest and optionally split into a highly-active
        group and a low-activity group.

        Parameters:
            agg_func (str): How to aggregate per-slice timing values across
                slices. ``"median"`` (default) or ``"mean"``.
            timing (str): Which spike time to extract per unit per slice.
                ``"median"`` — median spike time within the slice (default).
                ``"mean"`` — mean spike time within the slice.
                ``"first"`` — first spike time (onset latency).
                Ignored when ``timing_matrix`` is provided.
            min_spikes (int): Minimum number of spikes for a unit to count
                as active in a slice (default: 2). Ignored when
                ``timing_matrix`` is provided.
            min_frac_active (float or None): Minimum fraction of slices a
                unit must be active in to be placed in the highly-active
                group. ``0.0`` or ``None`` (default: 0.0) skips the split
                entirely and places all units in the highly-active group
                without computing activity fractions.
            frac_active (np.ndarray or None): Optional pre-computed
                fraction-active array of shape ``(U,)`` to override the
                internal calculation for the group split. Only used when
                ``min_frac_active > 0``. Compatible sources:
                ``SpikeSliceStack.compute_frac_active`` and
                ``SpikeData.get_frac_active`` (``frac_per_unit`` output).
            timing_matrix (np.ndarray or None): Optional pre-computed
                ``(U, S)`` timing matrix from ``get_unit_timing_per_slice``.
                When provided, ``timing`` and ``min_spikes`` are ignored and
                this matrix is used directly.

        Returns:
            reordered_stacks (tuple): Two ``SpikeSliceStack`` objects
                ``(highly_active, low_active)`` with units reordered by
                typical timing. The low-activity stack is ``None`` when the
                group is empty.
            unit_ids_in_order (tuple): Two ``ndarray``
                ``(highly_active, low_active)`` of original unit indices in
                the reordered sequence.
            unit_std (tuple): Two ``ndarray``
                ``(highly_active, low_active)`` of standard deviation of
                per-slice timing values. Lower values indicate more
                consistent timing across slices.
            unit_peak_times_ms (tuple): Two ``ndarray``
                ``(highly_active, low_active)`` of the aggregated typical
                timing in milliseconds relative to slice start. NaN for
                units with no active slices.
            unit_frac_active (tuple): Two ``ndarray``
                ``(highly_active, low_active)`` of the fraction of slices
                each unit was active in.

        Notes:
            - Call ``get_unit_timing_per_slice`` first to pre-compute the
              timing matrix if you want to reuse it across multiple calls
              (e.g. ``rank_order_correlation`` and this method).
            - When ``frac_active`` is None and ``min_frac_active > 0``,
              activity fraction is computed via ``compute_frac_active``.
            - Analogous to ``RateSliceStack.order_units_across_slices`` but
              operates on raw spike trains instead of firing rate curves.
        """
        if agg_func not in ("median", "mean"):
            raise ValueError(
                f"agg_func must be 'median' or 'mean', got {agg_func!r}"
            )
        num_units = self.N
        num_slices = len(self.spike_stack)

        if timing_matrix is not None:
            timing_matrix = np.asarray(timing_matrix, dtype=float)
            if timing_matrix.shape != (num_units, num_slices):
                raise ValueError(
                    f"timing_matrix must have shape ({num_units}, {num_slices}), "
                    f"got {timing_matrix.shape}"
                )
        else:
            timing_matrix = self.get_unit_timing_per_slice(
                timing=timing, min_spikes=min_spikes
            )

        # Standard deviation across slices
        unit_std_values = np.nanstd(timing_matrix, axis=1)

        # Aggregate across slices
        if agg_func == "median":
            unit_timing = np.nanmedian(timing_matrix, axis=1)
        else:
            unit_timing = np.nanmean(timing_matrix, axis=1)

        # Compute or validate frac_active only when splitting is requested
        skip_split = not min_frac_active
        if skip_split:
            frac_active = np.ones(num_units)
            ha_units = np.arange(num_units)
            la_units = np.array([], dtype=int)
        else:
            if frac_active is not None:
                frac_active = np.asarray(frac_active, dtype=float)
                if frac_active.shape != (num_units,):
                    raise ValueError(
                        f"frac_active must have shape ({num_units},), "
                        f"got {frac_active.shape}"
                    )
            else:
                frac_active = self.compute_frac_active(min_spikes=min_spikes)
            highly_active_mask = frac_active >= min_frac_active
            ha_units = np.where(highly_active_mask)[0]
            la_units = np.where(~highly_active_mask)[0]

        # Sort within each group by typical timing
        ha_order = ha_units[np.argsort(unit_timing[ha_units])]
        la_order = la_units[np.argsort(unit_timing[la_units])]

        # Build reordered SpikeSliceStacks
        def _reorder_stack(unit_indices):
            if len(unit_indices) == 0:
                return None
            return self.subset(list(unit_indices))

        ha_stack = _reorder_stack(ha_order)
        la_stack = _reorder_stack(la_order)

        return (
            (ha_stack, la_stack),
            (ha_order, la_order),
            (unit_std_values[ha_order], unit_std_values[la_order]),
            (unit_timing[ha_order], unit_timing[la_order]),
            (frac_active[ha_order], frac_active[la_order]),
        )

    def apply(self, func, *args, **kwargs):
        """Apply a function to each SpikeData in the stack and return
        stacked results.

        Calls ``func(sd, *args, **kwargs)`` on every slice and stacks the
        outputs into a single numpy array with a new leading axis of size S
        (number of slices).

        Parameters:
            func (callable): Function that accepts a SpikeData as its first
                argument and returns a numeric value (scalar, 1-D, or 2-D
                array). Output shape must be consistent across all slices.
            *args: Additional positional arguments forwarded to func.
            **kwargs: Additional keyword arguments forwarded to func.

        Returns:
            result (np.ndarray): Stacked results with shape ``(S, ...)``.

        Notes:
            - Intended for use with stacks built by ``SpikeData.frames``,
              ``SpikeData.align_to_events``,
              ``SpikeData.spike_shuffle_stack``, or
              ``SpikeData.subset_stack``. Pair with ``shuffle_z_score``,
              ``shuffle_percentile``, ``slice_trend``, or
              ``slice_stability`` from ``utils`` to interpret the results.
        """
        results = [func(sd, *args, **kwargs) for sd in self.spike_stack]
        return np.stack(results, axis=0)
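
    # Usage sketch (illustrative): per-slice spike counts for every unit.
    # The lambda returns a length-U list, so the stacked result is (S, U):
    #
    #   counts = stack.apply(lambda sd: [len(t) for t in sd.train])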

    def unit_to_unit_comparison(
        self,
        metric="ccg",
        delt=20.0,
        bin_size=1.0,
        max_lag=350,
        n_jobs=-1,
    ):
        """Compute pairwise unit-to-unit similarity within each slice using
        spike-based metrics.

        For each slice, computes a (U, U) similarity matrix between all unit
        pairs, then stacks the results into a
        ``PairwiseCompMatrixStack (U, U, S)``.

        Parameters:
            metric (str): Similarity metric to use. ``"ccg"`` for
                cross-correlogram on binned rasters (default), ``"sttc"``
                for spike time tiling coefficient.
            delt (float): STTC time window in milliseconds (default: 20.0).
                Only used when metric is ``"sttc"``.
            bin_size (float): Bin size in milliseconds for the binary raster
                (default: 1.0). Only used when metric is ``"ccg"``.
            max_lag (float): Maximum lag in milliseconds to search for the
                peak correlation (default: 350). Only used when metric is
                ``"ccg"``.
            n_jobs (int): Number of threads forwarded to
                ``get_pairwise_ccg``. -1 uses all cores (default). Only used
                when metric is ``"ccg"``.

        Returns:
            corr_stack (PairwiseCompMatrixStack): Pairwise similarity scores
                between all unit pairs for each slice. Shape is
                ``(U, U, S)``.
            lag_stack (PairwiseCompMatrixStack or None): Lag at which
                maximum similarity occurs for each pair per slice. Shape is
                ``(U, U, S)``. ``None`` when metric is ``"sttc"`` (STTC has
                no lag).
            av_corr (np.ndarray): Average similarity per slice across all
                unit pairs in the lower triangle. Shape is ``(S,)``.
            av_lag (np.ndarray or None): Average lag per slice. Shape is
                ``(S,)``. ``None`` when metric is ``"sttc"``.

        Notes:
            - Analogous to ``RateSliceStack.unit_to_unit_correlation`` but
              operates on raw spike trains instead of firing rate time
              series.
        """
        if metric not in ("sttc", "ccg"):
            raise ValueError(f"metric must be 'sttc' or 'ccg', got {metric!r}")
        num_units = self.N
        num_slices = len(self.spike_stack)

        if num_units < 2:
            warnings.warn(
                "Cannot compute unit-to-unit comparison with fewer than "
                "2 units. Returning NaN.",
                RuntimeWarning,
            )
            nan_stack = np.full((num_units, num_units, num_slices), np.nan)
            nan_avgs = np.full(num_slices, np.nan)
            return (
                PairwiseCompMatrixStack(stack=nan_stack, times=self.times),
                (
                    PairwiseCompMatrixStack(
                        stack=nan_stack.copy(), times=self.times
                    )
                    if metric == "ccg"
                    else None
                ),
                nan_avgs,
                nan_avgs.copy() if metric == "ccg" else None,
            )

        corr_matrices = []
        lag_matrices = []
        for sd in self.spike_stack:
            if metric == "sttc":
                pcm = sd.spike_time_tilings(delt=delt)
                corr_matrices.append(pcm.matrix)
            else:  # ccg
                corr_pcm, lag_pcm = sd.get_pairwise_ccg(
                    bin_size=bin_size, max_lag=max_lag, n_jobs=n_jobs
                )
                corr_matrices.append(corr_pcm.matrix)
                lag_matrices.append(lag_pcm.matrix)

        # Stack: list of (U, U) -> (S, U, U) -> transpose to (U, U, S)
        corr_array = np.moveaxis(np.stack(corr_matrices, axis=0), 0, 2)
        lower_tri = np.tril_indices(num_units, k=-1)
        av_corr = np.nanmean(corr_array[lower_tri[0], lower_tri[1], :], axis=0)
        corr_stack = PairwiseCompMatrixStack(stack=corr_array, times=self.times)

        if metric == "ccg":
            lag_array = np.moveaxis(np.stack(lag_matrices, axis=0), 0, 2)
            av_lag = np.nanmean(lag_array[lower_tri[0], lower_tri[1], :], axis=0)
            lag_stack = PairwiseCompMatrixStack(stack=lag_array, times=self.times)
        else:
            lag_stack = None
            av_lag = None

        return corr_stack, lag_stack, av_corr, av_lag
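
    # Usage sketch (illustrative): STTC has no lag, so the lag outputs are
    # None:
    #
    #   corr_stack, lag_stack, av_corr, av_lag = (
    #       stack.unit_to_unit_comparison(metric="sttc", delt=10.0)
    #   )
    #   av_corr  # shape (S,): mean lower-triangle STTC per slice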

    def get_slice_to_slice_unit_comparison(
        self,
        metric="ccg",
        delt=20.0,
        bin_size=1.0,
        max_lag=350,
        min_spikes=2,
        min_frac=0.3,
        frac_active=None,
        n_jobs=-1,
    ):
        """Compute slice-to-slice similarity for each unit using spike-based
        metrics.

        For each unit independently, compares its spike train across every
        pair of slices. Asks: "Does unit X fire in the same temporal pattern
        across repeated events?" Returns a
        ``PairwiseCompMatrixStack (S, S, U)``.

        Parameters:
            metric (str): Similarity metric to use. ``"ccg"`` for
                cross-correlogram on binned rasters (default), ``"sttc"``
                for spike time tiling coefficient.
            delt (float): STTC time window in milliseconds (default: 20.0).
                Only used when metric is ``"sttc"``.
            bin_size (float): Bin size in milliseconds for the binary raster
                (default: 1.0). Only used when metric is ``"ccg"``.
            max_lag (float): Maximum lag in milliseconds to search for the
                peak correlation (default: 350). Only used when metric is
                ``"ccg"``.
            min_spikes (int): Minimum number of spikes in a slice for a unit
                to be considered valid in that slice (default: 2).
            min_frac (float): Maximum fraction of slices that can be invalid
                before a unit's average is set to NaN (default: 0.3).
            frac_active (np.ndarray or None): Optional pre-computed
                fraction-active array of shape ``(U,)`` to override the
                internal per-unit validity check for computing averages.
                When provided, a unit's average is set to NaN if
                ``frac_active[u] < (1 - min_frac)``. ``min_spikes`` still
                controls which individual slice pairs are computed.
                Compatible sources: ``SpikeSliceStack.compute_frac_active``
                and ``SpikeData.get_frac_active`` (``frac_per_unit``
                output).
            n_jobs (int): Number of threads for parallel computation. -1
                uses all cores (default), 1 disables parallelism, None is
                serial.

        Returns:
            all_corr (PairwiseCompMatrixStack): Pairwise similarity between
                all slice pairs for each unit. Shape is ``(S, S, U)``.
            all_lag (PairwiseCompMatrixStack or None): Lag at which maximum
                similarity occurs for each slice pair per unit. Shape is
                ``(S, S, U)``. ``None`` when metric is ``"sttc"``.
            av_corr (np.ndarray): Average similarity per unit across all
                valid slice pairs. Shape is ``(U,)``.
            av_lag (np.ndarray or None): Average lag per unit. Shape is
                ``(U,)``. ``None`` when metric is ``"sttc"``.

        Notes:
            - Analogous to
              ``RateSliceStack.get_slice_to_slice_unit_corr_from_stack`` but
              operates on raw spike trains.
            - Spike times within each slice are relative to the slice time
              origin (0-based or event-centered) for aligned comparison.
        """
        if metric not in ("sttc", "ccg"):
            raise ValueError(f"metric must be 'sttc' or 'ccg', got {metric!r}")
        num_units = self.N
        num_slices = len(self.spike_stack)

        if num_slices < 2:
            warnings.warn(
                "Cannot compute slice-to-slice unit comparison with fewer "
                "than 2 slices. Returning NaN.",
                RuntimeWarning,
            )
            av_corr = np.full(num_units, np.nan)
            nan_stack = np.full((num_slices, num_slices, num_units), np.nan)
            return (
                PairwiseCompMatrixStack(stack=nan_stack),
                (
                    PairwiseCompMatrixStack(stack=nan_stack.copy())
                    if metric == "ccg"
                    else None
                ),
                av_corr,
                av_corr.copy() if metric == "ccg" else None,
            )

        # Pre-compute spike trains per slice (relative to slice time origin)
        # and per-slice rasters for CCG
        shifted_trains = []  # list of S lists, each containing U spike arrays
        slice_durations = []
        slice_rasters = []  # only populated for CCG
        for sd, (start, end) in zip(self.spike_stack, self.times):
            duration = end - start
            slice_durations.append(duration)
            trains = []
            for u in range(num_units):
                trains.append(np.asarray(sd.train[u]))
            shifted_trains.append(trains)
            if metric == "ccg":
                # Build shifted SpikeData for raster computation
                temp_sd = SpikeData(
                    trains, length=duration, start_time=sd.start_time, N=num_units
                )
                slice_rasters.append(temp_sd.raster(bin_size))

        max_lag_bins = int(round(max_lag / bin_size)) if metric == "ccg" else 0

        # Validate frac_active override if provided
        if frac_active is not None:
            frac_active = np.asarray(frac_active, dtype=float)
            if frac_active.shape != (num_units,):
                raise ValueError(
                    f"frac_active must have shape ({num_units},), "
                    f"got {frac_active.shape}"
                )

        # Initialize result arrays: (U, S, S), will transpose to (S, S, U)
        # at the end
        all_corr_scores = np.full((num_units, num_slices, num_slices), np.nan)
        all_lag_scores = (
            np.full((num_units, num_slices, num_slices), np.nan)
            if metric == "ccg"
            else None
        )
        av_corr = np.full(num_units, np.nan)
        av_lag = np.full(num_units, np.nan) if metric == "ccg" else None
        lower_tri = np.tril_indices(num_slices, k=-1)
        start_times = [sd.start_time for sd in self.spike_stack]

        def _process_unit(unit):
            unit_corr = np.full((num_slices, num_slices), np.nan)
            unit_lag = (
                np.full((num_slices, num_slices), np.nan)
                if metric == "ccg"
                else None
            )
            invalid_count = 0
            for ref_s in range(num_slices):
                ref_train = shifted_trains[ref_s][unit]
                if len(ref_train) < min_spikes:
                    invalid_count += 1
                    continue
                for comp_s in range(ref_s, num_slices):
                    comp_train = shifted_trains[comp_s][unit]
                    if len(comp_train) < min_spikes:
                        continue
                    if metric == "sttc":
                        length = max(
                            slice_durations[ref_s], slice_durations[comp_s]
                        )
                        # start_time from the ref slice is correct here:
                        # all slices share the same start_time
                        # (event-centered data has -pre_ms, frames()
                        # produces 0-based slices).
                        score = get_sttc(
                            ref_train,
                            comp_train,
                            delt=delt,
                            length=length,
                            start_time=start_times[ref_s],
                        )
                        unit_corr[ref_s, comp_s] = score
                        unit_corr[comp_s, ref_s] = score
                    else:
                        ref_signal = slice_rasters[ref_s][unit, :]
                        comp_signal = slice_rasters[comp_s][unit, :]
                        score, lag = compute_cross_correlation_with_lag(
                            ref_signal, comp_signal, max_lag=max_lag_bins
                        )
                        unit_corr[ref_s, comp_s] = score
                        unit_corr[comp_s, ref_s] = score
                        unit_lag[ref_s, comp_s] = lag
                        unit_lag[comp_s, ref_s] = -lag

            if frac_active is not None:
                unit_valid = frac_active[unit] >= (1 - min_frac)
            else:
                unit_valid = invalid_count / num_slices <= min_frac
            av_c = (
                np.nanmean(unit_corr[lower_tri[0], lower_tri[1]])
                if unit_valid
                else np.nan
            )
            av_l = np.nan
            if metric == "ccg" and unit_valid:
                av_l = np.nanmean(unit_lag[lower_tri[0], lower_tri[1]])
            return unit, unit_corr, unit_lag, av_c, av_l

        n_workers = _resolve_n_jobs(n_jobs)
        if n_workers > 1 and num_units > 1:
            with ThreadPoolExecutor(max_workers=n_workers) as pool:
                results = pool.map(_process_unit, range(num_units))
        else:
            results = map(_process_unit, range(num_units))

        for unit, unit_corr, unit_lag, av_c, av_l in results:
            all_corr_scores[unit] = unit_corr
            av_corr[unit] = av_c
            if metric == "ccg":
                all_lag_scores[unit] = unit_lag
                av_lag[unit] = av_l

        # Transpose from (U, S, S) to (S, S, U)
        all_corr_scores = np.moveaxis(all_corr_scores, 0, 2)
        all_corr_stack = PairwiseCompMatrixStack(stack=all_corr_scores)
        if metric == "ccg":
            all_lag_scores = np.moveaxis(all_lag_scores, 0, 2)
            all_lag_stack = PairwiseCompMatrixStack(stack=all_lag_scores)
        else:
            all_lag_stack = None

        return all_corr_stack, all_lag_stack, av_corr, av_lag
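
    # Usage sketch (illustrative): rank units by how reliably they repeat
    # their firing pattern across slices (higher av_corr = more reliable):
    #
    #   all_corr, all_lag, av_corr, av_lag = (
    #       stack.get_slice_to_slice_unit_comparison(metric="ccg", max_lag=100)
    #   )
    #   reliable_order = np.argsort(av_corr)[::-1]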

    def get_unit_timing_per_slice(
        self,
        timing="median",
        min_spikes=2,
    ):
        """Compute a representative spike time for each unit in each slice.

        Returns a ``(U, S)`` matrix where entry ``[u, s]`` is the timing
        value (in milliseconds relative to the slice's time origin) for unit
        u in slice s. For event-centered slices, t=0 is the event. Units
        with fewer than min_spikes spikes in a slice are marked NaN.

        Parameters:
            timing (str): Which spike time to extract per unit per slice.
                ``"median"`` (default), ``"mean"``, or ``"first"`` (onset
                latency).
            min_spikes (int): Minimum number of spikes for a unit to count
                as active in a slice (default: 2).

        Returns:
            timing_matrix (np.ndarray): Array of shape ``(U, S)`` with
                timing values in milliseconds relative to each slice's time
                origin. NaN where the unit is inactive.

        Notes:
            - Values are in milliseconds, not bin indices. This differs from
              ``RateSliceStack.get_unit_timing_per_slice`` which returns bin
              indices (suitable for direct indexing into the event stack).
              Both representations preserve rank order, so
              ``rank_order_correlation`` produces identical results either
              way.
            - The returned matrix can be passed to
              ``rank_order_correlation`` to compute Spearman rank
              correlations between slice pairs, or used as input to
              ``order_units_across_slices`` for manual inspection of
              per-slice timing values.
        """
        if timing not in ("median", "mean", "first"):
            raise ValueError(
                f"timing must be 'median', 'mean', or 'first', got {timing!r}"
            )
        num_units = self.N
        num_slices = len(self.spike_stack)
        timing_matrix = np.full((num_units, num_slices), np.nan)
        for s_idx, (sd, (start, end)) in enumerate(
            zip(self.spike_stack, self.times)
        ):
            for u in range(num_units):
                spikes = np.asarray(sd.train[u])
                duration = end - start
                spikes = spikes[
                    (spikes >= sd.start_time)
                    & (spikes <= sd.start_time + duration)
                ]
                if len(spikes) < min_spikes:
                    continue
                if timing == "median":
                    timing_matrix[u, s_idx] = np.median(spikes)
                elif timing == "mean":
                    timing_matrix[u, s_idx] = np.mean(spikes)
                else:
                    timing_matrix[u, s_idx] = spikes[0]
        return timing_matrix
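
    # Usage sketch (illustrative): compute the (U, S) onset-latency matrix
    # once and reuse it across calls:
    #
    #   tm = stack.get_unit_timing_per_slice(timing="first")
    #   stack.order_units_across_slices(timing_matrix=tm)
    #   stack.rank_order_correlation(timing_matrix=tm)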

    def rank_order_correlation(
        self,
        timing_matrix=None,
        timing="median",
        min_spikes=2,
        min_overlap=3,
        n_shuffles=100,
        min_overlap_frac=None,
        seed=1,
        n_jobs=-1,
    ):
        """Compute Spearman rank-order correlation of unit timing between
        all slice pairs.

        For each pair of slices, only units active in both slices (non-NaN
        in both columns of the timing matrix) are included. If the overlap
        falls below the required minimum, the pair is set to NaN. When
        ``n_shuffles > 0``, the rank orders are shuffled n_shuffles times
        for each pair to build a null distribution, and the raw correlation
        is z-score normalised against it.

        Parameters:
            timing_matrix (np.ndarray or None): Array of shape ``(U, S)``
                with timing values per unit per slice. NaN entries mark
                inactive units. Typically produced by
                ``get_unit_timing_per_slice``. When None, computed
                automatically using timing and min_spikes.
            timing (str): Which spike time to extract per unit per slice.
                ``"median"`` (default), ``"mean"``, or ``"first"``. Only
                used when timing_matrix is None.
            min_spikes (int): Minimum spikes for activity (default: 2).
                Only used when timing_matrix is None.
            min_overlap (int): Minimum number of units that must be active
                in both slices (default: 3).
            min_overlap_frac (float or None): Minimum fraction of total
                units that must be active in both slices (default: None).
                When provided, the effective threshold is
                ``max(min_overlap, ceil(min_overlap_frac * U))``.
            n_shuffles (int): Number of shuffle iterations for z-scoring
                (default: 100). Set to 0 to return raw Spearman
                correlations. Values between 1 and 4 are rejected (minimum
                5 required for a meaningful null distribution).
            seed (int or None): Random seed for reproducibility of the
                shuffle (default: 1).
            n_jobs (int): Number of threads for parallel computation. -1
                uses all cores (default), 1 disables parallelism, None is
                serial.

        Returns:
            corr_matrix (PairwiseCompMatrix): Spearman correlation matrix
                of shape ``(S, S)``. When ``n_shuffles > 0``, values are
                z-scores. When ``n_shuffles == 0``, values are raw Spearman
                correlations.
            av_corr (float): Average correlation (or z-score) across all
                valid lower-triangle pairs.
            overlap_matrix (PairwiseCompMatrix): Matrix of shape ``(S, S)``
                with fraction of units active in both slices.
        """
        from .utils import _rank_order_correlation_from_timing

        if timing_matrix is None:
            timing_matrix = self.get_unit_timing_per_slice(
                timing=timing, min_spikes=min_spikes
            )
        return _rank_order_correlation_from_timing(
            timing_matrix,
            min_overlap=min_overlap,
            min_overlap_frac=min_overlap_frac,
            n_shuffles=n_shuffles,
            seed=seed,
            n_jobs=n_jobs,
        )
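
    # Usage sketch (illustrative): z-scored rank-order consistency between
    # slice pairs, requiring at least 5 shared active units:
    #
    #   corr_matrix, av_corr, overlap = stack.rank_order_correlation(
    #       min_overlap=5, n_shuffles=200
    #   )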

    def plot_aligned_slice_single_unit(
        self,
        unit_idx,
        ax=None,
        color_vals=None,
        color_label="",
        cmap="viridis",
        time_offset=0,
        xlabel="Rel. time (ms)",
        ylabel="Burst",
        x_range=None,
        vlines=None,
        show_colorbar=True,
        marker_size=20,
        font_size=None,
        style="scatter",
        invert_y=False,
        linewidths=0.5,
    ):
        """Plot a single unit's spike times across all slices as a raster.

        Extracts the spike train for unit_idx from every slice and delegates
        to
        :func:`~spikelab.spikedata.plot_utils.plot_aligned_slice_single_unit`.

        Parameters:
            unit_idx (int): Index of the unit to plot.
            ax (matplotlib.axes.Axes or None): Target axes. If None, a new
                figure and axes are created.
            color_vals (np.ndarray or None): Per-slice colour values.
            color_label (str): Colorbar label.
            cmap (str): Matplotlib colormap name.
            time_offset (float): Value subtracted from every spike time
                before plotting. Slices from ``align_to_events`` are already
                event-centered (spike times in ``[-pre_ms, +post_ms]``), so
                use the default ``time_offset=0``. Only set a non-zero value
                when spike times are not already centered on the event.
            xlabel (str): X-axis label.
            ylabel (str): Y-axis label.
            x_range (tuple or None): ``(xmin, xmax)`` for the x-axis.
            vlines (list[dict] or None): Vertical reference lines. Each dict
                must contain ``'x'`` and may optionally include ``'color'``,
                ``'linestyle'``, ``'linewidth'``.
            show_colorbar (bool): Add a colorbar when color_vals is
                provided.
            marker_size (float): Scatter marker size.
            font_size (int or None): Font size for labels/ticks.
            style (str): ``"scatter"`` for dot markers, ``"eventplot"`` for
                vertical line markers.
            invert_y (bool): If True, first slice at top, last at bottom.
            linewidths (float): Line width for eventplot markers.

        Returns:
            result: ``(fig, ax, sc)`` when ax is None, otherwise just sc.
                sc is the scatter ``PathCollection`` (or None if no colour
                coding).
        """
        from .plot_utils import (
            plot_aligned_slice_single_unit as _plot_aligned_slice_single_unit,
        )

        try:
            import matplotlib.pyplot as plt
        except ImportError as e:
            raise ImportError(
                "plot_aligned_slice_single_unit requires 'matplotlib'. "
                "Install with: pip install matplotlib"
            ) from e

        if unit_idx < 0 or unit_idx >= self.N:
            raise IndexError(
                f"unit_idx {unit_idx} out of range for {self.N} units."
            )

        # Extract per-slice spike times for this unit
        spike_times_per_slice = [sd.train[unit_idx] for sd in self.spike_stack]

        standalone = ax is None
        if standalone:
            fig, ax = plt.subplots(figsize=(8, 6))

        sc = _plot_aligned_slice_single_unit(
            ax,
            spike_times_per_slice,
            color_vals=color_vals,
            color_label=color_label,
            cmap=cmap,
            time_offset=time_offset,
            xlabel=xlabel,
            ylabel=ylabel,
            x_range=x_range,
            vlines=vlines,
            show_colorbar=show_colorbar,
            marker_size=marker_size,
            font_size=font_size,
            style=style,
            invert_y=invert_y,
            linewidths=linewidths,
        )

        if standalone:
            plt.tight_layout()
            return fig, ax, sc
        return sc
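
    # Usage sketch (illustrative): raster one unit across event-centered
    # slices with a reference line at the event (t=0):
    #
    #   fig, ax, sc = stack.plot_aligned_slice_single_unit(
    #       0, vlines=[{"x": 0.0, "color": "red"}], style="eventplot"
    #   )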