Source code for spikelab.spike_sorting.stim_sorting.artifact_removal

"""Stimulation artifact removal for offline electrophysiology recordings.

Removes electrical stimulation artifacts from multi-electrode array
(MEA) recordings while preserving neural spikes.  Two methods are
provided:

``"polynomial"`` (default)
    Per-event, per-channel low-order polynomial detrend.  A polynomial
    (default cubic) is fit to the non-saturated samples in the artifact
    tail — after the electrode desaturates — and subtracted.  Because
    the polynomial is far too smooth to capture spike waveforms
    (~0.5-1 ms), spikes riding on the artifact tail are preserved in
    the residual.  Saturated samples are blanked (set to zero).

``"blank"``
    Simply zeros out the entire artifact window.  Crude but useful as
    a quick sanity check or when the artifact is too variable for a
    good polynomial fit.

The polynomial detrend approach is related to SALPA (Suprathreshold
Artifact-Level Polynomial Approximation):

    Wagenaar, D. A. & Potter, S. M. (2002). Real-time multi-channel
    stimulus artifact suppression by local curve fitting. J Neurosci
    Methods, 120(2), 113-120.

SALPA fits a local polynomial in a causal (backward-looking) sliding
window and forward-extrapolates during the artifact, which is necessary
for real-time operation.  This module is designed for offline use, so it
instead looks ahead past saturation and fits the polynomial to the
actual post-saturation recovery curve, yielding a more accurate fit
without the extrapolation drift inherent in SALPA's forward prediction.

Sequential stimulation handling
    When multiple stim events occur in rapid succession (e.g. burst or
    paired-pulse protocols), the signal may re-saturate before reaching
    baseline after the previous stim.  This module dynamically detects
    whether the signal has returned to baseline-like levels after each
    desaturation.  If re-saturation occurs before baseline is reached,
    the blanking region is extended and the polynomial fit is deferred
    until after the final stim in the burst.
"""

import warnings

import numpy as np


def _auto_saturation_threshold(traces, quantile=0.999):
    """Estimate a saturation threshold from the trace amplitude distribution.

    Uses a high quantile of the absolute voltage distribution as the
    threshold.  Recordings with genuine saturation will have a hard
    clip at the ADC rail, so the quantile lands just below that clip.

    Parameters:
        traces (np.ndarray): ``(channels, samples)``.
        quantile (float): Quantile of ``|traces|`` to use.

    Returns:
        threshold (float): Absolute voltage threshold.
    """
    return float(np.quantile(np.abs(traces), quantile))


