"""Live host-memory watchdog for spike-sorting runs.
The watchdog is a daemon thread that polls
``psutil.virtual_memory().percent`` at a configurable cadence. When the
system memory percentage crosses a *warning* threshold the watchdog
prints a rate-limited notice; when it crosses an *abort* threshold it:
1. Terminates every registered subprocess (e.g. the Kilosort2 MATLAB
child) so they release their RAM promptly.
2. Calls :func:`_thread.interrupt_main` to inject a
``KeyboardInterrupt`` into the main Python thread at the next
bytecode boundary.
The pipeline catches the resulting interrupt and re-raises it as
:class:`spikelab.spike_sorting._exceptions.HostMemoryWatchdogError`,
which is a :class:`ResourceSortFailure` subclass — so callers can apply
retry/skip policies uniformly with other resource failures.
Discovery
---------
The watchdog publishes itself via a :class:`contextvars.ContextVar` on
``__enter__`` and clears it on ``__exit__``. Backends that spawn child
processes (e.g. the Kilosort2 MATLAB runner) call
:func:`get_active_watchdog` to find the live instance and register
their ``subprocess.Popen`` handle with it. This avoids threading a
watchdog parameter through every backend signature.
Platform notes
--------------
The detection step is fully platform-agnostic — ``psutil`` reads
system-wide pressure on Linux, macOS, and Windows alike. The reaction
step has known limits:
* ``_thread.interrupt_main`` only fires at Python bytecode boundaries.
A long-running C extension (a single multi-GB ``np.concatenate``,
a numba ``@njit(parallel=True)`` kernel, a PyTorch CUDA kernel) will
not see the interrupt until it returns. The watchdog still emits its
warning, and the abort takes effect as soon as control returns to
the interpreter.
* Subprocess termination uses :meth:`subprocess.Popen.terminate` then
:meth:`subprocess.Popen.kill` after a grace period. On Windows the
MATLAB JVM occasionally ignores the initial terminate; the kill
fallback handles that.
* If ``psutil`` is not installed the watchdog degrades to a no-op
context manager so the pipeline still runs.
"""
from __future__ import annotations
import _thread
import contextvars
import logging
import math
import subprocess
import threading
import time
from typing import Callable, List, Optional, Tuple
from .._exceptions import HostMemoryWatchdogError
from ._audit import append_audit_event
# Module-level logger; every watchdog notice (activation, warning, abort,
# kill failures) in this file is emitted through it.
_logger = logging.getLogger(__name__)
# Context-local slot holding the currently active watchdog. It defaults to
# ``None``; ``HostMemoryWatchdog.__enter__`` sets it to the entering instance
# and ``__exit__`` resets it with the saved token. ``get_active_watchdog``
# is the read accessor.
_active_watchdog: contextvars.ContextVar[Optional["HostMemoryWatchdog"]] = (
contextvars.ContextVar("active_host_memory_watchdog", default=None)
)
def get_active_watchdog() -> Optional["HostMemoryWatchdog"]:
    """Look up the watchdog published for the current context.

    Code that launches child processes (for example the Kilosort2
    MATLAB runner or a Docker container) uses this accessor to find the
    live watchdog so it can register its ``Popen`` handle, letting the
    watchdog terminate the child when it aborts.

    Returns:
        watchdog (HostMemoryWatchdog or None): The instance stored in
            the module's context variable, or ``None`` (the variable's
            default) when no watchdog has published itself.
    """
    current = _active_watchdog.get()
    return current
