Source code for spikelab.data_loaders.data_exporters

"""Data exporters that mirror data_loaders, writing SpikeData to common formats.

Provided exporters:

- HDF5 generic with one of four styles: ``raster`` (units x time matrix
  with a specified bin size in ms), ``ragged`` (flat ``spike_times`` plus
  ``spike_times_index``), ``group`` (one HDF5 group per unit), or ``paired``
  (parallel ``idces`` and ``times`` arrays).
- NWB Units table (``spike_times`` / ``spike_times_index``) via h5py.
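- KiloSort/Phy (``spike_times.npy`` + ``spike_clusters.npy``).
- Pickle (``export_to_pickle``) for full-fidelity serialization of spikelab
  data objects, optionally uploaded to S3.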

The spike-train exporters read SpikeData times in milliseconds and convert
them to the requested target time unit as needed.
"""

from __future__ import annotations

from typing import Iterable, Literal, Optional, Sequence, Tuple, Union, TYPE_CHECKING

import os
import warnings

import numpy as np

import pickle

try:
    import h5py
except ImportError:  # pragma: no cover
    h5py = None  # type: ignore

if TYPE_CHECKING:  # avoid runtime circular import
    from ..spikedata import SpikeData  # noqa: F401

from ..spikedata.utils import TimeUnit, ensure_h5py, times_from_ms


def export_spikedata_to_hdf5(
    sd: "SpikeData",
    filepath: str,
    *,
    style: Literal["raster", "ragged", "group", "paired"] = "ragged",
    # raster
    raster_dataset: str = "raster",
    raster_bin_size_ms: Optional[float] = None,
    # ragged
    spike_times_dataset: str = "spike_times",
    spike_times_index_dataset: str = "spike_times_index",
    spike_times_unit: TimeUnit = "s",
    fs_Hz: Optional[float] = None,
    # group-per-unit
    group_per_unit: str = "units",
    group_time_unit: TimeUnit = "s",
    # paired arrays
    idces_dataset: str = "idces",
    times_dataset: str = "times",
    times_unit: TimeUnit = "ms",
    # optional raw arrays (written if present and destinations provided)
    raw_dataset: Optional[str] = None,
    raw_time_dataset: Optional[str] = None,
    raw_time_unit: TimeUnit = "ms",
) -> None:
    """Export a SpikeData to a generic HDF5 file using a chosen style.

    Parameters:
        sd (SpikeData): The SpikeData object to export.
        filepath (str): Path where the HDF5 file will be created (overwrites
            existing).
        style (Literal["raster", "ragged", "group", "paired"]): Export format
            style; see the module docstring for what each style produces.
        raster_dataset (str): HDF5 dataset name for the raster matrix.
        raster_bin_size_ms (float | None): Bin size in milliseconds for
            rasterization. Required for raster style.
        spike_times_dataset (str): Dataset name for concatenated spike times.
        spike_times_index_dataset (str): Dataset name for cumulative spike
            count indices.
        spike_times_unit (TimeUnit): Time unit for spike times ('ms', 's',
            'samples').
        fs_Hz (float | None): Sampling frequency in Hz. Required when any unit
            is 'samples'.
        group_per_unit (str): HDF5 group name containing per-unit datasets.
        group_time_unit (TimeUnit): Time unit for individual unit datasets.
        idces_dataset (str): Dataset name for unit indices array.
        times_dataset (str): Dataset name for spike times array.
        times_unit (TimeUnit): Time unit for spike times.
        raw_dataset (str | None): Dataset name for raw analog data (if present
            in sd).
        raw_time_dataset (str | None): Dataset name for raw data time vector.
        raw_time_unit (TimeUnit): Time unit for raw data timestamps.

    Raises:
        ImportError: If h5py is not available.
        ValueError: For invalid styles, missing required parameters, or
            missing fs_Hz when needed.

    Notes:
        - Spike times are automatically converted from milliseconds to the
          requested unit.
        - The function creates or overwrites the target HDF5 file. All
          arguments are validated and every array is built *before* the file
          is opened. Mode "w" truncates, so this ordering keeps an invalid call
          from destroying an existing file or leaving a half-written one.
        - Raw data is only written if both raw_dataset and raw_time_dataset
          are provided and the SpikeData has a non-empty ``raw_data`` and a
          non-None ``raw_time``.
        - ``start_time`` is stored as a file-level attribute for every style.
        - For raster style, the bin size is stored as an attribute for
          provenance.
        - Parameters mirror the corresponding loader function to ease
          round-tripping.
        - The generic HDF5 format does not persist ``neuron_attributes`` or
          ``metadata``; use ``AnalysisWorkspace.save`` (workspace HDF5) or
          ``export_to_pickle`` for full-fidelity round-trips.
    """
    ensure_h5py()
    style = style.lower()  # normalize
    valid_styles = {"raster", "ragged", "group", "paired"}
    if style not in valid_styles:
        raise ValueError(
            f"Unknown style '{style}' (choose one of {sorted(valid_styles)})"
        )

    # ------------------------------------------------------------------
    # Phase 1: validate arguments and build every array in memory. Nothing
    # touches the file system here, so any ValueError leaves `filepath` as-is.
    # ------------------------------------------------------------------
    # Stored as a file attribute for event-centered data round-trips.
    start_time = float(sd.start_time)

    raw_arrays: Optional[Tuple[np.ndarray, np.ndarray]] = None
    raw_data = getattr(sd, "raw_data", None)
    raw_time = getattr(sd, "raw_time", None)
    if (
        raw_dataset
        and raw_time_dataset
        and raw_data is not None
        and raw_time is not None
        and raw_data.size > 0
    ):
        # raw_time is kept in ms on the SpikeData; convert to the requested unit.
        raw_time_ms = np.asarray(raw_time)
        if raw_time_unit == "ms":
            raw_time_out = raw_time_ms
        elif raw_time_unit == "s":
            raw_time_out = raw_time_ms / 1e3
        elif raw_time_unit == "samples":
            if not fs_Hz or fs_Hz <= 0:
                raise ValueError(
                    "fs_Hz must be provided for raw_time_unit='samples'"
                )
            raw_time_out = np.rint(raw_time_ms * (fs_Hz / 1e3)).astype(int)
        else:
            raise ValueError("raw_time_unit must be one of 's','ms','samples'")
        raw_arrays = (np.asarray(raw_data), raw_time_out)

    if style == "raster":
        if raster_bin_size_ms is None or raster_bin_size_ms <= 0:
            raise ValueError(
                "raster_bin_size_ms must be provided and > 0 for raster style"
            )
        raster = np.asarray(sd.raster(raster_bin_size_ms))
    elif style == "ragged":
        # Flatten all trains; the index holds each unit's cumulative end
        # offset into the flat array (empty units repeat the previous offset).
        counts = [len(t) for t in sd.train]
        flat_ms = np.concatenate(sd.train) if sum(counts) else np.array([], float)
        spike_times = times_from_ms(flat_ms, spike_times_unit, fs_Hz)
        spike_times_index = np.cumsum(counts, dtype=int)
    elif style == "group":
        unit_times = [
            times_from_ms(np.asarray(tms), group_time_unit, fs_Hz)
            for tms in sd.train
        ]
    else:  # "paired"
        counts = [len(t) for t in sd.train]
        # One (unit index, time) pair per spike, in unit order; empty units
        # contribute nothing.
        idces_arr = np.repeat(
            np.arange(len(counts), dtype=int), np.asarray(counts, dtype=int)
        )
        times_ms = (
            np.concatenate([np.asarray(t, dtype=float) for t in sd.train])
            if sum(counts)
            else np.array([], dtype=float)
        )
        times_arr = times_from_ms(times_ms, times_unit, fs_Hz)

    # ------------------------------------------------------------------
    # Phase 2: create/overwrite the file and write the prepared arrays.
    # ------------------------------------------------------------------
    with h5py.File(filepath, "w") as f:  # type: ignore
        f.attrs["start_time"] = start_time

        if raw_arrays is not None:
            f.create_dataset(raw_dataset, data=raw_arrays[0])
            f.create_dataset(raw_time_dataset, data=raw_arrays[1])

        if style == "raster":
            dset = f.create_dataset(raster_dataset, data=raster)
            # Store bin size as an attribute for provenance (readers can ignore)
            dset.attrs["bin_size_ms"] = float(raster_bin_size_ms)
        elif style == "ragged":
            f.create_dataset(spike_times_dataset, data=spike_times)
            f.create_dataset(spike_times_index_dataset, data=spike_times_index)
        elif style == "group":
            grp = f.create_group(group_per_unit)
            for i, times in enumerate(unit_times):
                grp.create_dataset(str(i), data=times)
        else:  # "paired"
            f.create_dataset(idces_dataset, data=idces_arr)
            f.create_dataset(times_dataset, data=times_arr)
def export_spikedata_to_nwb(
    sd: "SpikeData",
    filepath: str,
    *,
    spike_times_dataset: str = "spike_times",
    spike_times_index_dataset: str = "spike_times_index",
    group: str = "units",
) -> None:
    """Export SpikeData to a minimal NWB-like file using h5py.

    Parameters:
        sd (SpikeData): The SpikeData object to export.
        filepath (str): Path where the NWB file will be created (overwrites
            existing).
        spike_times_dataset (str): Name of the dataset containing concatenated
            spike times. Default is "spike_times" per NWB convention.
        spike_times_index_dataset (str): Name of the dataset containing
            cumulative indices. Default is "spike_times_index" per NWB
            convention.
        group (str): Name of the HDF5 group to contain the datasets. Default is
            "units" per NWB convention.

    Raises:
        ImportError: If h5py is not available.

    Notes:
        - Spike times are automatically converted from milliseconds to seconds.
        - The output file structure follows NWB conventions but is minimal
          (does not include full NWB metadata or schema validation).
        - Empty units (no spikes) are handled correctly in the index array.
        - A UserWarning is emitted when ``sd.start_time`` is non-zero, because
          start_time is not stored.
        - When ``sd.electrodes`` is set, each unit gets exactly one entry in
          ``<group>/electrodes``, so ``electrodes_index`` is ``1..N``.
        - When ``sd.unit_locations`` is set, an electrodes table is written to
          ``general/extracellular_ephys/electrodes``. It has one row per
          unique electrode ID (ascending), located at the first unit on that
          electrode. Without electrode IDs, it has one row per unit with IDs
          ``0..N-1``.
        - All arrays are built before the file is opened, so a malformed
          ``unit_locations`` cannot leave a truncated file behind.
        - This is compatible with the load_spikedata_from_nwb function when
          prefer_pynwb=False.
    """
    ensure_h5py()
    if sd.start_time != 0:
        warnings.warn(
            f"Exporting event-centered SpikeData (start_time={sd.start_time}) "
            "to NWB. The NWB format does not store start_time, so spike times "
            "are written as-is. On reload, start_time will default to 0.",
            UserWarning,
        )

    # Ragged spike-time encoding: flat times (s) + cumulative end offsets.
    counts = [len(t) for t in sd.train]
    flat_ms = np.concatenate(sd.train) if sum(counts) else np.array([], float)
    flat_s = times_from_ms(flat_ms, "s", fs_Hz=None)
    index = np.cumsum(counts, dtype=int)

    electrodes = sd.electrodes
    unit_locations = sd.unit_locations

    # Electrodes table (IDs + coordinate columns), prepared before the file is
    # opened so indexing/shape errors surface without truncating `filepath`.
    elec_table: Optional[Tuple[np.ndarray, dict]] = None
    if unit_locations is not None:
        locations = np.asarray(unit_locations)
        if electrodes is not None:
            # np.unique returns IDs in ascending order together with the index
            # of each ID's first occurrence, which selects the representative
            # unit location for that electrode.
            elec_ids, first_indices = np.unique(
                np.asarray(electrodes, dtype=int), return_index=True
            )
            elec_locations = locations[first_indices]
        else:
            # Fallback: no explicit electrode IDs; use 0..N-1
            elec_ids = np.arange(sd.N, dtype=int)
            elec_locations = locations
        # Only dimensions present in the data are written (at most x, y, z).
        # On reload, locations will have fewer columns than the original if y
        # or z were omitted here; this is inherent to the format and cannot be
        # avoided without padding with zeros.
        n_dims = min(elec_locations.shape[1], 3)
        columns = {
            axis: elec_locations[:, col]
            for col, axis in enumerate("xyz"[:n_dims])
        }
        elec_table = (elec_ids, columns)

    with h5py.File(filepath, "w") as f:  # type: ignore
        g = f.create_group(group)
        g.create_dataset(spike_times_dataset, data=flat_s)
        g.create_dataset(spike_times_index_dataset, data=index)
        g.create_dataset("id", data=np.arange(sd.N, dtype=int))

        if electrodes is not None:
            # NOTE(review): NWB proper stores row indices into the electrodes
            # table here; this writer stores the electrode IDs themselves.
            # Confirm the h5py loader expects IDs before changing.
            g.create_dataset("electrodes", data=electrodes)
            g.create_dataset(
                "electrodes_index", data=np.arange(1, sd.N + 1, dtype=int)
            )

        if elec_table is not None:
            elec_ids, columns = elec_table
            elec_grp = f.create_group("general/extracellular_ephys/electrodes")
            elec_grp.create_dataset("id", data=elec_ids)
            for axis, values in columns.items():
                elec_grp.create_dataset(axis, data=values)
def export_spikedata_to_kilosort(
    sd: "SpikeData",
    folder: str,
    *,
    fs_Hz: float,
    spike_times_file: str = "spike_times.npy",
    spike_clusters_file: str = "spike_clusters.npy",
    time_unit: TimeUnit = "samples",
    cluster_ids: Optional[Sequence[int]] = None,
    sort_by_time: bool = False,
) -> Tuple[str, str]:
    """Export SpikeData to a KiloSort/Phy-like folder.

    Parameters:
        sd (SpikeData): The SpikeData object to export.
        folder (str): Directory path where the .npy files will be created.
            Created if it doesn't exist, but only after all arguments have
            been validated.
        fs_Hz (float): Sampling frequency in Hz. Always required and must be
            positive; used for the conversion when time_unit='samples'.
        spike_times_file (str): Filename for the spike times array. Default is
            "spike_times.npy".
        spike_clusters_file (str): Filename for the spike clusters array.
            Default is "spike_clusters.npy".
        time_unit (TimeUnit): Time unit for output spike times.
            'samples': integer sample indices (default, KiloSort standard).
            'ms': milliseconds (float).
            's': seconds (float).
        cluster_ids (Sequence[int] | None): Custom cluster IDs for each unit.
            If None, uses sequential integers 0, 1, 2, ... Length must match
            sd.N.
        sort_by_time (bool): If True, reorder spikes chronologically before
            saving. The sort is stable, so simultaneous spikes keep unit
            order. This matches the chronological layout KiloSort itself
            writes. Default False keeps spikes grouped by unit.

    Returns:
        paths (tuple[str, str]): Paths to the created spike_times.npy and
            spike_clusters.npy files.

    Raises:
        ValueError: If fs_Hz is not positive, time_unit is not one of
            'samples', 'ms', 's', or cluster_ids does not have length sd.N.
            These checks run before anything is written to disk.

    Notes:
        - The output arrays have the same length (one entry per spike across
          all units).
        - By default spikes are ordered by unit, not chronologically; see
          ``sort_by_time``.
        - Empty units (no spikes) don't contribute entries to the output
          arrays.
        - The 'samples' time unit produces integer arrays suitable for
          KiloSort/Phy.
        - Cluster IDs can be arbitrary integers and don't need to be
          sequential.
        - If ``sd.electrodes`` is set, it is also saved as ``channel_map.npy``
          (one entry per unit). NOTE(review): Phy's channel_map.npy is indexed
          by recording channel rather than by unit; confirm the matching
          loader expects this layout.
    """
    # --- validate everything before touching the file system -------------
    if not fs_Hz or fs_Hz <= 0:
        raise ValueError("A positive fs_Hz is required for KiloSort export")
    if time_unit not in ("samples", "ms", "s"):
        raise ValueError("time_unit must be one of 'samples','ms','s'")
    if cluster_ids is None:
        cluster_ids = list(range(sd.N))
    if len(cluster_ids) != sd.N:
        raise ValueError("cluster_ids length must match sd.N")

    if sd.start_time != 0:
        warnings.warn(
            f"Exporting event-centered SpikeData (start_time={sd.start_time}) "
            "to KiloSort. The format does not store start_time, so spike times "
            "are written as-is. On reload, start_time will default to 0.",
            UserWarning,
        )

    # --- build flat per-spike arrays (unit order) ---------------------------
    trains = [np.asarray(t, dtype=float) for t in sd.train]
    counts = np.array([len(t) for t in trains], dtype=int)
    times_ms = np.concatenate(trains) if trains else np.array([], dtype=float)
    unit_idx = np.repeat(np.arange(len(trains), dtype=int), counts)

    # Map unit index -> cluster id via a lookup table (int() per entry keeps
    # the original coercion semantics for non-int IDs).
    cluster_lookup = np.array([int(c) for c in cluster_ids], dtype=int)
    clusters = cluster_lookup[unit_idx]

    if sort_by_time:
        order = np.argsort(times_ms, kind="stable")
        times_ms = times_ms[order]
        clusters = clusters[order]

    # --- convert times -------------------------------------------------------
    if time_unit == "samples":
        times_out = times_from_ms(times_ms, "samples", fs_Hz)
    elif time_unit == "ms":
        times_out = times_ms
    else:  # "s"
        times_out = times_ms / 1e3

    # --- write ---------------------------------------------------------------
    os.makedirs(folder, exist_ok=True)
    spike_times_path = os.path.join(folder, spike_times_file)
    spike_clusters_path = os.path.join(folder, spike_clusters_file)
    np.save(spike_times_path, times_out)
    np.save(spike_clusters_path, clusters)
    if sd.electrodes is not None:
        np.save(os.path.join(folder, "channel_map.npy"), sd.electrodes)
    return spike_times_path, spike_clusters_path
def export_to_pickle(
    obj,
    filepath: str,
    *,
    protocol: Optional[int] = None,
    s3_upload: bool = False,
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    aws_session_token: Optional[str] = None,
    region_name: Optional[str] = None,
) -> str:
    """Export a spikelab data object to a pickle file.

    Supported types: ``SpikeData``, ``RateData``, ``PairwiseCompMatrix``,
    ``PairwiseCompMatrixStack``, ``RateSliceStack``, ``SpikeSliceStack``.

    Parameters:
        obj: The spikelab data object to export.
        filepath (str): Path where the pickle file will be created (overwrites
            existing). Missing parent directories are created. If
            s3_upload=True, this must be an S3 URL (s3://bucket/key).
        protocol (int | None): Pickle protocol version, passed straight to
            ``pickle.dump``. If None, ``pickle.DEFAULT_PROTOCOL`` is used;
            this is not necessarily the highest protocol available. Pass
            ``pickle.HIGHEST_PROTOCOL`` explicitly for that. Lower protocols
            (e.g., 2, 3) may be needed for compatibility with older Python
            versions.
        s3_upload (bool): If True, pickle to a local temporary file, upload it
            to the S3 URL in filepath, then delete the temporary file.
        aws_access_key_id (str | None): AWS access key ID for S3 uploads.
        aws_secret_access_key (str | None): AWS secret access key for S3
            uploads.
        aws_session_token (str | None): AWS session token for temporary
            credentials.
        region_name (str | None): AWS region name for S3 access.

    Returns:
        path (str): Path to the created pickle file (local path or S3 URL).

    Raises:
        TypeError: If obj is not one of the supported spikelab types.
        ValueError: If s3_upload=True and filepath is not an S3 URL.
    """
    import tempfile

    from ..spikedata.spikedata import SpikeData
    from ..spikedata.ratedata import RateData
    from ..spikedata.pairwise import PairwiseCompMatrix, PairwiseCompMatrixStack
    from ..spikedata.rateslicestack import RateSliceStack
    from ..spikedata.spikeslicestack import SpikeSliceStack
    from .s3_utils import is_s3_url, upload_to_s3 as _upload_to_s3

    supported_types = (
        SpikeData,
        RateData,
        PairwiseCompMatrix,
        PairwiseCompMatrixStack,
        RateSliceStack,
        SpikeSliceStack,
    )
    if not isinstance(obj, supported_types):
        supported_names = ", ".join(t.__name__ for t in supported_types)
        raise TypeError(
            f"Expected a spikelab data object ({supported_names}), "
            f"got {type(obj).__name__}"
        )

    if not s3_upload:
        dirpath = os.path.dirname(filepath)
        if dirpath:
            os.makedirs(dirpath, exist_ok=True)
        with open(filepath, "wb") as f:
            pickle.dump(obj, f, protocol=protocol)
        return filepath

    if not is_s3_url(filepath):
        raise ValueError(
            f"filepath must be an S3 URL when s3_upload=True (got '{filepath}')"
        )

    # Reserve a temp path, close the handle, then reopen for writing. This
    # avoids writing through a still-open NamedTemporaryFile handle.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as tmp:
        temp_path = tmp.name
    try:
        with open(temp_path, "wb") as f:
            pickle.dump(obj, f, protocol=protocol)
        _upload_to_s3(
            temp_path,
            filepath,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            region_name=region_name,
        )
    finally:
        try:
            os.remove(temp_path)
        except OSError:
            pass  # Best-effort cleanup: ignore failures when removing temporary file.
    return filepath
# Public API of this module: the four exporter entry points.
__all__ = [
    "export_spikedata_to_hdf5",
    "export_spikedata_to_nwb",
    "export_spikedata_to_kilosort",
    "export_to_pickle",
]