Source code for divi.pipeline.transformations

# SPDX-FileCopyrightText: 2025-2026 Qoro Quantum Ltd <divi@qoroquantum.de>
#
# SPDX-License-Identifier: Apache-2.0

"""Shared reduction/grouping helpers for pipeline stages.

Stages that fan out along an axis (ham, qem, obs_group, etc.) repeatedly:
- strip that axis from child labels to get a base key,
- group child results by base key (sometimes keyed by axis index for ordering),
- then reduce (e.g. mean or a postprocess over ordered values).

This module centralizes that logic for easier DevX and consistency.
"""

from collections.abc import Callable
from typing import Any

from divi.pipeline.abc import ChildResults

#: Key injected into scoped tokens by :func:`_reduce_with_isolated_axes`
#: so that stages can identify which foreign-axis group is being reduced.
FOREIGN_KEY_ATTR = "_foreign_key"


def strip_axis_from_label(
    child_label: tuple[Any, ...], axis_name: str
) -> tuple[Any, ...]:
    """Remove the (axis_name, value) pair from a child label to get the base key.

    Child labels are sequences of (axis_name, value) pairs. This returns
    the same tuple with any element whose first element equals
    *axis_name* removed.

    Example::

        >>> strip_axis_from_label((('ham', 0), ('obs', 1), ('qem', 2)), 'obs')
        (('ham', 0), ('qem', 2))
    """
    return tuple(
        element
        for element in child_label
        if not (
            isinstance(element, tuple) and len(element) >= 1 and element[0] == axis_name
        )
    )


def group_by_base_key(
    results: ChildResults,
    axis_name: str,
    *,
    indexed: bool = False,
) -> dict[tuple[Any, ...], Any]:
    """Group child results by base key (label with axis stripped).

    Args:
        results: Child label -> value mapping from the pipeline.
        axis_name: Axis to strip from labels to form base_key.
        indexed: If False, values are collected into a list per base_key.
            If True, values are stored in a dict[int, value] keyed by the
            axis value (parsed as int) so they can be ordered later.

    Returns:
        - If indexed=False: ``dict[base_key, list[value]]``
        - If indexed=True: ``dict[base_key, dict[int, value]]``

    Example::

        >>> results = {(('circ', 0), ('obs', 0)): 1.5, (('circ', 0), ('obs', 1)): 2.0}
        >>> group_by_base_key(results, 'obs')
        {(('circ', 0),): [1.5, 2.0]}
        >>> group_by_base_key(results, 'obs', indexed=True)
        {(('circ', 0),): {0: 1.5, 1: 2.0}}
    """
    if not indexed:
        grouped: dict[tuple[Any, ...], list[Any]] = {}
        for child_label, child_value in results.items():
            base_key = strip_axis_from_label(child_label, axis_name)
            grouped.setdefault(base_key, []).append(child_value)
        return grouped

    grouped_indexed: dict[tuple[Any, ...], dict[int, Any]] = {}
    for child_label, child_value in results.items():
        axis_values = dict(child_label)
        axis_idx = int(axis_values[axis_name])
        base_key = strip_axis_from_label(child_label, axis_name)
        grouped_indexed.setdefault(base_key, {})[axis_idx] = child_value
    return grouped_indexed


def reduce_mean(
    grouped: dict[tuple[Any, ...], list[Any]],
) -> ChildResults:
    """Reduce grouped values by averaging (e.g. Trotter ham samples).

    Each entry's values may be scalars (the standard case — averaged
    arithmetically) or per-observable lists of equal length emitted by a
    multi-observable :class:`MeasurementStage` postprocess (averaged
    element-wise so each observable's mean is preserved).

    Example::

        >>> reduce_mean({(('circ', 0),): [1.0, 3.0]})
        {(('circ', 0),): 2.0}
        >>> reduce_mean({(('circ', 0),): [[1.0, 5.0], [3.0, 7.0]]})
        {(('circ', 0),): [2.0, 6.0]}
    """
    out: ChildResults = {}
    for base_key, values in grouped.items():
        if values and isinstance(values[0], list):
            n = len(values)
            n_obs = len(values[0])
            out[base_key] = [sum(v[i] for v in values) / n for i in range(n_obs)]
        else:
            out[base_key] = sum(values) / len(values)
    return out


def reduce_postprocess_ordered(
    grouped: dict[tuple[Any, ...], dict[int, Any]],
    postprocess_fn: (
        Callable[[list[Any]], Any] | dict[tuple[Any, ...], Callable[[list[Any]], Any]]
    ),
) -> ChildResults:
    """Reduce grouped index->value dicts by sorting by index and calling a postprocess function.

    For each base_key, values are ordered by their integer index and passed
    to the postprocess function. Use a single callable for all keys (e.g. QEM)
    or a dict mapping base_key -> callable for per-spec postprocessing
    (e.g. observable grouping).

    Example::

        >>> grouped = {(('circ', 0),): {0: 10.0, 1: 20.0}}
        >>> reduce_postprocess_ordered(grouped, sum)
        {(('circ', 0),): 30.0}
    """
    reduced: ChildResults = {}
    for base_key, values_by_index in grouped.items():
        ordered = [v for _, v in sorted(values_by_index.items())]
        fn = (
            postprocess_fn[base_key]
            if isinstance(postprocess_fn, dict)
            else postprocess_fn
        )
        reduced[base_key] = fn(ordered)
    return reduced


[docs] def reduce_merge_histograms( grouped: dict[tuple[Any, ...], list[dict[str, float]]], ) -> ChildResults: """Reduce grouped probability dicts by averaging across groups. Equivalent to the VQA ``_average_probabilities`` logic: for each base_key, collects all probability dicts, unions all bitstrings, and averages the probability values. Used by ``TrotterSpecStage`` in measurement pipelines to merge probability histograms across Hamiltonian samples. Example:: >>> grouped = {(('circ', 0),): [{'00': 0.6, '11': 0.4}, {'00': 0.8, '11': 0.2}]} >>> reduce_merge_histograms(grouped) {(('circ', 0),): {'00': 0.7, '11': 0.3}} """ reduced: ChildResults = {} for base_key, prob_dicts in grouped.items(): if not prob_dicts: reduced[base_key] = {} continue all_bitstrings: set[str] = set() for probs in prob_dicts: all_bitstrings.update(probs.keys()) n = len(prob_dicts) reduced[base_key] = { bs: sum(p.get(bs, 0.0) for p in prob_dicts) / n for bs in all_bitstrings } return reduced