import warnings
import numpy as np
__all__ = ["SpikeSliceStack"]
from .pairwise import PairwiseCompMatrix, PairwiseCompMatrixStack
from .spikedata import SpikeData
from concurrent.futures import ThreadPoolExecutor
from .utils import (
_validate_time_start_to_end,
_get_attr,
get_sttc,
compute_cross_correlation_with_lag,
_resolve_n_jobs,
)
[docs]
class SpikeSliceStack:
"""A list of SpikeData objects, one per slice, with spike-based comparison capabilities.
U is units (neurons) and S is slices (bursts, events, etc). Construct from
either a single SpikeData with time specifications, or directly from a
pre-built list of SpikeData objects.
Parameters:
data_obj (SpikeData or None): A SpikeData object to slice. Provide
either this or spike_stack, not both.
times_start_to_end (list or None): Each entry is a tuple (start, end)
representing the start and end times of a desired slice. Each
tuple must have the same duration.
time_peaks (list or None): List of times as int or float where there
is a burst peak or stimulation event. Must be paired with
time_bounds. Alternative to times_start_to_end.
time_bounds (tuple or None): Single tuple (left_bound, right_bound).
For example, (250, 500) means 250 ms before peak and 500 ms
after peak. Must be paired with time_peaks.
spike_stack (list or None): List of SpikeData objects, one per slice.
All must have the same number of units. Spike times must be
relative to the slice (0-based or event-centered via
start_time), not absolute recording times. Provide either this
or data_obj, not both.
neuron_attributes (list or None): List of attribute dicts, one per
unit. If None, inherited from data_obj when available.
drop_slice_attributes (bool): If True (default), neuron_attributes
are removed from individual SpikeData slices after construction.
The shared copy is stored at neuron_attributes. This avoids
duplicating large per-unit data (e.g. waveform templates) across
every slice. Set to False to keep per-slice attributes.
Attributes:
spike_stack (list): List of SpikeData objects, one per slice. Spike
times are relative to the slice window. For 0-based slices,
times run from 0 to duration. For event-centered slices, times
run from -pre_ms to +post_ms with t=0 at the event. Use
self.times for absolute recording time positions.
times (list): List of (start, end) time bounds for each slice in
absolute recording time, sorted chronologically. Length equals S.
Example: [(100, 350), (500, 750), (1000, 1250)].
N (int): Number of units.
neuron_attributes (list or None): List of attribute dicts, one per
unit. None if not provided. When drop_slice_attributes is True
(default), this is the only copy and individual slices will have
neuron_attributes set to None.
"""
[docs]
def __init__(
    self,
    data_obj=None,
    times_start_to_end=None,
    time_peaks=None,
    time_bounds=None,
    spike_stack=None,
    neuron_attributes=None,
    drop_slice_attributes=True,
):
    """Construct from a SpikeData plus time windows (option 1) or from a
    pre-built list of SpikeData slices (option 2).

    See the class docstring for full parameter descriptions.
    """
    # Exactly one construction path must be chosen; spike_stack wins
    # over data_obj when both are given (with a warning).
    if data_obj is None and spike_stack is None:
        raise TypeError(
            "Must input either a SpikeData as data_obj (option 1) or spike_stack (option 2)"
        )
    if data_obj is not None and spike_stack is not None:
        warnings.warn(
            "User input both data_obj and spike_stack. "
            "Ignoring data_obj and using spike_stack instead.",
            UserWarning,
        )
        data_obj = None
    # Option 1: Using data_obj
    if data_obj is not None:
        if not isinstance(data_obj, SpikeData):
            raise TypeError("data_obj must be a SpikeData object")
        if times_start_to_end is None:
            # Derive (start, end) windows from event peaks and bounds:
            # each window is (peak - before, peak + after).
            if time_peaks is None or time_bounds is None:
                raise ValueError(
                    "Must provide either times_start_to_end or "
                    "both time_peaks and time_bounds"
                )
            if not isinstance(time_bounds, tuple) or len(time_bounds) != 2:
                raise TypeError(
                    "time_bounds must be a tuple of (before, after) durations"
                )
            before, after = time_bounds
            time_peaks = sorted(time_peaks)
            times_start_to_end = []
            for t in time_peaks:
                times_start_to_end.append((t - before, t + after))
        rec_range = (
            data_obj.start_time,
            data_obj.start_time + data_obj.length,
        )
        # NOTE(review): the zip with time_peaks below assumes
        # _validate_time_start_to_end preserves window order and count,
        # and that times_start_to_end was not user-supplied alongside
        # time_peaks with a different length -- confirm in utils.
        times_start_to_end = _validate_time_start_to_end(
            times_start_to_end, recording_range=rec_range
        )
        self.times = times_start_to_end
        self.spike_stack = []
        if time_peaks is not None:
            # Event-centered: shift_to=peak so t=0 is the event
            for peak, (start, end) in zip(time_peaks, times_start_to_end):
                self.spike_stack.append(data_obj.subtime(start, end, shift_to=peak))
        else:
            # Standard: shift_to=start so t=0 is the window start
            for start, end in times_start_to_end:
                self.spike_stack.append(data_obj.subtime(start, end))
        if neuron_attributes is None:
            # Inherit per-unit attributes from the source recording.
            neuron_attributes = data_obj.neuron_attributes
    # Option 2: Using spike_stack directly
    if spike_stack is not None:
        if not isinstance(spike_stack, list):
            raise TypeError("spike_stack must be a list of SpikeData objects")
        for s in spike_stack:
            if not isinstance(s, SpikeData):
                raise TypeError("spike_stack must be a list of SpikeData objects")
        if len(spike_stack) == 0:
            raise ValueError("spike_stack must not be empty")
        N = spike_stack[0].N
        for s in spike_stack:
            if s.N != N:
                raise ValueError(
                    "All SpikeData objects in spike_stack must have the same number of units"
                )
        if times_start_to_end is None:
            # No explicit windows: lay the slices back-to-back from 0,
            # each occupying its own length.
            t = 0.0
            times_start_to_end = []
            for s in spike_stack:
                times_start_to_end.append((t, t + s.length))
                t += s.length
        else:
            # Only warn about negative window starts when the slices are
            # 0-based; event-centered slices legitimately start below 0.
            warn_neg = spike_stack[0].start_time >= 0
            times_start_to_end = _validate_time_start_to_end(
                times_start_to_end, warn_negative_start=warn_neg
            )
            if len(times_start_to_end) != len(spike_stack):
                raise ValueError(
                    "times_start_to_end must have the same length as spike_stack"
                )
        self.spike_stack = list(spike_stack)
        self.times = times_start_to_end
        # Validate that spike times are consistent with the slice
        # duration. Spike times must be relative to the slice (0-based
        # or event-centered), not absolute recording times.
        for i, (sd, (start, end)) in enumerate(zip(self.spike_stack, self.times)):
            duration = end - start
            expected_start = sd.start_time
            expected_end = sd.start_time + duration
            for u, train in enumerate(sd.train):
                if len(train) == 0:
                    continue
                # Trains are assumed sorted, so checking the first and
                # last spikes bounds the whole train.
                if train[0] < expected_start or train[-1] > expected_end:
                    raise ValueError(
                        f"Slice {i}, unit {u}: spike times "
                        f"[{train[0]:.1f}, {train[-1]:.1f}] ms fall outside "
                        f"expected range [{expected_start:.1f}, "
                        f"{expected_end:.1f}] ms. "
                        "Spike times must be relative to the slice (0-based "
                        "or event-centered), not absolute recording times."
                    )
    self.N = self.spike_stack[0].N
    self.neuron_attributes = None
    if neuron_attributes is not None:
        # Shallow copy so list-level mutation by the caller does not
        # affect this stack (the per-unit dicts themselves are shared).
        self.neuron_attributes = neuron_attributes.copy()
        if len(self.neuron_attributes) != self.N:
            raise ValueError(
                f"neuron_attributes has {len(self.neuron_attributes)} items "
                f"but spike_stack has {self.N} units"
            )
    # Strip per-slice neuron_attributes to avoid duplicating large data
    # (e.g. waveform templates) across every slice.
    if drop_slice_attributes:
        for sd in self.spike_stack:
            sd.neuron_attributes = None