def _saturation_threshold_from_recording(
    recording, traces, frac=0.95, min_clip_samples=10
):
    """Derive a saturation threshold (µV) from gain metadata + observed extremes.

    Returns ``+inf`` (i.e. "do not blank anything") when the recording
    shows no evidence of ADC clipping.  Otherwise returns
    ``frac * round(max_uV / gain_uV_per_bit) * gain_uV_per_bit`` — the
    observed rail rounded to a whole number of raw ADC bits and pulled
    in by ``frac``.

    Saturation detection: a hard ADC clip produces many samples pinned
    at the rail (a flat top in the amplitude histogram).  A single
    large spike produces exactly one sample at the maximum.  We count
    samples within one raw bit of ``max(|traces|)``; if fewer than
    ``min_clip_samples``, we treat the recording as unsaturated and
    return ``+inf``.  This means "blank only completely saturated
    electrodes" semantics: high-amplitude artifacts that didn't reach
    the rail are left intact for the polynomial detrend.

    Why use the gain at all (vs. raw ``frac * max(|traces|)``):
      * Read from the recording's ``get_channel_gains()``, which on a
        SpikeInterface chain is propagated from the underlying integer
        extractor (e.g. ``MaxwellRecordingExtractor``).
      * Rounding the observed max to a whole number of raw bits anchors
        the threshold to a hardware-meaningful value, not a
        floating-point artefact of preprocessing arithmetic — two
        recordings of the same probe at the same gain settings produce
        the same threshold.
      * The "within one raw bit of max" tolerance for the clip-detection
        count is also gain-anchored, not arbitrary.

    Parameters:
        recording (BaseRecording or None): SpikeInterface recording
            exposing ``get_channel_gains()``.  When ``None`` or no
            gains available, falls back to a 1.0 µV/bit assumption.
        traces (np.ndarray): ``(channels, samples)`` already scaled to µV.
        frac (float): Fraction of the rail to use as threshold.  Default
            ``0.95`` — leaves ~5% margin so samples within the top bit
            of the rail still count as saturated.
        min_clip_samples (int): Minimum number of samples within one raw
            bit of ``max(|traces|)`` required to consider the recording
            saturated.  Below this, the recording is treated as
            unsaturated and the function returns ``+inf``.  Default
            ``10`` — high enough to ignore single-spike maxima and
            small numbers of outlier samples, low enough to catch even
            very sparse stimulation protocols.

    Returns:
        threshold (float, µV) — finite if saturation detected, ``+inf``
        if not.
    """
    abs_traces = np.abs(traces)
    observed_max_uV = float(np.max(abs_traces))

    # Resolve gain
    gain_uV_per_bit = 1.0
    if recording is not None:
        try:
            gains = recording.get_channel_gains()
        except (AttributeError, NotImplementedError):
            gains = None
        if gains is not None and len(gains) > 0:
            gain_uV_per_bit = max(1e-9, float(np.max(np.abs(gains))))

    # Saturation detection: how many samples sit at or just below the
    # observed max?  A hard clip pins many samples there; a single big
    # spike is just one sample.
    n_at_rail = int(np.sum(abs_traces >= observed_max_uV - gain_uV_per_bit))
    if n_at_rail < min_clip_samples:
        return float("inf")

    observed_rail_bits = round(observed_max_uV / gain_uV_per_bit)
    return frac * observed_rail_bits * gain_uV_per_bit


def _auto_baseline_threshold(traces, stim_times_ms, fs_Hz, k=5.0):
    """Estimate a baseline envelope threshold from pre-stim signal.

    Computes the median absolute deviation (MAD) of the signal in the
    2 ms window before the first stim event (or the first 2 ms of the
    recording if there's no pre-stim data), then returns
    ``median + k * MAD`` as the threshold for "signal has returned to
    baseline-like levels."

    Parameters:
        traces (np.ndarray): ``(channels, samples)``.
        stim_times_ms (np.ndarray): Corrected stim times in ms.
        fs_Hz (float): Sampling frequency in Hz.
        k (float): Multiplier on MAD.  Default 5.0.

    Returns:
        threshold (float): Baseline envelope threshold (absolute).
    """
    baseline_ms = 2.0
    baseline_samples = max(1, int(np.round(baseline_ms * fs_Hz / 1000.0)))

    if len(stim_times_ms) > 0:
        first_stim_sample = int(np.round(np.min(stim_times_ms) * fs_Hz / 1000.0))
        end = max(1, first_stim_sample)
        start = max(0, end - baseline_samples)
    else:
        start = 0
        end = min(baseline_samples, traces.shape[1])

    segment = traces[:, start:end]
    if segment.size == 0:
        return float(np.median(np.abs(traces)) * k)

    med = np.median(np.abs(segment))
    mad = np.median(np.abs(np.abs(segment) - med))
    return float(med + k * mad)


def _find_saturation_end(channel_trace, start, saturation_threshold, n_samples):
    """Find the first sample after *start* where the signal desaturates.

    Parameters:
        channel_trace (np.ndarray): 1-D voltage trace for one channel.
        start (int): Sample index to start searching from.
        saturation_threshold (float): Absolute voltage threshold.
        n_samples (int): Total number of samples in the trace.

    Returns:
        end (int): First sample index where
            ``|voltage| < saturation_threshold``, or ``n_samples`` if
            the signal never desaturates.
    """
    idx = start
    while idx < n_samples and np.abs(channel_trace[idx]) >= saturation_threshold:
        idx += 1
    return idx


