"""Kubernetes backend for batch job submission and monitoring."""
from __future__ import annotations
import subprocess
import tempfile
from pathlib import Path
from typing import Iterator, List, Optional
import yaml
try:
from kubernetes import client, config, watch
except ImportError: # pragma: no cover
client = None
config = None
watch = None
[docs]
class KubernetesBatchJobBackend:
    """Submit, inspect, delete, and tail Kubernetes batch Jobs.

    The backend prefers the ``kubernetes`` Python client. When that package
    could not be imported, or no usable kubeconfig can be loaded, each
    operation shells out to ``kubectl`` instead. If ``use_kubectl_fallback``
    is False, those fallback paths raise ``RuntimeError`` rather than
    invoking ``kubectl``.
    """

    def __init__(
        self,
        namespace: str = "default",
        kubeconfig: Optional[str] = None,
        use_kubectl_fallback: bool = True,
    ) -> None:
        """Configure the backend and try to build the API clients.

        Args:
            namespace: Namespace every operation is scoped to.
            kubeconfig: Optional kubeconfig path, passed both to the Python
                client loader and to ``kubectl --kubeconfig``.
            use_kubectl_fallback: Whether ``kubectl`` may be used when the
                Python client is unavailable.
        """
        self.namespace = namespace
        self.kubeconfig = kubeconfig
        self.use_kubectl_fallback = use_kubectl_fallback
        self._batch_api = None
        self._core_api = None
        if client is not None and config is not None:
            try:
                config.load_kube_config(config_file=kubeconfig)
                self._batch_api = client.BatchV1Api()
                self._core_api = client.CoreV1Api()
            except config.ConfigException:
                # No valid kubeconfig — fall back to kubectl for everything.
                self._batch_api = None
                self._core_api = None

    def _require_kubectl(self) -> None:
        """Raise ``RuntimeError`` if the kubectl fallback has been disabled."""
        if not self.use_kubectl_fallback:
            raise RuntimeError(
                "Kubernetes client unavailable and kubectl fallback disabled"
            )

    def _kubectl_command(self, args: List[str]) -> List[str]:
        """Return the full ``kubectl`` argv for *args*, honouring kubeconfig."""
        command = ["kubectl"]
        if self.kubeconfig:
            command.extend(["--kubeconfig", self.kubeconfig])
        command.extend(args)
        return command

    def _run_kubectl(self, args: List[str]) -> str:
        """Run ``kubectl`` with *args* and return its stripped stdout.

        Raises:
            subprocess.CalledProcessError: If ``kubectl`` exits non-zero.
        """
        out = subprocess.run(
            self._kubectl_command(args),
            check=True,
            text=True,
            capture_output=True,
        )
        return out.stdout.strip()

    @staticmethod
    def _existing_manifest_path(manifest_path_or_str: str) -> Optional[Path]:
        """Return *manifest_path_or_str* as a Path if it exists, else None."""
        try:
            candidate = Path(manifest_path_or_str)
            if candidate.exists():
                return candidate
        except (OSError, ValueError):
            # Raw YAML is often not a legal path (for example a component
            # longer than the filesystem allows); treat it as inline YAML.
            return None
        return None

    def _apply_with_kubectl(
        self, manifest_text: str, manifest_file: Optional[Path]
    ) -> str:
        """Run ``kubectl apply`` on a file, or on inline YAML via a temp file."""
        if manifest_file is not None:
            return self._run_kubectl(
                ["apply", "-f", str(manifest_file), "-n", self.namespace]
            )
        handle = tempfile.NamedTemporaryFile(
            mode="w", suffix=".yaml", encoding="utf-8", delete=False
        )
        try:
            # Close the file before kubectl reads it so the content is flushed.
            with handle:
                handle.write(manifest_text)
            return self._run_kubectl(
                ["apply", "-f", handle.name, "-n", self.namespace]
            )
        finally:
            # Covers write failures as well as kubectl failures.
            Path(handle.name).unlink(missing_ok=True)

    def apply_manifest(self, manifest_path_or_str: str) -> str:
        """Apply a job manifest by YAML file path or raw YAML string.

        Args:
            manifest_path_or_str: Path to an existing YAML file, or the YAML
                document itself.

        Returns:
            With the Python client, the created Job's name. With the kubectl
            fallback, the stripped stdout of ``kubectl apply``.

        Raises:
            RuntimeError: If the client is unavailable and the kubectl
                fallback is disabled.
            ValueError: If the manifest is blank, or (client path) does not
                parse to a YAML mapping.
        """
        if self._batch_api is None:
            self._require_kubectl()
        if not manifest_path_or_str.strip():
            # An empty string would otherwise resolve to the current directory.
            raise ValueError("manifest is empty")

        manifest_file = self._existing_manifest_path(manifest_path_or_str)
        if self._batch_api is None:
            return self._apply_with_kubectl(manifest_path_or_str, manifest_file)

        if manifest_file is not None:
            manifest_text = manifest_file.read_text(encoding="utf-8")
        else:
            manifest_text = manifest_path_or_str
        payload = yaml.safe_load(manifest_text)
        if not isinstance(payload, dict):
            # Typically a mistyped file path that parsed as a YAML scalar.
            raise ValueError(
                "manifest is neither an existing file nor a YAML mapping"
            )
        created = self._batch_api.create_namespaced_job(
            namespace=self.namespace, body=payload
        )
        declared_name = (payload.get("metadata") or {}).get("name")
        if declared_name:
            return declared_name
        # Manifests using generateName only learn their name from the server.
        return created.metadata.name

    def delete_job(self, name: str) -> None:
        """Delete a job; a job that does not exist is not an error.

        The client path requests background propagation so dependent pods
        are removed as well.

        Raises:
            RuntimeError: If the client is unavailable and the kubectl
                fallback is disabled.
        """
        if self._batch_api is None:
            self._require_kubectl()
            self._run_kubectl(
                [
                    "delete",
                    "job",
                    name,
                    "-n",
                    self.namespace,
                    "--ignore-not-found=true",
                ]
            )
            return
        try:
            self._batch_api.delete_namespaced_job(
                name=name,
                namespace=self.namespace,
                body=client.V1DeleteOptions(propagation_policy="Background"),
            )
        except client.rest.ApiException as exc:
            # Mirror kubectl's --ignore-not-found on the client path.
            if exc.status != 404:
                raise

    @staticmethod
    def _phase_from_status(status: dict) -> str:
        """Map a Job ``status`` mapping to a coarse phase string."""
        if not isinstance(status, dict):
            return "Unknown"
        # Terminal conditions are authoritative: the pod counters alone
        # cannot tell a retried-then-succeeded job from a failed one.
        for condition in status.get("conditions") or []:
            if str(condition.get("status")) != "True":
                continue
            if condition.get("type") == "Failed":
                return "Failed"
            if condition.get("type") == "Complete":
                return "Complete"
        if status.get("active"):
            return "Running"
        if status.get("succeeded"):
            return "Complete"
        if status.get("failed"):
            return "Failed"
        return "Pending"

    def job_status(self, name: str) -> str:
        """Return one of Pending/Running/Complete/Failed/Unknown.

        "Unknown" is returned when the kubectl output does not parse to a
        mapping.

        Raises:
            RuntimeError: If the client is unavailable and the kubectl
                fallback is disabled.
        """
        if self._batch_api is None:
            self._require_kubectl()
            out = self._run_kubectl(
                ["get", "job", name, "-n", self.namespace, "-o", "yaml"]
            )
            payload = yaml.safe_load(out)
            if not isinstance(payload, dict):
                return "Unknown"
            status = payload.get("status") or {}
        else:
            job = self._batch_api.read_namespaced_job_status(name, self.namespace)
            status = job.status.to_dict() if job and job.status else {}
        return self._phase_from_status(status)

    def pods_for_job(self, job_name: str) -> List[str]:
        """Return pod names selected by the ``job-name=<job_name>`` label.

        Raises:
            RuntimeError: If the client is unavailable and the kubectl
                fallback is disabled.
        """
        selector = f"job-name={job_name}"
        if self._core_api is None:
            self._require_kubectl()
            out = self._run_kubectl(
                ["get", "pods", "-n", self.namespace, "-l", selector, "-o", "yaml"]
            )
            # Empty output or a null item list both mean "no pods".
            payload = yaml.safe_load(out) or {}
            return [item["metadata"]["name"] for item in payload.get("items") or []]
        pods = self._core_api.list_namespaced_pod(
            namespace=self.namespace,
            label_selector=selector,
        )
        return [item.metadata.name for item in pods.items or []]

    def stream_logs(self, pod_name: str, follow: bool = True) -> Iterator[str]:
        """Yield log lines from a pod.

        Args:
            pod_name: Pod whose logs are read.
            follow: Keep streaming new lines instead of returning the
                current log content only.

        Raises:
            RuntimeError: On first iteration, if the client is unavailable
                and the kubectl fallback is disabled.
        """
        if self._core_api is None:
            self._require_kubectl()
            args = ["logs", pod_name, "-n", self.namespace]
            if follow:
                args.append("-f")
            # The context manager closes the pipe and reaps the child.
            with subprocess.Popen(
                self._kubectl_command(args),
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
            ) as process:
                assert process.stdout is not None  # narrowing for type checkers
                try:
                    for line in process.stdout:
                        yield line.rstrip("\n")
                finally:
                    # If the consumer stops early, `kubectl logs -f` would
                    # otherwise keep running and block the wait on exit.
                    if process.poll() is None:
                        process.terminate()
            return
        if follow and watch is not None:
            watcher = watch.Watch()
            try:
                for line in watcher.stream(
                    self._core_api.read_namespaced_pod_log,
                    name=pod_name,
                    namespace=self.namespace,
                    follow=True,
                ):
                    yield str(line)
            finally:
                watcher.stop()
            return
        text = self._core_api.read_namespaced_pod_log(
            name=pod_name, namespace=self.namespace
        )
        yield from (text or "").splitlines()