def __repr__(self) -> str:
S = len(self.spike_stack)
return f"SpikeSliceStack(N={self.N}, S={S})"
def __len__(self) -> int:
return len(self.spike_stack)
def __iter__(self):
return iter(self.spike_stack)
[docs]
def subslice(self, slices):
    """Extract a subset of slices from the spike stack.

    Parameters:
        slices (int or list): Slice index or list of slice indices to
            extract. Negative indices are supported and count from the
            end, as in normal Python indexing. Duplicate indices are
            ignored.

    Returns:
        result (SpikeSliceStack): New SpikeSliceStack containing only the
            specified slices, in chronological order. Shape changes from
            S to S_trimmed. All units and neuron_attributes are carried
            over.

    Raises:
        ValueError: If any index is outside ``[-S, S)``.
    """
    S = len(self.spike_stack)
    if isinstance(slices, int):
        slices = [slices]
    for s in slices:
        if s >= S or s < -S:
            raise ValueError(f"One or more slice indices out of range for S={S}")
    # Normalize negative indices to their positive equivalents BEFORE
    # sorting: otherwise e.g. [0, -1] would sort to [-1, 0] and select
    # (last, first), breaking the chronological ordering of self.times.
    # Deduplicate for consistency with subset(), which also dedupes.
    slices = sorted({s % S for s in slices})
    new_spike_stack = [self.spike_stack[s] for s in slices]
    new_times = [self.times[s] for s in slices]
    return SpikeSliceStack(
        spike_stack=new_spike_stack,
        times_start_to_end=new_times,
        neuron_attributes=self.neuron_attributes,
    )
[docs]
def subset(self, units, by=None):
    """Extract a subset of units from every slice in the spike stack.

    Parameters:
        units (int, str, or list): Unit indices to extract. If by is None,
            must be int(s). If by is set, values to match in
            neuron_attributes.
        by (str or None): If set, select units by this neuron_attribute
            key instead of by index.

    Returns:
        result (SpikeSliceStack): New SpikeSliceStack containing only the
            specified units across all slices. All slices and
            neuron_attributes are carried over.

    Notes:
        - Units appear in the output in ascending index order, not the
          order listed in units.
        - If IDs are not unique (when using by), every matching neuron is
          included.
    """
    if isinstance(units, (int, str)):
        units = [units]
    # Resolve which unit indices survive so neuron_attributes can be
    # trimmed to match.
    if by is not None:
        if self.neuron_attributes is None:
            raise ValueError("can't use `by` without `neuron_attributes`")
        _missing = object()
        wanted = set(units)
        kept_indices = [
            i
            for i in range(self.N)
            if _get_attr(self.neuron_attributes[i], by, _missing) in wanted
        ]
    else:
        kept_indices = sorted({int(u) for u in units})
        for u in kept_indices:
            if u < 0 or u >= self.N:
                raise ValueError(f"Unit index {u} out of range for {self.N} units.")
    new_spike_stack = [sd.subset(kept_indices) for sd in self.spike_stack]
    new_neuron_attributes = None
    if self.neuron_attributes is not None:
        new_neuron_attributes = [self.neuron_attributes[i] for i in kept_indices]
    return SpikeSliceStack(
        spike_stack=new_spike_stack,
        times_start_to_end=self.times,
        neuron_attributes=new_neuron_attributes,
    )
