# Source code for divi.qprog.checkpointing

# SPDX-FileCopyrightText: 2025-2026 Qoro Quantum Ltd <divi@qoroquantum.de>
#
# SPDX-License-Identifier: Apache-2.0

"""Checkpointing utilities for variational quantum algorithms."""

import json
import shutil
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any

# Public API of this module: the names exported by a star-import.
# Every other name in the file (leading underscore helpers, naming constants)
# is an implementation detail.
__all__ = [
    "CheckpointConfig",
    "CheckpointCorruptedError",
    "CheckpointError",
    "CheckpointInfo",
    "CheckpointNotFoundError",
    "cleanup_old_checkpoints",
    "get_checkpoint_info",
    "get_latest_checkpoint",
    "list_checkpoints",
    "resolve_checkpoint_path",
]

# Naming scheme for checkpoints on disk.
# PROGRAM_STATE_FILE and OPTIMIZER_STATE_FILE are the two files a checkpoint
# subdirectory must contain to be reported as valid; SUBDIR_PREFIX is the
# prefix of every checkpoint subdirectory name and is followed by the
# iteration number.
PROGRAM_STATE_FILE = "program_state.json"
OPTIMIZER_STATE_FILE = "optimizer_state.json"
SUBDIR_PREFIX = "checkpoint_"

# Largest iteration number accepted when parsing a subdirectory name.
# Names whose numeric suffix exceeds this bound are treated as "not a
# checkpoint" rather than as an absurdly large iteration.
_MAX_ITERATION_NUMBER = 1_000_000


def _get_checkpoint_subdir_name(iteration: int) -> str:
    """Build the name of the subdirectory that holds one iteration's checkpoint.

    The iteration number is zero-padded to a minimum width of three digits;
    larger numbers simply use as many digits as they need.

    Args:
        iteration (int): Iteration number.

    Returns:
        str: Subdirectory name (e.g., "checkpoint_001").
    """
    padded_iteration = format(iteration, "03d")
    return SUBDIR_PREFIX + padded_iteration


def _extract_iteration_from_subdir(subdir_name: str) -> int | None:
    """Extract iteration number from checkpoint subdirectory name.

    This function never raises for malformed names: anything that is not
    ``SUBDIR_PREFIX`` followed by a decimal number within the accepted range
    yields ``None`` so that callers can simply skip unrelated directories.

    Args:
        subdir_name (str): Subdirectory name (e.g., "checkpoint_001").

    Returns:
        int | None: Iteration number if valid and reasonable, None otherwise.
    """
    if not subdir_name.startswith(SUBDIR_PREFIX):
        return None
    suffix = subdir_name[len(SUBDIR_PREFIX) :]
    # str.isdecimal() (unlike str.isdigit()) only accepts characters that
    # int() can actually parse. isdigit() is also true for e.g. superscript
    # digits ("checkpoint_²"), for which int() raises ValueError.
    # An empty suffix is rejected here as well.
    if not suffix.isdecimal():
        return None
    iteration = int(suffix)
    # A decimal-only suffix can never be negative, so only the upper bound
    # needs to be enforced.
    if iteration > _MAX_ITERATION_NUMBER:
        return None
    return iteration


def _ensure_checkpoint_dir(checkpoint_dir: Path) -> Path:
    """Ensure checkpoint directory exists.

    Args:
        checkpoint_dir (Path): Directory path.

    Returns:
        Path: The checkpoint directory path.
    """
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    return checkpoint_dir


def _get_checkpoint_subdir_path(main_dir: Path, iteration: int) -> Path:
    """Locate the subdirectory of ``main_dir`` that belongs to an iteration.

    The path is only computed; nothing is created or checked on disk.

    Args:
        main_dir (Path): Main checkpoint directory.
        iteration (int): Iteration number.

    Returns:
        Path: Path to the checkpoint subdirectory.
    """
    return main_dir / _get_checkpoint_subdir_name(iteration)


