Source code for spikelab.batch_jobs.storage_s3

"""S3-compatible storage helpers for batch job artifacts."""

from __future__ import annotations

from pathlib import Path
from typing import Optional

try:
    import boto3
except ImportError:  # pragma: no cover
    boto3 = None  # type: ignore[assignment]

from ..data_loaders.s3_utils import parse_s3_url
from .models import StoragePathTemplates


class S3StorageClient:
    """Small wrapper around boto3 for upload/download URI handling.

    Path layout is controlled by *path_templates* (a
    :class:`StoragePathTemplates` instance loaded from the active profile).
    Each template is a format string that is filled with the ``prefix``,
    ``run_id`` and ``filename`` fields.
    """

    def __init__(
        self,
        *,
        prefix: Optional[str] = None,
        path_templates: Optional[StoragePathTemplates] = None,
        endpoint_url: Optional[str] = None,
        region_name: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
    ) -> None:
        """Create the client and the underlying boto3 S3 client.

        Parameters:
            prefix (Optional[str]): Base prefix substituted into the path
                templates. A trailing ``/`` is appended when missing; an
                empty or missing value is stored as ``None``.
            path_templates (Optional[StoragePathTemplates]): Path templates
                to use; a default ``StoragePathTemplates()`` is created when
                omitted.
            endpoint_url (Optional[str]): Forwarded to ``boto3.client``.
            region_name (Optional[str]): Forwarded to ``boto3.client``.
            aws_access_key_id (Optional[str]): Forwarded to ``boto3.client``.
            aws_secret_access_key (Optional[str]): Forwarded to
                ``boto3.client``.
            aws_session_token (Optional[str]): Forwarded to ``boto3.client``.

        Raises:
            ImportError: If boto3 is not installed.
        """
        # Fail fast: nothing below is useful without boto3.
        if boto3 is None:
            raise ImportError("boto3 is required for S3 storage: pip install boto3")

        # Normalise the prefix so templates can rely on a trailing slash.
        if prefix:
            self.prefix: Optional[str] = (
                prefix if prefix.endswith("/") else f"{prefix}/"
            )
        else:
            self.prefix = None

        self.endpoint_url = endpoint_url
        self.region_name = region_name
        self._templates = path_templates or StoragePathTemplates()
        self._client = boto3.client(
            "s3",
            endpoint_url=endpoint_url,
            region_name=region_name,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
        )

    def build_uri(self, *, run_id: str, filename: str, category: str = "inputs") -> str:
        """Build an S3 URI for a file using the active path templates.

        Parameters:
            run_id (str): Run identifier substituted into the template.
            filename (str): File name substituted into the template.
            category (str): Name of the template attribute to use. When the
                templates object has no attribute of that name, the
                ``inputs`` template is used instead.

        Returns:
            uri (str): The formatted template.

        Raises:
            ValueError: If no prefix is configured.
        """
        if not self.prefix:
            raise ValueError(
                "S3 prefix is not configured. Set it in the profile or command."
            )
        template = getattr(self._templates, category, self._templates.inputs)
        return template.format(prefix=self.prefix, run_id=run_id, filename=filename)

    def upload_file(self, *, local_path: str, s3_uri: str) -> str:
        """Upload a local file to S3 and return the URI.

        Parameters:
            local_path (str): Path of the file to upload.
            s3_uri (str): Destination URI, split into bucket and key by
                ``parse_s3_url``.

        Returns:
            s3_uri (str): The same *s3_uri* for convenience.
        """
        bucket, key = parse_s3_url(s3_uri)
        self._client.upload_file(local_path, bucket, key)
        return s3_uri

    def upload_bundle(self, *, local_zip: str, run_id: str) -> str:
        """Upload a zip bundle to S3 under the inputs path template.

        Parameters:
            local_zip (str): Path of the local bundle; its base name is used
                as the remote file name.
            run_id (str): Run identifier.

        Returns:
            s3_uri (str): URI the bundle was uploaded to.

        Raises:
            ValueError: If no prefix is configured.
        """
        filename = Path(local_zip).name
        uri = self.build_uri(run_id=run_id, filename=filename, category="inputs")
        return self.upload_file(local_path=local_zip, s3_uri=uri)

    def output_prefix_for_run(self, run_id: str) -> str:
        """Return the S3 prefix for a run's output files.

        Parameters:
            run_id (str): Run identifier.

        Returns:
            prefix (str): The outputs template formatted with an empty
                filename, or ``""`` when no prefix is configured.
        """
        if not self.prefix:
            return ""
        return self._templates.outputs.format(
            prefix=self.prefix, run_id=run_id, filename=""
        )

    def logs_prefix_for_run(self, run_id: str) -> str:
        """Return the S3 prefix for a run's log files.

        Parameters:
            run_id (str): Run identifier.

        Returns:
            prefix (str): The logs template formatted with an empty
                filename, or ``""`` when no prefix is configured.
        """
        if not self.prefix:
            return ""
        return self._templates.logs.format(
            prefix=self.prefix, run_id=run_id, filename=""
        )

    def download_file(self, *, s3_uri: str, local_path: str) -> str:
        """Download a single file from S3.

        Parameters:
            s3_uri (str): Full ``s3://bucket/key`` URI.
            local_path (str): Destination path on disk.

        Returns:
            local_path (str): The same *local_path* for convenience.
        """
        bucket, key = parse_s3_url(s3_uri)
        # Create the destination directory so the download cannot fail on a
        # missing parent.
        Path(local_path).parent.mkdir(parents=True, exist_ok=True)
        self._client.download_file(bucket, key, local_path)
        return local_path

    def download_output(self, *, run_id: str, filename: str, local_dir: str) -> str:
        """Download a file from the output prefix of a run.

        Parameters:
            run_id (str): Run identifier.
            filename (str): Name of the file within the output prefix.
            local_dir (str): Local directory to save the file into.

        Returns:
            local_path (str): Path of the downloaded file, i.e. *filename*
                joined onto *local_dir*.

        Raises:
            ValueError: If no prefix is configured.
        """
        # Without a prefix the output prefix is "", which would turn the bare
        # filename into the "URI"; report the real problem instead.
        if not self.prefix:
            raise ValueError(
                "S3 prefix is not configured. Set it in the profile or command."
            )
        s3_uri = self.output_prefix_for_run(run_id) + filename
        local_path = str(Path(local_dir) / filename)
        return self.download_file(s3_uri=s3_uri, local_path=local_path)

    def list_output_files(self, run_id: str) -> list[str]:
        """List object keys under the output prefix of a run.

        Parameters:
            run_id (str): Run identifier.

        Returns:
            keys (list[str]): S3 object keys found under the output prefix;
                empty when no prefix is configured.
        """
        prefix = self.output_prefix_for_run(run_id)
        if not prefix:
            return []
        bucket, key_prefix = parse_s3_url(prefix)
        paginator = self._client.get_paginator("list_objects_v2")
        # Pages without any matching object carry no "Contents" entry.
        return [
            obj["Key"]
            for page in paginator.paginate(Bucket=bucket, Prefix=key_prefix)
            for obj in page.get("Contents", [])
        ]