# Source code for divi.viz._gradients
# SPDX-FileCopyrightText: 2026 Qoro Quantum Ltd <divi@qoroquantum.de>
#
# SPDX-License-Identifier: Apache-2.0
"""Gradient computation strategies for :mod:`divi.viz`.
Provides finite-difference and parameter-shift gradient methods, used by
:func:`~divi.viz.compute_hessian` and :func:`~divi.viz.run_neb`.
"""
from enum import Enum
import numpy as np
import numpy.typing as npt
# [docs]
class GradientMethod(str, Enum):
    """Selects how the viz analysis helpers estimate gradients.

    Because the enum mixes in :class:`str`, every member compares equal to
    its string value (e.g. ``GradientMethod.PARAMETER_SHIFT ==
    "parameter_shift"``).

    * ``PARAMETER_SHIFT`` -- the parameter-shift rule with a fixed shift of
      π/2, exact for standard quantum gates.
    * ``FINITE_DIFFERENCE`` -- centered finite differences using the
      caller-supplied step size ``eps``.
    """

    # Declaration order below is the enum's iteration order.
    PARAMETER_SHIFT = "parameter_shift"
    FINITE_DIFFERENCE = "finite_difference"
# Probe offset for the parameter-shift rule: parameters are evaluated at θ ± π/2.
_PARAM_SHIFT = np.pi / 2
def _finite_difference_gradients(
    evaluate_fn,
    pivots: npt.NDArray[np.float64],
    eps: float,
) -> npt.NDArray[np.float64]:
    """Compute gradients via centered finite differences (shift = *eps*).

    For each pivot ``x`` (row of *pivots*) and coordinate ``j`` the gradient
    component is ``(f(x + eps*e_j) - f(x - eps*e_j)) / (2*eps)``. All
    ``2 * m * d`` probe points are sent to *evaluate_fn* in one batch,
    interleaved as ``[x0 + eps*e0, x0 - eps*e0, x0 + eps*e1, ...]``.

    Args:
        evaluate_fn: Callable receiving a ``(2*m*d, d)`` float64 array of
            probe points and returning one loss per probe (any array-like,
            e.g. a list or an ``(n, 1)`` array, is accepted).
        pivots: Points at which to differentiate, shape ``(m, d)``.
        eps: Finite-difference step; must be finite and non-zero.

    Returns:
        Gradient estimates of shape ``(m, d)``.

    Raises:
        ValueError: If *pivots* is not 2-D, *eps* is zero or non-finite, or
            *evaluate_fn* does not return exactly ``2 * m * d`` losses.
    """
    pivots = np.asarray(pivots, dtype=np.float64)
    if pivots.ndim != 2:
        raise ValueError(
            f"pivots must be a 2-D array of shape (m, d); got shape {pivots.shape}"
        )
    # A zero or non-finite step would silently produce inf/NaN gradients.
    if eps == 0 or not np.isfinite(eps):
        raise ValueError(f"eps must be finite and non-zero; got {eps!r}")

    m, d = pivots.shape
    if m == 0 or d == 0:
        # Nothing to differentiate: don't call the evaluator with an empty batch.
        return np.zeros((m, d), dtype=np.float64)

    shifts = eps * np.eye(d, dtype=np.float64)
    centres = pivots[:, np.newaxis, :]  # (m, 1, d)
    # Stack to (m, d, 2, d): axis 2 holds the (+, -) pair, so flattening
    # interleaves plus/minus probes in the order unpacked below.
    probes = np.stack((centres + shifts, centres - shifts), axis=2).reshape(-1, d)

    losses = np.asarray(evaluate_fn(probes)).reshape(-1)
    if losses.size != 2 * m * d:
        raise ValueError(
            f"evaluate_fn returned {losses.size} losses; expected {2 * m * d}"
        )
    pairs = losses.reshape(m, d, 2)
    return (pairs[..., 0] - pairs[..., 1]) / (2.0 * eps)
def _parameter_shift_gradients(
    evaluate_fn,
    pivots: npt.NDArray[np.float64],
) -> npt.NDArray[np.float64]:
    """Compute gradients via the parameter-shift rule (shift = π/2).

    For each pivot ``x`` (row of *pivots*) and coordinate ``j`` the gradient
    component is ``0.5 * (f(x + s*e_j) - f(x - s*e_j))`` with
    ``s = _PARAM_SHIFT``. All ``2 * m * d`` probe points are sent to
    *evaluate_fn* in one batch, interleaved as
    ``[x0 + s*e0, x0 - s*e0, x0 + s*e1, ...]``.

    Args:
        evaluate_fn: Callable receiving a ``(2*m*d, d)`` float64 array of
            probe points and returning one loss per probe (any array-like,
            e.g. a list or an ``(n, 1)`` array, is accepted).
        pivots: Points at which to differentiate, shape ``(m, d)``.

    Returns:
        Gradients of shape ``(m, d)``.

    Raises:
        ValueError: If *pivots* is not 2-D or *evaluate_fn* does not return
            exactly ``2 * m * d`` losses.
    """
    pivots = np.asarray(pivots, dtype=np.float64)
    if pivots.ndim != 2:
        raise ValueError(
            f"pivots must be a 2-D array of shape (m, d); got shape {pivots.shape}"
        )

    m, d = pivots.shape
    if m == 0 or d == 0:
        # Nothing to differentiate: don't call the evaluator with an empty batch.
        return np.zeros((m, d), dtype=np.float64)

    shifts = _PARAM_SHIFT * np.eye(d, dtype=np.float64)
    centres = pivots[:, np.newaxis, :]  # (m, 1, d)
    # Stack to (m, d, 2, d): axis 2 holds the (+, -) pair, so flattening
    # interleaves plus/minus probes in the order unpacked below.
    probes = np.stack((centres + shifts, centres - shifts), axis=2).reshape(-1, d)

    losses = np.asarray(evaluate_fn(probes)).reshape(-1)
    if losses.size != 2 * m * d:
        raise ValueError(
            f"evaluate_fn returned {losses.size} losses; expected {2 * m * d}"
        )
    pairs = losses.reshape(m, d, 2)
    # With a shift of π/2 the rule's normalisation 1 / (2 sin(s)) is 0.5.
    return 0.5 * (pairs[..., 0] - pairs[..., 1])
def _compute_gradients(
    evaluate_fn,
    pivots: npt.NDArray[np.float64],
    method: GradientMethod,
    eps: float,
) -> npt.NDArray[np.float64]:
    """Dispatch gradient computation to the chosen method.

    Args:
        evaluate_fn: Batched loss evaluator, forwarded unchanged to the
            selected gradient routine.
        pivots: Points at which to differentiate, forwarded unchanged.
        method: A :class:`GradientMethod` member or its string value
            (``"parameter_shift"`` or ``"finite_difference"``).
        eps: Finite-difference step size; ignored for ``PARAMETER_SHIFT``.

    Returns:
        The gradient array produced by the selected routine.

    Raises:
        ValueError: If *method* is not a valid :class:`GradientMethod`.
    """
    # GradientMethod mixes in ``str``, so callers may pass the plain value.
    # Normalising here makes such strings select the correct branch. It also
    # makes unknown values raise, instead of falling through to finite
    # differences.
    method = GradientMethod(method)
    if method is GradientMethod.PARAMETER_SHIFT:
        return _parameter_shift_gradients(evaluate_fn, pivots)
    return _finite_difference_gradients(evaluate_fn, pivots, eps)