def _find_latest_checkpoint_subdir(main_dir: Path) -> Path:
    """Find the latest checkpoint subdirectory by iteration number.

    Only direct subdirectories whose names parse as checkpoint names are
    considered. If several names map to the same iteration number (for
    example "checkpoint_1" and "checkpoint_001"), the lexicographically
    greatest name wins so that the result does not depend on the order in
    which the operating system lists the directory.

    Args:
        main_dir (Path): Main checkpoint directory. It must exist and be a
            directory; this function does not check that itself.

    Returns:
        Path: Path to the latest checkpoint subdirectory.

    Raises:
        CheckpointNotFoundError: If no checkpoint subdirectories are found.
            The exception's ``available_directories`` lists every
            subdirectory that was present.
    """
    # Scan the directory once and reuse the listing for the error report.
    subdirs = [entry for entry in main_dir.iterdir() if entry.is_dir()]

    # Parse each name exactly once, keeping (iteration, path) pairs.
    ranked: list[tuple[int, Path]] = []
    for entry in subdirs:
        iteration = _extract_iteration_from_subdir(entry.name)
        if iteration is not None:
            ranked.append((iteration, entry))

    if not ranked:
        raise CheckpointNotFoundError(
            f"No checkpoint subdirectories found in {main_dir}",
            main_dir=main_dir,
            available_directories=[entry.name for entry in subdirs],
        )

    _, latest = max(ranked, key=lambda pair: (pair[0], pair[1].name))
    return latest


[docs] def resolve_checkpoint_path( main_dir: Path | str, subdirectory: str | None = None ) -> Path: """Resolve the path to a checkpoint subdirectory. Args: main_dir (Path | str): Main checkpoint directory. subdirectory (str | None): Specific checkpoint subdirectory to load (e.g., "checkpoint_001"). If None, loads the latest checkpoint based on iteration number. Returns: Path: Path to the checkpoint subdirectory. Raises: CheckpointNotFoundError: If the main directory or specified subdirectory does not exist. """ main_path = Path(main_dir) if not main_path.exists(): raise CheckpointNotFoundError( f"Checkpoint directory not found: {main_path}", main_dir=main_path, ) # Determine which subdirectory to load if subdirectory is None: checkpoint_path = _find_latest_checkpoint_subdir(main_path) else: checkpoint_path = main_path / subdirectory if not checkpoint_path.exists(): # Provide helpful error with available checkpoints available = [ d.name for d in main_path.iterdir() if d.is_dir() and d.name.startswith(SUBDIR_PREFIX) ] raise CheckpointNotFoundError( f"Checkpoint subdirectory not found: {checkpoint_path}", main_dir=main_path, available_directories=available, ) return checkpoint_path
class CheckpointError(Exception):
    """Common parent of every checkpoint-specific exception in this module.

    Catching this type handles all checkpoint failures raised here.
    """
class CheckpointNotFoundError(CheckpointError):
    """Signals that a checkpoint directory or file could not be located.

    Attributes:
        main_dir: Directory that was searched, or None if not supplied.
        available_directories: Names of directories that were found instead;
            an empty list when none were supplied.
    """

    def __init__(
        self,
        message: str,
        main_dir: Path | None = None,
        available_directories: list[str] | None = None,
    ):
        super().__init__(message)
        self.main_dir = main_dir
        # Normalise "nothing supplied" (None or an empty list) to a fresh
        # empty list so the attribute can always be iterated.
        if available_directories:
            self.available_directories = available_directories
        else:
            self.available_directories = []
class CheckpointCorruptedError(CheckpointError):
    """Signals that a checkpoint file exists but is corrupted or invalid.

    Attributes:
        file_path: Path of the offending file, or None if not supplied.
        details: Extra diagnostic text, or None if not supplied.
    """

    def __init__(
        self,
        message: str,
        file_path: Path | None = None,
        details: str | None = None,
    ):
        super().__init__(message)
        self.file_path, self.details = file_path, details
