Source code for spikelab.data_loaders.data_loaders

"""
Lightweight loaders that convert common neurophysiology formats into
`spikedata.SpikeData` objects.

Supported inputs (best-effort, optional deps):
    - HDF5 (generic): spike times, (indices,times), or raster matrices
    - NWB: reads Units table spike_times (via pynwb if available, else h5py)
    - KiloSort/Phy outputs: spike_times.npy + spike_clusters.npy (+ optional TSV)
    - SpikeInterface: from a SortingExtractor
    - IBL (International Brain Laboratory): via ONE API + brainwidemap

Times are converted to milliseconds to match `SpikeData` conventions.
These helpers avoid hard dependencies: optional libraries are imported lazily.
"""

from __future__ import annotations

from typing import List, Mapping, Optional, Sequence, Union

import os
import warnings

import numpy as np

import pickle

try:
    import h5py
except ImportError:  # pragma: no cover
    h5py = None  # type: ignore

try:
    import pandas as pd  # noqa: F401  # used in type annotations only
except ImportError:  # pragma: no cover
    pd = None  # type: ignore[assignment]

from ..spikedata import SpikeData

__all__ = [
    "load_spikedata_from_hdf5",
    "load_spikedata_from_hdf5_raw_thresholded",
    "load_spikedata_from_nwb",
    "load_spikedata_from_kilosort",
    "load_spikedata_from_spikeinterface",
    "load_spikedata_from_spikeinterface_recording",
    "load_spikedata_from_pickle",
    "load_spikedata_from_ibl",
    "query_ibl_probes",
    "load_spikedata_from_spikelab_sorted_npz",
]

from ..spikedata.utils import ensure_h5py, to_ms


def _trains_from_flat_index(
    flat_times: np.ndarray,
    end_indices: np.ndarray,
    *,
    unit: str,
    fs_Hz: Optional[float],
) -> List[np.ndarray]:
    """Split a flat time array into per-unit trains using end indices and convert to ms."""
    if len(end_indices) > 0:
        if not np.all(np.diff(end_indices) >= 0):
            raise ValueError("spike_times_index must be monotonically non-decreasing")
        if end_indices[-1] > len(flat_times):
            raise ValueError(
                f"spike_times_index final value ({end_indices[-1]}) exceeds "
                f"flat_times length ({len(flat_times)})"
            )
    trains: List[np.ndarray] = []
    start = 0
    for stop in end_indices:
        segment = flat_times[start:stop]
        trains.append(to_ms(segment, unit, fs_Hz))
        start = stop
    return trains
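
# Illustrative sketch (not part of the module): how the ragged representation
# above splits into per-unit trains, assuming `to_ms` leaves 'ms' inputs unchanged.
#
#     flat_times = np.array([1.0, 5.0, 9.0, 2.0, 4.0])  # unit 0, then unit 1
#     end_indices = np.array([3, 5])                     # cumulative end index per unit
#     _trains_from_flat_index(flat_times, end_indices, unit="ms", fs_Hz=None)
#     # -> [array([1., 5., 9.]), array([2., 4.])]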


def _read_raw_arrays(
    f,
    raw_dataset: Optional[str],
    raw_time_dataset: Optional[str],
    raw_time_unit: str,
    fs_Hz: Optional[float],
) -> tuple[Optional[np.ndarray], Optional[Union[np.ndarray, float]]]:
    """Read optional raw arrays and convert the time vector to milliseconds."""
    raw_data = None
    raw_time: Optional[Union[np.ndarray, float]] = None
    if raw_dataset is not None:
        raw_data = np.asarray(f[raw_dataset])
        if raw_time_dataset is not None:
            raw_time_vals = np.asarray(f[raw_time_dataset])
            if raw_time_unit == "s":
                raw_time = raw_time_vals * 1e3
            elif raw_time_unit == "ms":
                raw_time = raw_time_vals
            elif raw_time_unit == "samples":
                if not fs_Hz:
                    raise ValueError(
                        "fs_Hz must be provided for raw_time_unit='samples'"
                    )
                raw_time = raw_time_vals / float(fs_Hz) * 1e3
            else:
                raise ValueError("raw_time_unit must be one of 's','ms','samples'")
    return raw_data, raw_time


def _maybe_with_raw(
    sd: SpikeData,
    raw_data: Optional[np.ndarray],
    raw_time: Optional[Union[np.ndarray, float]],
) -> SpikeData:
    """Return SpikeData with raw fields attached if provided, else original."""
    if raw_data is not None and raw_time is not None:
        return _build_spikedata(
            sd.train,
            length_ms=sd.length,
            start_time=sd.start_time,
            metadata=sd.metadata,
            raw_data=raw_data,
            raw_time=raw_time,
            neuron_attributes=sd.neuron_attributes,
        )
    if (raw_data is None) != (raw_time is None):
        present = "raw_data" if raw_data is not None else "raw_time"
        missing = "raw_time" if raw_data is not None else "raw_data"
        warnings.warn(
            f"{present} was provided but {missing} is None — "
            f"raw data will not be attached to the SpikeData.",
            UserWarning,
        )
    return sd


def _build_spikedata(
    trains_ms: List[np.ndarray],
    *,
    length_ms: Optional[float] = None,
    start_time: float = 0.0,
    metadata: Optional[Mapping[str, object]] = None,
    raw_data: Optional[np.ndarray] = None,
    raw_time: Optional[Union[np.ndarray, float]] = None,
    neuron_attributes: Optional[List[dict]] = None,
) -> SpikeData:
    """Internal helper to construct a SpikeData with sensible defaults. Infers `length_ms` from the last spike if not provided."""
    if length_ms is None:
        last = [t[-1] for t in trains_ms if len(t) > 0]
        length_ms = float(max(last)) - start_time if last else 0.0
    return SpikeData(
        trains_ms,
        length=length_ms,
        start_time=start_time,
        metadata=dict(metadata) if metadata else {},
        raw_data=raw_data,
        raw_time=raw_time,
        neuron_attributes=neuron_attributes,
    )


# ----------------------------
# HDF5
# ----------------------------


def load_spikedata_from_hdf5(
    filepath: str,
    *,
    raster_dataset: Optional[str] = None,
    raster_bin_size_ms: Optional[float] = None,
    spike_times_dataset: Optional[str] = None,
    spike_times_index_dataset: Optional[str] = None,
    spike_times_unit: str = "s",
    fs_Hz: Optional[float] = None,
    group_per_unit: Optional[str] = None,
    group_time_unit: str = "s",
    idces_dataset: Optional[str] = None,
    times_dataset: Optional[str] = None,
    times_unit: str = "s",
    raw_dataset: Optional[str] = None,
    raw_time_dataset: Optional[str] = None,
    raw_time_unit: str = "s",
    length_ms: Optional[float] = None,
    metadata: Optional[Mapping[str, object]] = None,
) -> SpikeData:
    """Load spike trains from a generic HDF5 file using one of four supported input styles.

    Exactly one input style must be specified. The four styles are: raster matrix,
    ragged arrays, group-per-unit, and paired arrays.

    Parameters:
        filepath (str): Path to the HDF5 file.
        raster_dataset (str | None): Dataset path for a 2D raster/counts matrix
            (units x time). Activates raster style.
        raster_bin_size_ms (float | None): Bin width in milliseconds. Required for
            raster style.
        spike_times_dataset (str | None): Dataset path for flat concatenated spike
            times. Activates ragged style (requires spike_times_index_dataset).
        spike_times_index_dataset (str | None): Dataset path for cumulative
            end-of-unit indices into the flat spike times array.
        spike_times_unit (str): Time unit for ragged spike times ('s', 'ms', or
            'samples').
        fs_Hz (float | None): Sampling frequency in Hz. Required when any time
            unit is 'samples'.
        group_per_unit (str | None): HDF5 group path containing one dataset per
            unit. Activates group-per-unit style.
        group_time_unit (str): Time unit for group-per-unit datasets ('s', 'ms',
            or 'samples').
        idces_dataset (str | None): Dataset path for unit index array. Activates
            paired-arrays style (requires times_dataset).
        times_dataset (str | None): Dataset path for spike times array (paired
            with idces_dataset).
        times_unit (str): Time unit for paired spike times ('s', 'ms', or
            'samples').
        raw_dataset (str | None): Dataset path for optional raw analog data.
        raw_time_dataset (str | None): Dataset path for the raw data time vector.
        raw_time_unit (str): Time unit for the raw time vector ('s', 'ms', or
            'samples').
        length_ms (float | None): Recording duration in milliseconds. If not
            provided, inferred from the latest spike time.
        metadata (Mapping | None): Additional metadata to attach to the resulting
            SpikeData.

    Returns:
        sd (SpikeData): The loaded spike train data.

    Raises:
        ValueError: If not exactly one input style is specified, or if required
            arguments are missing.
    """
    ensure_h5py()

    # Validate exactly one style is provided
    provided = [
        raster_dataset is not None,
        spike_times_dataset is not None and spike_times_index_dataset is not None,
        group_per_unit is not None,
        idces_dataset is not None and times_dataset is not None,
    ]
    if sum(provided) != 1:
        raise ValueError("Specify exactly one HDF5 input style")

    # Accumulate metadata and preserve file path provenance
    meta = dict(metadata or {})
    meta.setdefault("source_file", os.path.abspath(filepath))

    with h5py.File(filepath, "r") as f:  # type: ignore
        # Read start_time if stored (backward compatible default 0.0)
        file_start_time = float(f.attrs.get("start_time", 0.0))

        # Optionally read raw arrays and a time vector
        raw_data, raw_time = _read_raw_arrays(
            f, raw_dataset, raw_time_dataset, raw_time_unit, fs_Hz,
        )

        if raster_dataset is not None:
            # Style (1): counts/raster matrix -> SpikeData via from_raster
            if raster_bin_size_ms is None:
                raise ValueError("raster_bin_size_ms is required for raster_dataset")
            raster = np.asarray(f[raster_dataset])
            if raster.ndim != 2:
                raise ValueError("raster_dataset must be 2D (units, time)")
            total_time = raster.shape[1] * raster_bin_size_ms
            if total_time > 0:
                # subtract the smallest representable spacing so the length is
                # slightly less than the exact bin-aligned value and avoids
                # triggering the extra empty bin in `SpikeData.raster`.
                length_ms = max(total_time - np.spacing(total_time), 0.0)
            else:
                length_ms = 0.0
            sd = SpikeData.from_raster(
                raster, raster_bin_size_ms, length=length_ms, start_time=file_start_time
            )
            sd.metadata.update(meta)
            return _maybe_with_raw(sd, raw_data, raw_time)

        if spike_times_dataset is not None and spike_times_index_dataset is not None:
            # Style (2): flat ragged spike_times + spike_times_index
            flat = np.asarray(f[spike_times_dataset])
            index = np.asarray(f[spike_times_index_dataset])
            trains = _trains_from_flat_index(
                flat, index, unit=spike_times_unit, fs_Hz=fs_Hz
            )
            return _build_spikedata(
                trains,
                length_ms=length_ms,
                start_time=file_start_time,
                metadata=meta,
                raw_data=raw_data,
                raw_time=raw_time,
            )

        if group_per_unit is not None:
            # Style (3): each child dataset is a unit's spike times
            grp = f[group_per_unit]
            keys = sorted(list(grp.keys()))
            trains = [to_ms(np.asarray(grp[k]), group_time_unit, fs_Hz) for k in keys]
            return _build_spikedata(
                trains,
                length_ms=length_ms,
                start_time=file_start_time,
                metadata=meta,
                raw_data=raw_data,
                raw_time=raw_time,
            )

        # Style (4): paired indices and times arrays
        idces = np.asarray(f[idces_dataset])  # type: ignore
        times = to_ms(np.asarray(f[times_dataset]), times_unit, fs_Hz)  # type: ignore
        N = int(idces.max()) + 1 if idces.size else 0
        sd = SpikeData.from_idces_times(
            idces, times, N=N, length=length_ms, start_time=file_start_time
        )
        sd.metadata.update(meta)
        return _maybe_with_raw(sd, raw_data, raw_time)

def load_spikedata_from_hdf5_raw_thresholded(
    filepath: str,
    dataset: str,
    *,
    fs_Hz: float,
    threshold_sigma: float = 5.0,
    filter: Union[dict, bool] = True,
    hysteresis: bool = True,
    direction: str = "both",
) -> SpikeData:
    """Threshold-and-detect spikes from an HDF5 dataset of raw traces.

    Parameters:
        filepath (str): Path to HDF5 file.
        dataset (str): HDF5 dataset path containing raw traces shaped (channels, time).
        fs_Hz (float): Sampling frequency in Hz.
        threshold_sigma (float): Threshold in units of per-channel standard deviation.
        filter (dict | bool): If True, apply default Butterworth bandpass; if dict,
            pass to filter; if False, no filtering.
        hysteresis (bool): Use rising-edge detection if True.
        direction (str): 'both' | 'up' | 'down'.

    Returns:
        sd (SpikeData): The detected spike train data.
    """
    ensure_h5py()
    with h5py.File(filepath, "r") as f:  # type: ignore
        data = np.asarray(f[dataset])
    return SpikeData.from_thresholding(
        data,
        fs_Hz=fs_Hz,
        threshold_sigma=threshold_sigma,
        filter=filter,
        hysteresis=hysteresis,
        direction=direction,  # type: ignore[arg-type]
    )

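# Usage sketch (illustrative; the file path and dataset name are hypothetical):
#
#     sd = load_spikedata_from_hdf5_raw_thresholded(
#         "raw_traces.h5",
#         "ephys/raw",          # dataset shaped (channels, time)
#         fs_Hz=20_000.0,
#         threshold_sigma=5.0,
#     )
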
# ----------------------------
# NWB (units table)
# ----------------------------

def load_spikedata_from_nwb(
    filepath: str,
    *,
    prefer_pynwb: bool = True,
    length_ms: Optional[float] = None,
) -> SpikeData:
    """Load spike trains from an NWB file's Units table.

    Parameters:
        filepath (str): Path to the NWB file.
        prefer_pynwb (bool): If True, try pynwb first and fall back to h5py on
            failure; if False, read the Units table directly with h5py.
        length_ms (float | None): Recording duration in milliseconds.

    Returns:
        sd (SpikeData): The loaded spike train data.
    """
    trains: List[np.ndarray] = []
    neuron_attributes: List[dict] = []
    meta = {"source_file": os.path.abspath(filepath), "format": "NWB"}

    if prefer_pynwb:
        try:
            from pynwb import NWBHDF5IO  # type: ignore

            with NWBHDF5IO(filepath, "r") as io:
                nwb = io.read()
                if getattr(nwb, "units", None) is None:
                    raise ValueError("NWB file has no Units table")
                df = nwb.units.to_dataframe()

                electrode_positions: Optional[dict] = None
                if getattr(nwb, "electrodes", None) is not None:
                    elec_df = nwb.electrodes.to_dataframe()
                    electrode_positions = {}
                    for elec_row in elec_df.itertuples():
                        pos = []
                        for coord in ("x", "y", "z"):
                            if coord in elec_df.columns:
                                val = getattr(elec_row, coord, None)
                                if val is not None and not np.isnan(val):
                                    pos.append(float(val))
                        if pos:
                            electrode_positions[elec_row.Index] = pos

                for row in df.itertuples():
                    stimes = np.asarray(row.spike_times, dtype=float)
                    trains.append(stimes * 1e3)
                    attr = {"unit_id": row.Index}
                    electrode_id = None
                    for col in ("electrodes", "electrode_group", "channel", "ch"):
                        if col in df.columns:
                            val = getattr(row, col, None)
                            if val is not None:
                                if (
                                    hasattr(val, "__len__")
                                    and not isinstance(val, str)
                                    and len(val) > 0
                                ):
                                    channel_val = val[0]
                                else:
                                    channel_val = val
                                try:
                                    attr["electrode"] = int(channel_val)
                                    electrode_id = int(channel_val)
                                except (TypeError, ValueError):
                                    attr["electrode"] = channel_val
                                    electrode_id = channel_val
                                break
                    if electrode_positions and electrode_id in electrode_positions:
                        attr["location"] = electrode_positions[electrode_id]
                    neuron_attributes.append(attr)

                return _build_spikedata(
                    trains,
                    length_ms=length_ms,
                    metadata=meta,
                    neuron_attributes=neuron_attributes,
                )
        except ImportError:  # pragma: no cover
            pass  # pynwb not installed — fall back to h5py
        except (
            TypeError,
            ValueError,
            KeyError,
            AttributeError,
        ) as e:  # pragma: no cover
            warnings.warn(
                f"pynwb failed to load NWB file ({type(e).__name__}: {e}); "
                f"falling back to h5py. If this is unexpected, check the file "
                f"format or report a bug.",
                stacklevel=2,
            )

    ensure_h5py()
    with h5py.File(filepath, "r") as f:  # type: ignore
        if "units" not in f:
            raise ValueError("NWB file missing '/units' group")
        unit_grp = f["units"]
        st_key = "spike_times"
        idx_key = "spike_times_index"
        if st_key not in unit_grp or idx_key not in unit_grp:
            candidates = [k for k in unit_grp.keys() if k.endswith("spike_times")]
            idx_candidates = [
                k for k in unit_grp.keys() if k.endswith("spike_times_index")
            ]
            if not candidates or not idx_candidates:
                raise ValueError("Could not find spike_times datasets in NWB file")
            st_key = candidates[0]
            idx_key = idx_candidates[0]
        flat = np.asarray(unit_grp[st_key])
        index = np.asarray(unit_grp[idx_key])
        trains.extend(
            _trains_from_flat_index(flat.astype(float), index, unit="s", fs_Hz=None)
        )

        unit_ids = (
            np.asarray(unit_grp["id"]) if "id" in unit_grp else range(len(trains))
        )

        electrode_indices = None
        if "electrodes" in unit_grp and "electrodes_index" in unit_grp:
            elec_flat = np.asarray(unit_grp["electrodes"])
            elec_idx = np.asarray(unit_grp["electrodes_index"])
            if len(elec_idx) > 0 and elec_idx[-1] > len(elec_flat):
                warnings.warn(
                    "NWB electrodes_index exceeds electrodes array length; "
                    "electrode data may be truncated.",
                    UserWarning,
                )
            electrode_indices = []
            start = 0
            for stop in elec_idx:
                electrode_indices.append(elec_flat[start:stop])
                start = stop

        electrode_positions: Optional[dict] = None
        elec_table_path = "general/extracellular_ephys/electrodes"
        if elec_table_path in f:
            elec_grp = f[elec_table_path]
            electrode_positions = {}
            x_arr = np.asarray(elec_grp["x"]) if "x" in elec_grp else None
            y_arr = np.asarray(elec_grp["y"]) if "y" in elec_grp else None
            z_arr = np.asarray(elec_grp["z"]) if "z" in elec_grp else None
            elec_ids = (
                np.asarray(elec_grp["id"])
                if "id" in elec_grp
                else np.arange(len(x_arr) if x_arr is not None else 0)
            )
            for idx, eid in enumerate(elec_ids):
                pos = []
                if x_arr is not None and idx < len(x_arr):
                    pos.append(float(x_arr[idx]))
                if y_arr is not None and idx < len(y_arr):
                    pos.append(float(y_arr[idx]))
                if z_arr is not None and idx < len(z_arr):
                    pos.append(float(z_arr[idx]))
                if pos:
                    electrode_positions[int(eid)] = pos

        for i, uid in enumerate(unit_ids):
            attr = {"unit_id": int(uid)}
            electrode_id = None
            if (
                electrode_indices
                and i < len(electrode_indices)
                and len(electrode_indices[i]) > 0
            ):
                electrode_id = int(electrode_indices[i][0])
                attr["electrode"] = electrode_id
            if electrode_positions and electrode_id in electrode_positions:
                attr["location"] = electrode_positions[electrode_id]
            neuron_attributes.append(attr)

    return _build_spikedata(
        trains, length_ms=length_ms, metadata=meta, neuron_attributes=neuron_attributes
    )

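# Usage sketch (illustrative; the file path is hypothetical). pynwb is used when
# installed; otherwise the loader falls back to reading the Units table via h5py:
#
#     sd = load_spikedata_from_nwb("session.nwb")
#     print(len(sd.train), "units,", sd.length, "ms")
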
# ----------------------------
# SpikeInterface
# ----------------------------

def load_spikedata_from_spikeinterface(
    sorting,
    *,
    sampling_frequency: Optional[float] = None,
    unit_ids: Optional[Sequence[Union[int, str]]] = None,
    segment_index: int = 0,
) -> SpikeData:
    """Convert a SpikeInterface SortingExtractor-like object to SpikeData.

    Parameters:
        sorting (object): Exposes get_unit_ids(), get_sampling_frequency(),
            get_unit_spike_train(...).
        sampling_frequency (float | None): Optional override for sampling frequency (Hz).
        unit_ids (Sequence | None): Optional subset of unit IDs to include.
        segment_index (int): Segment index for multi-segment sortings.

    Returns:
        sd (SpikeData): The converted spike train data.
    """
    try:
        get_unit_ids = sorting.get_unit_ids  # type: ignore[attr-defined]
        get_sf = sorting.get_sampling_frequency  # type: ignore[attr-defined]
        get_train = sorting.get_unit_spike_train  # type: ignore[attr-defined]
    except AttributeError as e:
        raise TypeError(
            "`sorting` must be a SpikeInterface SortingExtractor-like object"
        ) from e

    fs = sampling_frequency or float(get_sf())
    if not fs or fs <= 0:
        raise ValueError("A positive sampling_frequency (Hz) is required")

    ids = list(unit_ids) if unit_ids is not None else list(get_unit_ids())
    trains: List[np.ndarray] = []
    neuron_attributes: List[dict] = []

    channel_prop = None
    location_prop = None
    if hasattr(sorting, "get_property"):
        for prop_name in ("channel", "ch", "peak_channel", "electrode"):
            try:
                channel_prop = sorting.get_property(prop_name)
            except (AttributeError, KeyError):
                continue
            if channel_prop is not None:
                break
        for prop_name in ("location", "unit_location", "position"):
            try:
                location_prop = sorting.get_property(prop_name)
            except (AttributeError, KeyError):
                continue
            if location_prop is not None:
                break

    for i, uid in enumerate(ids):
        st = np.asarray(get_train(unit_id=uid, segment_index=segment_index))
        trains.append(to_ms(st.astype(float), "samples", fs))
        attr = {"unit_id": uid}
        if channel_prop is not None and i < len(channel_prop):
            attr["electrode"] = int(channel_prop[i])
        if location_prop is not None and i < len(location_prop):
            loc = location_prop[i]
            if loc is not None:
                attr["location"] = list(loc) if hasattr(loc, "__iter__") else [loc]
        neuron_attributes.append(attr)

    meta = {"source_format": "SpikeInterface", "unit_ids": ids, "fs_Hz": fs}
    return _build_spikedata(trains, metadata=meta, neuron_attributes=neuron_attributes)

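# Usage sketch (illustrative, assuming spikeinterface is installed; the sorter
# output folder is hypothetical):
#
#     import spikeinterface.extractors as se
#     sorting = se.read_kilosort("path/to/kilosort_output")
#     sd = load_spikedata_from_spikeinterface(sorting)
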
# ----------------------------
# KiloSort / Phy
# ----------------------------

def load_spikedata_from_kilosort(
    folder: str,
    *,
    fs_Hz: float,
    spike_times_file: str = "spike_times.npy",
    spike_clusters_file: str = "spike_clusters.npy",
    cluster_info_tsv: Optional[str] = None,
    time_unit: str = "samples",
    include_noise: bool = False,
    length_ms: Optional[float] = None,
    channel_map_file: str = "channel_map.npy",
    channel_positions_file: str = "channel_positions.npy",
) -> SpikeData:
    """Load KiloSort/Phy outputs into SpikeData.

    Parameters:
        folder (str): Path to the KiloSort/Phy output directory.
        fs_Hz (float): Sampling frequency in Hz.
        spike_times_file (str): Filename of the spike times array (.npy),
            relative to folder.
        spike_clusters_file (str): Filename of the spike cluster assignments
            (.npy), relative to folder.
        cluster_info_tsv (str | None): Filename of the cluster info TSV, relative
            to folder; used to filter clusters by their curation label.
        time_unit (str): Unit of the spike times ('samples', 's', or 'ms').
        include_noise (bool): If True, keep all clusters listed in the TSV
            regardless of label.
        length_ms (float | None): Recording duration in milliseconds.
        channel_map_file (str): Filename of the channel map file relative to
            folder. Expected format: 1D numpy array mapping cluster indices to
            channel numbers.
        channel_positions_file (str): Filename of the channel positions file
            relative to folder. Expected format: 2D numpy array of shape
            (channels, 3) containing channel positions.

    Returns:
        sd (SpikeData): The loaded spike train data.

    Notes:
        - This loader does not extract or include waveform data; only spike times
          and cluster assignments are loaded.
        - Reads spike_times.npy (samples) and spike_clusters.npy; groups times
          per cluster and converts to ms using fs_Hz.
    """
    st_path = os.path.join(folder, spike_times_file)
    sc_path = os.path.join(folder, spike_clusters_file)
    spike_times = np.load(st_path)
    spike_clusters = np.load(sc_path)
    if spike_times.shape[0] != spike_clusters.shape[0]:
        raise ValueError("spike_times and spike_clusters length mismatch")

    channel_map: Optional[np.ndarray] = None
    cm_path = os.path.join(folder, channel_map_file)
    if os.path.exists(cm_path):
        try:
            channel_map = np.load(cm_path).flatten()
        except (IOError, ValueError) as e:
            warnings.warn(f"Failed loading channel_map: {e}")

    channel_positions: Optional[np.ndarray] = None
    cp_path = os.path.join(folder, channel_positions_file)
    if os.path.exists(cp_path):
        try:
            channel_positions = np.load(cp_path)
        except (IOError, ValueError) as e:
            warnings.warn(f"Failed loading channel_positions: {e}")

    keep_clusters: Optional[set] = None
    if cluster_info_tsv is not None:
        tsv_path = os.path.join(folder, cluster_info_tsv)
        if os.path.exists(tsv_path):
            try:
                import pandas as pd

                df = pd.read_csv(tsv_path, sep="\t")
                label_col = (
                    "group"
                    if "group" in df.columns
                    else ("KSLabel" if "KSLabel" in df.columns else None)
                )
                id_col = (
                    "cluster_id"
                    if "cluster_id" in df.columns
                    else ("id" if "id" in df.columns else None)
                )
                if id_col is None or label_col is None:
                    warnings.warn(
                        "Could not find id/label columns in cluster TSV; keeping all clusters"
                    )
                else:
                    if include_noise:
                        keep_clusters = set(df[id_col].astype(int).tolist())
                    else:
                        mask = (
                            df[label_col]
                            .astype(str)
                            .str.lower()
                            .isin(["good", "mua", "mua good"])
                        )  # permissive
                        keep_clusters = set(df.loc[mask, id_col].astype(int).tolist())
            except ImportError:
                warnings.warn(
                    "pandas is required to parse cluster info TSV. "
                    "Install with: pip install spikelab[io]. "
                    "Keeping all clusters."
                )
            except (IOError, ValueError, KeyError) as e:
                warnings.warn(
                    f"Failed parsing cluster info TSV: {e}; keeping all clusters"
                )

    trains: List[np.ndarray] = []
    metadata_units: List[int] = []
    neuron_attributes: List[dict] = []
    unique_clusters = np.unique(spike_clusters)

    if channel_map is not None and len(unique_clusters) > 0:
        expected_sequential = np.arange(len(unique_clusters))
        if not np.array_equal(unique_clusters, expected_sequential):
            warnings.warn(
                f"Cluster IDs are not sequential (0..{len(unique_clusters)-1}): "
                f"channel_map lookup uses cluster ID as array index, which "
                f"may assign incorrect electrode/location metadata after "
                f"Phy curation. Verify spatial analysis results.",
                UserWarning,
            )

    unit_idx = 0
    for clu in unique_clusters:
        if keep_clusters is not None and int(clu) not in keep_clusters:
            continue
        times = spike_times[spike_clusters == clu]
        times_ms = to_ms(times.astype(float), time_unit, fs_Hz)
        trains.append(np.sort(times_ms))
        metadata_units.append(int(clu))

        attr: dict = {"unit_id": int(clu)}
        channel_idx = None
        int_clu = int(clu)
        # channel_map is indexed by template/cluster ID — only correct
        # when cluster IDs are sequential integers starting from 0.
        # After Phy curation (merge/split), IDs become non-sequential
        # and this lookup silently maps to the wrong channel.
        if channel_map is not None and int_clu < len(channel_map):
            channel_idx = int(channel_map[int_clu])
            attr["electrode"] = channel_idx
        if channel_positions is not None:
            if channel_idx is not None and channel_idx < len(channel_positions):
                attr["location"] = list(channel_positions[channel_idx])
            elif unit_idx < len(channel_positions):
                # Fallback: use unit index when channel map lookup fails
                attr["location"] = list(channel_positions[unit_idx])
        neuron_attributes.append(attr)
        unit_idx += 1

    meta = {
        "source_folder": os.path.abspath(folder),
        "source_format": "KiloSort",
        "cluster_ids": metadata_units,
        "fs_Hz": fs_Hz,
    }
    return _build_spikedata(
        trains, length_ms=length_ms, metadata=meta, neuron_attributes=neuron_attributes
    )

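# Usage sketch (illustrative; the folder path is hypothetical). Passing the Phy
# curation TSV keeps only curated (good/mua) clusters:
#
#     sd = load_spikedata_from_kilosort(
#         "path/to/kilosort_output",
#         fs_Hz=30_000.0,
#         cluster_info_tsv="cluster_info.tsv",
#     )
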
# ----------------------------
# SpikeLab sorted .npz -> SpikeData
# ----------------------------

def load_spikedata_from_spikelab_sorted_npz(
    filepath: str,
    *,
    length_ms: Optional[float] = None,
) -> SpikeData:
    """Load a SpikeLab compiled sorting result (``.npz``) into SpikeData.

    These ``.npz`` files are produced by :func:`sort_with_kilosort2`'s
    ``compile_results`` step and contain per-unit spike trains, electrode
    locations, waveform templates, and quality metrics.

    Parameters:
        filepath (str): Path to the ``.npz`` file.
        length_ms (float | None): Recording duration in milliseconds. Inferred
            from the latest spike time when *None*.

    Returns:
        sd (SpikeData): The loaded spike train data with neuron attributes
            (unit_id, location, electrode, template, amplitudes, etc.).
    """
    data = np.load(filepath, allow_pickle=True)
    units = data["units"]
    fs_Hz = float(data["fs"])
    locations = data.get("locations", None)

    trains: List[np.ndarray] = []
    neuron_attributes: List[dict] = []
    for unit in units:
        spike_samples = unit["spike_train"]
        spike_times_ms = np.sort(spike_samples.astype(float) / fs_Hz * 1000.0)
        trains.append(spike_times_ms)

        attr: dict = {"unit_id": int(unit["unit_id"])}
        if "x_max" in unit and "y_max" in unit:
            attr["location"] = [float(unit["x_max"]), float(unit["y_max"])]
        if "electrode" in unit:
            attr["electrode"] = int(unit["electrode"])
        if "template" in unit:
            attr["template"] = np.asarray(unit["template"])
        if "amplitudes" in unit:
            attr["amplitudes"] = np.asarray(unit["amplitudes"])
        if "std_norms" in unit:
            attr["std_norms"] = np.asarray(unit["std_norms"])
        if "peak_sign" in unit:
            attr["peak_sign"] = str(unit["peak_sign"])
        if "max_channel_id" in unit:
            attr["max_channel_id"] = str(unit["max_channel_id"])
        neuron_attributes.append(attr)

    meta = {
        "source_file": os.path.abspath(filepath),
        "source_format": "SpikeLab_npz",
        "fs_Hz": fs_Hz,
    }
    if locations is not None:
        meta["channel_locations"] = locations

    return _build_spikedata(
        trains, length_ms=length_ms, metadata=meta, neuron_attributes=neuron_attributes
    )

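# Usage sketch (illustrative; the file path is hypothetical):
#
#     sd = load_spikedata_from_spikelab_sorted_npz("sorted/compiled_units.npz")
#     templates = [a.get("template") for a in sd.neuron_attributes]
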
# ----------------------------
# SpikeInterface BaseRecording -> SpikeData via thresholding
# ----------------------------

def load_spikedata_from_spikeinterface_recording(
    recording,
    *,
    segment_index: int = 0,
    threshold_sigma: float = 5.0,
    filter: Union[dict, bool] = False,
    hysteresis: bool = True,
    direction: str = "both",
) -> SpikeData:
    """Convert a SpikeInterface BaseRecording-like object into SpikeData.

    Parameters:
        recording (object): Exposes get_traces(segment_index=...),
            get_sampling_frequency(), get_num_channels().
        segment_index (int): Segment index for multi-segment recordings.
        threshold_sigma (float): Threshold in units of per-channel standard deviation.
        filter (dict | bool): If True, apply default Butterworth bandpass; if dict,
            pass to filter; if False, no filtering.
        hysteresis (bool): Use rising-edge detection if True.
        direction (str): 'both' | 'up' | 'down'.

    Returns:
        sd (SpikeData): The converted spike train data.
    """
    # Resolve sampling frequency
    if hasattr(recording, "get_sampling_frequency"):
        fs = float(recording.get_sampling_frequency())
    else:
        fs = float(getattr(recording, "sampling_frequency"))
    if not fs or fs <= 0:
        raise ValueError("A positive sampling_frequency (Hz) is required on recording")

    # Retrieve traces (2D array) and coerce to numpy
    traces = recording.get_traces(segment_index=segment_index)
    data = np.asarray(traces)

    # Ensure orientation is (channels, time) via robust heuristic:
    # choose the smaller dimension as channels (typical: channels << time).
    if data.ndim != 2:
        raise ValueError("recording.get_traces() must return a 2D array")
    if data.shape[0] == data.shape[1]:
        warnings.warn(
            f"Ambiguous data orientation: shape is {data.shape} (square). "
            "Assuming (channels, time). Pass data with an explicit orientation "
            "if this is incorrect.",
            UserWarning,
        )
    data_ct = data if data.shape[0] <= data.shape[1] else data.T

    # Delegate detection to SpikeData convenience constructor
    return SpikeData.from_thresholding(
        data_ct,
        fs_Hz=fs,
        threshold_sigma=threshold_sigma,
        filter=filter,
        hysteresis=hysteresis,
        direction=direction,  # type: ignore[arg-type]
    )

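# Usage sketch (illustrative, assuming spikeinterface is installed). Spikes are
# detected by amplitude thresholding, not by spike sorting:
#
#     recording = ...  # any SpikeInterface BaseRecording (e.g. from se.read_binary)
#     sd = load_spikedata_from_spikeinterface_recording(recording, threshold_sigma=5.0)
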
# ----------------------------
# Pickle
# ----------------------------

def load_spikedata_from_pickle(
    filepath: str,
    *,
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    aws_session_token: Optional[str] = None,
    region_name: Optional[str] = None,
) -> SpikeData:
    """Load a SpikeData object from a pickle file.

    Warning:
        Only load pickle files from trusted sources. Pickle deserialization can
        execute arbitrary code and should never be used with untrusted data.
        The file is deserialized before type checking — malicious payloads
        execute regardless of the subsequent isinstance check.

    Parameters:
        filepath (str): Path to the pickle file, or an S3 URL (s3://bucket/key).
        aws_access_key_id (str | None): AWS access key ID for S3 downloads.
        aws_secret_access_key (str | None): AWS secret access key for S3 downloads.
        aws_session_token (str | None): AWS session token for temporary credentials.
        region_name (str | None): AWS region name for S3 access.

    Returns:
        sd (SpikeData): The deserialized SpikeData object.
    """
    from .s3_utils import ensure_local_file

    local_path, is_temp = ensure_local_file(
        filepath,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        aws_session_token=aws_session_token,
        region_name=region_name,
    )
    try:
        with open(local_path, "rb") as f:
            obj = pickle.load(f)
    finally:
        if is_temp:
            try:
                os.remove(local_path)
            except OSError:
                pass

    if not isinstance(obj, SpikeData):
        raise ValueError(
            f"Pickle file does not contain a SpikeData object (found {type(obj).__name__})"
        )
    return obj

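# Usage sketch (illustrative; paths and bucket names are hypothetical). Local
# files and s3:// URLs are both accepted; only unpickle data you trust:
#
#     sd = load_spikedata_from_pickle("experiment.pkl")
#     sd = load_spikedata_from_pickle(
#         "s3://my-bucket/experiment.pkl", region_name="us-west-2"
#     )
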
# ----------------------------
# IBL (International Brain Laboratory)
# ----------------------------

#: IBL public server URL used for ONE authentication.
_IBL_BASE_URL = "https://openalyx.internationalbrainlab.org"

#: Collections searched in order when loading spikes. The probe-specific
#: collection is prepended at runtime based on the PID suffix.
_IBL_FALLBACK_COLLECTIONS = [
    "alf/probe00/pykilosort",
    "alf/probe01/pykilosort",
    "alf",
]

def load_spikedata_from_ibl(
    eid: str,
    pid: str,
    *,
    length_ms: Optional[float] = None,
) -> SpikeData:
    """Load spike trains for a single IBL probe into SpikeData.

    Authenticates against the public IBL server automatically. Only units
    labelled as good (``label == 1``) in the Brain-Wide Map unit table are
    included. Trial event times are stored in ``SpikeData.metadata`` as
    individual numpy arrays, all in milliseconds.

    Parameters:
        eid (str): IBL experiment ID (UUID string).
        pid (str): IBL probe ID (UUID string).
        length_ms (float | None): Recording duration in milliseconds. If not
            provided, the maximum spike time across all units is used.

    Returns:
        sd (SpikeData): Loaded spike train data. ``neuron_attributes`` contains
            ``{"region": <Beryl atlas region>}`` per unit. ``metadata`` contains
            ``eid``, ``pid``, ``n_trials``, ``trial_start_times``,
            ``trial_end_times``, ``stim_on_times``, ``stim_off_times``,
            ``go_cue_times``, ``response_times``, ``feedback_times``,
            ``first_movement_times``, ``choice``, ``feedback_type``,
            ``contrast_left``, ``contrast_right``, and ``probability_left``.
            All time arrays are in milliseconds.

    Notes:
        - Requires ``one-api`` and ``brainwidemap`` packages (optional dependencies).
        - Spike times are converted from seconds (IBL convention) to milliseconds.
        - Trial times are converted from seconds to milliseconds.
        - Probe collection is inferred from the PID suffix; falls back through
          ``alf/probe00/pykilosort``, ``alf/probe01/pykilosort``, and ``alf``.
    """
    try:
        from one.api import ONE  # type: ignore
    except ImportError as e:
        raise ImportError(
            "one-api is required for load_spikedata_from_ibl. "
            "Install with: pip install one-api"
        ) from e
    try:
        from brainwidemap import bwm_units  # type: ignore
    except ImportError as e:
        raise ImportError(
            "brainwidemap is required for load_spikedata_from_ibl. "
            "Install with: pip install brainwidemap"
        ) from e

    # Authenticate against the public IBL server.
    ONE.setup(base_url=_IBL_BASE_URL, silent=True)
    one = ONE(password="international")

    # Retrieve good units for this probe from the Brain-Wide Map table.
    unit_df = bwm_units(one)
    good_units = unit_df[(unit_df["pid"] == pid) & (unit_df["label"] == 1)]

    # Build the ordered list of collections to try, with the probe-specific
    # one first when the PID suffix hints at the probe number. This is a
    # best-effort heuristic for ordering (PIDs are UUIDs, so the last two
    # hex chars can coincidentally match "00"/"01"). All candidates are
    # tried regardless, so correctness is not affected — only the order.
    collections = []
    if pid.endswith("00") or pid.endswith("01"):
        collections.append(f"alf/probe{pid[-2:]}/pykilosort")
    collections.extend(_IBL_FALLBACK_COLLECTIONS)
    # Deduplicate while preserving order.
    seen: set = set()
    ordered_collections: List[str] = []
    for c in collections:
        if c not in seen:
            seen.add(c)
            ordered_collections.append(c)

    # Load spikes from the first available collection.
    spikes = None
    for collection in ordered_collections:
        try:
            spikes = one.load_object(eid, "spikes", collection=collection)
            break
        except (ValueError, KeyError, FileNotFoundError):
            continue

    # Build per-unit spike trains (seconds → milliseconds).
    spike_trains: List[np.ndarray] = []
    neuron_attributes: List[dict] = []
    for _, unit in good_units.iterrows():
        if spikes is None:
            spike_trains.append(np.array([], dtype=float))
        else:
            mask = spikes["clusters"] == unit["cluster_id"]
            spike_trains.append(spikes["times"][mask] * 1_000.0)
        neuron_attributes.append({"region": unit["Beryl"]})

    # Infer session length from the largest spike time if not provided.
    if length_ms is None:
        max_t = max((t.max() for t in spike_trains if len(t) > 0), default=0.0)
        length_ms = float(max_t) if max_t > 0 else 10_000.0

    # Load trials and extract relevant fields as numpy arrays (seconds → ms).
    trials = one.load_object(eid, "trials")
    trials_df = trials.to_df()
    n_trials = len(trials_df)

    def _to_ms_array(col: str) -> np.ndarray:
        """Extract a trials column and convert seconds to milliseconds."""
        return trials_df[col].to_numpy(dtype=float) * 1_000.0

    def _to_array(col: str) -> np.ndarray:
        """Extract a trials column as a plain numpy array (no unit conversion)."""
        return trials_df[col].to_numpy(dtype=float)

    metadata: dict = {
        "eid": eid,
        "pid": pid,
        "n_trials": n_trials,
        "trial_start_times": _to_ms_array("intervals_0"),
        "trial_end_times": _to_ms_array("intervals_1"),
        "stim_on_times": _to_ms_array("stimOn_times"),
        "stim_off_times": _to_ms_array("stimOff_times"),
        "go_cue_times": _to_ms_array("goCue_times"),
        "response_times": _to_ms_array("response_times"),
        "feedback_times": _to_ms_array("feedback_times"),
        "first_movement_times": _to_ms_array("firstMovement_times"),
        "choice": _to_array("choice"),
        "feedback_type": _to_array("feedbackType"),
        "contrast_left": _to_array("contrastLeft"),
        "contrast_right": _to_array("contrastRight"),
        "probability_left": _to_array("probabilityLeft"),
    }

    return _build_spikedata(
        spike_trains,
        length_ms=length_ms,
        metadata=metadata,
        neuron_attributes=neuron_attributes,
    )

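# Usage sketch (illustrative; the eid/pid UUIDs are placeholders, not real sessions):
#
#     sd = load_spikedata_from_ibl(
#         eid="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
#         pid="11111111-2222-3333-4444-555555555555",
#     )
#     regions = [a["region"] for a in sd.neuron_attributes]
#     stim_on_ms = sd.metadata["stim_on_times"]
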
def query_ibl_probes(
    target_regions: Optional[List[str]] = None,
    *,
    min_units: int = 0,
    min_fraction_in_target: float = 0.0,
) -> "tuple[list[tuple[str, str]], pd.DataFrame]":
    """Search the IBL Brain-Wide Map database for probes matching given criteria.

    Authenticates against the public IBL server automatically. Filters probes by
    brain region and unit count. Returns matching (eid, pid) pairs alongside a
    per-probe statistics DataFrame.

    Parameters:
        target_regions (list[str] | None): Beryl atlas region names to filter by
            (e.g. ``["MOs", "MOp"]``). If None, no region filter is applied.
        min_units (int): Minimum number of good units required per probe.
            Default ``0`` (no minimum).
        min_fraction_in_target (float): Minimum fraction (0–1) of good units that
            must fall within ``target_regions``. Ignored when ``target_regions``
            is ``None``. Default ``0.0``.

    Returns:
        probes (list[tuple[str, str]]): List of ``(eid, pid)`` pairs for probes
            that pass all filters, sorted by descending good unit count.
        stats (pd.DataFrame): One row per matching probe with columns: ``eid``,
            ``pid``, ``n_good_units``, and (when ``target_regions`` is not
            ``None``) ``n_in_target`` and ``fraction_in_target``.

    Notes:
        - Requires ``one-api`` and ``brainwidemap`` packages (optional dependencies).
        - ``bwm_units()`` fetches the full Brain-Wide Map unit table from the IBL
          server; this may take several seconds on first call.
    """
    try:
        from one.api import ONE  # type: ignore
    except ImportError as e:
        raise ImportError(
            "one-api is required for query_ibl_probes. "
            "Install with: pip install one-api"
        ) from e
    try:
        from brainwidemap import bwm_units  # type: ignore
    except ImportError as e:
        raise ImportError(
            "brainwidemap is required for query_ibl_probes. "
            "Install with: pip install brainwidemap"
        ) from e
    try:
        import pandas as pd  # type: ignore
    except ImportError as e:
        raise ImportError(
            "pandas is required for query_ibl_probes. "
            "Install with: pip install spikelab[io]"
        ) from e

    # Authenticate against the public IBL server.
    ONE.setup(base_url=_IBL_BASE_URL, silent=True)
    one = ONE(password="international")

    # Fetch all good units from the Brain-Wide Map table.
    unit_df = bwm_units(one)
    good_units = unit_df[unit_df["label"] == 1].copy()

    # Build per-probe aggregation.
    agg = good_units.groupby(["eid", "pid"], as_index=False).agg(
        n_good_units=("cluster_id", "count"),
    )

    # Compute region-based columns when target_regions is provided.
    if target_regions is not None:
        in_target = good_units["Beryl"].isin(target_regions)
        region_counts = (
            good_units[in_target]
            .groupby(["eid", "pid"], as_index=False)
            .agg(n_in_target=("cluster_id", "count"))
        )
        agg = agg.merge(region_counts, on=["eid", "pid"], how="left")
        agg["n_in_target"] = agg["n_in_target"].fillna(0).astype(int)
        agg["fraction_in_target"] = np.where(
            agg["n_good_units"] > 0,
            agg["n_in_target"] / agg["n_good_units"],
            0.0,
        )

    # Apply unit-count filter.
    mask = agg["n_good_units"] >= min_units
    # Apply region fraction filter.
    if target_regions is not None:
        mask = mask & (agg["fraction_in_target"] >= min_fraction_in_target)

    stats = (
        agg[mask].sort_values("n_good_units", ascending=False).reset_index(drop=True)
    )
    probes = list(zip(stats["eid"].tolist(), stats["pid"].tolist()))
    return probes, stats

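# Usage sketch (illustrative): find motor-cortex probes with enough good units,
# then load the top match with load_spikedata_from_ibl.
#
#     probes, stats = query_ibl_probes(
#         target_regions=["MOs", "MOp"],
#         min_units=50,
#         min_fraction_in_target=0.3,
#     )
#     if probes:
#         eid, pid = probes[0]
#         sd = load_spikedata_from_ibl(eid, pid)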