[docs]
def subtime_by_index(self, start_idx, end_idx):
    """Trim each slice to a sub-window specified by millisecond indices.

    Indices are measured from the start of each slice (1 index = 1 ms).
    Trims along the time axis while preserving all slices and units.

    Parameters:
        start_idx (int): Start index in ms from slice start (inclusive).
            Supports negative indexing.
        end_idx (int): End index in ms from slice start (exclusive).
            Supports negative indexing.

    Returns:
        result (SpikeSliceStack): New SpikeSliceStack where each slice is
            trimmed to the corresponding absolute time window. Absolute
            spike times are preserved (not shifted). self.times is updated
            to reflect the new absolute time bounds.

    Notes:
        - Indices are relative to each slice's own start (index 0 = slice
          start ms). They are converted to absolute recording times
          internally before trimming.
        - Original absolute timestamps are preserved. To get
          shifted-to-zero timestamps, create a new SpikeSliceStack.
        - All slices and neuron_attributes are carried over from the
          original.
    """
    # Window length in ms, taken from the FIRST slice only.
    # NOTE(review): assumes every slice has the same duration (guaranteed
    # for option-1 construction; option-2 stacks built from spike_stack
    # lengths may vary) -- confirm before using on heterogeneous stacks.
    slice_duration_ms = self.times[0][1] - self.times[0][0]
    T = int(round(slice_duration_ms))
    # Python-style negative indexing, measured back from the slice end.
    if start_idx < 0:
        start_idx += T
    if end_idx < 0:
        end_idx += T
    if start_idx < 0 or start_idx >= T:
        raise ValueError(f"start_idx {start_idx} out of range for T={T}")
    if end_idx <= start_idx or end_idx > T:
        raise ValueError(f"end_idx {end_idx} invalid for T={T}")
    new_spike_stack = []
    new_times = []
    for sd, t in zip(self.spike_stack, self.times):
        # Trim in slice-relative time (offset from sd.start_time), but
        # record the new window bounds in absolute recording time.
        new_spike_stack.append(
            sd.subtime(
                sd.start_time + float(start_idx), sd.start_time + float(end_idx)
            )
        )
        abs_start = t[0] + float(start_idx)
        abs_end = t[0] + float(end_idx)
        new_times.append((abs_start, abs_end))
    return SpikeSliceStack(
        spike_stack=new_spike_stack,
        times_start_to_end=new_times,
        neuron_attributes=self.neuron_attributes,
    )
[docs]
def to_raster_array(self, bin_size=1.0, absolute_times=False):
    """Convert the spike stack into a 3D raster array of shape (N, T, S).

    Each slice is rasterized with the given bin size, producing a spike
    count matrix where entry (n, t, s) is the number of spikes unit n
    fired in time bin t of slice s.

    Parameters:
        bin_size (float): Time bin size in ms (default 1.0).
        absolute_times (bool): If False (default), time bin 0 corresponds
            to the start of each slice (0-based). If True, each slice's
            spikes are offset by its absolute start time from self.times,
            so bin indices reflect the original recording position. The T
            dimension is sized to cover the full time span from the
            earliest slice start to the latest slice end. **Caution:**
            this can produce very large arrays when the recording span is
            long and bin_size is small.

    Returns:
        raster_stack (np.ndarray): 3D array of shape (N, T, S) with
            non-negative integer spike counts. When absolute_times is
            True, T covers the full recording span and all slices share
            the same time axis.
    """
    if not absolute_times:
        dense_list = []
        for sd in self.spike_stack:
            # Spike times are relative to each slice (0-based or event-centered).
            # sparse_raster handles start_time internally.
            dense_list.append(sd.sparse_raster(bin_size=bin_size).toarray())
        # np.stack requires equal shapes -- assumes all slices have the
        # same duration so their rasters have equal bin counts.
        return np.stack(dense_list, axis=2)
    # Absolute times: offset each slice by its start time so bin indices
    # reflect original recording position. All slices share the same
    # time axis spanning [min(start), max(end)].
    global_start = min(start for start, _ in self.times)
    global_end = max(end for _, end in self.times)
    total_bins = int(np.ceil((global_end - global_start) / bin_size))
    raster_stack = np.zeros((self.N, total_bins, len(self.spike_stack)), dtype=int)
    for s_idx, (sd, (start, _)) in enumerate(zip(self.spike_stack, self.times)):
        offset = start - global_start
        # NOTE(review): assumes sparse_raster(time_offset=offset) shifts
        # spikes later by `offset` ms, so the returned raster spans from
        # global_start up through this slice's end (hence writing it at
        # columns [0, r.shape[1]) is correct) -- confirm in SpikeData.
        r = sd.sparse_raster(bin_size=bin_size, time_offset=offset).toarray()
        raster_stack[:, : r.shape[1], s_idx] = r
    return raster_stack
