Source code for spikelab.spike_sorting.pipeline

"""Sorter-agnostic spike sorting pipeline orchestration.

This module contains the functions that run after a sorter backend
has produced its output: SpikeData conversion, curation, compilation,
and epoch splitting.  These functions are independent of which sorter
was used — they operate on SpikeData and the ``SortingPipelineConfig``.

The backend-specific steps (loading, sorting, waveform extraction) are
handled by the ``SorterBackend`` subclass passed to
``process_recording``.
"""

import json
import os
import pickle
import shutil
import sys
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Optional, Tuple

import numpy as np

from .config import SortingPipelineConfig

from .sorting_utils import (
    Stopwatch,
    Tee,
    print_stage,
    create_folder,
    delete_folder,
    get_paths,
)

# Display names for the source_format metadata field.
_SORTER_DISPLAY_NAMES = {
    "kilosort2": "Kilosort2",
    "kilosort4": "Kilosort4",
    "rt_sort": "RT-Sort",
}

# ---------------------------------------------------------------------------
# SpikeData conversion
# ---------------------------------------------------------------------------


def _get_noise_levels(
    recording: Any,
    return_scaled: bool = True,
    num_chunks: int = 20,
    chunk_size: int = 10000,
    seed: int = 0,
) -> np.ndarray:
    """Estimate per-channel noise using MAD on random recording chunks.

    Parameters:
        recording: SpikeInterface BaseRecording.
        return_scaled (bool): Use scaled traces.
        num_chunks (int): Number of random chunks to sample.
        chunk_size (int): Samples per chunk.
        seed (int): Random seed.

    Returns:
        noise_levels (np.ndarray): Per-channel noise, shape ``(channels,)``.
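
    Example:
        Sketch of the scaling used here: for Gaussian noise the MAD is
        about 0.6745 times the standard deviation (0.6745 being the
        75th percentile of the standard normal), so dividing by 0.6745
        recovers sigma::

            rng = np.random.default_rng(0)
            x = rng.normal(0.0, 5.0, size=100_000)
            sigma_hat = np.median(np.abs(x - np.median(x))) / 0.6745
            # sigma_hat is close to the true sigma of 5.0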
    """
    length = recording.get_num_samples()
    rng = np.random.RandomState(seed=seed)
    starts = rng.randint(0, length - chunk_size, size=num_chunks)
    chunks = []
    for s in starts:
        chunks.append(
            recording.get_traces(
                start_frame=s,
                end_frame=s + chunk_size,
                return_scaled=return_scaled,
            )
        )
    data = np.concatenate(chunks, axis=0)
    med = np.median(data, axis=0, keepdims=True)
    return np.median(np.abs(data - med), axis=0) / 0.6745


def build_spikedata(
    w_e: Any,
    rec_path: Any,
    config: Any,
    rec_chunks: Optional[list] = None,
    rec_chunk_names: Optional[list] = None,
) -> Any:
    """Convert a waveform extractor to a SpikeData with rich neuron attributes.

    This is the bridge between any sorter backend's waveform extractor
    and the sorter-agnostic downstream pipeline (curation, compilation).

    Parameters:
        w_e: Waveform extractor object (custom or SpikeInterface).
            Must provide: ``sorting``, ``recording``,
            ``sampling_frequency``, ``chans_max_all``, ``use_pos_peak``,
            ``peak_ind``, ``get_computed_template(unit_id, mode)``,
            ``ms_to_samples(ms)``, ``root_folder``.
        rec_path (str or Path): Original recording file path.
        config (SortingPipelineConfig): Pipeline configuration.
        rec_chunks (list of (int, int) or None): Frame boundaries for
            concatenated recording epochs.
        rec_chunk_names (list of str or None): File names for each epoch.

    Returns:
        sd (SpikeData): Enriched SpikeData with per-unit attributes.
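
    Example:
        A minimal sketch (assumes a backend has already produced the
        waveform extractor ``w_e``)::

            sd = build_spikedata(w_e, rec_path, config)
            print(sd.N, "units at", sd.metadata["fs_Hz"], "Hz")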
    """
    from spikelab.spikedata import SpikeData

    wf_cfg = config.waveform
    sorting = w_e.sorting
    fs_Hz = float(w_e.sampling_frequency)
    rec_locations = w_e.recording.get_channel_locations()
    channel_ids = w_e.recording.get_channel_ids()

    try:
        electrode_ids = w_e.recording.get_property("electrode")
    except Exception:
        electrode_ids = None
    if electrode_ids is None:
        electrode_ids = channel_ids

    noise_levels = _get_noise_levels(w_e.recording, getattr(w_e, "return_scaled", True))

    use_pos_peak = w_e.use_pos_peak

    nbefore_compiled = w_e.ms_to_samples(wf_cfg.compiled_ms_before)
    nafter_compiled = w_e.ms_to_samples(wf_cfg.compiled_ms_after) + 1

    has_epochs = rec_chunks is not None and len(rec_chunks) > 1

    trains = []
    neuron_attributes = []
    for uid in sorting.unit_ids:
        spike_samples = sorting.get_unit_spike_train(uid)
        spike_times_ms = np.sort(spike_samples.astype(float) / fs_Hz * 1000.0)
        trains.append(spike_times_ms)

        chan_max = int(w_e.chans_max_all[uid])
        x, y = rec_locations[chan_max]

        template_mean = w_e.get_computed_template(unit_id=uid, mode="average")
        template_std = w_e.get_computed_template(unit_id=uid, mode="std")
        peak_ind_full = w_e.peak_ind

        # When scale_compiled_waveforms is False, convert µV templates
        # back to raw ADC counts for users who want raw values.
        if not wf_cfg.scale_compiled_waveforms and getattr(w_e, "return_scaled", False):
            gain = w_e.recording.get_channel_gains()
            offset = w_e.recording.get_channel_offsets()
            template_mean = ((template_mean - offset) / gain).astype(
                w_e.recording.get_dtype()
            )
            template_std = ((template_std - offset) / gain).astype(
                w_e.recording.get_dtype()
            )

        template_windowed = template_mean[
            peak_ind_full - nbefore_compiled : peak_ind_full + nafter_compiled, :
        ]

        template_abs = np.abs(template_windowed)
        peak_inds = np.argmax(template_abs, axis=0)
        amplitudes = template_abs[peak_inds, range(peak_inds.size)]
        amplitude_max = float(amplitudes[chan_max])

        noise = float(noise_levels[chan_max]) if chan_max < len(noise_levels) else 1.0
        snr = float(amplitude_max / noise) if noise > 0 else 0.0

        peak_ind_buffer = peak_ind_full - nbefore_compiled
        if wf_cfg.std_at_peak:
            stds = template_std[peak_ind_buffer + peak_inds, range(peak_inds.size)]
        else:
            nb = w_e.ms_to_samples(wf_cfg.std_over_window_ms_before)
            na = w_e.ms_to_samples(wf_cfg.std_over_window_ms_after) + 1
            # Average the template std over a window centred on each
            # channel's own peak.  Slice bounds must be scalars, so
            # gather explicit row indices per channel instead of slicing
            # with arrays: rows has shape (window, channels).
            rows = (
                peak_ind_buffer
                + peak_inds[None, :]
                + np.arange(-nb, na)[:, None]
            )
            stds = np.mean(
                template_std[rows, np.arange(peak_inds.size)[None, :]],
                axis=0,
            )
        with np.errstate(divide="ignore", invalid="ignore"):
            std_norms_all = np.where(amplitudes > 0, stds / amplitudes, np.inf)
        std_norm = float(std_norms_all[chan_max])

        spike_train_samples = spike_samples.copy()

        attrs = {
            "unit_id": int(uid),
            "channel": chan_max,
            "channel_id": channel_ids[chan_max],
            "x": float(x),
            "y": float(y),
            "electrode": electrode_ids[chan_max],
            "template": template_mean[:, chan_max].copy(),
            "template_full": template_mean.copy(),
            "template_windowed": template_windowed.copy(),
            "template_peak_ind": int(peak_ind_full),
            "amplitude": amplitude_max,
            "amplitudes": amplitudes.copy(),
            "peak_inds": peak_inds.copy(),
            "std_norms_all": std_norms_all.copy(),
            "has_pos_peak": bool(use_pos_peak[uid]),
            "snr": snr,
            "std_norm": std_norm,
            "spike_train_samples": spike_train_samples,
        }

        # Per-epoch templates
        if has_epochs:
            wfs, sampled_indices = w_e.get_waveforms(uid, with_index=True)
            all_spike_samples = sorting.get_unit_spike_train(uid)
            epoch_templates = []
            sampled_spikes = all_spike_samples[
                np.asarray(sampled_indices, dtype=int)
            ]
            for start_frame, end_frame in rec_chunks:
                epoch_mask = (sampled_spikes >= start_frame) & (
                    sampled_spikes < end_frame
                )
                if np.any(epoch_mask):
                    epoch_wfs = wfs[epoch_mask]
                    epoch_avg = np.mean(epoch_wfs, axis=0)
                    epoch_templates.append(epoch_avg[:, chan_max].copy())
                else:
                    epoch_templates.append(np.zeros_like(template_mean[:, chan_max]))
            attrs["epoch_templates"] = epoch_templates

        wf_file = w_e.root_folder / "waveforms" / f"waveforms_{uid}.npy"
        if wf_file.exists():
            attrs["_waveforms_path"] = str(wf_file)
            attrs["_waveforms_window"] = (
                int(peak_ind_full - nbefore_compiled),
                int(peak_ind_full + nafter_compiled),
            )

        neuron_attributes.append(attrs)

    metadata = {
        "source_file": str(rec_path),
        "source_format": _SORTER_DISPLAY_NAMES.get(
            config.sorter.sorter_name, config.sorter.sorter_name
        ),
        "fs_Hz": fs_Hz,
        "channel_locations": rec_locations.copy(),
        "n_samples": int(w_e.recording.get_num_samples()),
    }
    if has_epochs:
        metadata["rec_chunks_frames"] = list(rec_chunks)
        metadata["rec_chunks_ms"] = [
            (s / fs_Hz * 1000.0, e / fs_Hz * 1000.0) for s, e in rec_chunks
        ]
        metadata["rec_chunk_names"] = list(rec_chunk_names) if rec_chunk_names else None

    return SpikeData(trains, metadata=metadata, neuron_attributes=neuron_attributes)


# ---------------------------------------------------------------------------
# Curation wrapper
# ---------------------------------------------------------------------------


def curate_spikedata(
    sd: Any, curation_folder: Any, config: Any, recurate: bool = False
) -> Tuple[Any, dict]:
    """Curate a SpikeData with disk caching.

    Reads curation thresholds from *config* and applies them via
    ``sd.curate()``.  Results are cached to *curation_folder*.

    Parameters:
        sd (SpikeData): Uncurated SpikeData.
        curation_folder (str or Path): Cache directory.
        config (SortingPipelineConfig): Pipeline configuration.
        recurate (bool): Re-run curation even when cached.

    Returns:
        sd_curated (SpikeData): Curated SpikeData.
        history (dict): Serializable curation history.
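
    Example:
        Typical use (a sketch; ``sd`` comes from ``build_spikedata``,
        and the cache directory is hypothetical)::

            sd_curated, history = curate_spikedata(
                sd, curation_folder="inter/curation_first", config=config
            )
            print(f"{sd.N} -> {sd_curated.N} units")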
    """
    from spikelab.spikedata.curation import build_curation_history

    cur = config.curation
    curate_kwargs = {}

    if cur.curate_first:
        if cur.fr_min is not None:
            curate_kwargs["min_rate_hz"] = cur.fr_min
        if cur.isi_viol_max is not None:
            curate_kwargs["isi_max"] = cur.isi_viol_max
            curate_kwargs["isi_threshold_ms"] = 1.5
            curate_kwargs["isi_method"] = cur.isi_violation_method
        if cur.snr_min is not None:
            curate_kwargs["min_snr"] = cur.snr_min
        if cur.spikes_min_first is not None:
            curate_kwargs["min_spikes"] = cur.spikes_min_first
    if cur.curate_second:
        if cur.spikes_min_second is not None:
            curate_kwargs["min_spikes"] = cur.spikes_min_second
        if cur.std_norm_max is not None:
            curate_kwargs["max_std_norm"] = cur.std_norm_max

    curation_folder = Path(curation_folder)
    unit_ids_path = curation_folder / "unit_ids.npy"
    history_path = curation_folder / "curation_history.json"

    # Check cache
    if not recurate and unit_ids_path.exists() and history_path.exists():
        cached_ids = set(int(x) for x in np.load(str(unit_ids_path)))
        passing = [
            i
            for i in range(sd.N)
            if sd.neuron_attributes is not None
            and int(sd.neuron_attributes[i].get("unit_id", i)) in cached_ids
        ]
        sd_curated = sd.subset(passing)
        with open(history_path, "r") as f:
            history = json.load(f)
        return sd_curated, history

    # Run curation
    sd_curated, results = sd.curate(**curate_kwargs)
    history = build_curation_history(sd, sd_curated, results, parameters=curate_kwargs)

    # Save to disk
    curation_folder.mkdir(parents=True, exist_ok=True)
    np.save(str(unit_ids_path), np.array(history["curated_final"]))
    with open(history_path, "w") as f:
        json.dump(history, f, indent=2, default=str)

    return sd_curated, history


# ---------------------------------------------------------------------------
# Compiler
# ---------------------------------------------------------------------------


class Compiler:
    """Aggregates sorting results from one or more SpikeData objects for export.

    Reads unit metadata from ``neuron_attributes`` and writes combined
    ``.npz``, ``.mat``, and figure outputs.

    Parameters:
        config (SortingPipelineConfig): Pipeline configuration.
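
    Example:
        Typical use (a sketch; ``sd_curated`` and ``history`` come from
        ``curate_spikedata``, and the output path is hypothetical)::

            compiler = Compiler(config)
            compiler.add_recording("rec0", sd_curated, history)
            compiler.save_results("results/sorted")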
    """

    def __init__(self, config: Any) -> None:
        self.config = config
        fig = config.figures
        comp = config.compilation
        cur = config.curation

        self.create_figures = fig.create_figures
        self.create_std_scatter_plot = (
            cur.curate_second
            and cur.spikes_min_second is not None
            and cur.std_norm_max is not None
        )
        self.compile_to_mat = comp.compile_to_mat
        self.compile_to_npz = comp.compile_to_npz
        self.save_electrodes = comp.save_electrodes
        self.recs_cache = []

    def add_recording(
        self, rec_name: str, sd: Any, curation_history: Optional[dict] = None
    ) -> None:
        """Queue a recording for compilation.

        Parameters:
            rec_name (str): Short name for the recording.
            sd (SpikeData): Curated SpikeData.
            curation_history (dict or None): Curation history dict.
        """
        self.recs_cache.append((rec_name, sd, curation_history))

    def save_results(self, folder: Any) -> None:
        """Compile and save results from all queued recordings.

        Parameters:
            folder (Path or str): Output directory.
        """
        try:
            from scipy.io import savemat
        except ImportError:
            savemat = None

        create_folder(folder)
        folder = Path(folder)

        cfg = self.config
        comp = cfg.compilation
        fig = cfg.figures

        all_units = []
        rec_metadata = {}
        bar_rec_names = []
        bar_n_total = []
        bar_n_selected = []
        scatter_n_spikes = {}
        scatter_std_norms = {}
        fig_fs_Hz = None

        for rec_name, sd, curation_history in self.recs_cache:
            print(f"Adding recording: {rec_name}")

            fs_Hz = sd.metadata.get("fs_Hz", 30000.0)
            rec_metadata[rec_name] = {
                "fs": fs_Hz,
                "locations": sd.metadata.get("channel_locations"),
                "n_samples": sd.metadata.get("n_samples", 0),
            }
            if fig_fs_Hz is None:
                fig_fs_Hz = fs_Hz

            for i in range(sd.N):
                attrs = sd.neuron_attributes[i] if sd.neuron_attributes else {}
                all_units.append((attrs, True, rec_name))

            if self.create_figures:
                curated_ids = set()
                if sd.neuron_attributes is not None:
                    for attrs in sd.neuron_attributes:
                        curated_ids.add(int(attrs.get("unit_id", -1)))
                n_total = len(curated_ids)
                if curation_history is not None:
                    n_total = len(curation_history.get("initial", curated_ids))
                bar_rec_names.append(rec_name)
                bar_n_total.append(n_total)
                bar_n_selected.append(sd.N)

                if self.create_std_scatter_plot and curation_history is not None:
                    scatter_n_spikes[rec_name] = curation_history.get(
                        "metrics", {}
                    ).get("spike_count", {})
                    scatter_std_norms[rec_name] = curation_history.get(
                        "metrics", {}
                    ).get("std_norm", {})

        # Sort by polarity then amplitude
        neg_units = [u for u in all_units if not u[0].get("has_pos_peak", False)]
        pos_units = [u for u in all_units if u[0].get("has_pos_peak", False)]
        neg_units.sort(key=lambda x: float(x[0].get("amplitude", 0)), reverse=True)
        pos_units.sort(key=lambda x: float(x[0].get("amplitude", 0)), reverse=True)

        compile_dict = None
        if self.compile_to_mat or self.compile_to_npz:
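            # The combined .mat/.npz export is only written when exactly
            # one recording was queued; units from multiple recordings
            # are not merged into a single export here.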
            if len(rec_metadata) == 1:
                rec = list(rec_metadata.keys())[0]
                meta = rec_metadata[rec]
                compile_dict = {
                    "units": [],
                    "locations": meta["locations"],
                    "fs": meta["fs"],
                }

        if comp.compile_waveforms:
            create_folder(folder / "negative_peaks")
            create_folder(folder / "positive_peaks")

        fig_templates = []
        fig_peak_indices = []
        fig_is_curated = []
        fig_has_pos_peak = []

        sorted_index = 0
        for group_label, units_group in [
            ("negative", neg_units),
            ("positive", pos_units),
        ]:
            has_pos = group_label == "positive"
            print(
                f"\nIterating through {len(units_group)} units with "
                f"{group_label} peaks"
            )
            for attrs, is_curated, rec_name in units_group:
                if is_curated:
                    if compile_dict is not None:
                        spike_train_samples = attrs.get("spike_train_samples")
                        if comp.save_dl_data:
                            unit_dict = {
                                "unit_id": attrs.get("unit_id"),
                                "spike_train": spike_train_samples,
                                "x_max": attrs.get("x"),
                                "y_max": attrs.get("y"),
                                "template": attrs.get("template_windowed"),
                                "sorted_index": sorted_index,
                                "max_channel_si": attrs.get("channel"),
                                "max_channel_id": attrs.get("channel_id"),
                                "peak_sign": group_label,
                                "peak_ind": attrs.get("peak_inds"),
                                "amplitudes": attrs.get("amplitudes"),
                                "std_norms": attrs.get("std_norms_all"),
                            }
                        else:
                            unit_dict = {
                                "unit_id": attrs.get("unit_id"),
                                "spike_train": spike_train_samples,
                                "x_max": attrs.get("x"),
                                "y_max": attrs.get("y"),
                                "template": attrs.get("template_windowed"),
                            }
                        if self.save_electrodes:
                            unit_dict["electrode"] = attrs.get("electrode")
                        compile_dict["units"].append(unit_dict)

                    if comp.compile_waveforms:
                        wf_path = attrs.get("_waveforms_path")
                        wf_window = attrs.get("_waveforms_window")
                        if wf_path is not None:
                            waveforms = np.load(wf_path, mmap_mode="r")
                            if wf_window is not None:
                                waveforms = waveforms[:, wf_window[0] : wf_window[1], :]
                            wf_folder = (
                                folder / "positive_peaks"
                                if has_pos
                                else folder / "negative_peaks"
                            )
                            np.save(
                                wf_folder / f"waveforms_{sorted_index}.npy",
                                np.array(waveforms),
                            )

                    sorted_index += 1

                if self.create_figures:
                    fig_templates.append(attrs.get("template", np.array([])))
                    fig_peak_indices.append(attrs.get("template_peak_ind", 0))
                    fig_is_curated.append(is_curated)
                    fig_has_pos_peak.append(has_pos)

        if compile_dict is not None:
            if self.compile_to_mat:
                if savemat is not None:
                    savemat(folder / "sorted.mat", compile_dict)
                    print("Compiled results to .mat")
                else:
                    print("scipy is not installed; skipping .mat export")
            if self.compile_to_npz:
                np.savez(folder / "sorted.npz", **compile_dict)
                print("Compiled results to .npz")

        if self.create_figures:
            from .figures import plot_curation_bar, plot_std_scatter, plot_templates

            figures_path = folder / "figures"
            print("\nSaving figures")
            create_folder(figures_path)

            plot_curation_bar(
                bar_rec_names,
                bar_n_total,
                bar_n_selected,
                total_label=fig.bar_total_label,
                selected_label=fig.bar_selected_label,
                x_label=fig.bar_x_label,
                y_label=fig.bar_y_label,
                label_rotation=fig.bar_label_rotation,
                save_path=str(figures_path / "curation_bar_plot.png"),
            )
            print("Curation bar plot has been saved")

            if self.create_std_scatter_plot and scatter_n_spikes:
                plot_std_scatter(
                    scatter_n_spikes,
                    scatter_std_norms,
                    spikes_thresh=cfg.curation.spikes_min_second,
                    std_thresh=cfg.curation.std_norm_max,
                    colors=fig.scatter_recording_colors[:],
                    alpha=fig.scatter_recording_alpha,
                    x_label=fig.scatter_x_label,
                    y_label=fig.scatter_y_label,
                    x_max_buffer=fig.scatter_x_max_buffer,
                    y_max_buffer=fig.scatter_y_max_buffer,
                    save_path=str(figures_path / "std_scatter_plot.png"),
                )
                print("Std scatter plot has been saved")

            if fig_templates and fig_fs_Hz is not None:
                plot_templates(
                    fig_templates,
                    fig_peak_indices,
                    fig_fs_Hz,
                    fig_is_curated,
                    fig_has_pos_peak,
                    templates_per_column=fig.templates_per_column,
                    y_spacing=fig.templates_y_spacing,
                    y_lim_buffer=fig.templates_y_lim_buffer,
                    color_curated=fig.templates_color_curated,
                    color_failed=fig.templates_color_failed,
                    window_ms_before=fig.templates_window_ms_before,
                    window_ms_after=fig.templates_window_ms_after,
                    line_ms_before=fig.templates_line_ms_before,
                    line_ms_after=fig.templates_line_ms_after,
                    x_label=fig.templates_x_label,
                    save_path=str(figures_path / "all_templates_plot.png"),
                )
                print("All templates plot has been saved")


# ---------------------------------------------------------------------------
# Pipeline orchestration
# ---------------------------------------------------------------------------


def process_recording(
    backend,
    config,
    rec_name,
    rec_path,
    inter_path,
    results_path,
    rec_loaded=None,
    rec_chunks=None,
    rec_chunk_names=None,
    rng=None,
):
    """Run the full sorting pipeline on a single recording.

    Delegates loading, sorting, and waveform extraction to the
    *backend*, then handles SpikeData conversion, curation, and
    compilation using the *config*.

    Parameters:
        backend (SorterBackend): Sorter backend instance.
        config (SortingPipelineConfig): Pipeline configuration.
        rec_name (str): Short name for the recording.
        rec_path (str or Path): Path to the recording file.
        inter_path (str or Path): Root intermediate directory.
        results_path (str or Path): Root results directory.
        rec_loaded: Pre-loaded recording object, or None.
        rec_chunks (list of (int, int) or None): Epoch frame boundaries.
        rec_chunk_names (list of str or None): Epoch file names.
        rng (np.random.Generator or None): Random number generator for
            reproducible waveform sampling.  When ``None``, a new
            ``default_rng()`` is created.

    Returns:
        result (SpikeData or tuple or Exception): ``sd_curated`` on
            success, or ``(sd_raw, sd_curated)`` when
            ``config.compilation.save_raw_pkl`` is True.  Returns the
            caught exception if any stage failed.
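
    Example:
        Direct use with an explicit backend (a sketch; most callers go
        through ``sort_recording``, and the paths here are
        hypothetical)::

            from spikelab.spike_sorting.backends import get_backend_class

            backend = get_backend_class("kilosort2")(config)
            sd_curated = process_recording(
                backend, config, "rec0", "rec0.raw.h5",
                inter_path="inter", results_path="sorted",
            )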
    """
    exe = config.execution
    cur = config.curation
    comp = config.compilation

    create_folder(inter_path)
    with Tee(Path(inter_path) / exe.out_file, "a"):
        stopwatch = Stopwatch()

        (
            rec_path,
            inter_path,
            recording_dat_path,
            output_folder,
            waveforms_root_folder,
            curation_initial_folder,
            curation_first_folder,
            curation_second_folder,
            results_path,
        ) = get_paths(rec_path, inter_path, results_path, exe)

        # Load Recording
        try:
            recording_filtered = backend.load_recording(
                rec_path if rec_loaded is None else rec_loaded
            )
        except Exception as e:
            print(f"Could not open the recording file because of {e}")
            print("Moving on to next recording")
            return e

        # Spike sorting
        sorting = backend.sort(
            recording_filtered, rec_path, recording_dat_path, output_folder
        )
        if isinstance(sorting, BaseException):
            return sorting

        # Extract waveforms
        w_e_raw = backend.extract_waveforms(
            recording_filtered,
            sorting,
            waveforms_root_folder,
            curation_initial_folder,
            rec_path=rec_path,
            rng=rng,
        )

        # Convert to SpikeData
        sd = build_spikedata(
            w_e_raw,
            rec_path,
            config,
            rec_chunks=rec_chunks,
            rec_chunk_names=rec_chunk_names,
        )

        # Generate figures if create_figures is enabled.
        # Per-unit figures are generated before curation (while individual
        # waveforms are still on disk), then sorted into curated/failed
        # subdirs after curation completes.
        unit_figures_dir = Path(results_path) / "figures" / "units"
        _fig = {}
        figures_dir = Path(results_path) / "figures"
        _thresholds = {
            "fr_min": cur.fr_min,
            "isi_viol_max": cur.isi_viol_max,
            "snr_min": cur.snr_min,
            "spikes_min_second": cur.spikes_min_second,
            "std_norm_max": cur.std_norm_max,
        }

        if not config.figures.create_figures:
            print("Skipping figure generation (create_figures=False)")
        else:
            unit_figures_dir.mkdir(parents=True, exist_ok=True)
            figures_dir.mkdir(parents=True, exist_ok=True)

            _fmod = None
            try:
                from scripts import generate_sorting_figures as _fmod
            except ImportError:
                import importlib.util

                _script = (
                    Path(__file__).parents[2]
                    / "scripts"
                    / "generate_sorting_figures.py"
                )
                if _script.exists():
                    _spec = importlib.util.spec_from_file_location(
                        "generate_sorting_figures", _script
                    )
                    _fmod = importlib.util.module_from_spec(_spec)
                    _spec.loader.exec_module(_fmod)

            if _fmod is not None:
                for name in (
                    "generate_per_unit_figures",
                    "generate_quality_distributions",
                    "generate_builtin_figures",
                    "generate_raster_overview",
                ):
                    _fig[name] = getattr(_fmod, name, None)

            if (
                config.figures.create_unit_figures
                and _fig.get("generate_per_unit_figures") is not None
            ):
                print_stage("GENERATING PER-UNIT FIGURES")
                _fig["generate_per_unit_figures"](
                    sd,
                    unit_figures_dir,
                    amp_thresh_uv=15.0,
                    w_e_raw=w_e_raw,
                )
            elif not config.figures.create_unit_figures:
                print("Skipping per-unit figures (create_unit_figures=False)")

            if _fig.get("generate_quality_distributions") is not None:
                print_stage("GENERATING QUALITY DISTRIBUTIONS (ALL UNITS)")
                _fig["generate_quality_distributions"](
                    sd,
                    is_pre_curation=True,
                    thresholds=_thresholds,
                    out_dir=figures_dir,
                )

        # Curate
        has_epochs = bool(sd.metadata.get("rec_chunks_ms"))
        if cur.curation_epoch is not None and has_epochs:
            epoch_sds = sd.split_epochs()
            if cur.curation_epoch < 0 or cur.curation_epoch >= len(epoch_sds):
                raise ValueError(
                    f"curation_epoch={cur.curation_epoch} is out of range "
                    f"(recording has {len(epoch_sds)} epochs, 0-indexed)."
                )
            sd_for_curation = epoch_sds[cur.curation_epoch]
            print(
                f"Curating based on epoch {cur.curation_epoch} "
                f"({sd_for_curation.metadata.get('source_file', '')})"
            )
        else:
            sd_for_curation = sd

        sd_epoch_curated, curation_history = curate_spikedata(
            sd_for_curation,
            curation_folder=curation_first_folder,
            config=config,
            recurate=exe.recurate_first or exe.recurate_second,
        )

        # When curating on a single epoch, apply passing units to full SD
        if sd_for_curation is not sd:
            passing_ids = set()
            if sd_epoch_curated.neuron_attributes is not None:
                for attrs in sd_epoch_curated.neuron_attributes:
                    uid = attrs.get("unit_id")
                    if uid is not None:
                        passing_ids.add(int(uid))
            passing_indices = [
                i
                for i in range(sd.N)
                if sd.neuron_attributes is not None
                and int(sd.neuron_attributes[i].get("unit_id", -1)) in passing_ids
            ]
            sd_curated = sd.subset(passing_indices)
        else:
            sd_curated = sd_epoch_curated

        n_before = sd.N
        n_after = sd_curated.N
        print(
            f"Curation: {n_before} -> {n_after} units "
            f"({n_before - n_after} removed)"
        )

        # Sort per-unit figures into curated/failed subdirectories
        if unit_figures_dir.exists() and any(unit_figures_dir.glob("unit_*.png")):
            curated_ids = set()
            if sd_curated.neuron_attributes is not None:
                for attrs in sd_curated.neuron_attributes:
                    uid = attrs.get("unit_id")
                    if uid is not None:
                        curated_ids.add(int(uid))

            curated_dir = unit_figures_dir / "curated"
            failed_dir = unit_figures_dir / "failed"
            curated_dir.mkdir(exist_ok=True)
            failed_dir.mkdir(exist_ok=True)

            for png in unit_figures_dir.glob("unit_*.png"):
                try:
                    uid = int(png.stem.split("_")[1])
                except (IndexError, ValueError):
                    continue
                dest = curated_dir if uid in curated_ids else failed_dir
                shutil.move(str(png), str(dest / png.name))

            n_curated_figs = len(list(curated_dir.glob("*.png")))
            n_failed_figs = len(list(failed_dir.glob("*.png")))
            print(
                f"Per-unit figures sorted: {n_curated_figs} curated, "
                f"{n_failed_figs} failed"
            )

        # Generate remaining figures (need curated SpikeData)
        if _fig.get("generate_builtin_figures") is not None:
            print_stage("GENERATING QC FIGURES")
            _fig["generate_builtin_figures"](sd_curated, _thresholds, figures_dir)
        if _fig.get("generate_raster_overview") is not None:
            generate_raster_overview = _fig["generate_raster_overview"]
            generate_raster_overview(sd_curated, figures_dir)

        # Compile results
        compile_results(
            config,
            rec_name,
            rec_path,
            results_path,
            sd_curated,
            curation_history,
            rec_chunks,
        )

        print_stage("DONE WITH RECORDING")
        print(f"Recording: {rec_path}")
        stopwatch.log_time("Total")

        if comp.save_raw_pkl:
            return sd, sd_curated
        return sd_curated


def compile_results(
    config, rec_name, rec_path, results_path, sd, curation_history=None, rec_chunks=None
):
    """Compile and export sorting results for a single recording.

    Parameters:
        config (SortingPipelineConfig): Pipeline configuration.
        rec_name (str): Short name for the recording.
        rec_path (str or Path): Original recording file path.
        results_path (Path): Output directory.
        sd (SpikeData): Curated SpikeData.
        curation_history (dict or None): Curation history dict.
        rec_chunks (list or None): Epoch frame boundaries.
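
    Example:
        Called by ``process_recording`` after curation; direct use is a
        sketch::

            compile_results(
                config, "rec0", rec_path, results_path, sd_curated,
                curation_history=history,
            )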
    """
    comp = config.compilation
    exe = config.execution

    compile_stopwatch = Stopwatch("COMPILING RESULTS")
    print(f"For recording: {rec_path}")
    if comp.compile_single_recording:
        if (
            not (Path(results_path) / "parameters.json").exists()
            or exe.recompile_single_recording
        ):
            print(f"Saving to path: {results_path}")
            if rec_chunks is not None and len(rec_chunks) > 1:
                epoch_sds = sd.split_epochs()
                for c, sd_chunk in enumerate(epoch_sds):
                    print(f"Compiling chunk {c}")
                    compiler = Compiler(config)
                    compiler.add_recording(rec_name, sd_chunk, curation_history)
                    compiler.save_results(Path(results_path) / f"chunk{c}")
            else:
                compiler = Compiler(config)
                compiler.add_recording(rec_name, sd, curation_history)
                compiler.save_results(results_path)
            compile_stopwatch.log_time("Done compiling results.")
        else:
            print(
                "Skipping compilation: results already exist and "
                "'recompile_single_recording' is False"
            )
    else:
        print(
            "Skipping compiling results because 'compile_single_recording' "
            "is set to False"
        )


# ---------------------------------------------------------------------------
# Generic entry points
# ---------------------------------------------------------------------------


@contextmanager
def _bounded_host_memory(frac: float = 0.8):
    """Cap the calling process's heap allocations at ``frac`` of system RAM.

    Best-effort guard against OOM during local sorting (especially RT-Sort,
    which can exhaust host RAM on long recordings or high-unit-count
    populations). Uses ``RLIMIT_DATA`` rather than ``RLIMIT_AS`` so that
    file-backed mmap regions used for recording I/O are not capped — only
    anonymous heap allocations (numpy / torch tensors) are bounded, which
    is where the OOM actually originates.

    Behaviour by platform:
        - Linux (kernel 4.7+): caps anonymous heap (brk + anonymous mmap).
          This is the strict OOM guard intended.
        - macOS / other POSIX: caps the brk segment only; large mmap
          allocations are not capped (semantics are weaker).
        - Windows: no-op with a printed notice (``resource`` module
          unavailable). Host RAM is unprotected — rely on Docker's
          ``mem_limit`` for containerised sorters, or monitor RAM
          manually for local runs.

    The original soft limit is restored on context exit so the cap does
    not leak into longer-lived sessions (e.g. notebooks).

    Parameters:
        frac (float): Fraction of total physical RAM to cap heap at.
            Defaults to ``0.8``.
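
    Example:
        How the pipeline's main loop uses the cap (a sketch)::

            with _bounded_host_memory(0.8):
                for rec in recordings:
                    sort_one(rec)  # hypothetical per-recording work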
    """
    try:
        import resource
    except ImportError:
        print(
            "[host memory cap] Windows detected — RLIMIT_DATA unavailable. "
            "Local sorting is not protected from host OOM. "
            "Use Docker, or monitor RAM manually."
        )
        yield
        return

    from .sorting_utils import get_system_ram_bytes

    ram_bytes = get_system_ram_bytes()
    if ram_bytes is None:
        print("[host memory cap] Could not detect system RAM; cap not enforced.")
        yield
        return

    new_soft = int(ram_bytes * frac)
    soft_orig, hard_orig = resource.getrlimit(resource.RLIMIT_DATA)
    if hard_orig != resource.RLIM_INFINITY and new_soft > hard_orig:
        new_soft = hard_orig

    try:
        resource.setrlimit(resource.RLIMIT_DATA, (new_soft, hard_orig))
    except (ValueError, OSError) as exc:
        print(f"[host memory cap] Failed to set RLIMIT_DATA: {exc}; cap not enforced.")
        yield
        return

    try:
        yield
    finally:
        try:
            resource.setrlimit(resource.RLIMIT_DATA, (soft_orig, hard_orig))
        except (ValueError, OSError):
            pass


def _print_pipeline_banner(
    sorter: str,
    rec_path: Any,
    config: "SortingPipelineConfig",
    log_path: Path,
) -> None:
    """Print an environment + system + input banner at the start of a sort.

    Captured by the surrounding ``Tee`` and persisted to the
    ``sorting_*.log`` file alongside the run's stdout.
    """
    import datetime as _dt
    import platform
    import socket
    import subprocess

    from .sorting_utils import get_system_ram_bytes, print_stage

    print_stage(f"SPIKE SORTING — {sorter.upper()}")
    print()
    print("-- Environment --")
    print(f"Started:        {_dt.datetime.now().isoformat(timespec='seconds')}")
    print(f"Host:           {socket.gethostname()}")
    print(f"Platform:       {platform.platform()}")
    print(f"Python:         {sys.version.split()[0]}")

    try:
        import spikeinterface as _si

        print(f"SpikeInterface: {_si.__version__}")
    except ImportError:
        pass

    try:
        import spikelab as _sl

        version = getattr(_sl, "__version__", "unknown")
        print(f"SpikeLab:       {version}")
    except ImportError:
        pass

    print()
    print("-- System Resources --")
    cpu_count = os.cpu_count()
    if cpu_count is not None:
        print(f"CPU cores:      {cpu_count}")

    ram_bytes = get_system_ram_bytes()
    if ram_bytes is not None:
        print(f"RAM total:      {ram_bytes / 1e9:.1f} GB")

    try:
        import resource

        soft, _hard = resource.getrlimit(resource.RLIMIT_DATA)
        if soft == resource.RLIM_INFINITY:
            print("Heap cap:       (unlimited)")
        else:
            print(f"Heap cap:       {soft / 1e9:.1f} GB (RLIMIT_DATA)")
    except ImportError:
        print("Heap cap:       (Windows — not enforced)")

    try:
        gpu_info = subprocess.check_output(
            [
                "nvidia-smi",
                "--query-gpu=name,driver_version,memory.total",
                "--format=csv,noheader",
            ],
            text=True,
            timeout=5,
        ).strip()
        print(f"GPU:            {gpu_info}")
    except (subprocess.SubprocessError, FileNotFoundError):
        print("GPU:            (nvidia-smi unavailable)")

    print()
    print("-- Run --")
    print(f"Sorter:         {sorter}")
    print(f"Use Docker:     {config.sorter.use_docker}")
    print(f"Recording:      {rec_path}")
    print(f"Log file:       {log_path}")
    print()


def _print_pipeline_summary(
    status: str,
    elapsed_s: float,
    error: Optional[BaseException] = None,
) -> None:
    """Print a closing summary banner with status, wall time, and resources."""
    import datetime as _dt
    import subprocess

    from .sorting_utils import get_system_ram_bytes, print_stage

    print()
    print_stage("SUMMARY")
    print()
    print(f"Status:         {status}")
    if error is not None:
        print(f"Error:          {type(error).__name__}: {error}")

    minutes, seconds = divmod(int(elapsed_s), 60)
    print(f"Wall time:      {minutes}m {seconds}s")

    ram_bytes = get_system_ram_bytes()
    if ram_bytes is not None:
        print(f"RAM total:      {ram_bytes / 1e9:.1f} GB")
    try:
        gpu_mem = subprocess.check_output(
            [
                "nvidia-smi",
                "--query-gpu=memory.used,memory.total",
                "--format=csv,noheader",
            ],
            text=True,
            timeout=5,
        ).strip()
        print(f"GPU memory:     {gpu_mem}")
    except (subprocess.SubprocessError, FileNotFoundError):
        pass

    print(f"Finished:       {_dt.datetime.now().isoformat(timespec='seconds')}")


def sort_recording(
    recording_files,
    config=None,
    sorter="kilosort2",
    intermediate_folders=None,
    results_folders=None,
    **kwargs,
):
    """Run spike sorting on one or more recordings using any registered backend.

    This is the primary entry point for the modular sorting pipeline.

    Parameters:
        recording_files (list): Paths to recording files or directories.
            Each entry is sorted independently.  Directories have their
            contents concatenated before sorting and split back into
            per-file SpikeData afterward.
        config (SortingPipelineConfig or None): Pre-built configuration.
            When provided, ``**kwargs`` are applied as overrides via
            ``config.override()``.  When None, a fresh config is built
            from ``sorter`` + ``**kwargs``.  Preset configs are
            available in ``spikelab.spike_sorting.config``
            (e.g. ``KILOSORT2``).
        sorter (str): Registered sorter backend name.  Only used when
            ``config`` is None.  Available: ``"kilosort2"``,
            ``"kilosort4"``.
        intermediate_folders (list or None): Intermediate result
            directories, one per recording.  Auto-generated if None.
        results_folders (list or None): Output directories, one per
            recording.  Auto-generated if None.
        **kwargs: Override individual config fields (e.g.
            ``snr_min=5.0``, ``use_docker=True``, ``fr_min=0.05``).
            See ``spikelab.spike_sorting.config`` for all available
            parameters, grouped by: ``RecordingConfig``,
            ``SorterConfig``, ``WaveformConfig``, ``CurationConfig``,
            ``CompilationConfig``, ``FigureConfig``,
            ``ExecutionConfig``.

    Returns:
        results (list[SpikeData]): One SpikeData per original recording
            file.  For directory inputs, the concatenated recording is
            split back into per-file SpikeData objects.

    Notes:
        - Pickle files (``sorted_spikedata_curated.pkl`` and optionally
          ``sorted_spikedata.pkl``) are saved to each results folder.
        - ``hdf5_plugin_path`` (passed via config or kwargs) sets
          ``os.environ['HDF5_PLUGIN_PATH']`` before any recording is
          loaded.  This is needed for Maxwell ``.h5`` files and applies
          to all backends.
    """
    import datetime

    from .backends import get_backend_class

    if config is not None:
        if kwargs:
            config = config.override(**kwargs)
        sorter = config.sorter.sorter_name
    else:
        config = SortingPipelineConfig.from_kwargs(**kwargs)

    # Set HDF5 plugin path before any recording is loaded (affects all
    # backends).
    if config.recording.hdf5_plugin_path is not None:
        os.environ["HDF5_PLUGIN_PATH"] = str(config.recording.hdf5_plugin_path)

    backend_cls = get_backend_class(sorter)
    backend = backend_cls(config)

    # Auto-generate folder paths
    def _rec_to_path(rec):
        try:
            from spikeinterface.core import BaseRecording as _BR
        except ImportError:
            _BR = None
        if _BR is not None and isinstance(rec, _BR):
            kw = rec._kwargs
            backing = kw.get("file_path") or (kw.get("file_paths") or [None])[0]
            if backing is None:
                raise ValueError(
                    f"Cannot auto-generate intermediate_folders / "
                    f"results_folders for a {type(rec).__name__} without "
                    f"a backing file path. Pass `intermediate_folders` "
                    f"and `results_folders` explicitly."
                )
            return Path(backing)
        return Path(rec)

    if intermediate_folders is None:
        cur_dt = datetime.datetime.now().strftime("%y%m%d_%H%M%S_%f")
        intermediate_folders = [
            _rec_to_path(rec).parent / f"inter_{sorter}_{cur_dt}"
            for rec in recording_files
        ]
    if results_folders is None:
        results_folders = [
            _rec_to_path(rec).parent / f"sorted_{sorter}"
            for rec in recording_files
        ]

    # Validate
    if not (
        len(recording_files) == len(intermediate_folders) == len(results_folders)
    ):
        raise ValueError(
            f"recording_files ({len(recording_files)}), "
            f"intermediate_folders ({len(intermediate_folders)}), and "
            f"results_folders ({len(results_folders)}) must all have "
            "the same length."
        )

    # Figure settings
    try:
        import matplotlib as mpl

        if config.figures.create_figures:
            if config.figures.dpi is not None:
                mpl.rcParams["figure.dpi"] = config.figures.dpi
            if config.figures.font_size is not None:
                mpl.rcParams["font.size"] = config.figures.font_size
    except ImportError:
        pass

    rng = np.random.default_rng(config.execution.random_seed)

    # Main loop — wrap in a host heap cap so local sorts (especially
    # RT-Sort) cannot drag the workstation into swap.  No-op on Windows;
    # restored on exit.
    spikedata_results = []
    with _bounded_host_memory(0.8):
        for rec_path, inter_path, res_path in zip(
            recording_files, intermediate_folders, results_folders
        ):
            try:
                from spikeinterface.core import BaseRecording
            except ImportError:
                BaseRecording = None

            rec_loaded = None
            if BaseRecording is not None and isinstance(rec_path, BaseRecording):
                rec_loaded = rec_path
                if "file_path" in rec_loaded._kwargs:
                    rec_path = rec_loaded._kwargs["file_path"]
                else:
                    rec_path = rec_loaded._kwargs["file_paths"][0]

            rec_name = str(rec_path).split("/")[-1].split("\\")[-1].split(".")[0]

            # Mirror stdout to a per-recording log file from start to
            # finish.  The log captures the environment banner, every
            # sorting stage, the closing summary, and any exception
            # traceback — making it the canonical artefact for the
            # post-sorting report.
            res_path_obj = Path(res_path)
            res_path_obj.mkdir(parents=True, exist_ok=True)
            log_ts = datetime.datetime.now().strftime("%y%m%d_%H%M%S")
            log_path = res_path_obj / f"sorting_{log_ts}.log"
            with Tee(log_path, file_mode="w"):
                _print_pipeline_banner(sorter, rec_path, config, log_path)
                t_start = time.time()

                result = process_recording(
                    backend,
                    config,
                    rec_name,
                    rec_path,
                    inter_path,
                    res_path,
                    rec_loaded=rec_loaded,
                    rec_chunks=config.recording.rec_chunks or None,
                    rec_chunk_names=getattr(backend, "rec_chunk_names", None),
                    rng=rng,
                )
                if isinstance(result, BaseException):
                    status = (
                        "OOM (MemoryError)"
                        if isinstance(result, MemoryError)
                        else "FAILED"
                    )
                    _print_pipeline_summary(
                        status, time.time() - t_start, error=result
                    )
                    continue

                if config.compilation.save_raw_pkl:
                    sd_raw, sd_curated = result
                else:
                    sd_curated = result

                # Save pickles
                res_path = Path(res_path)
                if config.compilation.save_raw_pkl:
                    raw_pkl = res_path / "sorted_spikedata.pkl"
                    with open(raw_pkl, "wb") as f:
                        pickle.dump(sd_raw, f)
                    print(f"Saved {sd_raw.N} raw units to {raw_pkl}")
                curated_pkl = res_path / "sorted_spikedata_curated.pkl"
                with open(curated_pkl, "wb") as f:
                    pickle.dump(sd_curated, f)
                print(f"Saved {sd_curated.N} curated units to {curated_pkl}")

                # Epoch splitting
                if sd_curated.metadata.get("rec_chunks_ms"):
                    epoch_sds = sd_curated.split_epochs()
                    spikedata_results.extend(epoch_sds)
                else:
                    spikedata_results.append(sd_curated)

                if config.execution.delete_inter:
                    shutil.rmtree(inter_path)

                _print_pipeline_summary("SUCCESS", time.time() - t_start)

    return spikedata_results
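

# Example usage (a sketch; the recording path is hypothetical):
#
#     results = sort_recording(
#         ["/data/rec0.raw.h5"],
#         sorter="kilosort2",
#         snr_min=5.0,
#     )
#     sd_curated = results[0]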


def sort_multistream(recording, stream_ids, config=None, sorter="kilosort2", **kwargs):
    """Sort a multi-stream recording across multiple stream IDs.

    Calls ``sort_recording`` once per stream ID, routing each stream to
    its own intermediate and results folders.  Validates that the
    requested stream IDs exist in the recording file before sorting.

    Parameters:
        recording (str or Path): Path to a single multi-stream recording
            file (e.g. MaxTwo ``.raw.h5``) or a directory of such files.
            When a directory is given, all files are concatenated per
            stream.
        stream_ids (list of str): Stream identifiers to sort, e.g.
            ``["well000", "well001", "well002"]``.
        config (SortingPipelineConfig or None): Pre-built configuration.
            When provided, ``**kwargs`` are applied as overrides.
        sorter (str): Registered sorter backend name (default
            ``"kilosort2"``).  Only used when ``config`` is None.
        **kwargs: Override individual config fields.  The following must
            not be provided:

            - ``intermediate_folders`` and ``results_folders`` are
              auto-generated per stream.
            - ``stream_id`` is set automatically per iteration.

    Returns:
        results (dict): ``{stream_id: list[SpikeData]}``.

    Notes:
        - Stream ID validation uses SpikeInterface's extractor for the
          recording format.  Currently supports Maxwell ``.h5`` files.
          For other formats, validation is skipped and invalid stream
          IDs will produce errors at loading time.
        - When *recording* is a directory of files, each file is
          concatenated per stream before sorting.  Channel count and
          sampling frequency must match across files (raises
          ``ValueError``); mismatched channel IDs or locations produce
          warnings.
    """
    import datetime

    if "stream_id" in kwargs:
        raise ValueError(
            "Do not pass 'stream_id' to sort_multistream — it is set "
            "automatically for each stream. Pass stream IDs via the "
            "'stream_ids' parameter instead."
        )
    if kwargs.get("intermediate_folders") is not None:
        raise ValueError(
            "'intermediate_folders' cannot be specified for "
            "sort_multistream — folders are auto-generated per stream."
        )
    if kwargs.get("results_folders") is not None:
        raise ValueError(
            "'results_folders' cannot be specified for "
            "sort_multistream — folders are auto-generated per stream."
        )

    recording = Path(recording)

    # Validate stream IDs against the recording file
    h5_files = []
    if recording.is_dir():
        try:
            from natsort import natsorted
        except ImportError:
            natsorted = sorted
        h5_files = [
            recording / name
            for name in natsorted(
                p.name for p in recording.iterdir() if p.name.endswith(".raw.h5")
            )
        ]
    elif str(recording).endswith(".h5"):
        h5_files = [recording]

    if h5_files:
        try:
            from spikeinterface.extractors import MaxwellRecordingExtractor

            _, available_ids = MaxwellRecordingExtractor.get_streams(
                str(h5_files[0])
            )
            missing = [sid for sid in stream_ids if sid not in available_ids]
            if missing:
                raise ValueError(
                    f"Stream ID(s) {missing} not found in "
                    f"{h5_files[0].name}. Available streams: {available_ids}"
                )
        except ImportError:
            pass  # SI not available — skip validation

    results = {}
    for sid in stream_ids:
        print_stage(f"SORTING STREAM: {sid}")

        if recording.is_dir():
            base = recording
        else:
            base = recording.parent
        cur_dt = datetime.datetime.now().strftime("%y%m%d_%H%M%S_%f")
        inter = [str(base / f"inter_{sorter}_{sid}_{cur_dt}")]
        res = [str(base / f"sorted_{sorter}_{sid}")]

        stream_results = sort_recording(
            recording_files=[str(recording)],
            config=config,
            sorter=sorter,
            intermediate_folders=inter,
            results_folders=res,
            stream_id=sid,
            **kwargs,
        )
        results[sid] = stream_results

    return results
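

# Example usage (a sketch; the plate file name is hypothetical):
#
#     results = sort_multistream(
#         "/data/plate.raw.h5",
#         stream_ids=["well000", "well001"],
#         sorter="kilosort2",
#     )
#     well0_sd = results["well000"][0]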