def _find_saturation_end_from_mask(mask_ch, start, n_samples):
    """Variant of ``_find_saturation_end`` driven by a pre-computed
    clip mask rather than an amplitude threshold.

    When the caller has raw (pre-filter) traces available, it is more
    correct to identify saturated samples from the raw signal and pass
    a boolean mask here — bandpass filtering of a stim artifact
    produces ringing whose amplitude can exceed the raw ADC rail even
    on samples that weren't actually clipped.
    """
    idx = start
    while idx < n_samples and mask_ch[idx]:
        idx += 1
    return idx


def _signal_reached_baseline(
    channel_trace, start, baseline_threshold, window_samples, n_samples
):
    """Check whether the signal has returned to baseline-like levels.

    The signal is considered at baseline when the rolling maximum
    of ``|voltage|`` over *window_samples* consecutive samples drops
    below *baseline_threshold*.

    Parameters:
        channel_trace (np.ndarray): 1-D voltage trace.
        start (int): Sample index to start checking from.
        baseline_threshold (float): Absolute voltage threshold.
        window_samples (int): Number of consecutive sub-threshold
            samples required.
        n_samples (int): Trace length.

    Returns:
        at_baseline (bool): True if the signal reached baseline before
            the end of the trace.
        end_idx (int): Sample index where baseline was reached, or
            ``n_samples``.
    """
    consecutive = 0
    idx = start
    while idx < n_samples:
        if np.abs(channel_trace[idx]) < baseline_threshold:
            consecutive += 1
            if consecutive >= window_samples:
                return True, idx - window_samples + 1
        else:
            consecutive = 0
        idx += 1
    return False, n_samples


_MIN_DESCENT_SAMPLES = 2  # min samples between fit_start and neg-peak to split


def _polyfit_and_subtract(
    channel_trace,
    blanked,
    ch_idx,
    lo,
    hi,
    poly_order,
    clamp_threshold=None,
    clamp_counter=None,
):
    """Fit a polynomial to ``channel_trace[lo:hi]`` (excluding blanked
    samples) and subtract it in-place.

    If too few non-blanked samples remain to support the fit (e.g.
    because Fit 1 of the auto-split landed on a very short descent
    window between the recentered stim time and the negative peak),
    the region is left untouched rather than blanked — those samples
    are not saturated, just covered by a window too small for a
    reliable polynomial of this order.

    When ``clamp_threshold`` is finite, the post-subtraction segment is
    sanity-checked: if any sample exceeds ``clamp_threshold`` in
    absolute value, the polynomial fit is treated as having diverged
    (e.g. extrapolating wildly across saturated tails at high stim
    amplitudes), the segment is blanked instead of left in place, and
    ``clamp_counter[0]`` is incremented for caller-side reporting.
    """
    if hi <= lo:
        return
    x = np.arange(hi - lo, dtype=np.float64)
    y = channel_trace[lo:hi].astype(np.float64)
    mask = ~blanked[ch_idx, lo:hi]
    if np.sum(mask) <= poly_order:
        return
    coeffs = np.polyfit(x[mask], y[mask], poly_order)
    channel_trace[lo:hi] -= np.polyval(coeffs, x)

    if clamp_threshold is not None and np.isfinite(clamp_threshold):
        seg = channel_trace[lo:hi]
        if seg.size and float(np.max(np.abs(seg))) > clamp_threshold:
            seg[:] = 0.0
            blanked[ch_idx, lo:hi] = True
            if clamp_counter is not None:
                clamp_counter[0] += 1