[docs]
def compute_frac_active(self, min_spikes=2):
"""Compute the fraction of slices each unit is active in.
A unit counts as active in a slice if it has at least min_spikes
spikes within that slice's time window.
Parameters:
min_spikes (int): Minimum number of spikes for a unit to count as
active in a slice (default: 2).
Returns:
frac_active (np.ndarray): 1-D array of shape ``(U,)`` with the
fraction of slices each unit is active in (values in [0, 1]).
Notes:
- The returned array can be passed as ``frac_active`` to
``RateSliceStack.order_units_across_slices``,
``RateSliceStack.get_slice_to_slice_unit_corr_from_stack``,
``SpikeSliceStack.order_units_across_slices``, or
``SpikeSliceStack.get_slice_to_slice_unit_comparison``
to override their internal activity calculation.
- ``SpikeData.get_frac_active`` produces a compatible ``(U,)``
array based on burst edges and can be used in the same way.
"""
num_units = self.N
num_slices = len(self.spike_stack)
active_count = np.zeros(num_units, dtype=int)
for sd, (start, end) in zip(self.spike_stack, self.times):
for u in range(num_units):
spikes = np.asarray(sd.train[u])
n_valid = np.sum(
(spikes >= sd.start_time)
& (spikes <= sd.start_time + (end - start))
)
if n_valid >= min_spikes:
active_count[u] += 1
return active_count / num_slices if num_slices > 0 else np.zeros(num_units)
[docs]
def order_units_across_slices(
    self,
    agg_func="median",
    timing="median",
    min_spikes=2,
    min_frac_active=0.0,
    frac_active=None,
    timing_matrix=None,
):
    """Reorder units by their typical spike timing across slices.

    For each unit in each slice, computes a representative spike time
    (median, mean, or first spike) relative to the slice's time origin. These
    per-slice values are aggregated across slices to obtain a single
    typical timing per unit. Units are then sorted by this value from
    earliest to latest and optionally split into a highly-active group
    and a low-activity group.

    Parameters:
        agg_func (str): How to aggregate per-slice timing values across
            slices. ``"median"`` (default) or ``"mean"``.
        timing (str): Which spike time to extract per unit per slice.
            ``"median"`` — median spike time within the slice (default).
            ``"mean"`` — mean spike time within the slice.
            ``"first"`` — first spike time (onset latency).
            Ignored when ``timing_matrix`` is provided.
        min_spikes (int): Minimum number of spikes for a unit to count as
            active in a slice (default: 2). Ignored when ``timing_matrix``
            is provided.
        min_frac_active (float or None): Minimum fraction of slices a unit
            must be active in to be placed in the highly-active group.
            ``0.0`` or ``None`` (default: 0.0) skips the split entirely
            and places all units in the highly-active group without
            computing activity fractions.
        frac_active (np.ndarray or None): Optional pre-computed
            fraction-active array of shape ``(U,)`` to override the
            internal calculation for the group split. Only used when
            ``min_frac_active > 0``. Compatible sources:
            ``SpikeSliceStack.compute_frac_active`` and
            ``SpikeData.get_frac_active`` (``frac_per_unit`` output).
        timing_matrix (np.ndarray or None): Optional pre-computed ``(U, S)``
            timing matrix from ``get_unit_timing_per_slice``. When provided,
            ``timing`` and ``min_spikes`` are ignored and this matrix is
            used directly.

    Returns:
        reordered_stacks (tuple): Two ``SpikeSliceStack`` objects
            ``(highly_active, low_active)`` with units reordered by typical
            timing. The low-activity stack is ``None`` when the group is
            empty.
        unit_ids_in_order (tuple): Two ``ndarray``
            ``(highly_active, low_active)`` of original unit indices in the
            reordered sequence.
        unit_std (tuple): Two ``ndarray`` ``(highly_active, low_active)``
            of standard deviation of per-slice timing values. Lower values
            indicate more consistent timing across slices.
        unit_peak_times_ms (tuple): Two ``ndarray``
            ``(highly_active, low_active)`` of the aggregated typical
            timing in milliseconds relative to slice start. NaN for units
            with no active slices.
        unit_frac_active (tuple): Two ``ndarray``
            ``(highly_active, low_active)`` of the fraction of slices each
            unit was active in.

    Notes:
        - Call ``get_unit_timing_per_slice`` first to pre-compute the
          timing matrix if you want to reuse it across multiple calls
          (e.g. ``rank_order_correlation`` and this method).
        - When ``frac_active`` is None and ``min_frac_active > 0``,
          activity fraction is computed via ``compute_frac_active``.
        - Analogous to ``RateSliceStack.order_units_across_slices`` but
          operates on raw spike trains instead of firing rate curves.
    """
    if agg_func not in ("median", "mean"):
        raise ValueError(f"agg_func must be 'median' or 'mean', got {agg_func!r}")
    num_units = self.N
    num_slices = len(self.spike_stack)
    if timing_matrix is not None:
        # Caller-supplied (U, S) matrix; `timing` / `min_spikes` ignored.
        timing_matrix = np.asarray(timing_matrix, dtype=float)
        if timing_matrix.shape != (num_units, num_slices):
            raise ValueError(
                f"timing_matrix must have shape ({num_units}, {num_slices}), "
                f"got {timing_matrix.shape}"
            )
    else:
        timing_matrix = self.get_unit_timing_per_slice(
            timing=timing, min_spikes=min_spikes
        )
    # Standard deviation across slices (NaN-aware: slices where a unit is
    # inactive are NaN in the timing matrix and are ignored here)
    unit_std_values = np.nanstd(timing_matrix, axis=1)
    # Aggregate across slices
    if agg_func == "median":
        unit_timing = np.nanmedian(timing_matrix, axis=1)
    else:
        unit_timing = np.nanmean(timing_matrix, axis=1)
    # Compute or validate frac_active only when splitting is requested.
    # min_frac_active of 0.0 or None disables the split entirely.
    skip_split = not min_frac_active
    if skip_split:
        frac_active = np.ones(num_units)
        ha_units = np.arange(num_units)
        la_units = np.array([], dtype=int)
    else:
        if frac_active is not None:
            frac_active = np.asarray(frac_active, dtype=float)
            if frac_active.shape != (num_units,):
                raise ValueError(
                    f"frac_active must have shape ({num_units},), "
                    f"got {frac_active.shape}"
                )
        else:
            frac_active = self.compute_frac_active(min_spikes=min_spikes)
        highly_active_mask = frac_active >= min_frac_active
        ha_units = np.where(highly_active_mask)[0]
        la_units = np.where(~highly_active_mask)[0]
    # Sort within each group by typical timing (np.argsort places NaN
    # timings at the end of each group)
    ha_order = ha_units[np.argsort(unit_timing[ha_units])]
    la_order = la_units[np.argsort(unit_timing[la_units])]
    # Build reordered SpikeSliceStacks
    def _reorder_stack(unit_indices):
        # An empty group yields None rather than a 0-unit stack.
        if len(unit_indices) == 0:
            return None
        return self.subset(list(unit_indices))
    ha_stack = _reorder_stack(ha_order)
    la_stack = _reorder_stack(la_order)
    return (
        (ha_stack, la_stack),
        (ha_order, la_order),
        (unit_std_values[ha_order], unit_std_values[la_order]),
        (unit_timing[ha_order], unit_timing[la_order]),
        (frac_active[ha_order], frac_active[la_order]),
    )
