Source code for spikelab.spikedata.ratedata

import warnings

import numpy as np

__all__ = ["RateData"]

from .pairwise import PairwiseCompMatrix
from concurrent.futures import ThreadPoolExecutor

from .utils import (
    compute_cross_correlation_with_lag,
    PCA_reduction,
    UMAP_reduction,
    UMAP_graph_communities,
    _get_attr,
    _resolve_n_jobs,
)


class RateData:
    """A 2D instantaneous firing rate matrix with unit-to-unit correlation capabilities.

    Parameters:
        inst_Frate_data (array): 2D array of shape (U, T). Each value is the
            instantaneous firing rate. U is the number of units/neurons and T
            is the number of time bins.
        times (list): List of time values that each column index in
            inst_Frate_data represents. For example, times = [5, 10, 15] so
            inst_Frate_data column 0 is 5 ms, column 1 is 10 ms, and column 2
            is 15 ms.
        neuron_attributes (list or None): List of dicts, one per unit,
            containing arbitrary metadata about each neuron. None if not
            provided.
        rate_unit (str or None): Unit of the rate values (for example
            ``"Hz"`` or ``"kHz"``). None when unspecified.

    Attributes:
        inst_Frate_data (numpy.ndarray): Float copy of the rate matrix, shape
            (U, T).
        times (numpy.ndarray): Time value represented by each column of
            inst_Frate_data.
        neuron_attributes (list or None): List of dicts, one per unit,
            containing arbitrary metadata about each neuron. None if not
            provided.
        rate_unit (str or None): Unit of the rate values, or None.
        N (int): Number of units in inst_Frate_data.

    Notes:
        - ``times`` may contain negative values when the RateData represents
          an event-aligned window (e.g., times from -200 to +500 ms around a
          stimulus).
        - ``subtime`` always treats ``start``/``end`` as literal time values.
          Use ``subtime_by_index`` for index-based slicing with negative
          indexing.
        - The per-unit attribute dicts are shared (not copied) between a
          RateData and the objects derived from it via ``subset``,
          ``subtime`` and ``subtime_by_index``.
    """

    def __init__(self, inst_Frate_data, times, neuron_attributes=None, rate_unit=None):
        """Initialize a RateData object.

        Parameters:
            inst_Frate_data (array-like): Firing rate data, shape (N, T). Any
                array-like (ndarray, nested lists) is accepted; a float copy
                is stored.
            times (numpy.ndarray or list): Time points, length T. Converted to
                an ndarray if it is not one already (an ndarray is stored
                without copying).
            neuron_attributes (iterable or None): Per-unit attribute dicts,
                one per row of inst_Frate_data.
            rate_unit (str or None): Unit of the rate values. Typically
                ``"Hz"`` (spikes/s) for ``resampled_isi`` or ``"kHz"``
                (spikes/ms) for ``sliding_rate``. When *None*, the unit is
                unspecified.

        Raises:
            ValueError: If inst_Frate_data is not 2D, if the length of times
                differs from the number of columns, or if neuron_attributes
                does not contain exactly one entry per unit.
        """
        # Convert before inspecting .ndim/.shape so array-likes such as nested
        # lists are accepted instead of failing with AttributeError.
        rates = np.asarray(inst_Frate_data)
        if rates.ndim != 2:
            raise ValueError(f"rates must be a 2D array, got shape {rates.shape}")
        if len(times) != rates.shape[1]:
            raise ValueError(
                "Number of columns in inst_Frate_data must be the same as length of times"
            )
        if not isinstance(times, np.ndarray):
            times = np.array(times)

        # np.array (not asarray) so the stored matrix never aliases the input.
        self.inst_Frate_data = np.array(rates, dtype=float)
        self.times = times
        self.rate_unit = rate_unit
        self.N = self.inst_Frate_data.shape[0]

        self.neuron_attributes = None
        if neuron_attributes is not None:
            attrs = list(neuron_attributes)
            # Measure the materialized list so one-shot iterables also work.
            if len(attrs) != self.N:
                raise ValueError(
                    f"neuron_attributes has {len(attrs)} items "
                    f"but inst_Frate_data has {self.N} rows"
                )
            self.neuron_attributes = attrs

    def __repr__(self) -> str:
        """Return a short summary with the matrix shape and first/last time."""
        if len(self.times) > 0:
            t0, t1 = float(self.times[0]), float(self.times[-1])
        else:
            t0 = t1 = 0.0
        return (
            f"RateData(shape={self.inst_Frate_data.shape}, "
            f"time_range=[{t0:.1f}, {t1:.1f}])"
        )

    def subset(self, units, by=None):
        """Extract a subset of units/neurons from the rate data.

        Parameters:
            units (list or array or scalar): Unit indices to extract. If by is
                None, this should contain ints. If by is not None, it can
                contain ints or strings matched against the ``by`` attribute.
                A single int, NumPy integer, or string selects one unit.
            by (str or None): Neuron attribute key to match against. Only use
                this if you initialized the object with neuron_attributes. Set
                to the key that contains neuron_id values. None selects by
                index (default).

        Returns:
            result (RateData): New RateData object containing only the
                specified units, in ascending index order with duplicates
                removed.

        Raises:
            ValueError: If ``by`` is given but the object has no
                neuron_attributes.
        """
        # Wrap a single identifier; np.integer covers NumPy scalars such as
        # the result of np.argmax, which set() cannot iterate.
        if isinstance(units, (int, np.integer, str)):
            units = [units]
        units = set(units)

        if by is not None:
            # VALUE-BASED: Look up by neuron_attribute
            if self.neuron_attributes is None:
                raise ValueError("can't use `by` without `neuron_attributes`")
            # A fresh sentinel can never be in ``units``; it is passed as what
            # appears to be _get_attr's fallback default, so units lacking the
            # key are never selected (TODO confirm _get_attr's signature).
            _missing = object()
            units = {
                i
                for i in range(self.N)
                if _get_attr(self.neuron_attributes[i], by, _missing) in units
            }

        units = sorted(units)
        output = self.inst_Frate_data[units, :]

        neuron_attributes = None
        if self.neuron_attributes is not None:
            neuron_attributes = [self.neuron_attributes[i] for i in units]

        return RateData(
            inst_Frate_data=output,
            times=self.times,
            neuron_attributes=neuron_attributes,
            rate_unit=self.rate_unit,
        )

    def subtime(self, start, end):
        """Extract a subset of time points from the rate data using time values.

        Original time values are preserved in the output.

        Parameters:
            start (int or float or None): Starting time value (inclusive).
                None or ``...`` means the first time point.
            end (int or float or None): Ending time value (exclusive). None or
                ``...`` means "through the last time point".

        Returns:
            result (RateData): New RateData object containing only the
                specified time range.

        Raises:
            ValueError: If start >= end, or if no time point falls inside
                ``[start, end)``.

        Notes:
            - Start and end are always treated as literal time values (not
              offsets from the end). To slice by array index with negative
              indexing support, use ``subtime_by_index(start_idx, end_idx)``.
        """
        # Handle start
        if start is None or start is Ellipsis:
            start = self.times[0] if len(self.times) > 0 else 0

        # Handle end — use a value just past the last time point so the
        # mask (times < end) includes the final bin.
        if end is None or end is Ellipsis:
            if len(self.times) > 1:
                end = self.times[-1] + (self.times[1] - self.times[0])
            elif len(self.times) == 1:
                end = self.times[-1] + 1
            else:
                end = 0

        # Validate
        if start >= end:
            raise ValueError(f"start ({start}) must be less than end ({end})")

        mask = (self.times >= start) & (self.times < end)

        # Check if start and end were in range; guard the message so an
        # empty RateData raises ValueError rather than IndexError.
        if not np.any(mask):
            if len(self.times) > 0:
                available = f"[{self.times[0]}, {self.times[-1]}]"
            else:
                available = "empty"
            raise ValueError(
                f"No time points found in range [{start}, {end}). "
                f"The available range is {available}"
            )

        output = self.inst_Frate_data[:, mask]
        new_times = self.times[mask]

        return RateData(
            inst_Frate_data=output,
            times=new_times,
            neuron_attributes=self.neuron_attributes,
            rate_unit=self.rate_unit,
        )

    def subtime_by_index(self, start_idx, end_idx):
        """Extract a subset of time points from the rate data using time index values.

        Original time values are preserved in the output.

        Parameters:
            start_idx (int): Starting time index (inclusive).
            end_idx (int): Ending time index (exclusive).

        Returns:
            result (RateData): New RateData object containing only the
                specified time range.

        Raises:
            ValueError: If, after resolving negative indices, start_idx is
                outside ``[0, T)`` or end_idx is not in ``(start_idx, T]``.

        Notes:
            - Supports negative indexing (e.g., -5 selects 5 from the end).
            - To slice by time values instead of array indices, use
              ``subtime(start, end)``.
        """
        n_times = len(self.times)
        if start_idx < 0:
            start_idx += n_times
        if end_idx < 0:
            end_idx += n_times

        if start_idx < 0 or start_idx >= n_times:
            raise ValueError(f"start_idx {start_idx} out of range")
        if end_idx <= start_idx or end_idx > n_times:
            raise ValueError(f"end_idx {end_idx} invalid")

        output = self.inst_Frate_data[:, start_idx:end_idx]
        new_times = self.times[start_idx:end_idx]

        return RateData(
            inst_Frate_data=output,
            times=new_times,
            neuron_attributes=self.neuron_attributes,
            rate_unit=self.rate_unit,
        )

    def frames(self, length, overlap=0):
        """Split the rate data into a RateSliceStack of fixed-length windows.

        Parameters:
            length (float): Length of each window in milliseconds. Must be
                positive.
            overlap (float): Overlap between consecutive windows in
                milliseconds. Default 0.

        Returns:
            stack (RateSliceStack): Stack of rate data windows, one per frame.

        Raises:
            ValueError: If overlap >= length, if length <= 0, if the object
                has no time points, or if the recording is shorter than one
                frame.

        Notes:
            - Windows that would extend past the end of the recording are
              excluded.
            - overlap must be strictly less than length.
            - The bin width is taken from the first two time points, so
              uniformly spaced ``times`` are assumed.
        """
        from .rateslicestack import RateSliceStack

        step = length - overlap
        if step <= 0:
            raise ValueError("overlap must be less than length")
        if length <= 0:
            raise ValueError(f"length must be positive, got {length}")
        if len(self.times) == 0:
            raise ValueError("Cannot split a RateData with no time points into frames")

        t0 = float(self.times[0])
        t_end = float(self.times[-1])
        step_size = float(self.times[1] - self.times[0]) if len(self.times) > 1 else 1.0
        # The last bin covers [t_end, t_end + step_size); the 1e-9 slack keeps
        # a start landing exactly on the boundary despite float rounding.
        upper = t_end - length + step_size + 1e-9
        times = [
            (float(start), float(start) + length)
            for start in np.arange(t0, upper, step)
        ]
        if not times:
            raise ValueError(
                f"Recording length ({t_end - t0 + step_size:.1f} ms) is shorter "
                f"than frame length ({length} ms)"
            )
        return RateSliceStack(self, times_start_to_end=times)

    def get_pairwise_fr_corr(
        self, compare_func=compute_cross_correlation_with_lag, max_lag=10, n_jobs=-1
    ):
        """Compute unit-to-unit similarity from the firing rate matrix (U, T).

        Parameters:
            compare_func (callable): Comparison function from utils. Specify
                cross-correlation or cosine similarity. The default is cross
                correlation. See utils.py for details. It is called as
                ``compare_func(a, b, max_lag=max_lag)`` and must return a
                ``(value, lag)`` pair. ``functools.partial`` objects are
                accepted.
            max_lag (int): Max number of lag steps around 0 to consider for
                finding the max correlation. If None, lag is set to 0.
            n_jobs (int): Number of threads for parallel computation. -1 uses
                all cores (default), 1 disables parallelism, None is serial.

        Returns:
            corr_matrix (PairwiseCompMatrix): Maximum correlation coefficients
                between all unit/neuron pairs. matrix[i, j] is the max
                correlation between unit i and unit j. Values range from -1
                to 1. Diagonal is always 1 (self-correlation).
            lag_matrix (PairwiseCompMatrix): Time lags (in time bins) at which
                maximum correlation occurs. lag_matrix[i, j] is the lag where
                correlation between i and j is maximal. Positive lag means
                unit j leads unit i (j fires earlier). Negative lag means
                unit i leads unit j (i fires earlier). Diagonal is always 0.
            Both matrices carry metadata ``{"compare_func": <name>,
            "max_lag": max_lag}``.
        """
        rate_matrix = self.inst_Frate_data
        num_units = rate_matrix.shape[0]  # N

        corr_matrix = np.full((num_units, num_units), np.nan)
        lag_matrix = np.full((num_units, num_units), np.nan)

        # Only the upper triangle (diagonal included) is computed; the lower
        # triangle is filled by mirroring below.
        pairs = [(n1, n2) for n1 in range(num_units) for n2 in range(n1, num_units)]

        def _compute_pair(pair):
            n1, n2 = pair
            return pair, compare_func(
                rate_matrix[n1, :], rate_matrix[n2, :], max_lag=max_lag
            )

        n_workers = _resolve_n_jobs(n_jobs)
        if n_workers > 1 and len(pairs) > 1:
            # Executor.map submits every task up front and leaving the ``with``
            # block waits for them, so ``results`` can be consumed afterwards.
            with ThreadPoolExecutor(max_workers=n_workers) as pool:
                results = pool.map(_compute_pair, pairs)
        else:
            results = map(_compute_pair, pairs)

        for (n1, n2), (max_corr, max_lag_idx) in results:
            corr_matrix[n1, n2] = max_corr
            lag_matrix[n1, n2] = max_lag_idx
            # Mirroring assumes swapping the two units keeps the value and
            # negates the lag.
            corr_matrix[n2, n1] = max_corr
            lag_matrix[n2, n1] = -max_lag_idx

        # functools.partial objects and callable instances have no __name__;
        # fall back to the wrapped function's name, then to the repr.
        func_name = getattr(compare_func, "__name__", None)
        if func_name is None:
            inner = getattr(compare_func, "func", None)
            func_name = getattr(inner, "__name__", None) or repr(compare_func)

        # Output is UxU, wrapped in PairwiseCompMatrix for API consistency
        meta = {"compare_func": func_name, "max_lag": max_lag}
        return (
            PairwiseCompMatrix(matrix=corr_matrix, metadata=meta),
            PairwiseCompMatrix(matrix=lag_matrix, metadata=meta),
        )

    def get_manifold(
        self,
        method: str = "PCA",
        n_components: int = 2,
        **kwargs,
    ):
        """Project the firing-rate data into a low-dimensional manifold using PCA or UMAP.

        Parameters:
            method (str): Which dimensionality reduction method to use. Either
                ``"PCA"`` (default) or ``"UMAP"`` (case-insensitive).
            n_components (int): Number of output dimensions to return
                (default 2).
            **kwargs: Additional options for UMAP. If method is ``"UMAP"``,
                you can specify use_graph_communities (bool), return_labels
                (bool), and other UMAP-specific keyword arguments such as
                n_neighbors, min_dist, metric, or resolution. Ignored (with a
                warning) for PCA.

        Returns:
            result (tuple): Depends on method and options:
                If method is ``"PCA"``: ``(embedding, explained_variance_ratio,
                components)`` where embedding has shape (T, n_components),
                explained_variance_ratio has shape (n_components,), and
                components has shape (n_components, U).
                If method is ``"UMAP"``: ``(embedding, trustworthiness)``
                where embedding has shape (T, n_components) and
                trustworthiness is a float from 0 to 1.
                If method is ``"UMAP"`` with use_graph_communities=True and
                return_labels=True: ``(embedding, labels, trustworthiness)``.

        Raises:
            ValueError: If method is neither PCA nor UMAP.

        Notes:
            - To visualise the resulting embedding, use
              :func:`~spikelab.spikedata.plot_utils.plot_manifold`. It accepts
              the embedding array directly and supports background masks,
              continuous colour values, and discrete group colouring.
        """
        # Shape is (U, T); treat each time bin as a sample.
        data_T = self.inst_Frate_data.T  # (T, U)
        method_upper = method.upper()

        if method_upper == "PCA":
            if kwargs:
                warnings.warn(
                    f"Additional keyword arguments {list(kwargs.keys())} are ignored for method='{method}'.",
                    UserWarning,
                    stacklevel=2,
                )
            return PCA_reduction(
                data_T, n_components=n_components
            )  # (embedding, var_ratio, components)

        if method_upper == "UMAP":
            # Optional graph-based UMAP + Louvain communities. Pop the
            # wrapper-only options so they are not forwarded to UMAP.
            use_graph_communities = kwargs.pop("use_graph_communities", False)
            return_labels = kwargs.pop("return_labels", False)
            if return_labels and not use_graph_communities:
                warnings.warn(
                    "return_labels=True has no effect without use_graph_communities=True; "
                    "labels will not be returned.",
                    UserWarning,
                    stacklevel=2,
                )
            if use_graph_communities:
                embedding, labels, tw = UMAP_graph_communities(
                    data_T,
                    n_components=n_components,
                    **kwargs,
                )
                if return_labels:
                    return embedding, labels, tw
                return embedding, tw

            # Default: plain UMAP embedding + trustworthiness.
            return UMAP_reduction(
                data_T,
                n_components=n_components,
                **kwargs,
            )  # (embedding, trustworthiness)

        raise ValueError(
            f"Unknown manifold method '{method}' (expected 'PCA' or 'UMAP')."
        )