import math
import os
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, List, Literal, Union, Dict, Any
import numpy as np
from itertools import groupby as _groupby
from scipy import ndimage, signal
from scipy.stats import norm
__all__ = [
"get_sttc",
"swap",
"randomize",
"trough_between",
"TimeUnit",
"ensure_h5py",
"times_from_ms",
"to_ms",
"extract_waveforms",
"check_neuron_attributes",
"get_channels_for_unit",
"compute_avg_waveform",
"get_valid_spike_times",
"waveforms_by_channel",
"extract_unit_waveforms",
"consecutive_durations",
"gplvm_state_entropy",
"gplvm_continuity_prob",
"gplvm_average_state_probability",
"shuffle_z_score",
"shuffle_percentile",
"slice_trend",
"slice_stability",
]
TimeUnit = Literal["ms", "s", "samples"]
try:
import h5py
except ImportError: # pragma: no cover
h5py = None # type: ignore
# Optional dependencies for manifold learning and graph-based clustering.
try: # optional, only needed for UMAP-based reductions
import umap # type: ignore
except ImportError: # pragma: no cover
umap = None # type: ignore
try: # optional, only needed for graph/community detection
import networkx as nx # type: ignore
except ImportError: # pragma: no cover
nx = None # type: ignore
try: # optional, only needed for Louvain community detection
import community as community_louvain # type: ignore
except ImportError: # pragma: no cover
community_louvain = None # type: ignore
# ---------------------------------------------------------------------------
# Thread-pool parallelisation helpers
# ---------------------------------------------------------------------------
def _resolve_n_jobs(n_jobs):
"""Resolve an n_jobs parameter to a concrete worker count.
    Parameters:
        n_jobs (int or None): Desired parallelism. None or 1 means serial
            execution, -1 means all cores, and other negative values count
            back from os.cpu_count() (e.g. -2 uses all cores but one).
Returns:
n_workers (int): Positive integer worker count (1 = serial).
"""
if n_jobs is None or n_jobs == 1:
return 1
if n_jobs == -1:
return os.cpu_count() or 1
if n_jobs < -1:
cores = os.cpu_count() or 1
return max(1, cores + 1 + n_jobs)
return n_jobs
def get_sttc(
tA, tB, delt=20.0, length: Optional[float] = None, start_time: float = 0.0
):
"""Calculate the spike time tiling coefficient between two spike trains.
Parameters:
tA (list): List of spike times for the first spike train.
tB (list): List of spike times for the second spike train.
delt (float): Time window in milliseconds (default: 20.0).
length (float or None): Total duration in milliseconds. If None,
inferred from the latest spike time after shifting, which may
underestimate the true recording duration if the last spike does
not fall near the end. Pass the actual recording length for
unbiased STTC.
start_time (float): Time origin of the spike trains (default 0.0).
Spike times are shifted by ``-start_time`` before computation so
that the STTC edge corrections work correctly for event-centered
data with negative spike times.
Returns:
sttc (float): Spike time tiling coefficient between the two spike
trains.
Notes:
Formula: STTC = ((PA - TB) / (1 - PA * TB) + (PB - TA) / (1 - PB * TA)) / 2
[1] Cutts & Eglen. Detecting pairwise correlations in spike trains:
An objective comparison of methods and application to the study of
retinal waves. Journal of Neuroscience 34:43, 14288-14303 (2014).
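    Example:
        A minimal usage sketch with hypothetical, sorted spike times in
        milliseconds (the internal helpers assume ascending order)::

            tA = [10.0, 55.0, 200.0, 350.0]
            tB = [12.0, 60.0, 420.0]
            sttc = get_sttc(tA, tB, delt=20.0, length=500.0)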
"""
if delt <= 0:
raise ValueError(f"delt must be positive, got {delt}")
if len(tA) == 0 or len(tB) == 0:
return 0.0
# Shift both trains by -start_time so they are 0-based. This ensures
# _sttc_ta edge corrections work correctly for event-centered data.
tA = np.asarray(tA, dtype=float) - start_time
tB = np.asarray(tB, dtype=float) - start_time
if length is None:
length = float(max(tA[-1], tB[-1]))
TA = _sttc_ta(tA, delt, length) / length
TB = _sttc_ta(tB, delt, length) / length
return _spike_time_tiling(tA, tB, TA, TB, delt)
def _spike_time_tiling(tA, tB, TA, TB, delt):
"""Internal helper method for the second half of STTC calculation."""
if len(tA) == 0 or len(tB) == 0:
return 0
PA = _sttc_na(tA, tB, delt) / len(tA)
PB = _sttc_na(tB, tA, delt) / len(tB)
aa = (PA - TB) / (1 - PA * TB) if PA * TB != 1 else 0
bb = (PB - TA) / (1 - PB * TA) if PB * TA != 1 else 0
return (aa + bb) / 2
def _sttc_ta(tA, delt: float, tmax: float) -> float:
"""Calculate the total amount of time within a range delt of spikes within tA."""
if len(tA) == 0:
return 0.0
base = min(delt, tA[0]) + min(delt, max(0, tmax - tA[-1]))
return base + np.minimum(np.diff(tA), 2 * delt).sum()
def _sttc_na(tA, tB, delt: float) -> int:
"""Helper function for STTC: Calculate the number of spikes in tA within delt of any spike in tB."""
if len(tB) == 0:
return 0
tA, tB = np.asarray(tA), np.asarray(tB)
if len(tB) == 1:
return int((np.abs(tA - tB[0]) <= delt).sum())
# Find the closest spike in B after spikes in A.
iB = np.searchsorted(tB, tA)
# Clip to ensure legal indexing, then check the spike at that
# index and its predecessor to see which is closer.
np.clip(iB, 1, len(tB) - 1, out=iB)
    dt_next = np.abs(tB[iB] - tA)
    dt_prev = np.abs(tB[iB - 1] - tA)
    # Return how many of those spikes are actually within delt.
    # Uses inclusive <= (common implementation practice) rather than strict <
    # from Cutts & Eglen (2014). For continuous spike times the difference is
    # negligible; for binned data it may slightly increase coincidence counts.
    return int((np.minimum(dt_next, dt_prev) <= delt).sum())
def _resampled_isi(spikes, times, sigma_ms):
"""Calculate the firing rate of a spike train at specific times using the reciprocal inter-spike interval.
Parameters:
spikes (list): List of spike times.
times (list): List of times.
sigma_ms (float): Standard deviation in milliseconds.
Returns:
fr (np.ndarray): Firing rate at specific times. Same size as times.
Notes:
Assumed to have been sampled halfway between any two given spikes,
interpolated, and then smoothed by a Gaussian kernel with the given
width.
"""
    if len(spikes) < 2:
        # Need at least 2 spikes to get an inter-spike interval
        return np.zeros_like(times)
    if len(times) == 0:
        return np.zeros_like(times, dtype=float)
    if len(times) == 1:
        # Single-time query: return the unsmoothed ISI-derived rate at that time.
        # If the time is outside the valid spike-interval support, the rate is 0.
        t = float(times[0])
spikes = np.array(spikes)
idx = np.searchsorted(spikes, t, side="right") - 1
if idx < 0 or idx >= len(spikes) - 1:
return np.zeros_like(times, dtype=float)
isi = spikes[idx + 1] - spikes[idx]
if isi <= 0:
return np.zeros_like(times, dtype=float)
return np.array([1.0 / isi * 1000], dtype=float)
spikes = np.array(spikes)
times = np.array(times)
# Remove duplicate spike times (BUG-002)
unique_spikes = np.unique(spikes)
if len(unique_spikes) < len(spikes):
warnings.warn(
f"{len(spikes) - len(unique_spikes)} duplicate spike time(s) removed "
f"before ISI computation.",
RuntimeWarning,
)
spikes = unique_spikes
if len(spikes) < 2:
return np.zeros_like(times)
# Reject duplicate time grid values (BUG-003)
if len(np.unique(times)) < len(times):
raise ValueError(
"times array contains duplicate values. "
"Provide an evenly-spaced grid with unique time points."
)
    # Compute inter-spike intervals between consecutive spikes
isi = np.diff(spikes)
isi = np.insert(isi, 0, 0) # Add spacer for first spike
# Compute instantaneous firing rates (1/isi, in Hz assuming ms units)
isi_rate = np.zeros_like(isi, dtype=float)
isi_rate[1:] = 1.0 / isi[1:] * 1000
# Create temporary result array matching times resolution
t_start, t_end = times[0], times[-1]
dt_ms = times[1] - times[0]
n_bins = int(round((t_end - t_start) / dt_ms)) + 1
isi_rate_temp = np.zeros(n_bins)
# Assign rates to bins between spikes.
# Note: int(round(...)) bin assignment can shift spikes at exact bin
# boundaries to adjacent bins — a known sub-ms precision limitation.
for i in range(1, len(spikes)):
start_bin = int(round((spikes[i - 1] - t_start) / dt_ms))
end_bin = int(round((spikes[i] - t_start) / dt_ms))
if start_bin < n_bins:
isi_rate_temp[start_bin : min(end_bin, n_bins)] = isi_rate[i]
# Interpolate to exact times grid (if needed)
fr = np.interp(times, t_start + dt_ms * np.arange(n_bins), isi_rate_temp)
# Apply Gaussian smoothing
if len(fr) < 2:
return fr
sigma = sigma_ms / dt_ms
if sigma > 0:
return ndimage.gaussian_filter1d(fr, sigma)
else:
return fr
def _sliding_rate_single_train(
spike_times,
window_size,
step_size=None,
sampling_rate=None,
t_start=None,
t_end=None,
gauss_sigma=0.0,
apply_square=True,
):
"""
Compute continuous firing rate from spike times using square and/or Gaussian smoothing.
For each time bin t, this can apply:
- square smoothing: counts spikes in centered window [t - W/2, t + W/2], rate R(t)=N/W
- Gaussian smoothing: 1D Gaussian filter over the rate trace
- both: square smoothing followed by Gaussian smoothing
Parameters:
        spike_times (array_like): 1D array of spike timestamps (time units consistent with other args).
        window_size (float): Width of the sliding window W. Centered window [t - W/2, t + W/2].
        step_size (float, optional): Advance step for time bins. Mutually exclusive with
            sampling_rate; providing both raises ValueError.
        sampling_rate (float, optional): Samples per time unit; step_size = 1 / sampling_rate if
            step_size is not provided.
        t_start (float, optional): Start of output time range. Default: first spike time - window_size / 2.
        t_end (float, optional): End of output time range. Default: last spike time + window_size / 2.
        gauss_sigma (float, optional): Gaussian smoothing sigma, in the same time units. If 0, Gaussian smoothing is disabled.
        apply_square (bool, optional): If True (default), apply square-window smoothing.
            If False, skip square smoothing and compute rates from per-bin spike counts before optional Gaussian smoothing.
Returns:
RateData: Single-unit rate object with inst_Frate_data (1, T) and times; units: spikes per time (e.g. kHz).
Notes:
Uses zero-padding at boundaries for square smoothing (mode='same'). Rate near edges
may be lower when the effective window extends beyond the data.
- Assumes spike_times are sorted.
"""
spike_times = np.asarray(spike_times)
if len(spike_times) == 0:
from .ratedata import RateData
return RateData(inst_Frate_data=np.zeros((1, 0)), times=np.array([]))
if window_size <= 0:
raise ValueError(f"window_size must be positive, got {window_size}")
if step_size is None and sampling_rate is None:
raise ValueError("Must provide either step_size or sampling_rate")
if step_size is not None and sampling_rate is not None:
raise ValueError(
"step_size and sampling_rate are mutually exclusive; provide one, not both"
)
if step_size is None:
if sampling_rate is None or sampling_rate <= 0:
raise ValueError(
f"sampling_rate must be positive when step_size is not provided, got {sampling_rate}"
)
step_size = 1.0 / sampling_rate
else:
if step_size <= 0:
raise ValueError(f"step_size must be positive, got {step_size}")
if gauss_sigma < 0:
raise ValueError(f"gauss_sigma must be non-negative, got {gauss_sigma}")
# Default time range extends half-window beyond first/last spike so edges are covered
half_window = window_size / 2
if t_start is None:
t_start = float(np.min(spike_times)) - half_window
if t_end is None:
t_end = float(np.max(spike_times)) + half_window
if t_end <= t_start:
raise ValueError(
f"t_end must be greater than t_start (got t_start={t_start}, t_end={t_end})"
)
# Use sparse_raster for binning (same rule as SpikeData)
span = t_end - t_start
n_bins_est = int(np.ceil(span / step_size))
remainder = span % step_size
if remainder < 1e-12 or abs(remainder - step_size) < 1e-12:
n_bins_est += 1
t_last = t_start + n_bins_est * step_size
mask = (spike_times >= t_start) & (spike_times < t_last)
spike_times_filtered = spike_times[mask] - t_start
from .spikedata import SpikeData
sd = SpikeData([spike_times_filtered], length=span)
raster = sd.sparse_raster(step_size)
hist = np.asarray(raster.toarray()).ravel()
n_bins = hist.size
bin_edges = t_start + np.arange(n_bins + 1) * step_size
if apply_square:
# Sliding window = convolution with uniform kernel: sums spike counts over
# window_size worth of adjacent bins. mode='same' keeps output aligned with input.
window_bins = min(max(1, int(round(window_size / step_size))), n_bins)
effective_window = window_bins * step_size
kernel = np.ones(window_bins)
counts = np.convolve(hist, kernel, mode="same")
# Rate = spike count in window / effective window duration (spikes per time unit)
rate_array = counts / effective_window
else:
# No square smoothing: convert per-bin counts directly to rates.
rate_array = hist / step_size
if gauss_sigma > 0:
sigma_bins = gauss_sigma / step_size
rate_array = ndimage.gaussian_filter1d(rate_array, sigma=sigma_bins)
time_vector = (bin_edges[:-1] + bin_edges[1:]) / 2 # Bin centers
from .ratedata import RateData
return RateData(inst_Frate_data=rate_array.reshape(1, -1), times=time_vector)
def _train_from_i_t_list(idces, times, N):
"""Given lists of spike times and unit indices, produce a list of per-unit spike times.
Parameters:
idces (list): List of spike indices.
times (list): List of spike times.
N (int): Number of units.
Returns:
ret (list): List whose ith entry is a list of the spike times of the
ith unit.
"""
idces, times = np.asarray(idces), np.asarray(times)
if N is None:
N = idces.max() + 1
ret = []
for i in range(N):
ret.append(times[idces == i])
return ret
def butter_filter(
data,
lowcut: Optional[float] = None,
highcut: Optional[float] = None,
fs=20000.0,
order=5,
):
"""Apply a digital Butterworth filter. Filter type is based on input values.
Parameters:
data (array_like): Data to be filtered.
lowcut (float or None): Low cutoff frequency. If None or 0, highcut
must be a number.
highcut (float or None): High cutoff frequency. If None, lowcut must
be a non-zero number.
fs (float): Sample rate.
order (int): Order of the filter.
Returns:
filtered_traces (np.ndarray): The filtered output with the same shape
as data.
Notes:
If lowcut and highcut are both given, this filter is bandpass. In
this case, lowcut must be smaller than highcut.
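    Example:
        A minimal sketch, band-pass filtering a hypothetical noise trace
        sampled at 20 kHz::

            import numpy as np
            raw = np.random.default_rng(0).standard_normal(20000)
            filtered = butter_filter(raw, lowcut=300.0, highcut=3000.0, fs=20000.0)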
"""
    if not lowcut and highcut is None:
        raise ValueError(
            "Need at least a low cutoff (lowcut) or high cutoff (highcut) frequency!"
        )
    elif not lowcut:
        # lowcut is None or 0: lowpass filter at highcut.
        filter_type = "lowpass"
        Wn = highcut / fs * 2
    elif highcut is None:
        filter_type = "highpass"
        Wn = lowcut / fs * 2
    else:
        if lowcut >= highcut:
            raise ValueError("lowcut must be smaller than highcut")
        filter_type = "bandpass"
        Wn = [e / fs * 2 for e in (lowcut, highcut)]
filter_coeff = signal.iirfilter(
order, Wn, analog=False, btype=filter_type, output="sos"
)
filtered_traces = signal.sosfiltfilt(filter_coeff, data)
return filtered_traces
def swap(ar, idxs, rng):
"""Attempt one double-edge swap in a binary spike raster while preserving per-row and per-column sums.
Parameters:
ar (np.ndarray): Binary spike raster.
idxs (tuple): Tuple of numpy arrays containing the indices of the
spikes.
rng (np.random.Generator): Random number generator for
reproducibility.
Returns:
success (bool): True if a swap was performed.
Notes:
Both ``ar`` and ``idxs`` are mutated in-place for performance.
The swap chooses two existing spike positions (i0, j0) and (i1, j1)
and, if the off-diagonal positions (i0, j1) and (i1, j0) are both
empty and the indices are distinct, swaps them so that spikes move
to those positions.
"""
idx0 = rng.integers(len(idxs[0]))
idx1 = rng.integers(len(idxs[0]))
i0, j0 = idxs[0][idx0], idxs[1][idx0]
i1, j1 = idxs[0][idx1], idxs[1][idx1]
if i0 == i1 or j0 == j1 or ar[i0, j1] == 1.0 or ar[i1, j0] == 1.0:
return False
ar[i0, j0] = ar[i1, j1] = 0.0
ar[i0, j1] = ar[i1, j0] = 1.0
idxs[0][idx0], idxs[1][idx0] = i0, j1
idxs[0][idx1], idxs[1][idx1] = i1, j0
return True
def randomize(ar, swap_per_spike=5, seed=None):
"""Randomize a binary spike raster using degree-preserving double-edge swaps.
Parameters:
ar (array_like): Binary matrix shaped (neurons, time) or
(time, neurons). Values should be 0/1.
swap_per_spike (int): Target number of successful swaps per spike.
seed (int or None): Random seed number. Set for repeatability during
experiments.
Returns:
randomized_raster (np.ndarray): Randomized binary matrix with the
same shape and row/column sums.
Notes:
Shuffling preserves each neuron's average firing rate but shuffles
which time bins it spikes in. Each time bin's population rate is also
preserved but the specific units active are shuffled. Every spike
swap involves 2 different spikes so on average every spike will get
swapped 2 * swap_per_spike times.
Okun, M. et al. Population rate dynamics and multineuron firing
patterns in sensory cortex. J. Neurosci. 32, 17108-17119 (2012).
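    Example:
        A minimal sketch on a small hypothetical binary raster; the total
        spike count (and per-row/per-column sums) is preserved::

            import numpy as np
            rng = np.random.default_rng(0)
            raster = (rng.random((10, 200)) < 0.1).astype(int)
            shuffled = randomize(raster, swap_per_spike=5, seed=0)
            assert shuffled.sum() == raster.sum()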
"""
rng = np.random.default_rng(seed)
ar = np.array(ar, dtype=float, copy=True)
unique_vals = np.unique(ar)
if not np.all(np.isin(unique_vals, [0.0, 1.0])):
raise ValueError(
"randomize() requires a binary (0/1) raster. "
f"Found values: {unique_vals}"
)
idxs = np.where(ar == 1.0)
n_spikes = int(np.sum(ar))
attempts = int((swap_per_spike + 1) * n_spikes)
cnt_swap = 0
for _ in range(attempts):
if swap(ar, idxs, rng):
cnt_swap += 1
if cnt_swap < swap_per_spike * n_spikes:
for _ in range(attempts):
if swap(ar, idxs, rng):
cnt_swap += 1
if cnt_swap < swap_per_spike * n_spikes:
warnings.warn(
"Not sufficient successful swaps, only {} of {} required".format(
cnt_swap, swap_per_spike * n_spikes
),
RuntimeWarning,
)
return ar.astype(int)
def trough_between(i0, i1, pop_rate):
"""Find the minimum value (trough) between two indices in a population rate array.
Parameters:
i0 (int): Time bin index of the first burst.
i1 (int): Time bin index of the second burst.
pop_rate (np.ndarray): Smoothed population spiking data in spikes
per bin.
Returns:
trough_idx (int or None): Time bin index of minimum value (trough)
between peaks. None if the indices are adjacent.
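    Example:
        A minimal sketch with a hypothetical population rate trace::

            import numpy as np
            pop_rate = np.array([1.0, 5.0, 2.0, 0.5, 3.0, 6.0])
            trough_idx = trough_between(1, 5, pop_rate)  # 3 (pop_rate[3] == 0.5)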
"""
L, R = int(i0), int(i1)
if R - L <= 1:
return None
seg = pop_rate[L:R]
return L + int(np.argmin(seg))
def compute_cross_correlation_with_lag(ref_rate, comp_rate, max_lag=0):
"""Compute normalized cross-correlation with lag information.
Parameters:
ref_rate (array): Reference firing rate signal.
comp_rate (array): Comparison firing rate signal.
max_lag (int or None): Maximum lag in frames to search for
similarity. If None, lag is set to 0. Negative values
are treated as their absolute value (lag is symmetric).
Returns:
max_corr (float): Maximum correlation coefficient.
max_lag_idx (int): Lag (in frames) at which maximum correlation
occurs.
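    Example:
        A minimal sketch comparing a hypothetical rate trace against a
        circularly shifted copy of itself::

            import numpy as np
            ref = np.sin(np.linspace(0, 4 * np.pi, 200))
            comp = np.roll(ref, 5)
            corr, lag = compute_cross_correlation_with_lag(ref, comp, max_lag=10)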
"""
if max_lag is None:
max_lag = 0
max_lag = abs(max_lag)
# Handle zero-norm vectors:
# - Both zero → undefined (NaN)
# - One zero, one not → uncorrelated (0.0)
ref_norm = np.sum(ref_rate**2)
comp_norm = np.sum(comp_rate**2)
if ref_norm == 0 and comp_norm == 0:
return np.nan, 0
if ref_norm == 0 or comp_norm == 0:
return 0.0, 0
norm_product = ref_norm * comp_norm
# Fast path for zero lag: direct dot-product normalisation.
# Normalises by sqrt(sum(ref^2) * sum(comp^2)) — the L2 norms.
if max_lag == 0:
max_corr = np.sum(ref_rate * comp_rate) / np.sqrt(norm_product)
return max_corr, 0
# General path: use scipy.signal.correlate and normalise by
# sqrt(autocorr_ref[0] * autocorr_comp[0]). For practical signal
# lengths this equals the L2-norm product used in the fast path;
# a tiny difference can arise for very short even-length signals
# because correlate(..., 'same') may offset the centre by half a
# sample.
auto_ref = signal.correlate(ref_rate, ref_rate, mode="same")[len(ref_rate) // 2]
auto_comp = signal.correlate(comp_rate, comp_rate, mode="same")[len(comp_rate) // 2]
denom = auto_ref * auto_comp
if denom <= 0:
return 0.0, 0
r = signal.correlate(ref_rate, comp_rate, mode="same") / np.sqrt(denom)
center = len(r) // 2
# Search within max_lag window
search_start = max(0, center - max_lag)
search_end = min(len(r), center + max_lag + 1)
search_window = r[search_start:search_end]
max_corr = np.max(search_window)
max_lag_idx = np.argmax(search_window) + search_start - center
return max_corr, max_lag_idx
def _cosine_sim(a, b):
"""Cosine similarity between two 1-D vectors. NaN if both zero-norm, 0.0 if one is."""
norm_a = np.linalg.norm(a)
norm_b = np.linalg.norm(b)
if norm_a == 0.0 and norm_b == 0.0:
return np.nan
if norm_a == 0.0 or norm_b == 0.0:
return 0.0
return float(np.dot(a, b) / (norm_a * norm_b))
def compute_cosine_similarity_with_lag(ref_rate, comp_rate, max_lag=0):
"""Compute cosine similarity with lag information.
Parameters:
ref_rate (array): Reference firing rate signal.
comp_rate (array): Comparison firing rate signal.
max_lag (int or None): Maximum lag in frames to search for
similarity. If None, lag is set to 0. Negative values
are treated as their absolute value (lag is symmetric).
Returns:
max_sim (float): Maximum cosine similarity coefficient.
max_lag_idx (int): Lag (in frames) at which maximum similarity
occurs.
"""
ref_rate = np.array(ref_rate).flatten()
comp_rate = np.array(comp_rate).flatten()
# Handle None case (convert to 0)
if max_lag is None:
max_lag = 0
max_lag = abs(max_lag)
if max_lag == 0:
# Only check zero lag
return _cosine_sim(ref_rate, comp_rate), 0
lag_range = range(-max_lag, max_lag + 1)
similarities = []
valid_lags = []
    # Compute cosine similarity at each lag, handling negative, positive, and zero lags separately.
for lag in lag_range:
if lag < 0:
# comp_rate leads ref_rate (shift comp_rate left, or ref_rate right)
ref_segment = ref_rate[-lag:]
comp_segment = comp_rate[:lag]
elif lag > 0:
# ref_rate leads comp_rate (shift ref_rate left, or comp_rate right)
ref_segment = ref_rate[:-lag]
comp_segment = comp_rate[lag:]
else:
# No lag
ref_segment = ref_rate
comp_segment = comp_rate
# Skip if segments are too short
if len(ref_segment) > 0 and len(comp_segment) > 0:
similarities.append(_cosine_sim(ref_segment, comp_segment))
valid_lags.append(lag)
# Find maximum similarity and corresponding lag
similarities = np.array(similarities)
valid_lags = np.array(valid_lags)
if np.all(np.isnan(similarities)):
return np.nan, 0
max_idx = np.nanargmax(similarities)
max_sim = similarities[max_idx]
max_lag_idx = valid_lags[max_idx]
return max_sim, max_lag_idx
def PCA_reduction(matrix_2d, n_components=2):
"""Compute PCA dimensionality reduction on axis 1 of a 2D matrix.
Parameters:
matrix_2d (array): 2D matrix of shape (samples, features) where
values must be int, float, or bool.
n_components (int): Number of principal components to retain
(default: 2).
Returns:
embedding (np.ndarray): 2D matrix of shape
``(samples, n_components)``.
explained_variance_ratio (np.ndarray): 1D array of shape
``(n_components,)`` with the fraction of total variance
explained by each component.
components (np.ndarray): 2D matrix of shape
``(n_components, features)`` with the principal axes
(loadings) -- each row is one PC expressed in the original
feature space.
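    Example:
        A minimal sketch on a hypothetical (samples, features) matrix,
        assuming scikit-learn is installed::

            import numpy as np
            X = np.random.default_rng(0).standard_normal((100, 20))
            emb, evr, comps = PCA_reduction(X, n_components=2)
            # emb.shape == (100, 2); evr.shape == (2,); comps.shape == (2, 20)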
"""
try:
from sklearn.decomposition import PCA
except ImportError:
raise ImportError(
"PCA_reduction requires the optional dependency 'scikit-learn'. "
"Install it with `pip install scikit-learn`."
)
max_components = min(matrix_2d.shape)
if n_components > max_components:
raise ValueError(
f"n_components={n_components} exceeds "
f"min(n_samples, n_features)={max_components}"
)
pca = PCA(n_components=n_components)
embedding = pca.fit_transform(matrix_2d)
return embedding, pca.explained_variance_ratio_, pca.components_
def _clamp_umap_n_neighbors(n_samples: int, n_neighbors: int) -> int:
"""Clamp n_neighbors so small datasets do not raise at UMAP fit time."""
if n_samples < 2:
return 1
max_nn = max(1, int(math.ceil(n_samples / 2)) - 1)
return min(max(int(n_neighbors), 2), max_nn)
def UMAP_reduction(
matrix_2d,
n_components: int = 2,
n_neighbors: int = 15,
min_dist: float = 0.1,
metric: str = "euclidean",
random_state: Optional[int] = None,
**umap_kwargs: Any,
):
"""Compute UMAP dimensionality reduction on a 2D matrix.
Parameters:
matrix_2d (array_like): Input data of shape
``(n_samples, n_features)``. Each row is a sample, each column
is a feature.
n_components (int): Dimension of the embedded space.
n_neighbors (int): Size of local neighborhood used for manifold
approximation.
min_dist (float): Controls how tightly UMAP packs points together in
the low-dimensional space.
metric (str): Distance metric used in the input space.
random_state (int or None): Random seed for reproducibility.
**umap_kwargs: Additional keyword arguments passed to
``umap.UMAP``.
Returns:
embedding (np.ndarray): Low-dimensional embedding of shape
``(n_samples, n_components)``.
trustworthiness_score (float): Trustworthiness of the embedding
(0 to 1). Measures how well local neighborhoods in the
high-dimensional space are preserved. Returns NaN if
scikit-learn is unavailable.
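    Example:
        A minimal sketch, assuming the optional umap-learn package (and
        scikit-learn for the trustworthiness score) is installed::

            import numpy as np
            X = np.random.default_rng(0).standard_normal((200, 30))
            embedding, tw = UMAP_reduction(X, n_components=2, random_state=0)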
"""
if umap is None:
raise ImportError(
"UMAP_reduction requires the optional dependency 'umap-learn'. "
"Install it with `pip install umap-learn`."
)
matrix_2d = np.asarray(matrix_2d)
n_neighbors = _clamp_umap_n_neighbors(matrix_2d.shape[0], n_neighbors)
reducer = umap.UMAP(
n_components=n_components,
n_neighbors=n_neighbors,
min_dist=min_dist,
metric=metric,
random_state=random_state,
**umap_kwargs,
)
embedding = reducer.fit_transform(matrix_2d)
try:
from sklearn.manifold import trustworthiness
tw = float(trustworthiness(matrix_2d, embedding, n_neighbors=n_neighbors))
except ImportError:
tw = float("nan")
return embedding, tw
def UMAP_graph_communities(
matrix_2d,
n_components: int = 2,
resolution: float = 1.0,
n_neighbors: int = 15,
min_dist: float = 0.1,
metric: str = "euclidean",
random_state: Optional[int] = None,
**umap_kwargs: Any,
):
"""Run UMAP and Louvain community detection on the UMAP connectivity graph.
This helper keeps UMAP_reduction simple while providing an optional
graph-based clustering approach that builds on UMAP's internal graph.
Parameters:
matrix_2d (array_like): Input data of shape
``(n_samples, n_features)``. Each row is a sample, each column
is a feature.
n_components (int): Dimension of the embedded space.
resolution (float): Resolution parameter for the Louvain community
detection algorithm. Higher values produce more, smaller
communities. Lower values produce fewer, larger communities.
n_neighbors (int): Passed through to ``umap.UMAP``.
min_dist (float): Passed through to ``umap.UMAP``.
metric (str): Passed through to ``umap.UMAP``.
random_state (int or None): Passed through to ``umap.UMAP``.
**umap_kwargs: Additional keyword arguments passed to
``umap.UMAP``.
Returns:
embedding (np.ndarray): Low-dimensional UMAP embedding of shape
``(n_samples, n_components)``.
labels (np.ndarray): Integer community label for each sample, shape
``(n_samples,)``.
trustworthiness_score (float): Trustworthiness of the embedding
(0 to 1). Returns NaN if scikit-learn is not available.
"""
# First compute the UMAP embedding and fitted mapper using the same
# configuration as UMAP_reduction.
if umap is None:
raise ImportError(
"UMAP_graph_communities requires the optional dependency 'umap-learn'. "
"Install it with `pip install umap-learn`."
)
if nx is None:
raise ImportError(
"UMAP_graph_communities requires the optional dependency 'networkx'. "
"Install it with `pip install networkx`."
)
if community_louvain is None:
raise ImportError(
"UMAP_graph_communities requires the optional dependency "
"'python-louvain'. Install it with `pip install python-louvain`."
)
matrix_2d = np.asarray(matrix_2d)
n_neighbors = _clamp_umap_n_neighbors(matrix_2d.shape[0], n_neighbors)
mapper = umap.UMAP(
n_components=n_components,
n_neighbors=n_neighbors,
min_dist=min_dist,
metric=metric,
random_state=random_state,
**umap_kwargs,
).fit(matrix_2d)
# UMAP's internal connectivity graph -> NetworkX graph
# Use a compatibility shim so both old and new NetworkX versions work.
if hasattr(nx, "from_scipy_sparse_array"):
G = nx.from_scipy_sparse_array(mapper.graph_)
else:
G = nx.from_scipy_sparse_matrix(mapper.graph_)
# Louvain community detection on the graph
clustering = community_louvain.best_partition(G, resolution=resolution)
# Convert dict {node_idx: community_id} -> label array
# Use the fitted mapper's embedding to determine n_samples so that
# callers can pass in any array-like that UMAP accepts (not just ndarrays).
n_samples = mapper.embedding_.shape[0]
labels = np.zeros(n_samples, dtype=int)
for node, c_id in clustering.items():
labels[node] = c_id
try:
from sklearn.manifold import trustworthiness
tw = float(
trustworthiness(matrix_2d, mapper.embedding_, n_neighbors=n_neighbors)
)
except ImportError:
tw = float("nan")
return mapper.embedding_, labels, tw
def ensure_h5py():
"""Raise ``ImportError`` if *h5py* is not installed."""
if h5py is None:
raise ImportError(
"h5py is required for this operation. " "Install it with: pip install h5py"
)
def times_from_ms(
times_ms: np.ndarray, unit: TimeUnit, fs_Hz: Optional[float]
) -> Union[np.ndarray, float, int]:
"""Convert times from milliseconds to the requested unit."""
if unit == "ms":
return times_ms.astype(float)
if unit == "s":
return times_ms.astype(float) / 1e3
if unit == "samples":
if not fs_Hz or fs_Hz <= 0:
raise ValueError("fs_Hz must be provided and > 0 when unit='samples'")
# Use round-to-nearest to produce integer samples
return np.rint(times_ms.astype(float) * (fs_Hz / 1e3)).astype(np.int64)
raise ValueError(f"Unknown time unit '{unit}' (expected 's','ms','samples')")
def to_ms(values: np.ndarray, unit: str, fs_Hz: Optional[float]) -> np.ndarray:
"""Convert a vector of times to milliseconds."""
if unit == "ms":
return values.astype(float)
if unit == "s":
return values.astype(float) * 1e3
if unit == "samples":
if not fs_Hz or fs_Hz <= 0:
raise ValueError("fs_Hz must be provided and > 0 when unit='samples'")
return values.astype(float) / fs_Hz * 1e3
raise ValueError(f"Unknown time unit '{unit}' (expected 's','ms','samples')")
def check_neuron_attributes(
neuron_attributes: List[dict], n_neurons: Optional[int] = None
) -> List[dict]:
"""Check a list of dictionaries for use as neuron_attributes to verify that keys and values are consistent.
Parameters:
neuron_attributes (list of dict): List of dictionaries containing
neuron attributes.
n_neurons (int or None): Expected number of neurons. If provided,
validates the list length.
Returns:
result (list of dict): A list of dictionaries where all dictionaries
have valid keys and values.
Notes:
If some dictionaries are missing keys that others have, a ValueError
is raised indicating which neuron entries have inconsistent keys.
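    Example:
        A minimal sketch with two hypothetical attribute dictionaries that
        share the same keys::

            attrs = [
                {"channel": 3, "position": (0.0, 10.0)},
                {"channel": 7, "position": (20.0, 10.0)},
            ]
            checked = check_neuron_attributes(attrs, n_neurons=2)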
"""
if not isinstance(neuron_attributes, list):
raise ValueError("neuron_attributes must be a list")
if n_neurons is not None and len(neuron_attributes) != n_neurons:
raise ValueError(
f"neuron_attributes has {len(neuron_attributes)} items, expected {n_neurons}"
)
for i, attr in enumerate(neuron_attributes):
if not isinstance(attr, dict):
raise ValueError(f"neuron_attributes[{i}] must be a dict")
if not neuron_attributes:
return []
all_keys = set().union(*(attr.keys() for attr in neuron_attributes))
if not all_keys:
return [d.copy() for d in neuron_attributes]
missing = {
i: all_keys - attr.keys()
for i, attr in enumerate(neuron_attributes)
if attr.keys() != all_keys
}
if missing:
parts = [f"Neuron {i} missing: {keys}" for i, keys in sorted(missing.items())]
raise ValueError(f"Inconsistent neuron_attributes keys. {'; '.join(parts)}.")
return [{key: attr.get(key) for key in all_keys} for attr in neuron_attributes]
def get_channels_for_unit(
unit_idx: int,
channels: Optional[Union[int, List[int]]],
neuron_to_channel: dict,
n_channels_total: int,
) -> List[int]:
"""Determine which channels to extract for a given unit.
Parameters:
unit_idx (int): Index of the unit.
channels (int, list of int, or None): Channel specification. None
uses neuron_to_channel mapping or all channels; int for single
channel; list for multiple; empty list for mapped channel.
neuron_to_channel (dict): Mapping from unit indices to channel
indices.
n_channels_total (int): Total number of channels in the raw data.
Returns:
result (list of int): Channel indices to extract.
Raises:
ValueError: If channels argument is invalid type.
"""
if channels is None:
if unit_idx in neuron_to_channel:
return [neuron_to_channel[unit_idx]]
return list(range(n_channels_total))
elif isinstance(channels, int):
return [channels]
elif isinstance(channels, list):
if len(channels) == 0:
if unit_idx in neuron_to_channel:
return [neuron_to_channel[unit_idx]]
return list(range(n_channels_total))
return channels
raise ValueError(f"Invalid channels argument: {channels}")
def get_valid_spike_times(
spike_times_ms: np.ndarray,
fs_kHz: float,
ms_before: float,
ms_after: float,
n_time_samples: int,
) -> np.ndarray:
"""Filter spike times to only those within valid bounds of the raw data.
Parameters:
spike_times_ms (np.ndarray): Array of spike times in milliseconds.
fs_kHz (float): Sampling rate in kHz.
ms_before (float): Milliseconds before each spike time.
ms_after (float): Milliseconds after each spike time.
n_time_samples (int): Total number of time samples in the raw data.
Returns:
valid (np.ndarray): Array of valid spike times in milliseconds.
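    Example:
        A minimal sketch assuming a 20 kHz recording with 10,000 samples
        (500 ms of data)::

            import numpy as np
            spikes_ms = np.array([0.5, 250.0, 499.8])
            valid = get_valid_spike_times(
                spikes_ms, fs_kHz=20.0, ms_before=1.0, ms_after=1.5, n_time_samples=10000
            )
            # Only 250.0 survives; the other spikes sit too close to the edges.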
"""
before_samples = round(ms_before * fs_kHz)
after_samples = round(ms_after * fs_kHz)
valid_spike_times = []
for spike_time_ms in spike_times_ms:
spike_sample = round(spike_time_ms * fs_kHz)
start = spike_sample - before_samples
end = spike_sample + after_samples
if start >= 0 and end <= n_time_samples:
valid_spike_times.append(spike_time_ms)
return np.array(valid_spike_times)
def consecutive_durations(signal, threshold, mode="above", min_dur=1):
"""
Compute the lengths of consecutive runs in a 1-D signal that satisfy a threshold condition.
Scans *signal* for contiguous stretches of bins that are above (>=) or
below (<) *threshold*, returns an array of their durations, and optionally
filters out runs shorter than *min_dur*.
Parameters:
signal (array_like): 1-D numeric array (e.g. continuity probability
time series from a GPLVM).
threshold (float): Threshold value for the condition.
mode (str): ``"above"`` keeps runs where ``signal >= threshold``;
``"below"`` keeps runs where ``signal < threshold``.
min_dur (int): Minimum run length to keep. Runs shorter than this
are discarded.
Returns:
durations (np.ndarray): 1-D integer array of run lengths that satisfy
the condition and are at least *min_dur* bins long. May be empty.
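    Example:
        A minimal sketch with a hypothetical continuity-probability trace::

            import numpy as np
            trace = np.array([0.1, 0.9, 0.95, 0.2, 0.8, 0.85, 0.9, 0.1])
            durs = consecutive_durations(trace, threshold=0.5, mode="above", min_dur=2)
            # durs -> array([2, 3])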
"""
signal = np.asarray(signal)
if signal.ndim != 1:
raise ValueError(f"signal must be 1-D, got shape {signal.shape}")
if mode == "above":
condition = signal >= threshold
elif mode == "below":
condition = signal < threshold
else:
raise ValueError("mode must be 'above' or 'below'")
# Compute lengths of consecutive True runs
durations = np.array(
[sum(1 for _ in group) for key, group in _groupby(condition) if key],
dtype=int,
)
if durations.size > 0:
durations = durations[durations >= min_dur]
return durations
def gplvm_state_entropy(posterior_latent_marg):
"""
Compute Shannon entropy of the latent state distribution at each time bin.
Parameters:
posterior_latent_marg (np.ndarray): Marginal posterior over latent
states with shape ``(T, K)`` where *T* is the number of time bins
and *K* is the number of latent states. Typically obtained from
``SpikeData.fit_gplvm()["decode_res"]["posterior_latent_marg"]``.
Returns:
entropy (np.ndarray): 1-D array of shape ``(T,)`` with the Shannon
entropy (in nats) for each time bin.
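    Example:
        A minimal sketch: a uniform posterior over K = 4 states yields an
        entropy of log(4) nats in every time bin::

            import numpy as np
            post = np.full((100, 4), 0.25)
            ent = gplvm_state_entropy(post)  # every entry ~= np.log(4)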
"""
from scipy.stats import entropy as _entropy
posterior_latent_marg = np.asarray(posterior_latent_marg)
if posterior_latent_marg.ndim != 2:
raise ValueError(
f"posterior_latent_marg must be 2-D (T, K), got shape "
f"{posterior_latent_marg.shape}"
)
return _entropy(posterior_latent_marg, axis=1)
def gplvm_continuity_prob(decode_res):
"""
Extract the continuity (non-jump) probability time series from a GPLVM decode result.
The continuity probability at each time bin is the marginal posterior
probability that the dynamics remained continuous (i.e. did not jump)
between the previous and current time bin.
Parameters:
decode_res (dict): Decoded latent state dictionary as returned by
``SpikeData.fit_gplvm()["decode_res"]``. Must contain the key
``"posterior_dynamics_marg"`` with shape ``(T, D)`` where the
first column (index 0) holds the continuity probability.
Returns:
continuity_prob (np.ndarray): 1-D array of shape ``(T,)`` with the
continuity probability at each time bin.
"""
if not isinstance(decode_res, dict):
raise TypeError("decode_res must be a dict from SpikeData.fit_gplvm()")
if "posterior_dynamics_marg" not in decode_res:
raise KeyError(
"decode_res must contain 'posterior_dynamics_marg'. "
"Pass the 'decode_res' dict from SpikeData.fit_gplvm()."
)
dynamics = np.asarray(decode_res["posterior_dynamics_marg"])
if dynamics.ndim != 2 or dynamics.shape[1] < 1:
raise ValueError(
f"posterior_dynamics_marg must be 2-D with at least 1 column, "
f"got shape {dynamics.shape}"
)
return dynamics[:, 0]
def gplvm_average_state_probability(posterior_latent_marg):
"""
Compute the average probability of each latent state across all time bins.
Parameters:
posterior_latent_marg (np.ndarray): Marginal posterior over latent
states with shape ``(T, K)`` where *T* is the number of time bins
and *K* is the number of latent states. Typically obtained from
``SpikeData.fit_gplvm()["decode_res"]["posterior_latent_marg"]``.
Returns:
avg_prob (np.ndarray): 1-D array of shape ``(K,)`` with the mean
probability of each latent state, averaged over all time bins.
"""
posterior_latent_marg = np.asarray(posterior_latent_marg)
if posterior_latent_marg.ndim != 2:
raise ValueError(
f"posterior_latent_marg must be 2-D (T, K), got shape "
f"{posterior_latent_marg.shape}"
)
return np.mean(posterior_latent_marg, axis=0)
def _get_attr(obj, key, default):
"""Get an attribute from a dict-like or object-like neuron attribute entry."""
if isinstance(obj, dict):
return obj.get(key, default)
return getattr(obj, key, default)
def _validate_time_start_to_end(
times_start_to_end, warn_negative_start=False, recording_range=None
):
"""Validate that a list of (start, end) tuples has the same duration and is in proper format.
Parameters:
times_start_to_end (list): Each entry must be a tuple
``(start, end)``.
warn_negative_start (bool): If True, emit a warning for windows with
negative start times (default False). Useful when times are
expected to be absolute recording positions.
recording_range (tuple or None): If provided, a
``(rec_start, rec_end)`` tuple defining the valid time range.
Any window that extends outside this range raises
``ValueError``. If None (default), no range check is performed.
Returns:
valid_time_tuples (list): Sorted list of valid ``(start, end)``
tuples. Negative-start windows are preserved.
"""
if not isinstance(times_start_to_end, list):
raise TypeError("times must be a list of tuples")
time_diff_check = []
valid_time_tuples = []
times_start_to_end = sorted(times_start_to_end)
for i, time_window in enumerate(times_start_to_end):
if not isinstance(time_window, tuple):
raise TypeError(f"Element {i} of times is not a tuple: {time_window}")
if len(time_window) != 2:
raise TypeError(
f"Element {i} of times must be a tuple of length 2 (start, end): "
f"{time_window}"
)
if not (
isinstance(time_window[0], (int, float, np.number))
and isinstance(time_window[1], (int, float, np.number))
):
raise TypeError(
f"Start and end times in element {i} must be numbers: {time_window}"
)
if time_window[0] > time_window[1]:
raise ValueError(
f"Start time must not exceed end time in element {i}: {time_window}"
)
if time_window[0] == time_window[1]:
warnings.warn(
f"Zero-duration time window in element {i}: {time_window}. "
"Treating as an empty slice.",
UserWarning,
)
if warn_negative_start and time_window[0] < 0:
warnings.warn(
f"Time window {i} has negative start ({time_window[0]}). "
"If these are absolute recording times, negative values are "
"unexpected. For event-centered data constructed via "
"time_peaks + time_bounds, this is normal.",
UserWarning,
)
if recording_range is not None:
rec_start, rec_end = recording_range
if time_window[0] < rec_start or time_window[1] > rec_end:
raise ValueError(
f"Time window {i} ({time_window[0]}, {time_window[1]}) "
f"extends outside the recording range "
f"[{rec_start}, {rec_end}]."
)
time_diff_check.append(time_window[1] - time_window[0])
valid_time_tuples.append(time_window)
if len(time_diff_check) > 1:
diffs = np.array(time_diff_check)
if not np.allclose(diffs, diffs[0], atol=1e-6, rtol=0):
raise ValueError("All time windows must have the same length")
return valid_time_tuples
def _rank_order_correlation_from_timing(
timing_matrix,
min_overlap=3,
min_overlap_frac=None,
n_shuffles=100,
seed=1,
n_jobs=-1,
):
"""
Compute Spearman rank-order correlation of unit timing between all slice pairs.
Shared implementation used by both SpikeSliceStack.rank_order_correlation
and RateSliceStack.rank_order_correlation.
Parameters:
timing_matrix (np.ndarray): Array of shape (U, S) with timing values
per unit per slice. NaN entries mark inactive units.
min_overlap (int): Minimum units active in both slices (default: 3).
min_overlap_frac (float or None): Minimum fraction of total units
active in both slices. Effective threshold is
max(min_overlap, ceil(min_overlap_frac * U)).
n_shuffles (int): Shuffle iterations for z-scoring (default: 100).
0 = raw correlations. Values 1-4 are rejected.
seed (int or None): Random seed for shuffle reproducibility.
n_jobs (int): Number of threads for parallel computation. -1 uses all
cores (default), 1 disables parallelism, None is serial.
Returns:
corr_matrix (PairwiseCompMatrix): (S, S) Spearman correlation or z-score matrix.
av_corr (float): Average over valid lower-triangle pairs.
overlap_matrix (PairwiseCompMatrix): (S, S) fraction of units active in both slices.
"""
from scipy.stats import spearmanr
# Import here to avoid circular import at module level
from .pairwise import PairwiseCompMatrix
if 0 < n_shuffles < 5:
raise ValueError(
f"n_shuffles must be 0 (no shuffling) or >= 5, got {n_shuffles}"
)
timing_matrix = np.asarray(timing_matrix)
if timing_matrix.ndim != 2:
raise ValueError(
f"timing_matrix must be 2-D (U, S), got shape {timing_matrix.shape}"
)
num_units = timing_matrix.shape[0]
effective_min = min_overlap
if min_overlap_frac is not None:
frac_count = int(np.ceil(min_overlap_frac * num_units))
effective_min = max(effective_min, frac_count)
num_slices = timing_matrix.shape[1]
corr = np.full((num_slices, num_slices), np.nan)
overlap = np.zeros((num_slices, num_slices), dtype=int)
if n_shuffles == 0:
np.fill_diagonal(corr, 1.0)
for i in range(num_slices):
overlap[i, i] = int(np.sum(~np.isnan(timing_matrix[:, i])))
# Pre-compute validity masks and extract data for each pair
pairs = [(i, j) for i in range(num_slices) for j in range(i + 1, num_slices)]
# Each pair needs its own independent RNG for reproducibility
ss = np.random.SeedSequence(seed)
pair_seeds = ss.spawn(len(pairs))
def _compute_pair(args):
(i, j), child_seed = args
valid = ~np.isnan(timing_matrix[:, i]) & ~np.isnan(timing_matrix[:, j])
n_valid = int(np.sum(valid))
if n_valid < effective_min:
return i, j, n_valid, np.nan
a = timing_matrix[valid, i]
b = timing_matrix[valid, j]
rho, _ = spearmanr(a, b)
if n_shuffles == 0:
return i, j, n_valid, rho
rng = np.random.default_rng(child_seed)
null_rhos = np.empty(n_shuffles)
for k in range(n_shuffles):
null_rhos[k], _ = spearmanr(a, rng.permutation(b))
null_mean = np.mean(null_rhos)
null_std = np.std(null_rhos)
z = (rho - null_mean) / null_std if null_std > 0 else np.nan
return i, j, n_valid, z
work_items = list(zip(pairs, pair_seeds))
n_workers = _resolve_n_jobs(n_jobs)
if n_workers > 1 and len(pairs) > 1:
with ThreadPoolExecutor(max_workers=n_workers) as pool:
results = pool.map(_compute_pair, work_items)
else:
results = map(_compute_pair, work_items)
for i, j, n_valid, value in results:
overlap[i, j] = n_valid
overlap[j, i] = n_valid
corr[i, j] = value
corr[j, i] = value
lower_tri = np.tril_indices(num_slices, k=-1)
av_corr = float(np.nanmean(corr[lower_tri]))
overlap_frac = (
overlap.astype(float) / num_units if num_units > 0 else overlap.astype(float)
)
return (
PairwiseCompMatrix(matrix=corr),
av_corr,
PairwiseCompMatrix(matrix=overlap_frac),
)
# ---------------------------------------------------------------------------
# Slice comparison utilities
# ---------------------------------------------------------------------------
def shuffle_z_score(observed, shuffle_distribution):
"""
Z-score an observed value against a shuffle null distribution.
Parameters:
observed (scalar or np.ndarray): The metric computed on the real data.
shuffle_distribution (np.ndarray): Shape ``(N, ...)`` array of the
same metric computed on N shuffled datasets (e.g. from
``SpikeSliceStack.apply`` on a shuffle stack built by
``SpikeData.spike_shuffle_stack``).
Returns:
z (np.ndarray): Z-score ``(observed - mean) / std`` computed along
axis 0. Same shape as *observed*.
Notes:
- Intended for determining whether an observed metric is significantly
different from what degree-preserving shuffled data produces.
- Elements where the shuffle standard deviation is zero will be NaN.
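    Example:
        A minimal sketch with a hypothetical scalar metric and a unit-normal
        shuffle distribution::

            import numpy as np
            shuffles = np.random.default_rng(0).normal(0.0, 1.0, size=1000)
            z = shuffle_z_score(3.0, shuffles)  # roughly 3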
"""
shuffle_distribution = np.asarray(shuffle_distribution)
mean = np.nanmean(shuffle_distribution, axis=0)
std = np.nanstd(shuffle_distribution, axis=0)
safe_std = np.where(std == 0, 1.0, std)
z = (np.asarray(observed) - mean) / safe_std
z = np.where(std == 0, np.nan, z)
return z
def shuffle_percentile(observed, shuffle_distribution):
"""
Compute the percentile rank of an observed value within a shuffle distribution.
Parameters:
observed (scalar or np.ndarray): The metric computed on the real data.
shuffle_distribution (np.ndarray): Shape ``(N, ...)`` array of the
same metric computed on N shuffled datasets.
Returns:
pct (np.ndarray): Fraction of shuffle values ≤ observed, computed
along axis 0. Values in [0, 1]. Same shape as *observed*.
Notes:
- Non-parametric alternative to ``shuffle_z_score``; gives the rank
of the observed value within the null distribution without assuming
normality.
"""
shuffle_distribution = np.asarray(shuffle_distribution)
observed = np.asarray(observed)
return np.mean(shuffle_distribution <= observed, axis=0)
def slice_trend(values, times=None):
"""
Fit a linear trend to a metric computed across ordered slices.
Parameters:
values (np.ndarray): Shape ``(S,)`` array of metric values, one per
slice, in temporal order.
times (np.ndarray | None): Shape ``(S,)`` array of slice midpoints
in milliseconds. If None, integer indices ``0 .. S-1`` are used.
Returns:
slope (float): Linear regression slope. Units are metric-change per
millisecond (if *times* provided) or per slice index.
p_value (float): Two-sided p-value for the null hypothesis that the
slope is zero.
Notes:
- Intended for detecting systematic drift of a metric over the course
of a recording. Apply to the output of ``SpikeSliceStack.apply`` on
a frames stack built by ``SpikeData.frames``. A significant
positive or negative slope indicates non-stationarity.
- Uses ``scipy.stats.linregress``.
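    Example:
        A minimal sketch with a hypothetical metric that drifts upward across
        ten slices::

            import numpy as np
            values = np.linspace(0.2, 0.8, 10)
            slope, p = slice_trend(values)  # positive slope, p-value near zero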
"""
from scipy.stats import linregress
values = np.asarray(values)
if values.ndim != 1:
raise ValueError(
f"values must be 1-D, got shape {values.shape}. "
"For higher-dimensional metrics, reduce to a scalar per slice "
"before calling slice_trend."
)
if times is None:
times = np.arange(len(values), dtype=float)
else:
times = np.asarray(times, dtype=float)
mask = ~np.isnan(values) & ~np.isnan(times)
n_valid = int(np.sum(mask))
if n_valid < 2:
raise ValueError(
"slice_trend requires at least 2 non-NaN (value, time) pairs; "
f"got {n_valid} after omitting NaNs."
)
result = linregress(times[mask], values[mask])
return result.slope, result.pvalue
def slice_stability(values):
"""
Compute the coefficient of variation of a metric across slices.
Parameters:
values (np.ndarray): Shape ``(S,)`` or ``(S, ...)`` array of metric
values from ``SpikeSliceStack.apply``.
Returns:
cv (np.ndarray or float): Coefficient of variation ``std / |mean|``
computed along axis 0. Scalar when input is ``(S,)``.
Notes:
- Intended for summarising how much a metric varies across slices
(frames, trials, or shuffles). Low CV indicates a stable metric;
high CV indicates instability or sensitivity to the slicing.
- Elements where the mean is zero will be NaN.
"""
values = np.asarray(values, dtype=float)
mean = np.nanmean(values, axis=0)
std = np.nanstd(values, axis=0)
abs_mean = np.abs(mean)
safe_mean = np.where(abs_mean == 0, 1.0, abs_mean)
cv = std / safe_mean
cv = np.where(abs_mean == 0, np.nan, cv)
return float(cv) if cv.ndim == 0 else cv
# ---------------------------------------------------------------------------
# Sorter comparison helpers
# ---------------------------------------------------------------------------
def _count_matching_spikes(times1, times2, delta):
"""Count the number of matching spikes between two sorted spike trains.
Two spikes are considered matching if they occur within *delta* of each
other and belong to different trains. Uses a greedy left-to-right scan:
both trains are traversed simultaneously and the first valid pair
encountered is consumed, advancing both pointers.
This algorithm is adapted from SpikeInterface's ``count_matching_events``
(Buccino et al., eLife 2020; https://doi.org/10.7554/eLife.61834).
It runs in O(n1 + n2) time and is deterministic given sorted inputs. The
greedy strategy can yield sub-optimal counts when spikes cluster within
*delta* of each other. For example, with ``times1 = [10.0, 10.3]``,
``times2 = [10.2]``, and ``delta = 0.3``, the algorithm matches
``(10.0, 10.2)`` and leaves ``10.3`` unmatched — even though
``(10.3, 10.2)`` is a tighter match. A globally optimal assignment (e.g.
via the Hungarian algorithm) would be O(n^3) and is not used here because
the Jaccard agreement metric is insensitive to such edge cases when trains
are well-separated relative to *delta*. This matches the convention used
by SpikeInterface and SpikeForest.
Parameters:
times1 (np.ndarray): Sorted spike times for train 1.
times2 (np.ndarray): Sorted spike times for train 2.
delta (float): Maximum allowed temporal distance for a match.
Returns:
n_matches (int): Number of matched spike pairs.
"""
times1 = np.asarray(times1)
times2 = np.asarray(times2)
if len(times1) == 0 or len(times2) == 0:
return 0
i = 0
j = 0
n_matches = 0
n1 = len(times1)
n2 = len(times2)
while i < n1 and j < n2:
dt = times1[i] - times2[j]
if abs(dt) <= delta:
n_matches += 1
i += 1
j += 1
elif dt < 0:
i += 1
else:
j += 1
return n_matches
def _compute_agreement_score(train1, train2, delta):
"""Compute spike-train agreement between two spike trains.
Parameters:
train1 (np.ndarray): Sorted spike times for train 1.
train2 (np.ndarray): Sorted spike times for train 2.
delta (float): Maximum allowed temporal distance for a match.
Returns:
agreement (float): Jaccard-style agreement score
``n_matches / (n1 + n2 - n_matches)``.
frac_1 (float): Fraction of train1 spikes that were matched.
frac_2 (float): Fraction of train2 spikes that were matched.
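    Example:
        A minimal worked example: 3 and 4 spikes with 2 matches within
        *delta* give an agreement of 2 / (3 + 4 - 2) = 0.4::

            import numpy as np
            t1 = np.array([10.0, 50.0, 90.0])
            t2 = np.array([10.1, 49.9, 200.0, 300.0])
            agreement, f1, f2 = _compute_agreement_score(t1, t2, delta=0.5)
            # agreement == 0.4, f1 == 2/3, f2 == 0.5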
"""
n1 = len(train1)
n2 = len(train2)
if n1 == 0 and n2 == 0:
return 0.0, 0.0, 0.0
n_matches = _count_matching_spikes(train1, train2, delta)
denom = n1 + n2 - n_matches
agreement = n_matches / denom if denom > 0 else 0.0
frac_1 = n_matches / n1 if n1 > 0 else 0.0
frac_2 = n_matches / n2 if n2 > 0 else 0.0
return agreement, frac_1, frac_2
def _compute_footprint(neuron_attrs, f_rel_to_trough, n_channels):
"""Build a spatial waveform footprint array for one unit.
The footprint is a 2-D array of shape ``(n_channels, n_samples)`` where
``n_samples = f_rel_to_trough[0] + f_rel_to_trough[1] + 1``. The
template waveform is placed at the unit's main channel row, and
neighbouring-channel templates are placed at their respective rows, all
aligned to the trough of the main template.
Parameters:
neuron_attrs (dict): Neuron attribute dictionary containing:
``template`` (1-D ndarray), ``neighbor_templates`` (2-D ndarray),
``channel`` (int), ``neighbor_channels`` (1-D ndarray).
f_rel_to_trough (tuple of int): ``(pre, post)`` number of samples
before and after the trough to include.
n_channels (int): Total number of channels on the probe.
Returns:
fp (np.ndarray): Footprint array of shape
``(n_channels, f_rel_to_trough[0] + f_rel_to_trough[1] + 1)``.
"""
n_samples = f_rel_to_trough[0] + f_rel_to_trough[1] + 1
fp = np.zeros((n_channels, n_samples))
template = np.asarray(neuron_attrs["template"])
nb_templates = np.asarray(neuron_attrs["neighbor_templates"])
channel = int(neuron_attrs["channel"])
nb_channels = np.asarray(neuron_attrs["neighbor_channels"])
t_i = int(np.argmin(template))
sel_start = max(0, t_i - f_rel_to_trough[0])
sel_end = min(len(template) - 1, t_i + f_rel_to_trough[1])
pre_seg = template[sel_start:t_i]
post_seg = template[t_i : sel_end + 1]
paste_start = f_rel_to_trough[0] - len(pre_seg)
paste_end = f_rel_to_trough[0] + len(post_seg)
fp[channel, paste_start : f_rel_to_trough[0]] = pre_seg
fp[channel, f_rel_to_trough[0] : paste_end] = post_seg
# nb_channels[0] is expected to be the primary channel (same as `channel`).
# Its template is already placed above via the main `template` array.
# Neighbor templates start at index 1. Validate the convention.
if len(nb_channels) > 0 and int(nb_channels[0]) != channel:
raise ValueError(
f"neighbor_channels[0] ({int(nb_channels[0])}) does not match the "
f"primary channel ({channel}). The first entry in neighbor_channels "
"must be the unit's own channel."
)
for nb_i in range(1, len(nb_channels)):
pre_nb = nb_templates[nb_i, sel_start:t_i]
post_nb = nb_templates[nb_i, t_i : sel_end + 1]
ch = int(nb_channels[nb_i])
if ch < n_channels:
fp[ch, paste_start : f_rel_to_trough[0]] = pre_nb
fp[ch, f_rel_to_trough[0] : paste_end] = post_nb
return fp
def _compute_footprint_similarity(fp1, fp2, max_lag=5):
"""Compute the best cosine similarity between two footprints over lag shifts.
The temporal lag is applied independently to each channel row (shifting
samples within a channel), then the resulting arrays are flattened and
compared via cosine similarity. Integer lags from ``-max_lag`` to
``+max_lag`` are tested and the maximum similarity is returned.
Parameters:
fp1 (np.ndarray): Footprint array (n_channels, n_samples).
fp2 (np.ndarray): Footprint array (n_channels, n_samples), same
shape as *fp1*.
max_lag (int): Maximum lag in samples to search (default 5).
Returns:
best_sim (float): Highest cosine similarity across all tested lags.
"""
if fp1.shape != fp2.shape:
raise ValueError(
f"Footprints must have the same shape, " f"got {fp1.shape} and {fp2.shape}"
)
n_samples = fp1.shape[1]
best = -np.inf
for lag in range(-max_lag, max_lag + 1):
if lag == 0:
vec1 = fp1.ravel()
vec2 = fp2.ravel()
elif lag > 0:
# Shift fp2 right by `lag` samples (compare overlapping region)
vec1 = fp1[:, lag:].ravel()
vec2 = fp2[:, : n_samples - lag].ravel()
else:
# Shift fp2 left by `|lag|` samples
vec1 = fp1[:, : n_samples + lag].ravel()
vec2 = fp2[:, -lag:].ravel()
sim = _cosine_sim(vec1, vec2)
if not np.isnan(sim) and sim > best:
best = sim
return float(best) if best > -np.inf else np.nan