"""Module spikelab.spikedata.rateslicestack: the RateSliceStack container."""

import numpy as np

__all__ = ["RateSliceStack"]
from .ratedata import RateData
from .spikedata import SpikeData
import warnings
from .pairwise import PairwiseCompMatrix, PairwiseCompMatrixStack


from concurrent.futures import ThreadPoolExecutor

from .utils import (
    compute_cross_correlation_with_lag,
    compute_cosine_similarity_with_lag,
    _validate_time_start_to_end,
    _get_attr,
    _resolve_n_jobs,
)


class RateSliceStack:
    """A 3D firing-rate matrix of shape (U, T, S) with correlation and similarity tools.

    U is units (neurons), T is time bins, and S is slices (bursts, events,
    etc.). Construct from either a ``data_obj`` (SpikeData or RateData) with
    time specifications, or directly from a pre-built ``event_matrix``. The
    instance variables are the same regardless of the input option.

    Parameters:
        data_obj (SpikeData or RateData): A data object to slice. Provide
            either this or event_matrix, not both. If both are given,
            data_obj is ignored and a UserWarning is emitted.
        times_start_to_end (list or None): Each entry is a tuple
            (start, end) representing the start and end times of a desired
            slice. Each tuple must have the same duration.
        time_peaks (list or None): List of times (int or float) where there
            is a burst peak or stimulation event. Must be paired with
            time_bounds. Alternative to times_start_to_end. The peaks are
            sorted before the windows are built.
        time_bounds (tuple or None): Single tuple (before, after). For
            example, (250, 500) means 250 ms before the peak and 500 ms after
            the peak. Must be paired with time_peaks.
        sigma_ms (float): Smoothing width passed to
            ``SpikeData.resampled_isi`` when data_obj is a SpikeData.
            Otherwise not used. Must be non-negative.
        event_matrix (np.ndarray or None): A 3D array of shape (U, T, S).
            Provide either this or data_obj, not both.
        step_size (float or None): Time resolution in milliseconds between
            consecutive time bins. For SpikeData input it is the resampling
            resolution, and for event_matrix input it is the bin width. In
            both cases it defaults to 1.0 when None. For RateData input it is
            ignored, and the resolution is inferred from the data's
            timestamps. Must be positive when it is used.
        neuron_attributes (list or None): List of attribute objects, one per
            unit, containing arbitrary metadata about each neuron. When None,
            the ``neuron_attributes`` attribute of data_obj (if any) is used.

    Attributes:
        event_stack (np.ndarray): 3D array of shape (U, T, S).
        times (list): (start, end) time bounds for each slice. Length
            equals S. Example: [(100, 200), (500, 600), (1000, 1100)].
        step_size (float): Time resolution in milliseconds between
            consecutive time bins. Example: 1.0 means the bins of a slice
            starting at 100 ms are at [100, 101, 102, ...] ms.
        neuron_attributes (list or None): Copy of the per-unit metadata.

    Raises:
        ValueError: If neither input option is given, if the time options
            are incomplete, if sigma_ms is negative, if step_size is not
            positive, if the slices have inconsistent shapes, if the stack
            has zero time bins, or if neuron_attributes does not have one
            entry per unit.
        TypeError: If data_obj, event_matrix or time_bounds has the wrong
            type.
    """

    def __init__(
        self,
        data_obj=None,  # Option 1
        times_start_to_end=None,
        time_peaks=None,
        time_bounds=None,
        sigma_ms=10,
        event_matrix=None,  # Option 2
        step_size=None,
        neuron_attributes=None,
    ):
        if (data_obj is None) and (event_matrix is None):
            raise ValueError(
                "Must input either data_obj(option 1) or event_matrix(option 2)"
            )
        if (data_obj is not None) and (event_matrix is not None):
            warnings.warn(
                "User input both data_obj and event_matrix. Ignoring data_obj and using event_matrix instead.",
                UserWarning,
            )
            data_obj = None
        if sigma_ms is not None and sigma_ms < 0:
            raise ValueError("sigma_ms must be non-negative")

        # Option 1: slice windows out of a SpikeData / RateData object.
        if data_obj is not None:
            if not isinstance(data_obj, (SpikeData, RateData)):
                raise TypeError(
                    "data_obj must either be a SpikeData object or RateData object"
                )
            if times_start_to_end is None:
                if time_peaks is None or time_bounds is None:
                    raise ValueError(
                        "Must provide either times_start_to_end or both times_peaks and time_bounds"
                    )
                if not isinstance(time_bounds, tuple) or len(time_bounds) != 2:
                    raise TypeError(
                        "time_bounds must be a tuple of (before, after) durations"
                    )
                # Convert peaks + bounds into the (start, end) format.
                before, after = time_bounds
                time_peaks = sorted(time_peaks)
                times_start_to_end = [(t - before, t + after) for t in time_peaks]

            times_start_to_end = _validate_time_start_to_end(
                times_start_to_end, recording_range=self._recording_range(data_obj)
            )

            if isinstance(data_obj, SpikeData):
                resolution = 1.0 if step_size is None else step_size
                # A zero/negative/NaN step would make np.arange fail or
                # produce an empty or reversed time grid.
                self._check_step_size(resolution)
                all_times = np.arange(
                    data_obj.start_time,
                    data_obj.start_time + data_obj.length,
                    resolution,
                )
                data_obj = data_obj.resampled_isi(all_times, sigma_ms)

            if len(data_obj.times) > 1:
                self.step_size = data_obj.times[1] - data_obj.times[0]
            else:
                self.step_size = 1.0
            self.times = times_start_to_end
            self.event_stack = self._stack_slices(data_obj, times_start_to_end)

        # Option 2: use a pre-built (U, T, S) matrix directly.
        if event_matrix is not None:
            if not isinstance(event_matrix, np.ndarray):
                raise TypeError("event_matrix must be a numpy array")
            if event_matrix.ndim != 3:
                raise ValueError(
                    f"event_matrix must be 3D (U x T x S), got {event_matrix.ndim}D array"
                )
            if step_size is None:
                self.step_size = 1.0
            else:
                self._check_step_size(step_size)
                self.step_size = step_size

            if times_start_to_end is None:
                # Lay the slices end to end starting at t=0.
                slice_duration = event_matrix.shape[1] * self.step_size
                times_start_to_end = [
                    (i * slice_duration, (i + 1) * slice_duration)
                    for i in range(event_matrix.shape[2])
                ]
            else:
                times_start_to_end = _validate_time_start_to_end(times_start_to_end)
            # Every slice needs its own (start, end) tuple.
            if len(times_start_to_end) != event_matrix.shape[2]:
                raise ValueError(
                    "times_start_to_end must have the same length as the last dimension of event_matrix"
                )
            self.event_stack = event_matrix
            self.times = times_start_to_end

        if self.event_stack.shape[1] == 0:
            raise ValueError(
                "event_stack has zero time bins (T=0). "
                "A RateSliceStack requires at least one time bin."
            )

        if neuron_attributes is None and data_obj is not None:
            neuron_attributes = getattr(data_obj, "neuron_attributes", None)
        self.neuron_attributes = None
        if neuron_attributes is not None:
            if len(neuron_attributes) != self.event_stack.shape[0]:
                raise ValueError(
                    f"neuron_attributes has {len(neuron_attributes)} items "
                    f"but event_stack has {self.event_stack.shape[0]} units"
                )
            self.neuron_attributes = neuron_attributes.copy()

    @staticmethod
    def _check_step_size(step_size):
        """Raise ValueError unless step_size is a positive, non-NaN number."""
        # `not x > 0` (rather than `x <= 0`) also rejects NaN.
        if not step_size > 0:
            raise ValueError(f"step_size must be positive, got {step_size}")

    @staticmethod
    def _recording_range(data_obj):
        """Return the (start, end) span covered by data_obj, or None if it has no samples."""
        if isinstance(data_obj, SpikeData):
            return (data_obj.start_time, data_obj.start_time + data_obj.length)
        times = data_obj.times
        if len(times) > 1:
            # Extend the last timestamp by one sample so the final bin is covered.
            return (times[0], times[-1] + (times[1] - times[0]))
        if len(times) == 1:
            return (times[0], times[0] + 1)
        return None

    @staticmethod
    def _stack_slices(rate_data, times_start_to_end):
        """Cut every (start, end) window out of rate_data and stack them into (U, T, S)."""
        slice_matrices = [
            rate_data.subtime(window[0], window[1]).inst_Frate_data
            for window in times_start_to_end
        ]
        # Windows of equal duration can still yield different bin counts
        # when they do not align with the data's time grid. Report that
        # clearly instead of letting np.stack fail with a generic message.
        shapes = {np.shape(matrix) for matrix in slice_matrices}
        if len(shapes) > 1:
            raise ValueError(
                "All slices must have the same (units, time bins) shape, but "
                f"got {sorted(shapes)}. Check that every (start, end) window "
                "has the same duration and aligns with the data's time grid."
            )
        return np.stack(slice_matrices, axis=2)

    @staticmethod
    def _nanmean_or_nan(values, axis=None):
        """Mean that ignores NaNs, returning NaN silently where no value is present.

        Numerically equivalent to ``np.nanmean`` (NaNs replaced by 0, summed,
        divided by the non-NaN count). All-NaN input is an expected outcome
        here, for example a unit active in fewer than two slices, so it must
        not emit a RuntimeWarning. ``np.errstate`` is used instead of
        ``warnings.catch_warnings``, which is not thread-safe, because this
        helper is called from worker threads.
        """
        values = np.asarray(values, dtype=float)
        present = ~np.isnan(values)
        totals = np.sum(np.where(present, values, 0.0), axis=axis)
        counts = np.sum(present, axis=axis)
        with np.errstate(invalid="ignore", divide="ignore"):
            return totals / counts

    def order_units_across_slices(
        self,
        agg_func,
        MIN_RATE_THRESHOLD=0.1,
        MIN_FRAC_ACTIVE=0.0,
        frac_active=None,
        timing_matrix=None,
    ):
        """Reorder units from earliest to latest peak firing rate across slices.

        Parameters:
            agg_func (str): Either ``"median"`` or ``"mean"``. Used to
                aggregate each unit's per-slice peak time bin.
            MIN_RATE_THRESHOLD (float): Minimum peak firing rate for a slice
                to be included in the ordering calculation. Slices where a
                unit's max rate is below this threshold are excluded from
                that unit's typical peak time calculation. Ignored when
                timing_matrix is provided.
            MIN_FRAC_ACTIVE (float): Minimum fraction of slices a unit must
                be active in to be placed in the highly-active group. The
                default 0.0 puts all units in the first group, so the second
                array in each output tuple is empty.
            frac_active (np.ndarray or None): Optional pre-computed
                fraction-active array of shape (U,). When given, it is used
                for the group split and reported in the output instead of
                the rate-based calculation. Compatible sources:
                SpikeSliceStack.compute_frac_active and
                SpikeData.get_frac_active (frac_per_unit output).
            timing_matrix (np.ndarray or None): Optional pre-computed (U, S)
                timing matrix from get_unit_timing_per_slice. When provided,
                MIN_RATE_THRESHOLD is ignored and this matrix is used
                directly.

        Returns:
            reordered_slice_matrices (tuple of arrays): Two 3D arrays from
                event_stack with the U dimension reordered temporally
                (highly-active group, lower-activity group).
            unit_ids_in_order (tuple of arrays): Two arrays of original unit
                IDs in temporal order (highly-active, low-activity). For
                example, [3, 1, 0, 2] means unit 3 fires first.
            unit_std_indices (tuple of arrays): Two arrays of the standard
                deviation of peak firing time bins (highly-active,
                low-activity). Lower values mean more consistent timing.
                NaN for units never active.
            unit_peak_times (tuple of arrays): Two arrays of the
                median/mean peak firing time bin, rounded to int
                (highly-active, low-activity). -1 for units never active.
            unit_frac_active (tuple of arrays): Two arrays of the fraction of
                slices each unit was active in (highly-active, low-activity).
                This is the supplied frac_active when given. Otherwise it is
                the fraction of non-NaN entries of the timing matrix.

        Raises:
            ValueError: If agg_func is invalid, or timing_matrix / frac_active
                has the wrong shape.

        Notes:
            - Call get_unit_timing_per_slice first to pre-compute the timing
              matrix if you want to reuse it across multiple calls (e.g.
              rank_order_correlation and this method).
            - Units whose frac_active is NaN are placed in the low-activity
              group rather than being dropped from both groups.
        """
        slice_matrices = self.event_stack  # (U, T, S)
        num_units = slice_matrices.shape[0]
        num_slices = slice_matrices.shape[2]

        if timing_matrix is not None:
            timing = np.asarray(timing_matrix, dtype=float)
            if timing.shape != (num_units, num_slices):
                raise ValueError(
                    f"timing_matrix must have shape ({num_units}, {num_slices}), "
                    f"got {timing.shape}"
                )
        else:
            timing = self.get_unit_timing_per_slice(
                MIN_RATE_THRESHOLD=MIN_RATE_THRESHOLD
            )
        active_mask = ~np.isnan(timing)  # (U, S): True where the unit is active

        if agg_func == "median":
            aggregate = np.nanmedian
        elif agg_func == "mean":
            aggregate = np.nanmean
        else:
            raise ValueError(
                f"{agg_func} is not a valid input option. Must be either median or mean"
            )

        # Only aggregate units with at least one active slice. Units that are
        # never active keep NaN, which avoids numpy's "All-NaN slice" and
        # "Degrees of freedom" RuntimeWarnings without changing any value.
        has_activity = active_mask.any(axis=1)
        unit_std_indices = np.full(num_units, np.nan)
        unit_peak_times = np.full(num_units, np.nan)
        if np.any(has_activity):
            active_rows = timing[has_activity]
            unit_std_indices[has_activity] = np.nanstd(active_rows, axis=1)
            unit_peak_times[has_activity] = aggregate(active_rows, axis=1)

        # Always compute the real fraction, because it is part of the output
        # contract even when no group split is requested.
        if frac_active is not None:
            frac_active = np.asarray(frac_active, dtype=float)
            if frac_active.shape != (num_units,):
                raise ValueError(
                    f"frac_active must have shape ({num_units},), "
                    f"got {frac_active.shape}"
                )
            unit_frac_active = frac_active
        elif num_slices > 0:
            unit_frac_active = active_mask.sum(axis=1) / num_slices
        else:
            unit_frac_active = np.zeros(num_units)

        if not MIN_FRAC_ACTIVE:
            # No threshold requested: every unit belongs to the first group.
            is_high = np.ones(num_units, dtype=bool)
        else:
            # NaN fractions compare False and so land in the low group,
            # instead of vanishing from both groups.
            is_high = unit_frac_active >= MIN_FRAC_ACTIVE
        highly_active_units = np.flatnonzero(is_high)
        low_active_units = np.flatnonzero(~is_high)

        # NaN peak times (never-active units) sort to the end.
        highly_active_order = highly_active_units[
            np.argsort(unit_peak_times[highly_active_units])
        ]
        low_active_order = low_active_units[
            np.argsort(unit_peak_times[low_active_units])
        ]

        # Cast to int for output only after sorting. NaN becomes -1.
        unit_peak_times_int = np.full(unit_peak_times.shape, -1, dtype=int)
        valid = ~np.isnan(unit_peak_times)
        unit_peak_times_int[valid] = np.round(unit_peak_times[valid]).astype(int)

        return (
            (
                slice_matrices[highly_active_order, :, :],
                slice_matrices[low_active_order, :, :],
            ),
            (highly_active_order, low_active_order),
            (
                unit_std_indices[highly_active_order],
                unit_std_indices[low_active_order],
            ),
            (
                unit_peak_times_int[highly_active_order],
                unit_peak_times_int[low_active_order],
            ),
            (
                unit_frac_active[highly_active_order],
                unit_frac_active[low_active_order],
            ),
        )

    def get_slice_to_slice_unit_corr_from_stack(
        self,
        compare_func=compute_cross_correlation_with_lag,
        MIN_RATE_THRESHOLD=0.1,
        MIN_FRAC=0.3,
        max_lag=10,
        frac_active=None,
        n_jobs=-1,
    ):
        """Compute slice-to-slice similarity along the unit axis of event_stack (U, T, S).

        Output is a PairwiseCompMatrixStack of shape (S, S, U).

        Parameters:
            compare_func (callable): Comparison function called as
                ``compare_func(a, b, max_lag) -> (score, lag)``. Use
                cross-correlation or cosine similarity from utils. The
                default is cross-correlation.
            MIN_RATE_THRESHOLD (float): Minimum mean firing rate to consider
                a slice valid for that neuron.
            MIN_FRAC (float): Maximum fraction of slices that can be skipped
                before a unit is deemed invalid (default 0.3 = 30%).
            max_lag (int): Maximum lag in frames to search for similarity.
                If None, lag is set to 0.
            frac_active (np.ndarray or None): Optional pre-computed
                fraction-active array of shape (U,) that overrides the
                internal rate-based validity check for computing averages.
                When provided, a unit's average is set to NaN if
                frac_active[u] < (1 - MIN_FRAC). MIN_RATE_THRESHOLD still
                controls which individual slice pairs are computed.
                Compatible sources: SpikeSliceStack.compute_frac_active and
                SpikeData.get_frac_active (frac_per_unit output).
            n_jobs (int): Number of threads for parallel computation. -1 uses
                all cores (default), 1 disables parallelism, None is serial.

        Notes:
            When ``max_lag`` is 0 or None and compare_func is one of the two
            built-in utils functions, the inner S x S loop is replaced by a
            vectorized normalised dot product, which is much faster for large
            S. Any other compare_func is always called pairwise, so custom
            comparison functions are honoured at zero lag too. The pairwise
            path is parallelised across units and can be slow for large S
            (e.g. S > 100).

        Returns:
            all_slice_corr_scores (PairwiseCompMatrixStack): Pairwise scores
                between all slice pairs for each unit. Shape is (S, S, U) in
                the stack attribute.
            av_slice_corr_scores (np.ndarray): Average score per neuron
                across all valid slice pairs. Shape is (U,).
        """
        event_stack = self.event_stack
        num_units = event_stack.shape[0]  # U
        num_slices = event_stack.shape[2]  # S

        if frac_active is not None:
            frac_active = np.asarray(frac_active, dtype=float)
            if frac_active.shape != (num_units,):
                raise ValueError(
                    f"frac_active must have shape ({num_units},), "
                    f"got {frac_active.shape}"
                )

        # Pairwise comparison is undefined for a single slice (BUG-005).
        if num_slices < 2:
            warnings.warn(
                "Cannot compute slice-to-slice unit correlation with fewer than "
                "2 slices. Returning NaN.",
                RuntimeWarning,
            )
            av_slice_corr_scores = np.full(num_units, np.nan)
            all_slice_corr_scores = np.full((num_slices, num_slices, num_units), np.nan)
            return (
                PairwiseCompMatrixStack(stack=all_slice_corr_scores),
                av_slice_corr_scores,
            )

        # Computed as (U, S, S) and moved to (S, S, U) at the end.
        av_slice_corr_scores = np.full(num_units, np.nan)
        all_slice_corr_scores = np.full((num_units, num_slices, num_slices), np.nan)
        lower_tri_indices = np.tril_indices(num_slices, k=-1)
        effective_lag = 0 if max_lag is None else max_lag

        def _unit_is_valid(unit, n_invalid):
            """Decide whether a unit has enough valid slices for its average to be reported."""
            if frac_active is not None:
                return frac_active[unit] >= (1 - MIN_FRAC)
            return n_invalid / num_slices <= MIN_FRAC

        # NOTE(review): the fast path assumes both built-in helpers reduce to
        # the normalised dot product at zero lag. This preserves the previous
        # behaviour for them; confirm against utils.
        builtin_funcs = (
            compute_cross_correlation_with_lag,
            compute_cosine_similarity_with_lag,
        )
        use_fast_path = effective_lag == 0 and any(
            compare_func is func for func in builtin_funcs
        )

        if use_fast_path:
            for unit in range(num_units):
                rates = event_stack[unit, :, :]  # (T, S)
                slice_means = np.mean(rates, axis=0)  # (S,)
                valid = slice_means >= MIN_RATE_THRESHOLD
                n_invalid = int(np.sum(~valid))
                if np.sum(valid) < 2:
                    # Too few valid slices to compare; the row stays NaN.
                    continue

                norms = np.linalg.norm(rates, axis=0)  # (S,)
                safe_norms = np.where(norms > 0, norms, 1.0)  # avoid 0-division
                normed = rates / safe_norms[np.newaxis, :]  # (T, S)
                corr_matrix = normed.T @ normed  # (S, S)

                valid_idx = np.flatnonzero(valid)
                ix = np.ix_(valid_idx, valid_idx)
                sub = corr_matrix[ix]
                # Zero-norm semantics: both zero -> NaN, exactly one zero -> 0.0.
                zero_norm = norms[valid_idx] == 0
                if np.any(zero_norm):
                    both_zero = np.outer(zero_norm, zero_norm)
                    one_zero = np.outer(zero_norm, ~zero_norm) | np.outer(
                        ~zero_norm, zero_norm
                    )
                    sub[both_zero] = np.nan
                    sub[one_zero] = 0.0

                unit_corr = np.full((num_slices, num_slices), np.nan)
                unit_corr[ix] = sub
                all_slice_corr_scores[unit] = unit_corr
                av_slice_corr_scores[unit] = (
                    self._nanmean_or_nan(unit_corr[lower_tri_indices])
                    if _unit_is_valid(unit, n_invalid)
                    else np.nan
                )
        else:

            def _process_unit(unit):
                """Fill one unit's S x S matrix pairwise with compare_func."""
                unit_corr = np.full((num_slices, num_slices), np.nan)
                # Per-slice means are computed once (O(S)) instead of once per pair.
                slice_ok = [
                    np.mean(event_stack[unit, :, b]) >= MIN_RATE_THRESHOLD
                    for b in range(num_slices)
                ]
                n_invalid = 0
                for ref_b in range(num_slices):
                    if not slice_ok[ref_b]:
                        n_invalid += 1
                        continue
                    ref_rate = event_stack[unit, :, ref_b]
                    for comp_b in range(ref_b, num_slices):
                        if not slice_ok[comp_b]:
                            continue
                        max_corr, _ = compare_func(
                            ref_rate, event_stack[unit, :, comp_b], effective_lag
                        )
                        unit_corr[comp_b, ref_b] = max_corr
                        unit_corr[ref_b, comp_b] = max_corr
                av = (
                    self._nanmean_or_nan(unit_corr[lower_tri_indices])
                    if _unit_is_valid(unit, n_invalid)
                    else np.nan
                )
                return unit, unit_corr, av

            n_workers = _resolve_n_jobs(n_jobs)
            if n_workers > 1 and num_units > 1:
                with ThreadPoolExecutor(max_workers=n_workers) as pool:
                    results = list(pool.map(_process_unit, range(num_units)))
            else:
                results = map(_process_unit, range(num_units))
            for unit, unit_corr, av in results:
                all_slice_corr_scores[unit] = unit_corr
                av_slice_corr_scores[unit] = av

        # (U, S, S) -> (S, S, U) for the n x n x S convention.
        all_slice_corr_scores = np.moveaxis(all_slice_corr_scores, 0, 2)
        return (
            PairwiseCompMatrixStack(stack=all_slice_corr_scores),
            av_slice_corr_scores,
        )

    def get_slice_to_slice_time_corr_from_stack(
        self, compare_func=compute_cosine_similarity_with_lag, max_lag=0, n_jobs=-1
    ):
        """Compute slice-to-slice similarity along the time axis of event_stack (U, T, S).

        Output is a PairwiseCompMatrixStack of shape (S, S, T). For each time
        bin t, the population vectors ``event_stack[:, t, s]`` of every
        slice pair are compared.

        Parameters:
            compare_func (callable): Comparison function called as
                ``compare_func(a, b, max_lag) -> (score, lag)``. Use
                cross-correlation or cosine similarity from utils. The
                default is cosine similarity.
            max_lag (int): Maximum lag in frames to search for similarity.
                If None, lag is set to 0 (0 is what is passed to
                compare_func).
            n_jobs (int): Number of threads for parallel computation. -1 uses
                all cores (default), 1 disables parallelism, None is serial.

        Returns:
            all_slice_corr_scores (PairwiseCompMatrixStack): Pairwise scores
                between all slice pairs for each time bin. Shape is
                (S, S, T) in the stack attribute.
            av_slice_corr_scores (np.ndarray): Average score per time bin
                across all valid slice pairs. Shape is (T,).
        """
        event_stack = self.event_stack
        num_time_bins = event_stack.shape[1]  # T
        num_slices = event_stack.shape[2]  # S

        # Pairwise comparison is undefined for a single slice.
        if num_slices < 2:
            warnings.warn(
                "Cannot compute slice-to-slice time correlation with fewer than "
                "2 slices. Returning NaN.",
                RuntimeWarning,
            )
            av_slice_corr_scores = np.full(num_time_bins, np.nan)
            all_slice_corr_scores = np.full(
                (num_slices, num_slices, num_time_bins), np.nan
            )
            return (
                PairwiseCompMatrixStack(stack=all_slice_corr_scores),
                av_slice_corr_scores,
            )

        # Computed as (T, S, S) and moved to (S, S, T) at the end.
        av_slice_corr_scores = np.full(num_time_bins, np.nan)
        all_slice_corr_scores = np.full((num_time_bins, num_slices, num_slices), np.nan)
        lower_tri_indices = np.tril_indices(num_slices, k=-1)
        # Honour the documented "None means 0" instead of forwarding None.
        lag = 0 if max_lag is None else max_lag

        def _process_time(t):
            """Fill the S x S matrix for time bin t."""
            time_corr = np.full((num_slices, num_slices), np.nan)
            for ref_b in range(num_slices):
                ref_rate = event_stack[:, t, ref_b]
                for comp_b in range(ref_b, num_slices):
                    max_corr, _ = compare_func(
                        ref_rate, event_stack[:, t, comp_b], lag
                    )
                    time_corr[comp_b, ref_b] = max_corr
                    time_corr[ref_b, comp_b] = max_corr
            return t, time_corr, self._nanmean_or_nan(time_corr[lower_tri_indices])

        n_workers = _resolve_n_jobs(n_jobs)
        if n_workers > 1 and num_time_bins > 1:
            with ThreadPoolExecutor(max_workers=n_workers) as pool:
                results = list(pool.map(_process_time, range(num_time_bins)))
        else:
            results = map(_process_time, range(num_time_bins))
        for t, time_corr, av in results:
            all_slice_corr_scores[t] = time_corr
            av_slice_corr_scores[t] = av

        # (T, S, S) -> (S, S, T) for the n x n x S convention.
        all_slice_corr_scores = np.moveaxis(all_slice_corr_scores, 0, 2)
        return (
            PairwiseCompMatrixStack(stack=all_slice_corr_scores),
            av_slice_corr_scores,
        )

    def convert_to_list_of_RateData(self):
        """Create one RateData object per slice of the 3D event_stack.

        Each slice's timestamps are ``start + k * step_size`` for
        k = 0..T-1, where start comes from the slice's entry in self.times.

        Returns:
            output (list): List of RateData objects. Length equals S.
        """
        output = []
        for s_idx in range(self.event_stack.shape[2]):
            matrix = self.event_stack[:, :, s_idx]  # (U, T)
            start, end = self.times[s_idx]
            time = start + np.arange(matrix.shape[1]) * self.step_size
            if time[-1] > end:
                # Floating-point overshoot guard: keep timestamps strictly
                # below `end`. `end - eps` would round back to `end` for
                # |end| >= 2, so use the next representable float instead.
                time = np.clip(time, start, np.nextafter(end, -np.inf))
            output.append(
                RateData(matrix, time, neuron_attributes=self.neuron_attributes)
            )
        return output

    def unit_to_unit_correlation(
        self, compare_func=compute_cross_correlation_with_lag, max_lag=10, n_jobs=-1
    ):
        """Compute unit-to-unit similarity within each slice of event_stack (U, T, S).

        Output is a PairwiseCompMatrixStack of shape (U, U, S).

        Parameters:
            compare_func (callable): Comparison function from utils. Use
                cross-correlation or cosine similarity. The default is
                cross-correlation.
            max_lag (int): Maximum lag in frames to search for similarity.
                If None, lag is set to 0.
            n_jobs (int): Number of threads for parallel computation. -1 uses
                all cores (default), 1 disables parallelism, None is serial.

        Returns:
            max_corr_array (PairwiseCompMatrixStack): Pairwise scores between
                all unit pairs for each slice. Shape is (U, U, S).
            max_corr_lag_array (PairwiseCompMatrixStack): Lag at which each
                pair's score is maximal. Shape is (U, U, S).
            av_max_corr (np.ndarray): Average score per slice across all
                valid unit pairs. Shape is (S,).
            av_max_corr_lag (np.ndarray): Average best lag per slice. Shape
                is (S,).
        """
        num_units = self.event_stack.shape[0]
        num_slices = self.event_stack.shape[2]

        # Pairwise comparison is undefined for a single unit (BUG-005).
        if num_units < 2:
            warnings.warn(
                "Cannot compute unit-to-unit correlation with fewer than "
                "2 units. Returning NaN.",
                RuntimeWarning,
            )
            nan_stack = np.full((num_units, num_units, num_slices), np.nan)
            nan_avgs = np.full(num_slices, np.nan)
            return (
                PairwiseCompMatrixStack(stack=nan_stack, times=self.times),
                PairwiseCompMatrixStack(stack=nan_stack.copy(), times=self.times),
                nan_avgs,
                nan_avgs.copy(),
            )

        max_corr_stack = []
        max_corr_lag_stack = []
        for rate_data in self.convert_to_list_of_RateData():
            # Two U x U matrices per slice.
            max_corr_matrix, lag_corr_matrix = rate_data.get_pairwise_fr_corr(
                compare_func, max_lag, n_jobs=n_jobs
            )
            max_corr_stack.append(max_corr_matrix.matrix)
            max_corr_lag_stack.append(lag_corr_matrix.matrix)

        # (S, U, U)
        max_corr_array = np.stack(max_corr_stack, axis=0)
        max_corr_lag_array = np.stack(max_corr_lag_stack, axis=0)

        rows, cols = np.tril_indices(max_corr_array.shape[1], k=-1)
        av_max_corr = self._nanmean_or_nan(max_corr_array[:, rows, cols], axis=1)
        av_max_corr_lag = self._nanmean_or_nan(
            max_corr_lag_array[:, rows, cols], axis=1
        )

        # (S, U, U) -> (U, U, S) for the n x n x S convention.
        max_corr_array = np.moveaxis(max_corr_array, 0, 2)
        max_corr_lag_array = np.moveaxis(max_corr_lag_array, 0, 2)
        return (
            PairwiseCompMatrixStack(stack=max_corr_array, times=self.times),
            PairwiseCompMatrixStack(stack=max_corr_lag_array, times=self.times),
            av_max_corr,
            av_max_corr_lag,
        )

    def __repr__(self) -> str:
        U, T, S = self.event_stack.shape
        return f"RateSliceStack(U={U}, T={T}, S={S})"

    def __len__(self) -> int:
        return self.event_stack.shape[2]

    def __iter__(self):
        return iter(self.convert_to_list_of_RateData())

    def subset(self, units, by=None):
        """Extract a subset of units/neurons from the rate slice stack.

        Parameters:
            units (int, str, list or array): Unit indices (or attribute
                values when ``by`` is set) to extract. A single int, NumPy
                integer or str is treated as a one-element list. Duplicates
                are ignored and the result keeps the original unit order.
            by (str or None): Neuron attribute key to match against. Only use
                this if you initialized the object with neuron_attributes.
                None selects by index (default).

        Returns:
            result (RateSliceStack): New RateSliceStack containing only the
                specified units.

        Raises:
            TypeError: If an index is not an integer when ``by`` is None.
            ValueError: If an index is out of range, or ``by`` is used without
                neuron_attributes.
        """
        N = self.event_stack.shape[0]
        # np.integer scalars are not iterable, so wrap them like plain ints.
        if isinstance(units, (int, np.integer, str)):
            units = [units]
        units = set(units)

        if by is None:
            for u in units:
                if not isinstance(u, (int, np.integer)):
                    raise TypeError(f"Unit index must be an integer, got {type(u)}")
                if u < 0 or u >= N:
                    raise ValueError(f"Unit index {u} out of range for {N} units.")
        else:
            # Value-based lookup via neuron attributes.
            if self.neuron_attributes is None:
                raise ValueError("can't use `by` without `neuron_attributes`")
            _missing = object()
            units = {
                i
                for i in range(N)
                if _get_attr(self.neuron_attributes[i], by, _missing) in units
            }

        units = sorted(units)
        neuron_attributes = None
        if self.neuron_attributes is not None:
            neuron_attributes = [self.neuron_attributes[i] for i in units]
        return RateSliceStack(
            event_matrix=self.event_stack[units, :, :],
            times_start_to_end=self.times,
            step_size=self.step_size,
            neuron_attributes=neuron_attributes,
        )

    def subtime_by_index(self, start_idx, end_idx):
        """Extract a range of time bins from every slice using index values.

        Trims along the time axis (T dimension) while preserving all slices
        (S dimension).

        Parameters:
            start_idx (int): Starting time bin index (inclusive). Supports
                negative indexing.
            end_idx (int): Ending time bin index (exclusive). Supports
                negative indexing.

        Returns:
            result (RateSliceStack): New RateSliceStack of shape
                (U, T_trimmed, S).

        Raises:
            ValueError: If the resolved indices are out of range or empty.

        Notes:
            - Timestamps stay in the original frame (not shifted to zero):
              each slice's new bounds are
              ``start + start_idx * step_size`` to
              ``start + end_idx * step_size``.
            - All slices, neuron_attributes and step_size are carried over.
        """
        T = self.event_stack.shape[1]
        if start_idx < 0:
            start_idx += T
        if end_idx < 0:
            end_idx += T
        if start_idx < 0 or start_idx >= T:
            raise ValueError(f"start_idx {start_idx} out of range for T={T}")
        if end_idx <= start_idx or end_idx > T:
            raise ValueError(f"end_idx {end_idx} invalid")

        new_times = [
            (t[0] + start_idx * self.step_size, t[0] + end_idx * self.step_size)
            for t in self.times
        ]
        return RateSliceStack(
            event_matrix=self.event_stack[:, start_idx:end_idx, :].copy(),
            times_start_to_end=new_times,
            step_size=self.step_size,
            neuron_attributes=self.neuron_attributes,
        )

    def subslice(self, slices):
        """Extract a subset of slices from the event stack using index values.

        Trims along the slice axis (S dimension) while preserving all time
        bins (T dimension).

        Parameters:
            slices (int or list): Slice index or list of slice indices to
                extract. Negative indices are resolved against S before
                sorting, so the output stays in the original slice order.

        Returns:
            result (RateSliceStack): New RateSliceStack of shape
                (U, T, S_trimmed).

        Raises:
            ValueError: If any index is out of range.

        Notes:
            - All units, neuron_attributes and step_size are carried over.
        """
        length = self.event_stack.shape[2]
        # np.integer scalars are not iterable, so wrap them like plain ints.
        if isinstance(slices, (int, np.integer)):
            slices = [slices]
        resolved = []
        for s in slices:
            if s >= length or s < -length:
                raise ValueError(
                    f"One or more slice indices out of range for S={length}"
                )
            # Sorting raw mixed-sign indices (e.g. [0, -1] -> [-1, 0]) would
            # reverse the chronological order, so normalise first.
            resolved.append(s + length if s < 0 else s)
        resolved = sorted(resolved)

        return RateSliceStack(
            event_matrix=self.event_stack[:, :, resolved],
            times_start_to_end=[self.times[s] for s in resolved],
            step_size=self.step_size,
            neuron_attributes=self.neuron_attributes,
        )

    def get_unit_timing_per_slice(self, MIN_RATE_THRESHOLD=0.1):
        """Compute the peak firing rate time bin for each unit in each slice.

        Returns a ``(U, S)`` matrix where entry ``[u, s]`` is the time bin
        index of the peak firing rate for unit u in slice s. Units whose
        peak rate falls below MIN_RATE_THRESHOLD in a slice are marked NaN.

        Parameters:
            MIN_RATE_THRESHOLD (float): Minimum peak firing rate for a unit
                to count as active in a slice (default: 0.1).

        Returns:
            timing_matrix (np.ndarray): Array of shape ``(U, S)`` with peak
                time bin indices (integers cast to float for NaN support).
                These can index directly into the ``(U, T, S)`` event stack.
                NaN where the unit is inactive.

        Notes:
            - Values are bin indices, not milliseconds. This differs from
              ``SpikeSliceStack.get_unit_timing_per_slice``, which returns
              milliseconds. Both preserve rank order, so
              ``rank_order_correlation`` gives identical results either way.
            - The returned matrix can be passed to ``rank_order_correlation``
              to compute Spearman rank correlations between slice pairs.
        """
        slice_matrices = self.event_stack
        unit_max_indices = np.argmax(slice_matrices, axis=1).astype(float)
        unit_max_rates = np.max(slice_matrices, axis=1)
        unit_max_indices[unit_max_rates < MIN_RATE_THRESHOLD] = np.nan
        return unit_max_indices

    def rank_order_correlation(
        self,
        timing_matrix=None,
        MIN_RATE_THRESHOLD=0.1,
        min_overlap=3,
        min_overlap_frac=None,
        n_shuffles=100,
        seed=1,
        n_jobs=-1,
    ):
        """Compute Spearman rank-order correlation of unit timing between all slice pairs.

        For each pair of slices, only units active in both slices (non-NaN
        in both columns of the timing matrix) are included. If the overlap
        falls below the required minimum, the pair is set to NaN. When
        ``n_shuffles > 0``, the rank orders are shuffled n_shuffles times for
        each pair to build a null distribution, and the raw correlation is
        z-score normalised against it.

        Parameters:
            timing_matrix (np.ndarray or None): Array of shape ``(U, S)`` with
                timing values per unit per slice. NaN entries mark inactive
                units. Typically produced by ``get_unit_timing_per_slice``.
                When None, computed automatically using MIN_RATE_THRESHOLD.
            MIN_RATE_THRESHOLD (float): Minimum peak firing rate threshold
                (default: 0.1). Only used when timing_matrix is None.
            min_overlap (int): Minimum number of units that must be active
                in both slices (default: 3).
            min_overlap_frac (float or None): Minimum fraction of total units
                that must be active in both slices (default: None). When
                provided, the effective threshold is
                ``max(min_overlap, ceil(min_overlap_frac * U))``.
            n_shuffles (int): Number of shuffle iterations for z-scoring
                (default: 100). Set to 0 to return raw Spearman correlations.
                Values between 1 and 4 are rejected (minimum 5 required for
                a meaningful null distribution).
            seed (int or None): Random seed for reproducibility of the
                shuffle (default: 1).
            n_jobs (int): Number of threads for parallel computation. -1 uses
                all cores (default), 1 disables parallelism, None is serial.

        Returns:
            corr_matrix (PairwiseCompMatrix): Spearman correlation matrix of
                shape ``(S, S)``. Z-scores when ``n_shuffles > 0``, raw
                Spearman correlations when ``n_shuffles == 0``.
            av_corr (float): Average correlation (or z-score) across all
                valid lower-triangle pairs.
            overlap_matrix (PairwiseCompMatrix): Matrix of shape ``(S, S)``
                with the fraction of units active in both slices.
        """
        from .utils import _rank_order_correlation_from_timing

        if timing_matrix is None:
            timing_matrix = self.get_unit_timing_per_slice(
                MIN_RATE_THRESHOLD=MIN_RATE_THRESHOLD
            )
        return _rank_order_correlation_from_timing(
            timing_matrix,
            min_overlap=min_overlap,
            min_overlap_frac=min_overlap_frac,
            n_shuffles=n_shuffles,
            seed=seed,
            n_jobs=n_jobs,
        )