[docs]
def apply(self, func, *args, **kwargs):
"""Apply a function to each SpikeData in the stack and return stacked results.
Calls ``func(sd, *args, **kwargs)`` on every slice and stacks the
outputs into a single numpy array with a new leading axis of size S
(number of slices).
Parameters:
func (callable): Function that accepts a SpikeData as its first
argument and returns a numeric value (scalar, 1-D, or 2-D
array). Output shape must be consistent across all slices.
*args: Additional positional arguments forwarded to func.
**kwargs: Additional keyword arguments forwarded to func.
Returns:
result (np.ndarray): Stacked results with shape ``(S, ...)``.
Notes:
- Intended for use with stacks built by ``SpikeData.frames``,
``SpikeData.align_to_events``, ``SpikeData.spike_shuffle_stack``,
or ``SpikeData.subset_stack``. Pair with ``shuffle_z_score``,
``shuffle_percentile``, ``slice_trend``, or ``slice_stability``
from ``utils`` to interpret the results.
"""
results = [func(sd, *args, **kwargs) for sd in self.spike_stack]
return np.stack(results, axis=0)
[docs]
def unit_to_unit_comparison(
    self,
    metric="ccg",
    delt=20.0,
    bin_size=1.0,
    max_lag=350,
    n_jobs=-1,
):
    """Compute pairwise unit-to-unit similarity within each slice using spike-based metrics.

    For each slice, computes a (U, U) similarity matrix between all unit pairs,
    then stacks the results into a ``PairwiseCompMatrixStack (U, U, S)``.

    Parameters:
        metric (str): Similarity metric to use. ``"ccg"`` for cross-correlogram
            on binned rasters (default), ``"sttc"`` for spike time tiling coefficient.
        delt (float): STTC time window in milliseconds (default: 20.0).
            Only used when metric is ``"sttc"``.
        bin_size (float): Bin size in milliseconds for the binary raster
            (default: 1.0). Only used when metric is ``"ccg"``.
        max_lag (float): Maximum lag in milliseconds to search for the peak
            correlation (default: 350). Only used when metric is ``"ccg"``.
        n_jobs (int): Number of threads forwarded to ``get_pairwise_ccg``
            (only used when metric is ``"ccg"``).

    Returns:
        corr_stack (PairwiseCompMatrixStack): Pairwise similarity scores between
            all unit pairs for each slice. Shape is ``(U, U, S)``.
        lag_stack (PairwiseCompMatrixStack or None): Lag at which maximum
            similarity occurs for each pair per slice. Shape is ``(U, U, S)``.
            ``None`` when metric is ``"sttc"`` (STTC has no lag).
        av_corr (np.ndarray): Average similarity per slice across all unit
            pairs in the lower triangle. Shape is ``(S,)``.
        av_lag (np.ndarray or None): Average lag per slice. Shape is ``(S,)``.
            ``None`` when metric is ``"sttc"``.

    Notes:
        - Analogous to ``RateSliceStack.unit_to_unit_correlation`` but operates
          on raw spike trains instead of firing rate time series.
    """
    if metric not in ("sttc", "ccg"):
        raise ValueError(f"metric must be 'sttc' or 'ccg', got {metric!r}")
    num_units = self.N
    num_slices = len(self.spike_stack)
    if num_units < 2:
        # Degenerate case: with < 2 units there are no pairs; return
        # NaN-filled structures of the documented shapes.
        warnings.warn(
            "Cannot compute unit-to-unit comparison with fewer than "
            "2 units. Returning NaN.",
            RuntimeWarning,
        )
        nan_stack = np.full((num_units, num_units, num_slices), np.nan)
        nan_avgs = np.full(num_slices, np.nan)
        return (
            PairwiseCompMatrixStack(stack=nan_stack, times=self.times),
            (
                PairwiseCompMatrixStack(stack=nan_stack.copy(), times=self.times)
                if metric == "ccg"
                else None
            ),
            nan_avgs,
            nan_avgs.copy() if metric == "ccg" else None,
        )
    corr_matrices = []
    lag_matrices = []
    for sd in self.spike_stack:
        if metric == "sttc":
            pcm = sd.spike_time_tilings(delt=delt)
            corr_matrices.append(pcm.matrix)
        else:  # ccg
            corr_pcm, lag_pcm = sd.get_pairwise_ccg(
                bin_size=bin_size, max_lag=max_lag, n_jobs=n_jobs
            )
            corr_matrices.append(corr_pcm.matrix)
            lag_matrices.append(lag_pcm.matrix)
    # Stack: list of (U, U) -> (S, U, U) -> transpose to (U, U, S)
    corr_array = np.moveaxis(np.stack(corr_matrices, axis=0), 0, 2)
    # Lower triangle (k=-1) excludes the diagonal self-comparisons, so
    # the per-slice average covers each unordered pair once.
    lower_tri = np.tril_indices(num_units, k=-1)
    av_corr = np.nanmean(corr_array[lower_tri[0], lower_tri[1], :], axis=0)
    corr_stack = PairwiseCompMatrixStack(stack=corr_array, times=self.times)
    if metric == "ccg":
        lag_array = np.moveaxis(np.stack(lag_matrices, axis=0), 0, 2)
        av_lag = np.nanmean(lag_array[lower_tri[0], lower_tri[1], :], axis=0)
        lag_stack = PairwiseCompMatrixStack(stack=lag_array, times=self.times)
    else:
        lag_stack = None
        av_lag = None
    return corr_stack, lag_stack, av_corr, av_lag
