Source code for divi.qprog.early_stopping

# SPDX-FileCopyrightText: 2025-2026 Qoro Quantum Ltd <divi@qoroquantum.de>
#
# SPDX-License-Identifier: Apache-2.0

"""Early stopping utilities for variational quantum algorithms."""

from collections import deque
from enum import Enum

import numpy as np

__all__ = ["EarlyStopping", "StopReason"]


[docs] class StopReason(str, Enum): """Reason why early stopping was triggered. Inherits from ``str`` so that values serialize naturally to JSON and can be compared directly with plain strings. """ PATIENCE_EXCEEDED = "patience_exceeded" """Cost did not improve by at least ``min_delta`` for ``patience`` consecutive iterations.""" GRADIENT_BELOW_THRESHOLD = "gradient_below_threshold" """L2 norm of the gradient fell below ``grad_norm_threshold``.""" COST_VARIANCE_SETTLED = "cost_variance_settled" """Variance of recent cost values dropped below ``variance_threshold``."""
[docs] class EarlyStopping: """Early stopping controller for variational quantum algorithm optimization. Tracks optimization progress and signals when to stop based on configurable criteria. A single instance is created before the optimization loop and :meth:`check` is called once per iteration. Args: patience: Number of consecutive iterations with no improvement (by at least ``min_delta``) before stopping. Must be ≥ 1. min_delta: Minimum absolute decrease in ``best_loss`` that counts as an improvement. Must be ≥ 0. grad_norm_threshold: If not ``None``, stop when the L2 norm of the gradient drops below this value. Only effective when the optimizer exposes gradient information (e.g. ``ScipyOptimizer`` with ``L_BFGS_B``). variance_window: Number of recent cost values used to compute the rolling variance. Must be ≥ 2. variance_threshold: If not ``None``, stop when the variance of the last ``variance_window`` cost values drops below this value. Raises: ValueError: If any parameter violates its constraints. """ def __init__( self, patience: int = 5, min_delta: float = 1e-4, grad_norm_threshold: float | None = None, variance_window: int = 20, variance_threshold: float | None = None, ) -> None: if patience < 1: raise ValueError(f"patience must be >= 1, got {patience}") if min_delta < 0: raise ValueError(f"min_delta must be >= 0, got {min_delta}") if variance_window < 2: raise ValueError(f"variance_window must be >= 2, got {variance_window}") self.patience = patience self.min_delta = min_delta self.grad_norm_threshold = grad_norm_threshold self.variance_window = variance_window self.variance_threshold = variance_threshold # --- Internal state --- self._stale_count: int = 0 self._tracked_best: float = float("inf") self._loss_history: deque[float] = deque(maxlen=variance_window) # ------------------------------------------------------------------ # # Public API # ------------------------------------------------------------------ #
[docs] def check( self, current_loss: float, *, grad_norm: float | None = None, ) -> StopReason | None: """Evaluate all enabled stopping criteria. Must be called **once per iteration**, after loss (and optionally gradient) computation. Args: current_loss: The minimum loss value observed at this iteration. grad_norm: L2 norm of the current gradient vector, or ``None`` if gradient information is not available. Returns: A :class:`StopReason` if any criterion triggered, otherwise ``None`` (meaning optimization should continue). """ # 1. Patience -------------------------------------------------- if current_loss < self._tracked_best - self.min_delta: self._tracked_best = current_loss self._stale_count = 0 else: self._stale_count += 1 if self._stale_count >= self.patience: return StopReason.PATIENCE_EXCEEDED # 2. Gradient norm --------------------------------------------- if ( self.grad_norm_threshold is not None and grad_norm is not None and grad_norm < self.grad_norm_threshold ): return StopReason.GRADIENT_BELOW_THRESHOLD # 3. Cost variance --------------------------------------------- if self.variance_threshold is not None: self._loss_history.append(current_loss) if len(self._loss_history) >= self.variance_window: variance = float(np.var(self._loss_history)) if variance < self.variance_threshold: return StopReason.COST_VARIANCE_SETTLED return None
[docs] def reset(self) -> None: """Reset internal state so the instance can be reused.""" self._stale_count = 0 self._tracked_best = float("inf") self._loss_history.clear()