def _process_stim_group_polynomial(
    channel_trace,
    group_start,
    last_desat,
    artifact_window_samples,
    baseline_threshold,
    baseline_window_samples,
    poly_order,
    n_samples,
    blanked,
    ch_idx,
    pre_artifact_samples=0,  # accepted for API stability, currently unused
    clip_mask_ch=None,  # accepted for API stability, currently unused
    clamp_threshold=None,
    clamp_counter=None,
):
    """Polynomial detrend for one stim group on one channel.

    Workflow per stim group:
      1. Blank ``[group_start, last_desat)`` (any genuine ADC clip).
      2. Determine the fit window
         ``[fit_start = last_desat, fit_end = last_desat + artifact_window]``,
         extending ``fit_end`` to where the signal returns to baseline.
      3. Locate the negative peak (``argmin``) inside the window, and
         the subsequent positive peak (``argmax`` after the negative
         peak) inside the window.
      4. Split the fit at the meaningful peaks and run an independent
         polynomial on each segment:

         * **3-fit split** (descent + ascent + decay) when both a
           descent of ≥ ``_MIN_DESCENT_SAMPLES`` and an ascent of
           ≥ ``_MIN_DESCENT_SAMPLES`` exist:
              - Fit A: ``[fit_start, neg_peak]`` — descent.
              - Fit B: ``[neg_peak, pos_peak]`` — ascent through zero
                up to the post-artifact positive overshoot.
              - Fit C: ``[pos_peak, fit_end]`` — decay back to
                baseline (the original implementation).
           This is the typical biphasic anodic-first case sorted with
           ``peak_mode="down_edge"``: the post-stim signal goes down,
           up through zero, may overshoot, and decays.

         * **2-fit split** (descent + tail) when there is a descent
           but no meaningful ascent before ``fit_end``:
              - Fit A: ``[fit_start, neg_peak]``
              - Fit B+C: ``[neg_peak, fit_end]`` — single tail fit.

         * **Single fit** when there's no descent (stim time is
           already at or essentially at the negative peak — e.g.
           ``peak_mode="abs_max"`` or ``"neg_peak"``):
              - Fit C: ``[fit_start, fit_end]``.

      Each segment is monotonic-ish, so a low-order polynomial (cubic)
      fits each well; one polynomial trying to fit the full
      down-up-down shape would have to interpolate two inflection
      points and leaves residuals.
    """
    # Blank from group start through desaturation
    blank_end = min(last_desat, n_samples)
    channel_trace[group_start:blank_end] = 0.0
    blanked[ch_idx, group_start:blank_end] = True

    # Determine the fit region: from desaturation through the artifact tail
    fit_start = last_desat
    fit_end = min(last_desat + artifact_window_samples, n_samples)

    if fit_start >= n_samples or fit_start >= fit_end:
        return

    # Extend fit_end to where the signal reaches baseline (if within window)
    reached, baseline_idx = _signal_reached_baseline(
        channel_trace,
        fit_start,
        baseline_threshold,
        baseline_window_samples,
        min(fit_end, n_samples),
    )
    if reached:
        # Anchor the fit polynomial to a span of clean baseline
        # samples beyond the artifact tail.  Without this anchor the
        # cubic had freedom to curl in the trailing region and left a
        # small step at the boundary between the subtracted region
        # and the un-touched baseline tail.  Extending 3 ms past
        # ``baseline_idx`` (≈1 ms for the detection window + 2 ms of
        # additional anchor) gives the polynomial enough "known-
        # baseline" points to be pulled naturally toward zero at the
        # boundary without over-extending the fit.
        fit_end = min(baseline_idx + 3 * baseline_window_samples, n_samples)

    if fit_end <= fit_start:
        return

    # Locate the negative peak in the fit window, then the subsequent
    # positive peak.  Both indices are computed on the un-modified
    # trace before any subtraction so the splits are stable.
    neg_peak_offset = int(np.argmin(channel_trace[fit_start:fit_end]))
    neg_peak_sample = fit_start + neg_peak_offset

    if neg_peak_sample + 1 < fit_end:
        pos_peak_offset_after = int(
            np.argmax(channel_trace[neg_peak_sample + 1 : fit_end])
        )
        pos_peak_sample = neg_peak_sample + 1 + pos_peak_offset_after
    else:
        pos_peak_sample = neg_peak_sample  # no room for a subsequent peak

    descent_samples = neg_peak_offset
    ascent_samples = pos_peak_sample - neg_peak_sample

    has_descent = descent_samples >= _MIN_DESCENT_SAMPLES
    has_ascent = ascent_samples >= _MIN_DESCENT_SAMPLES

    if has_descent and has_ascent:
        # 3-fit split: descent + ascent + decay
        _polyfit_and_subtract(
            channel_trace,
            blanked,
            ch_idx,
            fit_start,
            neg_peak_sample + 1,
            poly_order,
            clamp_threshold=clamp_threshold,
            clamp_counter=clamp_counter,
        )
        _polyfit_and_subtract(
            channel_trace,
            blanked,
            ch_idx,
            neg_peak_sample + 1,
            pos_peak_sample + 1,
            poly_order,
            clamp_threshold=clamp_threshold,
            clamp_counter=clamp_counter,
        )
        _polyfit_and_subtract(
            channel_trace,
            blanked,
            ch_idx,
            pos_peak_sample + 1,
            fit_end,
            poly_order,
            clamp_threshold=clamp_threshold,
            clamp_counter=clamp_counter,
        )
    elif has_descent:
        # 2-fit split: descent + tail (no positive overshoot found)
        _polyfit_and_subtract(
            channel_trace,
            blanked,
            ch_idx,
            fit_start,
            neg_peak_sample + 1,
            poly_order,
            clamp_threshold=clamp_threshold,
            clamp_counter=clamp_counter,
        )
        _polyfit_and_subtract(
            channel_trace,
            blanked,
            ch_idx,
            neg_peak_sample + 1,
            fit_end,
            poly_order,
            clamp_threshold=clamp_threshold,
            clamp_counter=clamp_counter,
        )
    else:
        # No descent — stim already at neg peak; single fit.
        _polyfit_and_subtract(
            channel_trace,
            blanked,
            ch_idx,
            fit_start,
            fit_end,
            poly_order,
            clamp_threshold=clamp_threshold,
            clamp_counter=clamp_counter,
        )