[docs]
def get_slice_to_slice_unit_comparison(
    self,
    metric="ccg",
    delt=20.0,
    bin_size=1.0,
    max_lag=350,
    min_spikes=2,
    min_frac=0.3,
    frac_active=None,
    n_jobs=-1,
):
    """Compute slice-to-slice similarity for each unit using spike-based metrics.

    For each unit independently, compares its spike train across every pair of
    slices. Asks: "Does unit X fire in the same temporal pattern across repeated
    events?" Returns a ``PairwiseCompMatrixStack (S, S, U)``.

    Parameters:
        metric (str): Similarity metric to use. ``"ccg"`` for cross-correlogram
            on binned rasters (default), ``"sttc"`` for spike time tiling coefficient.
        delt (float): STTC time window in milliseconds (default: 20.0).
            Only used when metric is ``"sttc"``.
        bin_size (float): Bin size in milliseconds for the binary raster
            (default: 1.0). Only used when metric is ``"ccg"``.
        max_lag (float): Maximum lag in milliseconds to search for the peak
            correlation (default: 350). Only used when metric is ``"ccg"``.
        min_spikes (int): Minimum number of spikes in a slice for a unit to
            be considered valid in that slice (default: 2).
        min_frac (float): Maximum fraction of slices that can be invalid before
            a unit's average is set to NaN (default: 0.3).
        frac_active (np.ndarray or None): Optional pre-computed
            fraction-active array of shape ``(U,)`` to override the
            internal per-unit validity check for computing averages.
            When provided, a unit's average is set to NaN if
            ``frac_active[u] < (1 - min_frac)``. ``min_spikes`` still
            controls which individual slice pairs are computed.
            Compatible sources: ``SpikeSliceStack.compute_frac_active``
            and ``SpikeData.get_frac_active`` (``frac_per_unit`` output).
        n_jobs (int): Number of threads for parallel computation. -1 uses
            all cores (default), 1 disables parallelism, None is serial.

    Returns:
        all_corr (PairwiseCompMatrixStack): Pairwise similarity between all
            slice pairs for each unit. Shape is ``(S, S, U)``.
        all_lag (PairwiseCompMatrixStack or None): Lag at which maximum
            similarity occurs for each slice pair per unit. Shape is ``(S, S, U)``.
            ``None`` when metric is ``"sttc"``.
        av_corr (np.ndarray): Average similarity per unit across all valid
            slice pairs. Shape is ``(U,)``.
        av_lag (np.ndarray or None): Average lag per unit. Shape is ``(U,)``.
            ``None`` when metric is ``"sttc"``.

    Notes:
        - Analogous to ``RateSliceStack.get_slice_to_slice_unit_corr_from_stack``
          but operates on raw spike trains.
        - Spike times within each slice are relative to the slice time
          origin (0-based or event-centered) for aligned comparison.
    """
    if metric not in ("sttc", "ccg"):
        raise ValueError(f"metric must be 'sttc' or 'ccg', got {metric!r}")
    num_units = self.N
    num_slices = len(self.spike_stack)
    if num_slices < 2:
        # Degenerate case: no slice pairs to compare; return NaN-filled
        # structures of the documented shapes.
        warnings.warn(
            "Cannot compute slice-to-slice unit comparison with fewer than "
            "2 slices. Returning NaN.",
            RuntimeWarning,
        )
        av_corr = np.full(num_units, np.nan)
        nan_stack = np.full((num_slices, num_slices, num_units), np.nan)
        return (
            PairwiseCompMatrixStack(stack=nan_stack),
            (
                PairwiseCompMatrixStack(stack=nan_stack.copy())
                if metric == "ccg"
                else None
            ),
            av_corr,
            av_corr.copy() if metric == "ccg" else None,
        )
    # Pre-compute spike trains per slice (relative to slice time origin)
    # and per-slice rasters for CCG
    shifted_trains = []  # list of S lists, each containing U spike arrays
    slice_durations = []
    slice_rasters = []  # only populated for CCG
    for sd, (start, end) in zip(self.spike_stack, self.times):
        duration = end - start
        slice_durations.append(duration)
        trains = []
        for u in range(num_units):
            trains.append(np.asarray(sd.train[u]))
        shifted_trains.append(trains)
        if metric == "ccg":
            # Build shifted SpikeData for raster computation
            temp_sd = SpikeData(
                trains, length=duration, start_time=sd.start_time, N=num_units
            )
            slice_rasters.append(temp_sd.raster(bin_size))
    # Convert the ms lag window into raster bins for the CCG search.
    max_lag_bins = int(round(max_lag / bin_size)) if metric == "ccg" else 0
    # Validate frac_active override if provided
    if frac_active is not None:
        frac_active = np.asarray(frac_active, dtype=float)
        if frac_active.shape != (num_units,):
            raise ValueError(
                f"frac_active must have shape ({num_units},), "
                f"got {frac_active.shape}"
            )
    # Initialize result arrays: (U, S, S), will transpose to (S, S, U) at end
    all_corr_scores = np.full((num_units, num_slices, num_slices), np.nan)
    all_lag_scores = (
        np.full((num_units, num_slices, num_slices), np.nan)
        if metric == "ccg"
        else None
    )
    av_corr = np.full(num_units, np.nan)
    av_lag = np.full(num_units, np.nan) if metric == "ccg" else None
    # Lower triangle (k=-1) excludes diagonal self-comparisons from the
    # per-unit averages.
    lower_tri = np.tril_indices(num_slices, k=-1)
    start_times = [sd.start_time for sd in self.spike_stack]
    def _process_unit(unit):
        # Compute the full (S, S) similarity matrix for one unit.
        # Thread-safe: only reads the precomputed shared structures and
        # writes to its own local arrays.
        unit_corr = np.full((num_slices, num_slices), np.nan)
        unit_lag = (
            np.full((num_slices, num_slices), np.nan) if metric == "ccg" else None
        )
        invalid_count = 0
        for ref_s in range(num_slices):
            ref_train = shifted_trains[ref_s][unit]
            if len(ref_train) < min_spikes:
                # Too few spikes in the reference slice; the whole row
                # stays NaN and counts against the unit's validity.
                invalid_count += 1
                continue
            # comp_s starts at ref_s, so the diagonal (self-comparison)
            # is computed; it is excluded from averages via lower_tri.
            for comp_s in range(ref_s, num_slices):
                comp_train = shifted_trains[comp_s][unit]
                if len(comp_train) < min_spikes:
                    continue
                if metric == "sttc":
                    length = max(slice_durations[ref_s], slice_durations[comp_s])
                    # start_time from the ref slice is correct here:
                    # all slices share the same start_time (event-centered
                    # data has -pre_ms, frames() produces 0-based slices).
                    score = get_sttc(
                        ref_train,
                        comp_train,
                        delt=delt,
                        length=length,
                        start_time=start_times[ref_s],
                    )
                    unit_corr[ref_s, comp_s] = score
                    unit_corr[comp_s, ref_s] = score
                else:
                    ref_signal = slice_rasters[ref_s][unit, :]
                    comp_signal = slice_rasters[comp_s][unit, :]
                    score, lag = compute_cross_correlation_with_lag(
                        ref_signal, comp_signal, max_lag=max_lag_bins
                    )
                    unit_corr[ref_s, comp_s] = score
                    unit_corr[comp_s, ref_s] = score
                    # Lag is antisymmetric: swapping the pair negates it.
                    unit_lag[ref_s, comp_s] = lag
                    unit_lag[comp_s, ref_s] = -lag
        if frac_active is not None:
            unit_valid = frac_active[unit] >= (1 - min_frac)
        else:
            unit_valid = invalid_count / num_slices <= min_frac
        av_c = (
            np.nanmean(unit_corr[lower_tri[0], lower_tri[1]])
            if unit_valid
            else np.nan
        )
        av_l = np.nan
        if metric == "ccg" and unit_valid:
            av_l = np.nanmean(unit_lag[lower_tri[0], lower_tri[1]])
        return unit, unit_corr, unit_lag, av_c, av_l
    n_workers = _resolve_n_jobs(n_jobs)
    if n_workers > 1 and num_units > 1:
        with ThreadPoolExecutor(max_workers=n_workers) as pool:
            results = pool.map(_process_unit, range(num_units))
    else:
        results = map(_process_unit, range(num_units))
    for unit, unit_corr, unit_lag, av_c, av_l in results:
        all_corr_scores[unit] = unit_corr
        av_corr[unit] = av_c
        if metric == "ccg":
            all_lag_scores[unit] = unit_lag
            av_lag[unit] = av_l
    # Transpose from (U, S, S) to (S, S, U)
    all_corr_scores = np.moveaxis(all_corr_scores, 0, 2)
    all_corr_stack = PairwiseCompMatrixStack(stack=all_corr_scores)
    if metric == "ccg":
        all_lag_scores = np.moveaxis(all_lag_scores, 0, 2)
        all_lag_stack = PairwiseCompMatrixStack(stack=all_lag_scores)
    else:
        all_lag_stack = None
    return all_corr_stack, all_lag_stack, av_corr, av_lag