[docs]
class HostMemoryWatchdog:
    """Daemon-thread watchdog that aborts the sort on host RAM pressure.

    Use as a context manager. While the context is active a daemon
    thread polls system memory; on abort it terminates registered
    subprocesses, runs registered kill callbacks, and injects a
    ``KeyboardInterrupt`` into the main thread.

    Parameters:
        warn_pct (float): System memory percentage at which the
            watchdog prints a (rate-limited) warning. Defaults to
            ``85.0``.
        abort_pct (float): System memory percentage at which the
            watchdog terminates registered subprocesses and aborts
            the main thread. Defaults to ``92.0``.
        poll_interval_s (float): Seconds between polls. Must be a
            positive finite number. Defaults to ``2.0``.
        warn_repeat_s (float): Minimum seconds between repeated
            warnings at the same level. Must not be NaN. Defaults to
            ``30.0``.
        kill_grace_s (float): Default seconds between
            ``terminate()`` and ``kill()`` for registered
            subprocesses. Must be a non-negative finite number.
            Per-subprocess overrides are accepted in
            :meth:`register_subprocess`. Defaults to ``5.0``.

    Raises:
        ValueError: If the thresholds do not satisfy
            ``0 < warn_pct < abort_pct <= 100`` or any timing
            parameter is out of range.

    Notes:
        - Degrades to a no-op when ``psutil`` is missing.
        - Safe to nest: the inner context is the active one for the
          duration of its body, and the outer context resumes on exit.
        - Trip state is reset on ``__enter__`` so a reused instance
          does not report a trip from an earlier context.
    """

    # Upper bound (seconds) on each sleep while waiting for terminated
    # children to exit; keeps the wait responsive without busy-spinning.
    _GRACE_POLL_S = 0.05

    def __init__(
        self,
        warn_pct: float = 85.0,
        abort_pct: float = 92.0,
        poll_interval_s: float = 2.0,
        warn_repeat_s: float = 30.0,
        kill_grace_s: float = 5.0,
    ) -> None:
        if not 0.0 < warn_pct < abort_pct <= 100.0:
            raise ValueError(
                f"warn_pct ({warn_pct}) and abort_pct ({abort_pct}) must "
                f"satisfy 0 < warn_pct < abort_pct <= 100."
            )
        # Written as ``not (x > 0 and finite)`` so NaN and inf are
        # rejected too; ``x <= 0`` is False for NaN and would let it in.
        if not (poll_interval_s > 0.0 and math.isfinite(poll_interval_s)):
            raise ValueError(
                f"poll_interval_s must be positive, got {poll_interval_s}."
            )
        # A NaN repeat interval makes the blind-warning threshold NaN,
        # which silently disables that warning.
        if math.isnan(warn_repeat_s):
            raise ValueError(f"warn_repeat_s must not be NaN, got {warn_repeat_s}.")
        if not (kill_grace_s >= 0.0 and math.isfinite(kill_grace_s)):
            raise ValueError(f"kill_grace_s must be non-negative, got {kill_grace_s}.")
        self.warn_pct = float(warn_pct)
        self.abort_pct = float(abort_pct)
        self.poll_interval_s = float(poll_interval_s)
        self.warn_repeat_s = float(warn_repeat_s)
        self.kill_grace_s = float(kill_grace_s)
        # (handle, grace seconds) pairs registered by backends.
        self._subprocesses: List[Tuple[subprocess.Popen, float]] = []
        self._kill_callbacks: List[Callable[[], None]] = []
        # Guards the two registration lists above.
        self._lock = threading.Lock()
        self._stop_event = threading.Event()
        self._thread: Optional[threading.Thread] = None
        self._token: Optional[contextvars.Token] = None
        self._tripped = False
        self._percent_at_trip: Optional[float] = None
        # Monotonic timestamp of the last emitted warning. Starting at
        # -inf guarantees the first warning is never rate-limited.
        self._last_warn_t = -math.inf
        self._psutil = None
        self._enabled = False
        # Captured at ``__enter__`` time on the main thread because
        # ContextVars do not propagate to the polling thread.
        self._snapshot_log_path = None
        # Set True when the trip cascade ran but
        # ``_thread.interrupt_main`` raised — the main thread did not
        # receive the KeyboardInterrupt and will surface a downstream
        # error instead of the classified watchdog error. Catch sites
        # check this via :meth:`interrupt_delivery_failed` to
        # reclassify the downstream exception.
        self._interrupt_main_failed = False

    # ------------------------------------------------------------------
    # Subprocess registration (called by backends)
    # ------------------------------------------------------------------
    def register_subprocess(
        self,
        popen: subprocess.Popen,
        *,
        kill_grace_s: Optional[float] = None,
    ) -> None:
        """Track a subprocess for termination on watchdog abort.

        Parameters:
            popen (subprocess.Popen): The child process handle. The
                watchdog calls ``terminate()`` first, then ``kill()``
                after ``kill_grace_s`` seconds if the process is still
                alive.
            kill_grace_s (float or None): Override the default grace
                period for this subprocess. ``None`` uses the
                watchdog's ``kill_grace_s``.

        Raises:
            ValueError: If the effective grace period is negative,
                NaN, or infinite.
        """
        grace = self.kill_grace_s if kill_grace_s is None else float(kill_grace_s)
        # Reject bad values here, on the caller's thread, instead of
        # letting them break the abort cascade on the polling thread.
        if not (grace >= 0.0 and math.isfinite(grace)):
            raise ValueError(f"kill_grace_s must be non-negative, got {grace}.")
        with self._lock:
            self._subprocesses.append((popen, grace))

    def unregister_subprocess(self, popen: subprocess.Popen) -> None:
        """Stop tracking a previously registered subprocess.

        Parameters:
            popen (subprocess.Popen): Handle previously passed to
                :meth:`register_subprocess`. No-op if not registered.
                Identity comparison is used.
        """
        with self._lock:
            self._subprocesses = [
                (p, g) for (p, g) in self._subprocesses if p is not popen
            ]

    def register_kill_callback(self, callback: Callable[[], None]) -> None:
        """Track a zero-arg callable to invoke on watchdog abort.

        Used for kill targets that are not ``subprocess.Popen``
        objects — Docker containers, kubernetes pods, custom
        cleanup hooks. The callback runs after any registered
        subprocesses have been terminated. Exceptions raised by a
        callback are logged but do not prevent other callbacks from
        running.

        Parameters:
            callback (Callable[[], None]): Zero-arg function. Should
                be idempotent and tolerate being called on an
                already-stopped target — the watchdog cannot tell
                whether the kill target is still alive.

        Notes:
            - To allow the kill target to be garbage-collected even
              while registered, build the callback with a weakref to
              the target rather than capturing it directly. See
              ``docker_utils.patched_container_client`` for the
              container-kill pattern.
            - ``__exit__`` clears registered subprocess handles but
              leaves kill callbacks registered; use
              :meth:`unregister_kill_callback` to remove one.
        """
        with self._lock:
            self._kill_callbacks.append(callback)

    def unregister_kill_callback(self, callback: Callable[[], None]) -> None:
        """Stop tracking a previously registered kill callback.

        Parameters:
            callback (Callable[[], None]): Callable previously passed
                to :meth:`register_kill_callback`. No-op if not
                registered. Identity comparison is used.
        """
        with self._lock:
            self._kill_callbacks = [
                c for c in self._kill_callbacks if c is not callback
            ]

    # ------------------------------------------------------------------
    # Trip state (read by the pipeline catch site)
    # ------------------------------------------------------------------
    def tripped(self) -> bool:
        """Return True if the watchdog has fired its abort path."""
        return self._tripped

    def interrupt_delivery_failed(self) -> bool:
        """Return True if the trip fired but ``_thread.interrupt_main`` raised.

        When True, host protection ran (subprocesses terminated, kill
        callbacks invoked) but the main thread did not receive a
        ``KeyboardInterrupt``. The pipeline's catch site checks this
        to reclassify a downstream ``BrokenPipeError`` /
        ``RuntimeError`` (caused by the now-dead subprocess) as the
        appropriate watchdog error.

        Returns:
            failed (bool): True only when the watchdog tripped and
                the interrupt delivery raised.
        """
        return self._interrupt_main_failed

    def percent_at_trip(self) -> Optional[float]:
        """Return the memory percent at the trip moment, or None."""
        return self._percent_at_trip

    def make_error(self, message: Optional[str] = None) -> "HostMemoryWatchdogError":
        """Build a :class:`HostMemoryWatchdogError` from the trip state.

        Parameters:
            message (str or None): Override the default message.
                ``None`` builds one from the recorded trip percentage
                (``"?"`` when no trip was recorded) and the abort
                threshold.

        Returns:
            err (HostMemoryWatchdogError): Exception ready to raise,
                carrying ``percent_at_trip`` and ``abort_pct``.
        """
        if message is None:
            pct = (
                f"{self._percent_at_trip:.1f}"
                if self._percent_at_trip is not None
                else "?"
            )
            message = (
                f"Host RAM watchdog tripped at {pct}% "
                f"(abort threshold: {self.abort_pct:.1f}%). "
                "Subprocesses terminated; current recording aborted."
            )
        return HostMemoryWatchdogError(
            message,
            percent_at_trip=self._percent_at_trip,
            abort_pct=self.abort_pct,
        )

    # ------------------------------------------------------------------
    # Context manager
    # ------------------------------------------------------------------
    def __enter__(self) -> "HostMemoryWatchdog":
        """Publish the watchdog and start the polling thread.

        Returns:
            watchdog (HostMemoryWatchdog): ``self``. When ``psutil``
                cannot be imported the instance is still published
                but no polling thread is started.
        """
        # Start from a clean slate so a reused instance does not report
        # a trip that belongs to an earlier ``with`` block.
        self._tripped = False
        self._percent_at_trip = None
        self._interrupt_main_failed = False
        # Capture the active per-recording log path on the main
        # thread; the daemon polling thread cannot read the
        # ContextVar reliably.
        try:
            from ._inactivity import get_active_log_path

            self._snapshot_log_path = get_active_log_path()
        except Exception:
            self._snapshot_log_path = None
        try:
            import psutil

            self._psutil = psutil
            self._enabled = True
        except ImportError:
            _logger.warning(
                "psutil not installed — watchdog disabled. Install "
                "psutil to enable host RAM monitoring."
            )
            self._enabled = False
            self._token = _active_watchdog.set(self)
            return self
        _logger.info(
            "active: warn=%.1f%% abort=%.1f%% poll=%.1fs",
            self.warn_pct,
            self.abort_pct,
            self.poll_interval_s,
        )
        self._token = _active_watchdog.set(self)
        self._stop_event.clear()
        self._thread = threading.Thread(
            target=self._poll_loop,
            name="HostMemoryWatchdog",
            daemon=True,
        )
        self._thread.start()
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Stop polling, unpublish the watchdog, and drop subprocess handles."""
        self._stop_event.set()
        if self._thread is not None:
            # Bounded join: the thread may be mid-cascade, and teardown
            # must not block on it indefinitely.
            self._thread.join(timeout=self.poll_interval_s + 1.0)
            self._thread = None
        if self._token is not None:
            try:
                _active_watchdog.reset(self._token)
            except (LookupError, ValueError, RuntimeError):
                # Another context modified the var between set/reset,
                # or the token was already consumed (Python 3.10+
                # raises RuntimeError on re-used tokens). Matches the
                # symmetric guards in GpuMemoryWatchdog and
                # IOStallWatchdog so a degenerate teardown does not
                # leak the active-watchdog publication.
                pass
            self._token = None
        with self._lock:
            self._subprocesses.clear()

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------
    def _poll_loop(self) -> None:
        """Polling loop: warn, then trip, then exit."""
        # Defer the first measurement by one poll interval so
        # ``__enter__`` always returns and the protected body starts
        # executing before any trip can fire. Without this delay a
        # watchdog spawned in an already-stressed environment could
        # land its KeyboardInterrupt inside ``Thread.start`` itself,
        # which leaves the with-block in a half-entered state.
        if self._stop_event.wait(self.poll_interval_s):
            return
        blind_threshold_s = 5.0 * self.warn_repeat_s
        blind_started_t: Optional[float] = None
        blind_warned = False
        while not self._stop_event.is_set():
            # Monotonic clock: blind-duration tracking must not be
            # skewed by wall-clock adjustments.
            now = time.monotonic()
            try:
                pct = float(self._psutil.virtual_memory().percent)
            except Exception:
                # psutil on some platforms can transiently fail; treat
                # the tick as unreadable rather than tearing down the
                # watchdog.
                pct = math.nan
            # NaN comparisons are always False, so a NaN reading would
            # silently disable the watchdog. Skip the tick rather than
            # treating it as either a healthy or unhealthy reading, and
            # track how long readings have been unavailable so a
            # sustained failure surfaces a one-time warning.
            if math.isnan(pct):
                if blind_started_t is None:
                    blind_started_t = now
                elif not blind_warned and now - blind_started_t >= blind_threshold_s:
                    self._warn_blind(now - blind_started_t)
                    blind_warned = True
                self._stop_event.wait(self.poll_interval_s)
                continue
            # Successful reading — clear the blindness tracker so a
            # later episode is reported afresh.
            blind_started_t = None
            blind_warned = False
            if pct >= self.abort_pct:
                self._on_abort(pct)
                return
            if pct >= self.warn_pct:
                self._maybe_warn(pct)
            self._stop_event.wait(self.poll_interval_s)

    def _maybe_warn(self, pct: float) -> None:
        """Log a warning if enough time has passed since the last one.

        Parameters:
            pct (float): Current system memory percentage.
        """
        now = time.monotonic()
        if now - self._last_warn_t < self.warn_repeat_s:
            return
        self._last_warn_t = now
        _logger.warning(
            "system memory at %.1f%% (warn=%.1f%% / abort=%.1f%%). "
            "Free memory or expect an abort if pressure keeps climbing.",
            pct,
            self.warn_pct,
            self.abort_pct,
        )
        append_audit_event(
            watchdog="host_memory",
            event="warn",
            log_path=self._snapshot_log_path,
            used_pct=pct,
            warn_pct=self.warn_pct,
            abort_pct=self.abort_pct,
        )

    def _warn_blind(self, blind_for: float) -> None:
        """Report that memory readings have been unavailable.

        Parameters:
            blind_for (float): Seconds since the first unreadable tick
                of the current episode.
        """
        _logger.warning(
            "psutil.virtual_memory() unreadable for %.1fs — watchdog "
            "is blind to RAM-pressure aborts until readings recover.",
            blind_for,
        )
        append_audit_event(
            watchdog="host_memory",
            event="blind_warn",
            source="virtual_memory",
            log_path=self._snapshot_log_path,
            blind_for_s=blind_for,
        )

    def _on_abort(self, pct: float) -> None:
        """Terminate registered subprocesses and interrupt the main thread.

        Parameters:
            pct (float): System memory percentage that crossed the
                abort threshold; recorded as the trip percentage.
        """
        self._tripped = True
        self._percent_at_trip = pct
        _logger.error(
            "ABORT: system memory at %.1f%% (>= %.1f%%). Terminating "
            "subprocesses and raising into main thread.",
            pct,
            self.abort_pct,
        )
        append_audit_event(
            watchdog="host_memory",
            event="abort",
            log_path=self._snapshot_log_path,
            used_pct=pct,
            abort_pct=self.abort_pct,
        )
        # Best-effort GPU snapshot for postmortem analysis. Useful
        # even on host-RAM trips since RT-Sort / KS4 often hold
        # significant GPU state alongside their host buffers.
        try:
            from ._gpu_watchdog import _try_capture_snapshot_to_results

            _try_capture_snapshot_to_results(
                self._snapshot_log_path,
                f"Host memory watchdog trip — system at {pct:.1f}%",
            )
        except Exception as exc:
            # Never let the optional snapshot block the abort cascade;
            # leave a trace at debug level for diagnosis.
            _logger.debug("GPU snapshot capture skipped: %r", exc)
        self._terminate_registered()
        self._run_kill_callbacks()
        # If __exit__ ran while we were mid-cascade (terminate +
        # grace + callbacks can take several seconds), the with-block
        # has already torn down. Sending interrupt_main() now would
        # land a phantom KeyboardInterrupt in whatever code is running
        # next — the next sort, an exception handler, or the
        # interactive prompt. Skip it.
        if self._stop_event.is_set():
            _logger.info("suppressing interrupt_main: watchdog is already exiting.")
            return
        try:
            _thread.interrupt_main()
        except Exception as exc:
            self._interrupt_main_failed = True
            _logger.error("failed to interrupt main: %s", exc)
            append_audit_event(
                watchdog="host_memory",
                event="interrupt_delivery_failed",
                log_path=self._snapshot_log_path,
                error=repr(exc),
            )

    def _run_kill_callbacks(self) -> None:
        """Invoke every registered kill callback; isolate failures."""
        # Snapshot under the lock, call outside it, so a callback that
        # (un)registers something cannot deadlock.
        with self._lock:
            callbacks = list(self._kill_callbacks)
        for cb in callbacks:
            try:
                cb()
            except Exception as exc:
                _logger.error("kill_callback raised: %r; continuing.", exc)

    def _terminate_registered(self) -> None:
        """Best-effort terminate-then-kill of every registered subprocess.

        Every still-alive process is sent ``terminate()`` up front so
        the grace periods run in parallel. Each process is then
        watched until it exits or its own grace period elapses, at
        which point it is ``kill()``-ed. The method returns as soon as
        no process is left to wait for, so the abort reaches the main
        thread without sitting out unused grace time.
        """
        with self._lock:
            entries = list(self._subprocesses)
        if not entries:
            return
        started = time.monotonic()
        pending: List[Tuple[subprocess.Popen, float]] = []
        for popen, grace in entries:
            try:
                if popen.poll() is not None:
                    continue  # already exited; nothing to do
                popen.terminate()
            except Exception as exc:
                _logger.error(
                    "terminate() failed for pid=%s: %s",
                    getattr(popen, "pid", "?"),
                    exc,
                )
            # Kept even when terminate() failed so kill() is still tried.
            pending.append((popen, started + max(0.0, grace)))
        while pending:
            now = time.monotonic()
            waiting: List[Tuple[subprocess.Popen, float]] = []
            for popen, deadline in pending:
                if now < deadline:
                    try:
                        exited = popen.poll() is not None
                    except Exception:
                        # Unreadable status during the grace window:
                        # keep waiting and decide at the deadline.
                        exited = False
                    if not exited:
                        waiting.append((popen, deadline))
                    continue
                try:
                    if popen.poll() is None:
                        popen.kill()
                        _logger.warning(
                            "killed pid=%s (terminate ignored).",
                            getattr(popen, "pid", "?"),
                        )
                except Exception as exc:
                    _logger.error(
                        "kill() failed for pid=%s: %s",
                        getattr(popen, "pid", "?"),
                        exc,
                    )
            pending = waiting
            if pending:
                next_deadline = min(deadline for _, deadline in pending)
                time.sleep(max(0.0, min(self._GRACE_POLL_S, next_deadline - now)))