def _global_polynomial_detrend(
    channel_trace,
    window_samples,
    overlap_samples,
    saturation_threshold,
    poly_order,
    n_samples,
    blanked,
    ch_idx,
    clamp_threshold=None,
    clamp_counter=None,
):
    """Sliding-window polynomial detrend applied to an entire channel.

    Divides the trace into overlapping windows, fits a polynomial to
    the non-saturated samples in each window, and subtracts the fit.
    Overlap regions are blended with a linear crossfade to avoid
    discontinuities at window boundaries.  Saturated samples are
    blanked.

    Parameters:
        channel_trace (np.ndarray): 1-D trace (modified in-place).
        window_samples (int): Window length in samples.
        overlap_samples (int): Overlap between consecutive windows.
        saturation_threshold (float): Absolute voltage saturation level.
        poly_order (int): Polynomial order for the detrend.
        n_samples (int): Trace length.
        blanked (np.ndarray): 2-D boolean mask ``(channels, samples)``,
            modified in-place.
        ch_idx (int): Channel index for the blanked mask.
    """
    step = window_samples - overlap_samples
    if step < 1:
        step = 1

    # Pre-compute the output buffer so we can blend overlaps
    output = np.zeros(n_samples, dtype=np.float64)
    weight = np.zeros(n_samples, dtype=np.float64)

    start = 0
    while start < n_samples:
        end = min(start + window_samples, n_samples)
        seg = channel_trace[start:end].astype(np.float64)
        seg_len = end - start

        # Mark saturated samples
        sat_mask = np.abs(seg) >= saturation_threshold
        if np.any(sat_mask):
            blanked[ch_idx, start:end] |= sat_mask

        fit_mask = ~sat_mask & ~np.isnan(seg)
        if np.sum(fit_mask) > poly_order:
            x = np.arange(seg_len, dtype=np.float64)
            coeffs = np.polyfit(x[fit_mask], seg[fit_mask], poly_order)
            artifact_estimate = np.polyval(coeffs, x)
            detrended = seg - artifact_estimate
        else:
            # Not enough non-saturated samples — zero out
            detrended = np.zeros(seg_len)
            blanked[ch_idx, start:end] = True

        # Zero saturated samples in the detrended output
        detrended[sat_mask] = 0.0

        # Sanity clamp: a polynomial fit that diverged across saturated
        # samples can produce extra-physiological residuals.  Blank the
        # whole window in that case rather than ship 10+ V "neural" data.
        if (
            clamp_threshold is not None
            and np.isfinite(clamp_threshold)
            and detrended.size
            and float(np.max(np.abs(detrended))) > clamp_threshold
        ):
            detrended[:] = 0.0
            blanked[ch_idx, start:end] = True
            if clamp_counter is not None:
                clamp_counter[0] += 1

        # Build a blending window (linear ramps in overlap regions)
        w = np.ones(seg_len)
        if start > 0 and overlap_samples > 0:
            ramp_len = min(overlap_samples, seg_len)
            w[:ramp_len] = np.linspace(0, 1, ramp_len)
        if end < n_samples and overlap_samples > 0:
            ramp_len = min(overlap_samples, seg_len)
            w[-ramp_len:] = np.linspace(1, 0, ramp_len)

        output[start:end] += detrended * w
        weight[start:end] += w

        start += step

    # Normalize by blending weights
    nonzero = weight > 0
    channel_trace[nonzero] = output[nonzero] / weight[nonzero]
    channel_trace[~nonzero] = 0.0