[docs]
def get_unit_timing_per_slice(
    self,
    timing="median",
    min_spikes=2,
):
    """Compute a representative spike time for each unit in each slice.
    Returns a ``(U, S)`` matrix where entry ``[u, s]`` is the timing
    value (in milliseconds relative to the slice's time origin) for unit u
    in slice s. For event-centered slices, t=0 is the event. Units with
    fewer than min_spikes spikes in a slice are marked NaN.
    Parameters:
        timing (str): Which spike time to extract per unit per slice.
            ``"median"`` (default), ``"mean"``, or ``"first"`` (onset
            latency).
        min_spikes (int): Minimum number of spikes for a unit to count
            as active in a slice (default: 2).
    Returns:
        timing_matrix (np.ndarray): Array of shape ``(U, S)`` with timing
            values in milliseconds relative to each slice's time origin.
            NaN where the unit is inactive.
    Notes:
        - Values are in milliseconds, not bin indices. This differs from
          ``RateSliceStack.get_unit_timing_per_slice`` which returns bin
          indices (suitable for direct indexing into the event stack).
          Both representations preserve rank order, so
          ``rank_order_correlation`` produces identical results either way.
        - The returned matrix can be passed to ``rank_order_correlation``
          to compute Spearman rank correlations between slice pairs, or
          used as input to ``order_units_across_slices`` for manual
          inspection of per-slice timing values.
    """
    if timing not in ("median", "mean", "first"):
        raise ValueError(
            f"timing must be 'median', 'mean', or 'first', got {timing!r}"
        )
    num_units = self.N
    num_slices = len(self.spike_stack)
    timing_matrix = np.full((num_units, num_slices), np.nan)
    for s_idx, (sd, (start, end)) in enumerate(zip(self.spike_stack, self.times)):
        # Window bounds are constant for every unit within a slice, so
        # hoist them out of the per-unit loop (was recomputed per unit).
        duration = end - start
        window_lo = sd.start_time
        window_hi = sd.start_time + duration
        for u in range(num_units):
            spikes = np.asarray(sd.train[u])
            # Keep only spikes inside this slice's time window.
            spikes = spikes[(spikes >= window_lo) & (spikes <= window_hi)]
            if len(spikes) < min_spikes:
                continue  # inactive unit: leave NaN
            if timing == "median":
                timing_matrix[u, s_idx] = np.median(spikes)
            elif timing == "mean":
                timing_matrix[u, s_idx] = np.mean(spikes)
            else:
                # "first" = onset latency. NOTE(review): assumes each
                # train is sorted ascending — TODO confirm upstream.
                timing_matrix[u, s_idx] = spikes[0]
    return timing_matrix