def _atomic_write(path: Path, content: str) -> None: """Write content to a file atomically using a temporary file and rename. This ensures that if the write is interrupted, the original file is not corrupted. Args: path (Path): Target file path. content (str): Content to write. Raises: OSError: If the file cannot be written. """ # Create temporary file in the same directory to ensure atomic rename works temp_file = path.with_suffix(path.suffix + ".tmp") try: with open(temp_file, "w") as f: f.write(content) # Atomic rename on POSIX systems temp_file.replace(path) except Exception as e: # Clean up temp file if it exists if temp_file.exists(): temp_file.unlink() raise OSError(f"Failed to write checkpoint file {path}: {e}") from e def _validate_checkpoint_json( path: Path, required_fields: Sequence[str] | None = None ) -> dict[str, Any]: """Validate that a checkpoint JSON file exists and is valid. Args: path (Path): Path to the JSON file. required_fields (list[str] | None): List of required top-level fields. Returns: dict[str, Any]: Parsed JSON data. Raises: CheckpointNotFoundError: If the file does not exist. CheckpointCorruptedError: If the file is invalid JSON or missing required fields. 
""" if not path.exists(): raise CheckpointNotFoundError( f"Checkpoint file not found: {path}", main_dir=path.parent, ) try: with open(path, "r") as f: data = json.load(f) except json.JSONDecodeError as e: raise CheckpointCorruptedError( f"Checkpoint file is not valid JSON: {path}", file_path=path, details=f"JSON decode error: {e}", ) from e except Exception as e: raise CheckpointCorruptedError( f"Failed to read checkpoint file: {path}", file_path=path, details=str(e), ) from e if required_fields: missing_fields = [field for field in required_fields if field not in data] if missing_fields: raise CheckpointCorruptedError( f"Checkpoint file is missing required fields: {path}", file_path=path, details=f"Missing fields: {', '.join(missing_fields)}", ) return data def _load_and_validate_pydantic_model( path: Path, model_class: type, required_fields: Sequence[str] | None = None, error_context: str | None = None, ) -> Any: """Load and validate a checkpoint JSON file with a Pydantic model. This function combines JSON validation, conversion to string, and Pydantic model validation into a single operation. Args: path (Path): Path to the JSON file. model_class (type): Pydantic model class to validate against. required_fields (list[str] | None): List of required top-level JSON fields. error_context (str | None): Additional context for error messages (e.g., "Program state" or "Pymoo optimizer state"). Returns: Any: Validated Pydantic model instance. Raises: CheckpointNotFoundError: If the file does not exist. CheckpointCorruptedError: If the file is invalid JSON, missing required fields, or fails Pydantic validation. 
""" try: json_data_dict = _validate_checkpoint_json( path, required_fields=required_fields ) # Convert dict back to JSON string for Pydantic json_data = json.dumps(json_data_dict) except CheckpointNotFoundError: raise CheckpointNotFoundError( f"Checkpoint file not found: {path}", main_dir=path.parent, ) except CheckpointCorruptedError: # Re-raise JSON validation errors as-is raise try: return model_class.model_validate_json(json_data) except Exception as e: context = f"{error_context} " if error_context else "" raise CheckpointCorruptedError( f"Failed to validate {context}checkpoint state: {path}", file_path=path, details=str(e), ) from e
[docs] @dataclass(frozen=True) class CheckpointConfig: """Configuration for checkpointing during optimization. Attributes: checkpoint_dir: Directory path for saving checkpoints. - If None: No checkpointing. - If Path: Uses that directory. checkpoint_interval: Save checkpoint every N iterations. If None, saves every iteration (if checkpoint_dir is set). """ checkpoint_dir: Path | None = None checkpoint_interval: int | None = None
[docs] @classmethod def with_timestamped_dir( cls, checkpoint_interval: int | None = None ) -> "CheckpointConfig": """Create CheckpointConfig with auto-generated directory name. Args: checkpoint_interval (int | None): Save checkpoint every N iterations. If None, saves every iteration (default). Returns: CheckpointConfig: A new CheckpointConfig with auto-generated directory. """ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") generated_dir = Path(f"checkpoint_{timestamp}") return cls( checkpoint_dir=generated_dir, checkpoint_interval=checkpoint_interval )
def _should_checkpoint(self, iteration: int) -> bool: """Determine if a checkpoint should be saved at the given iteration. Args: iteration (int): Current iteration number. Returns: bool: True if checkpointing is enabled and should occur at this iteration. """ if self.checkpoint_dir is None: return False if self.checkpoint_interval is None: return True return iteration % self.checkpoint_interval == 0
@dataclass(frozen=True)
class CheckpointInfo:
    """Immutable summary of a single checkpoint subdirectory.

    Attributes:
        path: Path to the checkpoint subdirectory.
        iteration: Iteration number of this checkpoint.
        timestamp: Modification time of the checkpoint directory.
        size_bytes: Total size of the checkpoint in bytes.
        is_valid: Whether the checkpoint is valid (has required files).
    """

    path: Path
    iteration: int
    timestamp: datetime
    size_bytes: int
    is_valid: bool
def _calculate_checkpoint_size(checkpoint_path: Path) -> int: """Calculate total size of a checkpoint directory in bytes. Args: checkpoint_path (Path): Path to checkpoint subdirectory. Returns: int: Total size in bytes. """ total_size = 0 if checkpoint_path.exists(): for file_path in checkpoint_path.rglob("*"): if file_path.is_file(): total_size += file_path.stat().st_size return total_size def _is_checkpoint_valid(checkpoint_path: Path) -> bool: """Check if a checkpoint directory contains required files. Args: checkpoint_path (Path): Path to checkpoint subdirectory. Returns: bool: True if checkpoint has required files, False otherwise. """ program_state = checkpoint_path / PROGRAM_STATE_FILE optimizer_state = checkpoint_path / OPTIMIZER_STATE_FILE return program_state.exists() and optimizer_state.exists()
def get_checkpoint_info(checkpoint_path: Path) -> CheckpointInfo:
    """Collect iteration, modification time, size and validity of a checkpoint.

    Args:
        checkpoint_path (Path): Path to the checkpoint subdirectory.

    Returns:
        CheckpointInfo: Information about the checkpoint.

    Raises:
        CheckpointNotFoundError: If the checkpoint path does not exist or is
            not a directory.
        ValueError: If the directory name is not a valid checkpoint name.
    """
    parent_dir = checkpoint_path.parent
    if not checkpoint_path.exists():
        raise CheckpointNotFoundError(
            f"Checkpoint directory not found: {checkpoint_path}",
            main_dir=parent_dir,
        )
    if not checkpoint_path.is_dir():
        raise CheckpointNotFoundError(
            f"Checkpoint path is not a directory: {checkpoint_path}",
            main_dir=parent_dir,
        )

    parsed_iteration = _extract_iteration_from_subdir(checkpoint_path.name)
    if parsed_iteration is None:
        raise ValueError(
            f"Invalid checkpoint directory name: {checkpoint_path.name}. "
            f"Expected format: {SUBDIR_PREFIX}XXX"
        )

    # Keyword arguments are evaluated in order: modification time first, then
    # the recursive size, then the required-files check.
    return CheckpointInfo(
        path=checkpoint_path,
        iteration=parsed_iteration,
        timestamp=datetime.fromtimestamp(checkpoint_path.stat().st_mtime),
        size_bytes=_calculate_checkpoint_size(checkpoint_path),
        is_valid=_is_checkpoint_valid(checkpoint_path),
    )
[docs] def list_checkpoints(main_dir: Path) -> list[CheckpointInfo]: """List all checkpoints in a main checkpoint directory. Args: main_dir (Path): Main checkpoint directory. Returns: list[CheckpointInfo]: List of checkpoint information, sorted by iteration number. Raises: CheckpointNotFoundError: If the main directory does not exist. """ if not main_dir.exists(): raise CheckpointNotFoundError( f"Checkpoint directory not found: {main_dir}", main_dir=main_dir, ) if not main_dir.is_dir(): raise CheckpointNotFoundError( f"Path is not a directory: {main_dir}", main_dir=main_dir, ) checkpoints = [] for subdir in main_dir.iterdir(): if not subdir.is_dir(): continue iteration = _extract_iteration_from_subdir(subdir.name) if iteration is None: continue try: info = get_checkpoint_info(subdir) checkpoints.append(info) except (CheckpointNotFoundError, ValueError): # Skip invalid checkpoints continue # Sort by iteration number checkpoints.sort(key=lambda x: x.iteration) return checkpoints
[docs] def get_latest_checkpoint(main_dir: Path) -> Path | None: """Get the path to the latest checkpoint. Args: main_dir (Path): Main checkpoint directory. Returns: Path | None: Path to the latest checkpoint, or None if no checkpoints exist. """ try: return _find_latest_checkpoint_subdir(main_dir) except CheckpointNotFoundError: return None
def cleanup_old_checkpoints(main_dir: Path, keep_last_n: int) -> None:
    """Delete all but the N checkpoints with the highest iteration numbers.

    Args:
        main_dir (Path): Main checkpoint directory.
        keep_last_n (int): Number of most recent checkpoints to keep.

    Raises:
        ValueError: If keep_last_n is less than 1.
        CheckpointNotFoundError: If the main directory does not exist.
    """
    if keep_last_n < 1:
        raise ValueError("keep_last_n must be at least 1")

    # Newest first, so everything past the first keep_last_n entries is stale.
    newest_first = sorted(
        list_checkpoints(main_dir),
        key=lambda info: info.iteration,
        reverse=True,
    )
    for stale in newest_first[keep_last_n:]:
        # Removes the checkpoint directory together with all of its contents.
        shutil.rmtree(stale.path)