"""Unit curation methods for SpikeData objects.
Each public function accepts a SpikeData as its first argument and returns
``(SpikeData, result_dict)`` where *result_dict* always contains:
- ``metric`` — ``np.ndarray (N,)`` with the per-unit metric value
(computed over **all** original units).
- ``passed`` — ``np.ndarray (N,)`` boolean mask indicating which units
passed the curation criterion.
The returned SpikeData contains only the passing units (via ``subset``).
These functions are bound as methods on ``SpikeData`` by
``spikedata.py`` so they can be called as ``sd.curate_by_*(…)``.
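
Example:
    Every curator follows the same contract; a minimal sketch
    (the threshold value is illustrative)::

        sd_clean, res = sd.curate_by_min_spikes(min_spikes=30)
        assert res["passed"].sum() == sd_clean.N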
"""
import warnings
import numpy as np
from .utils import compute_cosine_similarity_with_lag
from spikelab.spike_sorting._exceptions import EmptyWaveformMetricsError
def curate_by_min_spikes(sd, min_spikes=30):
"""Remove units with fewer than *min_spikes* spikes.
Parameters:
sd (SpikeData): Source spike data.
min_spikes (int): Minimum spike count threshold.
Returns:
sd_out (SpikeData): SpikeData with only passing units.
result (dict): ``{"metric": (N,) spike counts, "passed": (N,) bool mask}``.
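
    Example:
        A sketch inspecting the per-unit counts against the mask
        (the threshold value is illustrative)::

            sd_clean, res = curate_by_min_spikes(sd, min_spikes=50)
            for count, ok in zip(res["metric"], res["passed"]):
                print(f"{count:.0f} spikes -> {'kept' if ok else 'dropped'}")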
"""
metric = np.array([len(t) for t in sd.train], dtype=float)
passed = metric >= min_spikes
return sd.subset(np.where(passed)[0]), {"metric": metric, "passed": passed}
def curate_by_firing_rate(sd, min_rate_hz=0.05):
"""Remove units whose firing rate is below *min_rate_hz*.
Parameters:
sd (SpikeData): Source spike data.
min_rate_hz (float): Minimum firing rate in Hz.
Returns:
sd_out (SpikeData): SpikeData with only passing units.
result (dict): ``{"metric": (N,) firing rates in Hz, "passed": (N,) bool mask}``.
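
    Example:
        A sketch dropping units that fire below 0.1 Hz (the threshold
        is illustrative)::

            sd_clean, res = curate_by_firing_rate(sd, min_rate_hz=0.1)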
"""
duration_s = sd.length / 1000.0
if duration_s <= 0:
metric = np.zeros(sd.N, dtype=float)
else:
metric = np.array([len(t) / duration_s for t in sd.train], dtype=float)
passed = metric >= min_rate_hz
return sd.subset(np.where(passed)[0]), {"metric": metric, "passed": passed}
def curate_by_isi_violations(
sd, max_violation=0.01, threshold_ms=1.5, min_isi_ms=0.0, method="percent"
):
"""Remove units with excessive inter-spike-interval violations.
Two methods are available:
- ``"percent"`` — violation count divided by total spike count,
expressed as a **fraction** in ``[0, 1]`` (e.g. ``0.01`` means 1 %
of spikes are ISI violations). The ``"percent"`` name is kept for
backward compatibility with prior versions; the value is now a
fraction, not a percentage.
- ``"hill"`` — violation rate ratio from Hill et al. (2011)
J Neurosci 31:8699-8705. Values above 1 indicate highly
contaminated units.
.. deprecated:: 0.105
With ``method="percent"``, ``max_violation`` is now a fraction
(``0.01`` = 1 % of spikes) instead of a percent value
(``1.0`` = 1 %). Passing a value ``>= 1.0`` with
``method="percent"`` emits a :class:`DeprecationWarning` and is
auto-converted by dividing by 100. The legacy default ``1.0``
is therefore treated as ``0.01``. This compatibility shim will
be removed in a future release.
Parameters:
sd (SpikeData): Source spike data.
max_violation (float): Maximum allowed metric. With
``method="percent"`` this is a fraction in ``[0, 1]``
(default ``0.01`` = 1 % of spikes). With ``method="hill"``
it is a contamination ratio.
threshold_ms (float): Refractory period threshold in ms.
min_isi_ms (float): Minimum possible ISI enforced by hardware or
post-processing, in ms.
method (str): ``"percent"`` or ``"hill"``.
Returns:
sd_out (SpikeData): SpikeData with only passing units.
result (dict): ``{"metric": (N,) ISI violation metric, "passed": (N,) bool mask}``.
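
    Example:
        A sketch of both methods (thresholds are illustrative)::

            # keep units with at most 1% of spikes violating the
            # 1.5 ms refractory period
            sd_a, res_a = curate_by_isi_violations(sd, max_violation=0.01)
            # Hill contamination ratio instead of a raw fraction
            sd_b, res_b = curate_by_isi_violations(
                sd, max_violation=0.5, method="hill"
            )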
"""
if method not in ("percent", "hill"):
raise ValueError(f"method must be 'percent' or 'hill', got '{method}'")
if method == "percent" and max_violation is not None and max_violation >= 1.0:
legacy_value = max_violation
max_violation = max_violation / 100.0
warnings.warn(
(
f"curate_by_isi_violations: max_violation={legacy_value!r} "
f"(>= 1.0) with method='percent' is interpreted as a legacy "
f"percent value and auto-converted to {max_violation!r}. As of "
f"this release, max_violation is a fraction (0.01 = 1% of "
f"spikes). Pass {max_violation!r} explicitly to silence this "
f"warning. The compatibility shim will be removed in a future "
f"release."
),
DeprecationWarning,
stacklevel=2,
)
duration_s = sd.length / 1000.0
threshold_s = threshold_ms / 1000.0
min_isi_s = min_isi_ms / 1000.0
metric = np.zeros(sd.N, dtype=float)
for i, train in enumerate(sd.train):
n_spikes = len(train)
if n_spikes < 2:
metric[i] = 0.0
continue
isis = np.diff(train) # already in ms
violation_count = np.sum(isis < threshold_ms)
if method == "hill":
violation_time = 2 * n_spikes * (threshold_s - min_isi_s)
total_rate = n_spikes / duration_s if duration_s > 0 else 0.0
violation_rate = (
violation_count / violation_time if violation_time > 0 else 0.0
)
metric[i] = violation_rate / total_rate if total_rate > 0 else 0.0
else:
metric[i] = violation_count / n_spikes
passed = metric <= max_violation
return sd.subset(np.where(passed)[0]), {"metric": metric, "passed": passed}
def curate_by_snr(sd, min_snr=5.0, ms_before=1.0, ms_after=2.0):
"""Remove units whose signal-to-noise ratio is below *min_snr*.
SNR is defined as ``peak_amplitude / noise_level`` where peak
amplitude is the absolute maximum of the average waveform on the
channel with the largest amplitude, and noise level is estimated
via the median absolute deviation (MAD) of the raw trace on that
channel.
The method first checks for a precomputed ``"snr"`` value in
``neuron_attributes``. If not found, it computes SNR from
    ``raw_data`` (using ``get_waveform_traces``). If neither is
    available an :class:`EmptyWaveformMetricsError` is raised.
Parameters:
sd (SpikeData): Source spike data.
min_snr (float): Minimum SNR threshold.
ms_before (float): ms before spike for waveform extraction
(only used when computing from raw_data).
ms_after (float): ms after spike for waveform extraction
(only used when computing from raw_data).
Returns:
sd_out (SpikeData): SpikeData with only passing units.
result (dict): ``{"metric": (N,) per-unit SNR, "passed": (N,) bool mask}``.
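
    Example:
        A sketch; ``sd`` must carry precomputed ``"snr"`` attributes or
        raw traces (the threshold is illustrative)::

            sd_clean, res = curate_by_snr(sd, min_snr=4.0)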
"""
metric = _get_or_compute_waveform_metric(sd, "snr", ms_before, ms_after)
passed = metric >= min_snr
return sd.subset(np.where(passed)[0]), {"metric": metric, "passed": passed}
def curate_by_std_norm(
sd,
max_std_norm=1.0,
at_peak=True,
window_ms_before=0.5,
window_ms_after=1.5,
ms_before=1.0,
ms_after=2.0,
):
"""Remove units whose normalized waveform standard deviation exceeds
*max_std_norm*.
Normalized STD is ``|std| / |amplitude|`` on the channel with the
largest amplitude. When *at_peak* is True, STD is measured at the
single peak sample; otherwise it is averaged over a window around
the peak.
The method first checks for a precomputed ``"std_norm"`` value in
``neuron_attributes``. If not found, it computes the metric from
    ``raw_data``. If neither is available an
    :class:`EmptyWaveformMetricsError` is raised.
Parameters:
sd (SpikeData): Source spike data.
max_std_norm (float): Maximum allowed normalized STD.
at_peak (bool): Measure STD at peak sample only.
window_ms_before (float): Window before peak for averaging STD
(only used when *at_peak* is False).
window_ms_after (float): Window after peak for averaging STD
(only used when *at_peak* is False).
ms_before (float): ms before spike for waveform extraction
(only used when computing from raw_data).
ms_after (float): ms after spike for waveform extraction
(only used when computing from raw_data).
Returns:
sd_out (SpikeData): SpikeData with only passing units.
result (dict): ``{"metric": (N,) normalized STD, "passed": (N,) bool mask}``.
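
    Example:
        A sketch of the windowed variant, averaging STD from 0.5 ms
        before to 1.5 ms after the peak (values are illustrative)::

            sd_clean, res = curate_by_std_norm(
                sd, max_std_norm=0.8, at_peak=False
            )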
"""
metric = _get_or_compute_waveform_metric(
sd,
"std_norm",
ms_before,
ms_after,
at_peak=at_peak,
window_ms_before=window_ms_before,
window_ms_after=window_ms_after,
)
passed = metric <= max_std_norm
return sd.subset(np.where(passed)[0]), {"metric": metric, "passed": passed}
def curate(
sd,
min_spikes=None,
min_rate_hz=None,
isi_max=None,
isi_threshold_ms=1.5,
isi_min_ms=0.0,
isi_method="percent",
min_snr=None,
max_std_norm=None,
std_at_peak=True,
std_window_ms_before=0.5,
std_window_ms_after=1.5,
snr_ms_before=1.0,
snr_ms_after=2.0,
):
"""Apply multiple curation criteria in sequence (intersection).
Only criteria whose threshold is not None are applied. Returns the
filtered SpikeData and a dict of per-criterion results.
Parameters:
sd (SpikeData): Source spike data.
min_spikes (int or None): Minimum spike count.
min_rate_hz (float or None): Minimum firing rate in Hz.
isi_max (float or None): Maximum ISI violation metric.
isi_threshold_ms (float): Refractory period for ISI check.
isi_min_ms (float): Minimum possible ISI for ISI check.
isi_method (str): ``"percent"`` or ``"hill"`` for ISI check.
min_snr (float or None): Minimum SNR.
max_std_norm (float or None): Maximum normalized STD.
std_at_peak (bool): Measure STD at peak only.
std_window_ms_before (float): Window before peak for STD averaging.
std_window_ms_after (float): Window after peak for STD averaging.
snr_ms_before (float): ms before spike for waveform extraction.
snr_ms_after (float): ms after spike for waveform extraction.
Returns:
sd_out (SpikeData): SpikeData with only units passing all criteria.
results (dict): Mapping from criterion name to ``{"metric": (N,), "passed": (N,)}``.
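
    Example:
        A typical combined sweep; only criteria given non-None
        thresholds run (values are illustrative)::

            sd_clean, results = curate(
                sd, min_spikes=30, min_rate_hz=0.1, isi_max=0.01
            )
            for name, r in results.items():
                print(name, int(r["passed"].sum()), "passed")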
"""
results = {}
current = sd
if min_spikes is not None:
current, res = curate_by_min_spikes(current, min_spikes=min_spikes)
results["spike_count"] = res
if min_rate_hz is not None:
current, res = curate_by_firing_rate(current, min_rate_hz=min_rate_hz)
results["firing_rate"] = res
if isi_max is not None:
current, res = curate_by_isi_violations(
current,
max_violation=isi_max,
threshold_ms=isi_threshold_ms,
min_isi_ms=isi_min_ms,
method=isi_method,
)
results["isi_violation"] = res
if min_snr is not None:
current, res = curate_by_snr(
current,
min_snr=min_snr,
ms_before=snr_ms_before,
ms_after=snr_ms_after,
)
results["snr"] = res
if max_std_norm is not None:
current, res = curate_by_std_norm(
current,
max_std_norm=max_std_norm,
at_peak=std_at_peak,
window_ms_before=std_window_ms_before,
window_ms_after=std_window_ms_after,
ms_before=snr_ms_before,
ms_after=snr_ms_after,
)
results["std_norm"] = res
return current, results
def build_curation_history(sd_original, sd_curated, results, parameters=None):
"""Translate curation results into a serializable history dict.
The output format mirrors the curation history produced by the
Kilosort2 pipeline, making it suitable for saving as JSON.
Parameters:
sd_original (SpikeData): The SpikeData **before** curation.
sd_curated (SpikeData): The SpikeData **after** curation.
results (dict): Results dict returned by ``curate()`` or
assembled manually from individual ``curate_by_*`` calls.
Keys are criterion names, values are dicts with ``"metric"``
and ``"passed"`` arrays.
parameters (dict or None): Curation parameter values to record.
If None, an empty dict is stored.
Returns:
history (dict): Serializable curation history with keys:
``curation_parameters``, ``initial``, ``curations``,
``curated``, ``failed``, ``metrics``, ``curated_final``.
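
    Example:
        A sketch persisting the history as JSON (the file path is
        illustrative)::

            import json
            sd_clean, results = curate(sd, min_spikes=30)
            history = build_curation_history(
                sd, sd_clean, results, parameters={"min_spikes": 30}
            )
            with open("curation_history.json", "w") as f:
                json.dump(history, f, indent=2)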
"""
# Resolve unit IDs: use neuron_attributes["unit_id"] if available,
# otherwise fall back to positional indices.
def _unit_ids(sd):
if sd.neuron_attributes is not None:
ids = [a.get("unit_id") for a in sd.neuron_attributes]
if all(uid is not None for uid in ids):
return [int(uid) for uid in ids]
return list(range(sd.N))
original_ids = _unit_ids(sd_original)
final_ids = _unit_ids(sd_curated)
curations = []
curated = {}
failed = {}
metrics = {}
# Walk through results in insertion order. Each result was computed
# on the SpikeData that entered that stage (after previous filters),
# but the metric and passed arrays are indexed relative to that
# stage's input. We need to map back to the original unit IDs.
#
# Because curate() applies criteria sequentially, each stage's input
# is a subset of the original. We track the surviving ID list to
# perform the mapping.
surviving_ids = list(original_ids)
for criterion, res in results.items():
curations.append(criterion)
metric_arr = res["metric"]
passed_arr = res["passed"]
stage_curated = []
stage_failed = []
stage_metrics = {}
for j, uid in enumerate(surviving_ids):
stage_metrics[uid] = float(metric_arr[j])
if passed_arr[j]:
stage_curated.append(uid)
else:
stage_failed.append(uid)
curated[criterion] = stage_curated
failed[criterion] = stage_failed
metrics[criterion] = stage_metrics
# Update survivors for the next stage
surviving_ids = stage_curated
return {
"curation_parameters": parameters if parameters is not None else {},
"initial": original_ids,
"curations": curations,
"curated": curated,
"failed": failed,
"metrics": metrics,
"curated_final": final_ids,
}
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _estimate_noise_levels(raw_data, num_chunks=20, chunk_size=10000, seed=0):
"""Estimate per-channel noise via MAD on random chunks of *raw_data*.
Parameters:
raw_data (np.ndarray): Shape ``(channels, time)``.
num_chunks (int): Number of random chunks to sample.
chunk_size (int): Samples per chunk.
seed (int): Random seed.
Returns:
noise (np.ndarray): Shape ``(channels,)``.
"""
rng = np.random.default_rng(seed)
n_channels, n_samples = raw_data.shape
max_start = n_samples - chunk_size
if max_start <= 0:
# Recording shorter than one chunk — use all data
data = raw_data
else:
starts = rng.integers(0, max_start, size=num_chunks)
chunks = [raw_data[:, s : s + chunk_size] for s in starts]
data = np.concatenate(chunks, axis=1)
# MAD-based noise estimate: median(|x - median(x)|) / 0.6745
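    # (0.6745 ≈ Φ⁻¹(0.75): dividing the MAD by this constant rescales it
    # to the standard deviation of a Gaussian.)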
medians = np.median(data, axis=1, keepdims=True)
noise = np.median(np.abs(data - medians), axis=1) / 0.6745
return noise
def _get_or_compute_waveform_metric(sd, metric_name, ms_before, ms_after, **kwargs):
"""Try to read a precomputed metric from neuron_attributes, fall back
to computing from raw_data, or raise if neither is available.
Returns:
metric (np.ndarray): Shape ``(N,)``.
"""
# 1. Check neuron_attributes for precomputed values
if sd.neuron_attributes is not None:
values = []
for attrs in sd.neuron_attributes:
val = attrs.get(metric_name)
if val is None:
break
values.append(float(val))
if len(values) == sd.N:
return np.array(values, dtype=float)
# 2. Fall back to computing from raw_data
if sd.raw_data.size > 0:
at_peak = kwargs.get("at_peak", True)
window_ms_before = kwargs.get("window_ms_before", 0.5)
window_ms_after = kwargs.get("window_ms_after", 1.5)
_, metrics = compute_waveform_metrics(
sd,
ms_before=ms_before,
ms_after=ms_after,
at_peak=at_peak,
window_ms_before=window_ms_before,
window_ms_after=window_ms_after,
)
return metrics[metric_name]
# 3. Neither available
raise EmptyWaveformMetricsError(
f"Cannot compute '{metric_name}': no precomputed values in "
"neuron_attributes and raw_data is empty. Call "
"compute_waveform_metrics() first, or attach raw voltage traces.",
metric_name=metric_name,
)
# ---------------------------------------------------------------------------
# Merge-based deduplication
# ---------------------------------------------------------------------------
def _find_nearby_unit_pairs(sd, dist_um=24.8):
"""Return all pairs of units whose electrode positions are within distance.
Uses ``sd.unit_locations`` (which normalizes across ``"location"``,
``"x"/"y"``, and ``"position"`` keys) so this works regardless of
which loader populated the SpikeData object.
Parameters:
sd (SpikeData): spike data
dist_um (float): Maximum inter-electrode distance in um.
Default 24.8 accounts for the 24.7 µm electrode neighbourhood
radius plus floating-point tolerance.
Returns:
pairs (set[tuple[int, int]]): Set of (i, j) index tuples with i < j.
"""
locations = sd.unit_locations
if locations is None:
raise ValueError(
"sd.unit_locations is None. Position data is required to find nearby pairs."
)
pairs = set()
for i in range(sd.N):
for j in range(i + 1, sd.N):
pos_i, pos_j = locations[i], locations[j]
dist = np.sqrt(np.sum((pos_i[:2] - pos_j[:2]) ** 2))
if dist <= dist_um:
pairs.add((i, j))
return pairs
def _filter_pairs_by_isi_violations(
sd, pairs, max_violation_rate=0.04, threshold_ms=1.5
):
"""Remove pairs where either unit exceeds the ISI violation rate threshold.
ISI violation rate is n_violations / n_spikes where a violation is any
inter-spike interval shorter than threshold_ms.
Parameters:
sd (SpikeData): spike data
pairs (set[tuple[int, int]]): Candidate unit-index pairs.
max_violation_rate (float): Maximum allowed violation rate as a
fraction (not percent). Default 0.04 (4 %).
threshold_ms (float): Refractory period threshold in ms.
Returns:
filtered_pairs (set[tuple[int, int]]): Pairs where both units pass.
violation_rates (dict[int, float]): Per-unit violation rates for
every unit that appeared in pairs.
"""
if not pairs:
return set(), {}
units_in_pairs = {u for pair in pairs for u in pair}
violation_rates = {
u: _isi_violation_fraction(sd.train[u], threshold_ms) for u in units_in_pairs
}
filtered_pairs = {
(i, j)
for i, j in pairs
if violation_rates[i] <= max_violation_rate
and violation_rates[j] <= max_violation_rate
}
return filtered_pairs, violation_rates
def _compute_pairwise_similarity(sd, pairs, max_lag=10):
"""Compute cosine similarity for candidate pairs using flat concatenated waveforms.
Each unit is represented as a single 1-D array built by concatenating
avg_waveform in a globally consistent channel order (sorted numerically).
Channels absent for a unit are zero-padded. No spatial weighting or channel
selection is applied.
Requires ``avg_waveform`` (shape: n_channels x n_samples) and
``traces_meta["channels"]`` in neuron_attributes, as populated by
``sd.get_waveforms()`` or ``sd.get_waveform_traces(store=True)``.
Parameters:
sd (SpikeData): Source spike data with neuron_attributes.
pairs (set[tuple[int, int]]): Candidate unit-index pairs to evaluate.
max_lag (int): Maximum lag in samples for cosine similarity alignment.
Pairs whose best match falls at the boundary (abs(lag) == max_lag)
are assigned 0 similarity. Default 10.
Returns:
similarity_matrix (np.ndarray): Shape (N, N). NaN for unevaluated
pairs; 1.0 on the diagonal.
lag_matrix (np.ndarray): Shape (N, N). Best lag in samples; NaN
for unevaluated pairs; 0 on the diagonal.
unit_ids (list): neuron_attributes["unit_id"] or index fallback,
one entry per unit.
"""
if sd.neuron_attributes is None:
raise ValueError(
"neuron_attributes is None. Waveform data is required for similarity computation."
)
n = sd.N
sim_mat = np.full((n, n), np.nan)
lag_mat = np.full((n, n), np.nan)
np.fill_diagonal(sim_mat, 1.0)
np.fill_diagonal(lag_mat, 0.0)
unit_ids = [attrs.get("unit_id", i) for i, attrs in enumerate(sd.neuron_attributes)]
if not pairs:
return sim_mat, lag_mat, unit_ids
# Build a single global channel list from all units, sorted numerically.
all_channels = set()
wf_lengths = []
for attrs in sd.neuron_attributes:
avg_wf = attrs.get("avg_waveform")
if avg_wf is None:
continue
wf_lengths.append(avg_wf.shape[1])
traces_meta = attrs.get("traces_meta", {})
for ch in traces_meta.get("channels", []):
all_channels.add(int(ch))
if not wf_lengths:
raise ValueError(
"No units have 'avg_waveform' in neuron_attributes. "
"Load waveform data before calling _compute_pairwise_similarity()."
)
if not all_channels:
raise ValueError(
"avg_waveform found in neuron_attributes but no unit has "
"traces_meta['channels']. Call compute_waveform_metrics() or "
"get_waveform_traces(store=True) to populate both keys together."
)
global_channels = sorted(all_channels)
template_len = max(wf_lengths)
# Pre-build a 1-D array for every unit using the shared channel order.
unit_arrays = {
i: _build_1d_array_for_channels(sd, i, global_channels, template_len)
for i in range(n)
}
for i, j in pairs:
arr_i = unit_arrays[i]
arr_j = unit_arrays[j]
if not np.any(arr_i) or not np.any(arr_j):
continue
sim, best_lag = compute_cosine_similarity_with_lag(
arr_i, arr_j, max_lag=max_lag
)
if abs(best_lag) == max_lag:
sim = 0.0
sim_mat[i, j] = sim_mat[j, i] = sim
lag_mat[i, j] = best_lag
lag_mat[j, i] = -best_lag
return sim_mat, lag_mat, unit_ids
def _filter_by_cosine_sim(pairs, similarity_matrix, threshold=0.9):
"""Return the subset of pairs whose cosine similarity meets threshold.
Parameters:
pairs (set[tuple[int, int]]): Candidate unit-index pairs.
similarity_matrix (np.ndarray): Shape (N, N) similarity values,
e.g. from _compute_pairwise_similarity().
threshold (float): Minimum cosine similarity to retain a pair.
Returns:
filtered_pairs (set[tuple[int, int]]): Pairs passing the threshold.
"""
return {
(i, j)
for i, j in pairs
if not np.isnan(similarity_matrix[i, j])
and similarity_matrix[i, j] >= threshold
}
def curate_by_merge_duplicates(
sd,
dist_um=24.8,
max_violation_rate=0.04,
isi_threshold_ms=1.5,
cosine_threshold=0.5,
max_lag=10,
delta_ms=0.4,
max_isi_increase=0.04,
verbose=False,
):
"""Remove duplicate units by merging nearby pairs with similar waveforms.
Runs the full merge-based deduplication pipeline:
1. Find spatially nearby unit pairs within dist_um.
2. Discard pairs where either unit exceeds the ISI violation threshold.
3. Compute pairwise cosine waveform similarity.
4. Discard pairs below cosine_threshold.
5. Greedily merge accepted pairs; a merge is rejected if the ISI
violation fraction increases by more than max_isi_increase.
Requires neuron_attributes with position and avg_waveform
entries. Unlike other curate_by_* functions this merges spike
trains rather than simply removing units.
Parameters:
sd (SpikeData): spike data.
dist_um (float): Maximum inter-electrode distance in µm to consider
a pair as candidate duplicates.
max_violation_rate (float): Maximum ISI violation rate (fraction,
not percent) for a unit to participate in a merge.
isi_threshold_ms (float): Refractory period threshold in ms.
cosine_threshold (float): Minimum cosine similarity to merge a pair.
max_lag (int): Maximum lag in samples for cosine similarity alignment.
delta_ms (float): Spike deduplication window in ms when merging trains.
max_isi_increase (float): Maximum allowable absolute increase in ISI
violation fraction after merging.
verbose (bool): Print per-pair merge decisions.
Returns:
sd_out (SpikeData): SpikeData with merged units.
result (dict): ``{"metric": (N,) cosine similarity to merge partner (0 if unmerged), "passed": (N,) bool mask of retained units}``.
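
    Example:
        A sketch with default thresholds; requires unit positions and
        ``avg_waveform`` in ``neuron_attributes``, plus ``fs_Hz`` in
        ``metadata`` when lag correction applies::

            sd_merged, res = curate_by_merge_duplicates(sd, verbose=True)
            print(sd.N - sd_merged.N, "duplicate units merged away")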
"""
metric = np.zeros(sd.N, dtype=float)
passed = np.ones(sd.N, dtype=bool)
pairs = _find_nearby_unit_pairs(sd, dist_um=dist_um)
if not pairs:
return sd.subset(np.arange(sd.N)), {"metric": metric, "passed": passed}
pairs, _ = _filter_pairs_by_isi_violations(
sd, pairs, max_violation_rate=max_violation_rate, threshold_ms=isi_threshold_ms
)
if not pairs:
return sd.subset(np.arange(sd.N)), {"metric": metric, "passed": passed}
sim_mat, lag_mat, _ = _compute_pairwise_similarity(sd, pairs, max_lag=max_lag)
pairs = _filter_by_cosine_sim(pairs, sim_mat, threshold=cosine_threshold)
if not pairs:
return sd.subset(np.arange(sd.N)), {"metric": metric, "passed": passed}
sd_out, merge_result = _merge_redundant_units(
sd,
pairs,
sim_mat,
lag_matrix=lag_mat,
delta_ms=delta_ms,
max_isi_increase=max_isi_increase,
isi_threshold_ms=isi_threshold_ms,
verbose=verbose,
)
for primary, secondary, sim in merge_result["merged_pairs"]:
passed[secondary] = False
metric[secondary] = sim
metric[primary] = max(metric[primary], sim)
return sd_out, {"metric": metric, "passed": passed}
def _merge_redundant_units(
sd,
pairs,
similarity_matrix,
lag_matrix=None,
delta_ms=0.4,
max_isi_increase=0.04,
isi_threshold_ms=1.5,
verbose=False,
):
"""Merge pre-filtered candidate duplicate unit pairs into a new SpikeData.
Pairs are processed in descending similarity order (greedy). For each
pair the unit with more spikes is kept as primary; the unit with fewer
spikes is merged into it. A merge is accepted only if the ISI violation
rate after merging does not exceed the pre-merge maximum by more than
max_isi_increase. Units can be involved in multiple merges (e.g., A→B
then B→C results in a final unit containing spikes from A, B, and C).
Parameters:
sd (SpikeData): Source spike data.
pairs (set[tuple[int, int]] or list[tuple[int, int]]): Candidate
duplicate pairs, e.g. from _filter_by_cosine_sim().
similarity_matrix (np.ndarray): Shape (N, N) similarity values
used to sort pairs and record scores.
lag_matrix (np.ndarray, optional): Shape (N, N) lag values in samples.
If provided, the secondary unit's spikes are shifted by the lag
before merging to correct for timing offsets. Default None.
delta_ms (float): Spike deduplication window in ms.
max_isi_increase (float): Maximum allowable absolute increase in ISI
violation fraction after merging. Default 0.04 (4 percentage points).
isi_threshold_ms (float): ISI violation threshold in ms.
verbose (bool): Print a line for each pair decision.
Returns:
sd_out (SpikeData): New SpikeData with merged units.
result (dict): {"merged_pairs": list[tuple], "n_removed": int}.
merged_pairs is a list of (primary, secondary, similarity)
tuples that were accepted.
"""
if not pairs:
raise ValueError(
"pairs must be a non-empty collection. "
"Run _filter_by_cosine_sim() first."
)
sorted_pairs = sorted(
((i, j, float(similarity_matrix[i, j])) for i, j in pairs),
key=lambda x: x[2],
reverse=True,
)
    merge_chain: dict = {}  # unit_idx → primary_idx (maps each unit to its final primary)
current_train: dict = {} # tracks merged-so-far train for each primary
accepted_pairs = []
def _resolve(unit):
"""Follow merge_chain to its root (with path compression)."""
while unit in merge_chain:
nxt = merge_chain[unit]
if nxt in merge_chain:
merge_chain[unit] = merge_chain[nxt]
unit = nxt
return unit
for i, j, sim in sorted_pairs:
# Resolve which primary each unit is currently merged into (full chain)
prim_i = _resolve(i)
prim_j = _resolve(j)
# Skip if both units are already merged into the same primary
if prim_i == prim_j:
continue
primary, secondary = _choose_primary_unit(sd, i, j)
prim_primary = _resolve(primary)
prim_secondary = _resolve(secondary)
# Skip if already chained together via different paths
if prim_primary == prim_secondary:
continue
# Use the already-merged primary train if it has prior merges accepted
primary_train = current_train.get(prim_primary, sd.train[prim_primary])
secondary_train = current_train.get(prim_secondary, sd.train[prim_secondary])
# Apply lag correction only when secondary is fresh (not yet merged into another chain)
if lag_matrix is not None and prim_secondary == secondary:
lag_val = lag_matrix[primary, secondary]
if not np.isnan(lag_val) and lag_val != 0:
fs = sd.metadata.get("fs_Hz") if sd.metadata else None
if fs is None:
raise ValueError(
"Lag correction requires 'fs_Hz' in sd.metadata. "
"Set sd.metadata['fs_Hz'] to the sampling rate in Hz."
)
secondary_train = secondary_train + float(lag_val) / (fs / 1000.0)
before_max = max(
_isi_violation_fraction(primary_train, isi_threshold_ms),
_isi_violation_fraction(secondary_train, isi_threshold_ms),
)
merged_train, _ = _merge_two_trains(primary_train, secondary_train, delta_ms)
after_rate = _isi_violation_fraction(merged_train, isi_threshold_ms)
isi_increase = after_rate - before_max
if isi_increase <= max_isi_increase:
# Accept the merge: prim_secondary is merged into prim_primary
merge_chain[prim_secondary] = prim_primary
current_train[prim_primary] = merged_train
accepted_pairs.append((primary, secondary, sim))
if verbose:
print(
f" Merge [{i},{j}]: sim={sim:.3f}, "
f"ISI {before_max:.3f}→{after_rate:.3f} (Δ={isi_increase:+.3f})"
)
else:
if verbose:
print(
f" Skip [{i},{j}]: sim={sim:.3f}, "
f"ISI increase too high (Δ={isi_increase:+.3f} > {max_isi_increase})"
)
# Build a mapping: final_primary → list of original units that merged into it.
# Follow the chain for merges (e.g., A→B, B→C yields A,B→C).
primary_groups: dict = {}
for orig_unit in range(sd.N):
final_primary = merge_chain.get(orig_unit, orig_unit)
while final_primary in merge_chain:
final_primary = merge_chain[final_primary]
primary_groups.setdefault(final_primary, []).append(orig_unit)
new_trains = []
new_attrs = []
for primary in sorted(primary_groups.keys()):
constituent_units = primary_groups[primary]
# Reuse the pre-merged train from current_train if available
merged = current_train.get(primary, sd.train[primary])
original_spike_count = sum(len(sd.train[u]) for u in constituent_units)
total_dup = original_spike_count - len(merged)
new_trains.append(merged)
attrs = sd.neuron_attributes[primary].copy() if sd.neuron_attributes else {}
attrs["merged_from"] = [
(sd.neuron_attributes[u].get("unit_id", u) if sd.neuron_attributes else u)
for u in constituent_units
]
attrs["n_duplicates_removed"] = total_dup
new_attrs.append(attrs)
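    # Local import avoids a circular dependency: spikedata.py binds these
    # curation functions as SpikeData methods (see module docstring).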
from .spikedata import SpikeData
sd_out = SpikeData(
new_trains,
length=sd.length,
start_time=sd.start_time,
neuron_attributes=new_attrs,
metadata=sd.metadata.copy() if sd.metadata else {},
raw_data=sd.raw_data,
raw_time=sd.raw_time,
)
n_removed = sd.N - len(new_trains)
if verbose:
print(f" {n_removed} units merged; " f"{sd.N} → {sd_out.N} units")
return sd_out, {"merged_pairs": accepted_pairs, "n_removed": n_removed}
# ---------------------------------------------------------------------------
# Internal helpers (merge-based deduplication)
# ---------------------------------------------------------------------------
def _isi_violation_fraction(train, threshold_ms):
"""Return the ISI violation rate as a fraction for a spike train.
ISI violation rate is n_violations / n_spikes.
Parameters:
train (np.ndarray): Spike times in milliseconds.
threshold_ms (float): Refractory period threshold in ms.
Returns:
rate (float): Violation rate as a fraction (0.0-1.0), or 0.0 for
fewer than 2 spikes.
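
    Example:
        Three spikes with one sub-threshold gap give ``1/3``; note the
        divisor is the spike count, not the interval count::

            _isi_violation_fraction(np.array([0.0, 1.0, 10.0]), 1.5)  # ≈ 0.333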
"""
if len(train) < 2:
return 0.0
isis = np.diff(train)
violation_count = np.sum(isis < threshold_ms)
return float(violation_count / len(train))
def _build_1d_array_for_channels(sd, unit_idx, channels, template_len):
"""Build a 1-D waveform vector for unit_idx on an explicit channel list.
Channels present in the unit's avg_waveform are copied in; missing
channels are zero-padded.
Parameters:
sd (SpikeData): Source spike data with neuron_attributes.
unit_idx (int): Index of the unit.
channels (list[int]): Ordered channel list (defines output layout).
template_len (int): Samples per channel slot.
Returns:
arr (np.ndarray): 1-D array of length len(channels) * template_len.
"""
attrs = sd.neuron_attributes[unit_idx]
avg_wf = attrs.get("avg_waveform")
traces_meta = attrs.get("traces_meta", {})
channel_list = [int(c) for c in traces_meta.get("channels", [])]
ch_to_row = {ch: idx for idx, ch in enumerate(channel_list)}
arr = np.zeros(len(channels) * template_len)
if avg_wf is None:
return arr
for k, ch in enumerate(channels):
row = ch_to_row.get(ch)
if row is not None:
wf = avg_wf[row, :]
n = min(len(wf), template_len)
arr[k * template_len : k * template_len + n] = wf[:n]
return arr
def _merge_two_trains(train1, train2, delta_ms=0.4):
"""Merge two spike trains, removing duplicates within delta_ms.
Parameters:
train1 (np.ndarray): First spike train (ms).
train2 (np.ndarray): Second spike train (ms).
delta_ms (float): Deduplication window in ms.
Returns:
merged (np.ndarray): Sorted merged spike train.
n_duplicates (int): Number of spikes removed as duplicates.
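
    Example:
        A worked sketch: 10.0 and 10.3 ms fall within the 0.4 ms window
        and come from different trains, so the later spike is dropped::

            merged, n_dup = _merge_two_trains(
                np.array([10.0, 50.0]), np.array([10.3, 80.0]), delta_ms=0.4
            )
            # merged -> [10.0, 50.0, 80.0]; n_dup -> 1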
"""
if len(train1) == 0 and len(train2) == 0:
return np.array([]), 0
if len(train1) == 0:
return np.sort(train2), 0
if len(train2) == 0:
return np.sort(train1), 0
times = np.concatenate([train1, train2])
membership = np.concatenate(
[np.zeros(len(train1), dtype=np.int8), np.ones(len(train2), dtype=np.int8)]
)
idx = np.argsort(times, kind="mergesort")
times = times[idx]
membership = membership[idx]
diffs = np.diff(times)
cross_train = np.diff(membership) != 0
dup_mask = (diffs <= delta_ms) & cross_train
keep = np.ones(len(times), dtype=bool)
keep[1:][dup_mask] = False
merged = times[keep]
return merged, int(np.sum(~keep))
def _choose_primary_unit(sd, i, j):
"""Return (primary, secondary) based on spike count.
The unit with the larger number of spikes is kept as primary.
"""
spike_count_i = len(sd.train[i])
spike_count_j = len(sd.train[j])
return (i, j) if spike_count_i >= spike_count_j else (j, i)