[docs]
def rank_order_correlation(
    self,
    timing_matrix=None,
    timing="median",
    min_spikes=2,
    min_overlap=3,
    n_shuffles=100,
    min_overlap_frac=None,
    seed=1,
    n_jobs=-1,
):
    """Spearman rank-order correlation of unit timing between all slice pairs.
    Each slice pair is compared using only the units active in both
    slices (non-NaN in both columns of the timing matrix); a pair whose
    overlap is below the required minimum is set to NaN.
    With ``n_shuffles > 0``, a null distribution is built by shuffling
    the rank orders n_shuffles times per pair, and the raw correlation
    is z-score normalised against it.
    Parameters:
        timing_matrix (np.ndarray or None): Array of shape ``(U, S)`` with
            timing values per unit per slice. NaN entries mark inactive
            units. Typically produced by ``get_unit_timing_per_slice``.
            When None, computed automatically using timing and
            min_spikes.
        timing (str): Which spike time to extract per unit per slice.
            ``"median"`` (default), ``"mean"``, or ``"first"``. Only used
            when timing_matrix is None.
        min_spikes (int): Minimum spikes for activity (default: 2). Only
            used when timing_matrix is None.
        min_overlap (int): Minimum number of units that must be active in
            both slices (default: 3).
        n_shuffles (int): Number of shuffle iterations for z-scoring
            (default: 100). Set to 0 to return raw Spearman correlations.
            Values between 1 and 4 are rejected (minimum 5 required for
            a meaningful null distribution).
        min_overlap_frac (float or None): Minimum fraction of total units
            that must be active in both slices (default: None). When
            provided, the effective threshold is
            ``max(min_overlap, ceil(min_overlap_frac * U))``.
        seed (int or None): Random seed for reproducibility of the shuffle
            (default: 1).
        n_jobs (int): Number of threads for parallel computation. -1 uses
            all cores (default), 1 disables parallelism, None is serial.
    Returns:
        corr_matrix (PairwiseCompMatrix): Spearman correlation matrix of
            shape ``(S, S)``. When ``n_shuffles > 0``, values are z-scores.
            When ``n_shuffles == 0``, values are raw Spearman correlations.
        av_corr (float): Average correlation (or z-score) across all valid
            lower-triangle pairs.
        overlap_matrix (PairwiseCompMatrix): Matrix of shape ``(S, S)``
            with fraction of units active in both slices.
    """
    from .utils import _rank_order_correlation_from_timing

    # Compute the timing matrix on the fly when the caller did not
    # supply one.
    matrix = (
        timing_matrix
        if timing_matrix is not None
        else self.get_unit_timing_per_slice(timing=timing, min_spikes=min_spikes)
    )
    return _rank_order_correlation_from_timing(
        matrix,
        min_overlap=min_overlap,
        min_overlap_frac=min_overlap_frac,
        n_shuffles=n_shuffles,
        seed=seed,
        n_jobs=n_jobs,
    )
[docs]
def plot_aligned_slice_single_unit(
    self,
    unit_idx,
    ax=None,
    color_vals=None,
    color_label="",
    cmap="viridis",
    time_offset=0,
    xlabel="Rel. time (ms)",
    ylabel="Burst",
    x_range=None,
    vlines=None,
    show_colorbar=True,
    marker_size=20,
    font_size=None,
    style="scatter",
    invert_y=False,
    linewidths=0.5,
):
    """Raster plot of one unit's spike times across every slice.
    Gathers the spike train of unit_idx from each slice in the stack and
    forwards everything to
    :func:`~SpikeLab.spikedata.plot_utils.plot_aligned_slice_single_unit`.
    Parameters:
        unit_idx (int): Index of the unit to plot.
        ax (matplotlib.axes.Axes or None): Target axes. If None, a new
            figure and axes are created.
        color_vals (np.ndarray or None): Per-slice colour values.
        color_label (str): Colorbar label.
        cmap (str): Matplotlib colormap name.
        time_offset (float): Value subtracted from every spike time
            before plotting. Slices from ``align_to_events`` are
            already event-centered (spike times in
            ``[-pre_ms, +post_ms]``), so use the default
            ``time_offset=0``. Only set a non-zero value when spike
            times are not already centered on the event.
        xlabel (str): X-axis label.
        ylabel (str): Y-axis label.
        x_range (tuple or None): ``(xmin, xmax)`` for the x-axis.
        vlines (list[dict] or None): Vertical reference lines. Each dict
            must contain ``'x'`` and may optionally include ``'color'``,
            ``'linestyle'``, ``'linewidth'``.
        show_colorbar (bool): Add a colorbar when color_vals is provided.
        marker_size (float): Scatter marker size.
        font_size (int or None): Font size for labels/ticks.
        style (str): ``"scatter"`` for dot markers, ``"eventplot"`` for
            vertical line markers.
        invert_y (bool): If True, first slice at top, last at bottom.
        linewidths (float): Line width for eventplot markers.
    Returns:
        result: ``(fig, ax, sc)`` when ax is None, otherwise just sc.
            sc is the scatter ``PathCollection`` (or None if no colour
            coding).
    """
    from .plot_utils import (
        plot_aligned_slice_single_unit as _plot_aligned_slice_single_unit,
    )
    # matplotlib is an optional dependency; fail with an actionable hint.
    try:
        import matplotlib.pyplot as plt
    except ImportError as e:
        raise ImportError(
            "plot_aligned_slice_single_unit requires 'matplotlib'. "
            "Install with: pip install matplotlib"
        ) from e
    if not (0 <= unit_idx < self.N):
        raise IndexError(f"unit_idx {unit_idx} out of range for {self.N} units.")
    # One spike-time array per slice for the requested unit.
    unit_trains = [sd.train[unit_idx] for sd in self.spike_stack]
    owns_figure = ax is None
    if owns_figure:
        fig, ax = plt.subplots(figsize=(8, 6))
    sc = _plot_aligned_slice_single_unit(
        ax,
        unit_trains,
        color_vals=color_vals,
        color_label=color_label,
        cmap=cmap,
        time_offset=time_offset,
        xlabel=xlabel,
        ylabel=ylabel,
        x_range=x_range,
        vlines=vlines,
        show_colorbar=show_colorbar,
        marker_size=marker_size,
        font_size=font_size,
        style=style,
        invert_y=invert_y,
        linewidths=linewidths,
    )
    if not owns_figure:
        return sc
    plt.tight_layout()
    return fig, ax, sc