Source code for spikelab.batch_jobs.session

"""High-level run orchestration for packaging, uploading, and job submission."""

from __future__ import annotations

import dataclasses
import json
import os
import tempfile
import time
from pathlib import Path
from typing import Any, Dict, Iterable, Optional, Union
from uuid import uuid4

from .artifact_packager import package_analysis_bundle
from .backend_k8s import KubernetesBatchJobBackend
from .credentials import ResolvedCredentials, resolve_credentials
from .models import ClusterProfile, JobSpec, SubmitResult
from .policy import evaluate_policy, summarize_preflight
from .storage_s3 import S3StorageClient
from .templating import build_template_context, render_job_manifest

# Workspace path convention: save(base) produces base.h5 + base.json
_WORKSPACE_BASE_NAME = "workspace"


class RunSession:
    """Coordinates artifact packaging, job submission, and result retrieval."""
    def __init__(
        self,
        *,
        profile: ClusterProfile,
        backend: KubernetesBatchJobBackend,
        storage_client: S3StorageClient,
        credentials: Optional[ResolvedCredentials] = None,
    ) -> None:
        self.profile = profile
        self.backend = backend
        self.storage = storage_client
        self.credentials = credentials or resolve_credentials()
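    # Example: constructing a session. A hedged sketch only; the ClusterProfile,
    # backend, and storage constructor arguments are elided because their
    # signatures are defined elsewhere in this package.
    #
    #   profile = ClusterProfile(...)
    #   session = RunSession(
    #       profile=profile,
    #       backend=KubernetesBatchJobBackend(...),
    #       storage_client=S3StorageClient(...),
    #   )
    #
    # Credentials are resolved via resolve_credentials() when the ``credentials``
    # argument is omitted.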
    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _build_job_name(prefix: str) -> str:
        token = uuid4().hex[:8]
        max_prefix = 63 - 1 - len(token)  # 54
        truncated = prefix[:max_prefix].rstrip("-")
        return f"{truncated}-{token}"

    def _preflight(self, job_spec: JobSpec, allow_policy_risk: bool) -> None:
        """Run policy checks and raise on BLOCK unless overridden."""
        findings = evaluate_policy(job_spec, self.profile)
        status, summary = summarize_preflight(findings)
        if status == "BLOCK" and not allow_policy_risk:
            raise RuntimeError(
                "Policy preflight blocked submission. "
                f"Re-run with allow_policy_risk=True if intentional.\n{summary}"
            )
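    # Worked example of the name budget above: a Kubernetes object name is capped
    # at 63 characters, the token is 8 hex characters, and one character is spent
    # on the joining "-", so the prefix is truncated to 63 - 1 - 8 = 54 characters.
    # The sample output below is illustrative, not a real token.
    #
    #   RunSession._build_job_name("sorting-run")  # -> e.g. "sorting-run-3f9a12bc"
    #
    # The ``rstrip("-")`` avoids a double hyphen when truncation ends the prefix
    # on a "-".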
    def render_manifest(self, *, job_name: str, job_spec: JobSpec, run_id: str) -> str:
        """Render a Kubernetes Job manifest from a spec and profile."""
        context = build_template_context(
            job_name=job_name,
            job_spec=job_spec,
            profile=self.profile,
            extra_labels={"run_id": run_id},
        )
        return render_job_manifest(context)
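    # Illustrative use (values hypothetical): render the manifest without
    # submitting, e.g. to inspect the YAML that _submit() would apply.
    #
    #   yaml_text = session.render_manifest(
    #       job_name="analysis-abc12345",
    #       job_spec=spec,
    #       run_id=run_id,
    #   )
    #   print(yaml_text)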
    def _submit(
        self,
        *,
        job_spec: JobSpec,
        run_id: str,
        uploaded_input_uri: str,
        job_type: str,
    ) -> SubmitResult:
        """Render manifest, apply to cluster, return result."""
        job_name = self._build_job_name(job_spec.name_prefix)
        manifest_text = self.render_manifest(
            job_name=job_name, job_spec=job_spec, run_id=run_id
        )
        with tempfile.NamedTemporaryFile(
            mode="w",
            suffix=".yaml",
            prefix=f"{job_name}-",
            delete=False,
            encoding="utf-8",
        ) as f:
            f.write(manifest_text)
            manifest_path = f.name
        try:
            self.backend.apply_manifest(manifest_path)
        finally:
            os.unlink(manifest_path)
        return SubmitResult(
            job_name=job_name,
            manifest_yaml=manifest_text,
            run_id=run_id,
            uploaded_input_uri=uploaded_input_uri,
            output_prefix=self.storage.output_prefix_for_run(run_id),
            logs_prefix=self.storage.logs_prefix_for_run(run_id),
            job_type=job_type,
        )

    @staticmethod
    def _inject_env(job_spec: JobSpec, env: Dict[str, str]) -> JobSpec:
        """Return a copy of *job_spec* with additional env vars on the container."""
        merged = dict(job_spec.container.env)
        merged.update(env)
        updated_container = job_spec.container.model_copy(update={"env": merged})
        return job_spec.model_copy(update={"container": updated_container})

    # ------------------------------------------------------------------
    # Submission: workspace job
    # ------------------------------------------------------------------
    def submit_workspace_job(
        self,
        *,
        workspace: Any,
        script: str,
        job_spec: JobSpec,
        allow_policy_risk: bool = False,
        bundle_input_paths: Optional[Iterable[str]] = None,
        metadata: Optional[Dict[str, object]] = None,
    ) -> SubmitResult:
        """Save a workspace, bundle it with a script, and submit a job.

        Parameters:
            workspace: An ``AnalysisWorkspace`` instance or a ``str`` path to an
                existing workspace base path (without extension).
            script (str): Path to the analysis script to run inside the container.
            job_spec (JobSpec): Kubernetes job specification.
            allow_policy_risk (bool): Override policy BLOCK findings.
            bundle_input_paths (iterable[str] | None): Extra files to include in
                the bundle.
            metadata (dict | None): Arbitrary metadata written into the bundle
                manifest.

        Returns:
            result (SubmitResult): Submission details including the output prefix
                where the updated workspace will appear.
        """
        self._preflight(job_spec, allow_policy_risk)
        run_id = uuid4().hex
        with tempfile.TemporaryDirectory(prefix=f"{run_id}-session-") as temp_dir:
            # Resolve workspace to .h5 + .json on disk
            workspace_base = self._save_workspace(workspace, temp_dir)
            script_path = Path(script)
            if not script_path.exists():
                raise FileNotFoundError(f"Analysis script not found: {script_path}")
            input_files = [
                f"{workspace_base}.h5",
                f"{workspace_base}.json",
                str(script_path),
                *(bundle_input_paths or []),
            ]
            bundle_zip = package_analysis_bundle(
                input_paths=input_files,
                run_id=run_id,
                output_dir=temp_dir,
                output_format="workspace",
                metadata=metadata,
            )
            uploaded_input_uri = self.storage.upload_bundle(
                local_zip=bundle_zip, run_id=run_id
            )
            enriched_spec = self._inject_env(
                job_spec,
                {
                    "INPUT_URI": uploaded_input_uri,
                    "OUTPUT_PREFIX": self.storage.output_prefix_for_run(run_id),
                    "SCRIPT_NAME": script_path.name,
                },
            )
            # Set container command to the workspace entrypoint
            enriched_spec = enriched_spec.model_copy(
                update={
                    "container": enriched_spec.container.model_copy(
                        update={
                            "command": [
                                "python",
                                "-m",
                                "spikelab.batch_jobs.entrypoints.workspace",
                            ],
                        }
                    )
                }
            )
            return self._submit(
                job_spec=enriched_spec,
                run_id=run_id,
                uploaded_input_uri=uploaded_input_uri,
                job_type="workspace",
            )
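    # Usage sketch. The paths, workspace object, and metadata values below are
    # hypothetical; only the keyword arguments come from the signature above.
    #
    #   result = session.submit_workspace_job(
    #       workspace=ws,                       # AnalysisWorkspace or saved base path
    #       script="analysis/compute_rates.py",
    #       job_spec=spec,
    #       bundle_input_paths=["configs/params.yaml"],
    #       metadata={"experiment": "exp-042"},
    #   )
    #   print(result.job_name, result.output_prefix)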
    # ------------------------------------------------------------------
    # Submission: sorting job
    # ------------------------------------------------------------------
    def submit_sorting_job(
        self,
        *,
        recording_paths: list,
        config: Any = None,
        config_overrides: Optional[Dict[str, Any]] = None,
        job_spec: JobSpec,
        allow_policy_risk: bool = False,
        metadata: Optional[Dict[str, object]] = None,
    ) -> SubmitResult:
        """Bundle recording files with a sorting config and submit a job.

        Parameters:
            recording_paths (list[str]): Paths to recording files.
            config: A ``SortingPipelineConfig`` instance, a preset name string
                (e.g. ``"kilosort4"``), or None for defaults.
            config_overrides (dict | None): Flat keyword overrides applied to the
                config via ``config.override()``.
            job_spec (JobSpec): Kubernetes job specification.
            allow_policy_risk (bool): Override policy BLOCK findings.
            metadata (dict | None): Arbitrary metadata written into the bundle
                manifest.

        Returns:
            result (SubmitResult): Submission details including the output prefix
                where sorted results will appear.
        """
        self._preflight(job_spec, allow_policy_risk)
        run_id = uuid4().hex
        with tempfile.TemporaryDirectory(prefix=f"{run_id}-session-") as temp_dir:
            config_dict = self._resolve_sorting_config(config, config_overrides)
            config_path = Path(temp_dir) / "sorting_config.json"
            config_path.write_text(
                json.dumps(config_dict, indent=2, default=str), encoding="utf-8"
            )
            # Validate recording paths
            for rpath in recording_paths:
                if not Path(rpath).exists():
                    raise FileNotFoundError(f"Recording file not found: {rpath}")
            input_files = [
                str(config_path),
                *[str(p) for p in recording_paths],
            ]
            bundle_zip = package_analysis_bundle(
                input_paths=input_files,
                run_id=run_id,
                output_dir=temp_dir,
                output_format="sorting",
                metadata=metadata,
            )
            uploaded_input_uri = self.storage.upload_bundle(
                local_zip=bundle_zip, run_id=run_id
            )
            enriched_spec = self._inject_env(
                job_spec,
                {
                    "INPUT_URI": uploaded_input_uri,
                    "OUTPUT_PREFIX": self.storage.output_prefix_for_run(run_id),
                },
            )
            enriched_spec = enriched_spec.model_copy(
                update={
                    "container": enriched_spec.container.model_copy(
                        update={
                            "command": [
                                "python",
                                "-m",
                                "spikelab.batch_jobs.entrypoints.sorting",
                            ],
                        }
                    )
                }
            )
            return self._submit(
                job_spec=enriched_spec,
                run_id=run_id,
                uploaded_input_uri=uploaded_input_uri,
                job_type="sorting",
            )
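    # Usage sketch. The recording paths and the override key are hypothetical;
    # the preset name "kilosort4" and the override mechanism are the ones
    # described in the docstring above.
    #
    #   result = session.submit_sorting_job(
    #       recording_paths=["data/rec1.raw.h5", "data/rec2.raw.h5"],
    #       config="kilosort4",
    #       config_overrides={"n_jobs": 8},
    #       job_spec=spec,
    #   )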
    # ------------------------------------------------------------------
    # Submission: prepared job (no bundling)
    # ------------------------------------------------------------------
    def submit_prepared_job(
        self,
        *,
        job_spec: JobSpec,
        run_id: Optional[str] = None,
        allow_policy_risk: bool = False,
    ) -> SubmitResult:
        """Submit a job without generating bundle artifacts."""
        self._preflight(job_spec, allow_policy_risk)
        current_run_id = run_id or uuid4().hex
        return self._submit(
            job_spec=job_spec,
            run_id=current_run_id,
            uploaded_input_uri="",
            job_type="prepared",
        )
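    # Usage sketch: a prepared job runs whatever command and image the JobSpec
    # already defines; nothing is bundled or uploaded (uploaded_input_uri stays
    # empty). The spec variable is illustrative.
    #
    #   result = session.submit_prepared_job(job_spec=spec)
    #   # or pin the run id yourself:
    #   result = session.submit_prepared_job(job_spec=spec, run_id="my-run-id")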
    # ------------------------------------------------------------------
    # Retrieval
    # ------------------------------------------------------------------
    def retrieve_result(
        self,
        submit_result: SubmitResult,
        local_dir: str,
    ) -> Any:
        """Download job outputs and return an AnalysisWorkspace.

        Parameters:
            submit_result (SubmitResult): The result from a prior
                ``submit_workspace_job`` or ``submit_sorting_job`` call.
            local_dir (str): Local directory to download outputs into.

        Returns:
            workspace (AnalysisWorkspace): The workspace produced by the job. For
                workspace jobs this is the updated workspace; for sorting jobs it
                contains per-recording namespaces with SpikeData at key
                ``"spikedata"``.

        Notes:
            - Call ``wait_for_completion`` before calling this method to ensure
              the job has finished.
        """
        from ..workspace.workspace import AnalysisWorkspace

        local = Path(local_dir)
        local.mkdir(parents=True, exist_ok=True)
        if submit_result.job_type == "workspace":
            return self._retrieve_workspace(submit_result, local)
        elif submit_result.job_type == "sorting":
            return self._retrieve_sorting(submit_result, local)
        else:
            raise ValueError(
                f"Cannot retrieve results for job_type={submit_result.job_type!r}. "
                "Only 'workspace' and 'sorting' jobs produce retrievable outputs."
            )
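    # Typical wait-then-retrieve flow (the local directory is illustrative):
    #
    #   state = session.wait_for_completion(job_name=result.job_name)
    #   if state == "Complete":
    #       ws = session.retrieve_result(result, local_dir="./outputs")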
    def _retrieve_workspace(self, result: SubmitResult, local_dir: Path) -> Any:
        """Download workspace .h5 + .json and load."""
        from ..workspace.workspace import AnalysisWorkspace

        h5_name = f"{_WORKSPACE_BASE_NAME}.h5"
        json_name = f"{_WORKSPACE_BASE_NAME}.json"
        self.storage.download_output(
            run_id=result.run_id, filename=h5_name, local_dir=str(local_dir)
        )
        self.storage.download_output(
            run_id=result.run_id, filename=json_name, local_dir=str(local_dir)
        )
        base_path = str(local_dir / _WORKSPACE_BASE_NAME)
        return AnalysisWorkspace.load(base_path)

    def _retrieve_sorting(self, result: SubmitResult, local_dir: Path) -> Any:
        """Download all sorting outputs, build workspace from pickles."""
        from ..data_loaders.data_loaders import load_spikedata_from_pickle
        from ..workspace.workspace import AnalysisWorkspace

        # Download everything under the output prefix
        keys = self.storage.list_output_files(result.run_id)
        if not keys:
            raise FileNotFoundError(
                f"No output files found for run_id={result.run_id}"
            )

        from ..data_loaders.s3_utils import parse_s3_url

        prefix = result.output_prefix
        bucket, prefix_key = parse_s3_url(prefix)

        downloaded = []
        for key in keys:
            # Derive relative path from prefix
            relative = key[len(prefix_key) :] if key.startswith(prefix_key) else key
            local_path = local_dir / relative
            local_path.parent.mkdir(parents=True, exist_ok=True)
            s3_uri = f"s3://{bucket}/{key}"
            self.storage.download_file(s3_uri=s3_uri, local_path=str(local_path))
            downloaded.append((relative, str(local_path)))

        # Build workspace from downloaded SpikeData pickles
        ws = AnalysisWorkspace(name=f"sorting-{result.run_id[:8]}")
        for relative, local_path in downloaded:
            if local_path.endswith(".pkl"):
                try:
                    sd = load_spikedata_from_pickle(local_path)
                except Exception:
                    continue
                namespace = Path(relative).stem
                ws.store(namespace, "spikedata", sd)
            elif local_path.endswith(".json") and "config" not in relative:
                # Store sorting metadata
                try:
                    with open(local_path, "r", encoding="utf-8") as f:
                        meta = json.load(f)
                    namespace = Path(relative).stem
                    ws.store(namespace, "sorting_metadata", meta)
                except Exception:
                    continue
        return ws

    # ------------------------------------------------------------------
    # Wait
    # ------------------------------------------------------------------
    def wait_for_completion(
        self,
        *,
        job_name: str,
        max_wait_seconds: int = 3600,
        poll_interval_seconds: int = 10,
    ) -> str:
        """Poll until completion/failure or timeout and return final state."""
        deadline = time.time() + max_wait_seconds
        while time.time() < deadline:
            state = self.backend.job_status(job_name)
            if state in {"Complete", "Failed"}:
                return state
            time.sleep(poll_interval_seconds)
        return "Timeout"
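    # The possible return values are "Complete" and "Failed" (reported by the
    # backend) plus "Timeout" when the deadline elapses first. Illustrative call
    # with a tighter budget than the defaults:
    #
    #   state = session.wait_for_completion(
    #       job_name=result.job_name,
    #       max_wait_seconds=600,
    #       poll_interval_seconds=5,
    #   )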
    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _save_workspace(workspace: Any, work_dir: str) -> str:
        """Ensure workspace is saved to disk, return the base path.

        Parameters:
            workspace: ``AnalysisWorkspace`` or ``str`` base path.
            work_dir (str): Directory to save into if workspace is an object.

        Returns:
            base_path (str): Path without extension.
        """
        if isinstance(workspace, str):
            # Assume it's a base path; verify .h5 exists
            if not Path(f"{workspace}.h5").exists():
                raise FileNotFoundError(f"Workspace file not found: {workspace}.h5")
            return workspace
        # It's an AnalysisWorkspace object
        base_path = str(Path(work_dir) / _WORKSPACE_BASE_NAME)
        workspace.save(base_path)
        return base_path

    @staticmethod
    def _resolve_sorting_config(
        config: Any, overrides: Optional[Dict[str, Any]]
    ) -> dict:
        """Resolve a sorting config to a serializable dict.

        Parameters:
            config: ``SortingPipelineConfig``, preset name string, or None.
            overrides (dict | None): Flat keyword overrides.

        Returns:
            config_dict (dict): JSON-serializable nested dict.
        """
        from ..spike_sorting.config import SortingPipelineConfig

        if config is None:
            resolved = SortingPipelineConfig()
        elif isinstance(config, str):
            # Treat as preset name
            import spikelab.spike_sorting.config as cfg_module

            preset = getattr(cfg_module, config.upper(), None)
            if preset is None:
                raise ValueError(
                    f"Unknown sorting preset: {config!r}. Available: "
                    "KILOSORT2, KILOSORT4, KILOSORT2_DOCKER, KILOSORT4_DOCKER, "
                    "RT_SORT_MEA, RT_SORT_NEUROPIXELS"
                )
            resolved = preset
        else:
            resolved = config
        if overrides:
            resolved = resolved.override(**overrides)
        return dataclasses.asdict(resolved)
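# Sketch of how _resolve_sorting_config treats its inputs (the override key is
# hypothetical and depends on SortingPipelineConfig's fields):
#
#   RunSession._resolve_sorting_config(None, None)              # default config as dict
#   RunSession._resolve_sorting_config("kilosort4", None)       # preset KILOSORT4
#   RunSession._resolve_sorting_config("kilosort4", {"n_jobs": 4})
#
# Preset lookup is effectively case-insensitive because the name is upper-cased
# before getattr(); unknown names raise ValueError listing the available presets.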