def _maybe_warn_polynomial_clamp(counter, clamp_threshold, saturation_threshold):
    """Emit one warning per ``remove_stim_artifacts`` call when the
    polynomial divergence sanity clamp fired one or more times."""
    if counter is None or counter[0] == 0 or clamp_threshold is None:
        return
    warnings.warn(
        f"remove_stim_artifacts: polynomial fit diverged on "
        f"{counter[0]} segment(s) — exceeded clamp threshold "
        f"{clamp_threshold:.0f} (= poly_clamp_factor * "
        f"saturation_threshold = {saturation_threshold:.0f}).  Those "
        f"segments were blanked instead.  This usually indicates a stim "
        f"amplitude high enough to keep electrodes saturated through the "
        f"polynomial's fit window (e.g. >500 mV on MaxOne); consider "
        f"method='blank' for such recordings, or pass "
        f"poly_clamp_factor=None to disable this fallback.",
        UserWarning,
        stacklevel=3,
    )


[docs] def remove_stim_artifacts( traces, stim_times_ms, fs_Hz, method="polynomial", artifact_window_ms=10.0, saturation_threshold=None, baseline_threshold=None, poly_order=3, artifact_window_only=True, copy=True, *, recording=None, raw_traces=None, poly_clamp_factor=10.0, ): """Remove stimulation artifacts from multi-channel voltage traces. Processes each stim event independently per channel. Saturated samples are always blanked (zeroed). For the ``"polynomial"`` method, a low-order polynomial is fit to the post-saturation artifact tail and subtracted, preserving neural spikes (which are too fast for the smooth polynomial to capture). When multiple stim events occur in rapid succession and the signal re-saturates before reaching baseline levels, the blanking region is extended dynamically and the polynomial fit is deferred until after the final desaturation in the burst. The polynomial detrend is conceptually related to SALPA (Wagenaar & Potter 2002, J Neurosci Methods), adapted for offline processing where look-ahead past saturation is available — see the module docstring for details. Parameters: traces (np.ndarray): Raw voltage traces, shape ``(channels, samples)``. stim_times_ms (array-like): Corrected stim times in milliseconds (e.g. from ``recenter_stim_times``). fs_Hz (float): Sampling frequency in Hz. method (str): ``"polynomial"`` (default) or ``"blank"``. artifact_window_ms (float): Maximum duration in milliseconds of the artifact tail after the last desaturation point. The polynomial is fit over this window. Default 10.0. Note: when the post-stim window contains a clear descent from the recentered stim time to a subsequent negative peak (typical for biphasic anodic-first pulses sorted with ``peak_mode="down_edge"``), the fit is automatically split into two independent polynomials at the negative peak — one for ``[stim_time, neg_peak]`` (the descent) and one for ``[neg_peak, baseline_recovery]`` (the tail). When the recentered stim time IS the negative peak (e.g. ``peak_mode="abs_max"`` or ``"neg_peak"``), no descent exists and a single fit is used. This is automatic; no user knob. saturation_threshold (float or None): Absolute voltage value above which a sample is considered saturated. When None, auto-detected — preferring gain-anchored detection from ``recording`` metadata when supplied (see ``recording`` kwarg below), falling back to the 99.9th percentile of ``|traces|`` otherwise. raw_traces (np.ndarray or None): Optional pre-bandpass traces, same shape as ``traces``, used as the source of truth for saturation detection. Bandpass filtering of a stim artifact produces ringing whose filtered amplitude can exceed the raw ADC rail even on unsaturated samples, so auto-detection from ``traces`` (filtered) both over- reports (ringing overshoot) and under-reports (group-delay smoothing) clips. When provided, the threshold is derived from ``raw_traces`` and the clip mask is built from ``np.abs(raw_traces) >= threshold``; the filtered ``traces`` are blanked at those same sample indices and polynomial- detrended around them. baseline_threshold (float or None): Absolute voltage envelope below which the signal is considered to have returned to baseline. When None, auto-detected from pre-stim MAD. poly_order (int): Polynomial order for the detrend. Default 3 (cubic). Higher orders risk fitting spike-like features; lower orders may not capture the artifact decay shape. artifact_window_only (bool): If True (default), only process windows around stim events. If False, apply a global polynomial detrend to the entire trace (for recordings with very frequent stimulation). copy (bool): If True (default), return a copy; if False, modify ``traces`` in-place. poly_clamp_factor (float or None): Sanity-clamp factor for the ``"polynomial"`` method. After each polynomial subtraction, if any post-subtract sample exceeds ``poly_clamp_factor * saturation_threshold`` in absolute value, the segment is treated as a divergent fit (extrapolated wildly across saturated samples), blanked instead of left in place, and counted toward a one-shot warning emitted at the end of the call. Default ``10.0`` — well above any plausible neural amplitude (~100 µV) when ``saturation_threshold`` is in the multi-thousand-µV range. Set to ``None`` to disable. Has no effect when ``saturation_threshold`` is ``+inf`` (no clipping detected) or ``method="blank"``. Returns: cleaned (np.ndarray): Cleaned traces, shape ``(channels, samples)``. blanked_mask (np.ndarray): Boolean array, shape ``(channels, samples)``. True for samples that were blanked (zeroed) because they fell within a saturation region. """ stim_times_ms = np.asarray(stim_times_ms, dtype=np.float64) if copy: traces = traces.copy() n_channels, n_samples = traces.shape blanked = np.zeros((n_channels, n_samples), dtype=bool) if len(stim_times_ms) == 0: return traces, blanked if method not in ("polynomial", "blank"): raise ValueError( f"Unknown artifact removal method {method!r}; " "expected 'polynomial' or 'blank'." ) if traces.shape[0] == 0 or traces.shape[1] == 0: raise ValueError( f"traces must have at least one channel and one sample, " f"got shape {traces.shape}" ) # Pick the source of truth for saturation detection. Prefer raw # (pre-bandpass) traces when provided — filter ringing after a stim # artifact can drive filtered samples past the raw ADC rail even # when nothing was actually clipped, so detecting on the filtered # signal both over-reports clips (filter overshoot) and under- # reports them (group delay + smoothing). ``raw_traces`` is # typically the un-filtered ``ScaleRecording`` output extracted by # the caller; for a filtered-only path pass nothing and the # filtered ``traces`` will be used. detection_traces = raw_traces if raw_traces is not None else traces # Auto-detect thresholds. Prefer gain-anchored detection when a # recording object is provided — anchors the threshold to actual # ADC bit boundaries and returns +inf when no clipping is detected, # so non-saturated recordings are left alone. Falls back to the # quantile-based heuristic when no recording metadata is available. if saturation_threshold is None: if recording is not None: saturation_threshold = _saturation_threshold_from_recording( recording, detection_traces ) else: saturation_threshold = _auto_saturation_threshold(detection_traces) # Pre-compute the clip mask once, from detection_traces. All # downstream saturation checks read this mask instead of re- # computing ``|trace| >= threshold`` against the filtered signal. # When ``saturation_threshold`` is ``+inf`` (no clipping detected) # the mask is all-False and all blanking logic short-circuits. clip_mask = np.abs(detection_traces) >= saturation_threshold if baseline_threshold is None: baseline_threshold = _auto_baseline_threshold(traces, stim_times_ms, fs_Hz) artifact_window_samples = int(np.round(artifact_window_ms * fs_Hz / 1000.0)) baseline_window_samples = max( 1, int(np.round(1.0 * fs_Hz / 1000.0)) # 1 ms of consecutive samples ) # Sanity-clamp threshold for the polynomial fit. Inactive when the # caller disabled it, when no clipping was detected (saturation # threshold = +inf), or when method != "polynomial". if ( method == "polynomial" and poly_clamp_factor is not None and np.isfinite(saturation_threshold) ): poly_clamp_threshold = float(poly_clamp_factor) * float(saturation_threshold) else: poly_clamp_threshold = None poly_clamp_counter = [0] # Convert stim times to sample indices and sort stim_samples = np.round(stim_times_ms * fs_Hz / 1000.0).astype(int) stim_samples = np.sort(stim_samples) stim_samples = stim_samples[(stim_samples >= 0) & (stim_samples < n_samples)] if len(stim_samples) == 0: return traces, blanked if not artifact_window_only: # Global mode: apply a sliding-window polynomial detrend to the # entire recording. Useful when stimulation is so frequent that # artifact windows overlap or cover most of the trace, or when # stim timing information is unavailable. overlap_samples = artifact_window_samples // 2 for ch in range(n_channels): if method == "polynomial": _global_polynomial_detrend( traces[ch], artifact_window_samples, overlap_samples, saturation_threshold, poly_order, n_samples, blanked, ch, clamp_threshold=poly_clamp_threshold, clamp_counter=poly_clamp_counter, ) elif method == "blank": # Global blank: blank only the saturated samples sat = clip_mask[ch] traces[ch, sat] = 0.0 blanked[ch, sat] = True _maybe_warn_polynomial_clamp( poly_clamp_counter, poly_clamp_threshold, saturation_threshold ) return traces, blanked # Process each channel independently for ch in range(n_channels): ch_trace = traces[ch] # Group stim events that form a sequential burst. # Walk through sorted stim samples; after each stim, find where # saturation ends. If the signal re-saturates or hasn't reached # baseline before the next stim, merge into the same group. i = 0 while i < len(stim_samples): group_start = max(0, stim_samples[i]) # Walk forward through this stim and any sequential stims current_stim_idx = i last_desat = _find_saturation_end_from_mask( clip_mask[ch], group_start, n_samples ) while True: # Check if the next stim event is before the signal # reaches baseline next_idx = current_stim_idx + 1 if next_idx < len(stim_samples): next_stim = stim_samples[next_idx] # Has signal reached baseline before the next stim? reached, _ = _signal_reached_baseline( ch_trace, last_desat, baseline_threshold, baseline_window_samples, min(next_stim, n_samples), ) if not reached: # Signal hasn't recovered — merge with next stim current_stim_idx = next_idx new_desat = _find_saturation_end_from_mask( clip_mask[ch], next_stim, n_samples, ) last_desat = max(last_desat, new_desat) continue # Either no more stim events, or signal reached baseline break # Now process this group if method == "polynomial": _process_stim_group_polynomial( ch_trace, group_start, last_desat, artifact_window_samples, baseline_threshold, baseline_window_samples, poly_order, n_samples, blanked, ch, clamp_threshold=poly_clamp_threshold, clamp_counter=poly_clamp_counter, ) elif method == "blank": blank_end = min(last_desat + artifact_window_samples, n_samples) ch_trace[group_start:blank_end] = 0.0 blanked[ch, group_start:blank_end] = True # Advance past all stim events in this group i = current_stim_idx + 1 _maybe_warn_polynomial_clamp( poly_clamp_counter, poly_clamp_threshold, saturation_threshold ) return traces, blanked