# Source code for divi.qprog.problems._graphs

# SPDX-FileCopyrightText: 2025-2026 Qoro Quantum Ltd <divi@qoroquantum.de>
#
# SPDX-License-Identifier: Apache-2.0

"""Graph problem classes for QAOA."""

from collections.abc import Callable, Hashable
from typing import Any
from warnings import warn

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from qiskit.quantum_info import SparsePauliOp

from divi.hamiltonians._term_ops import _clean_hamiltonian_spo
from divi.qprog import GraphProblemTypes
from divi.qprog.algorithms import (
    InitialState,
    OnesState,
    SuperpositionState,
    ZerosState,
)
from divi.qprog.problems import GraphPartitioningConfig, QAOAProblem
from divi.qprog.problems._graph_hamiltonians import (
    max_clique_hamiltonians,
    max_independent_set_hamiltonians,
    max_weight_cycle_hamiltonians,
    maxcut_hamiltonians,
    min_vertex_cover_hamiltonians,
)
from divi.qprog.problems._graph_partitioning_utils import _node_partition_graph


class _GraphProblemBase(QAOAProblem):
    """Shared logic for graph problems built directly from ``SparsePauliOp``.

    Subclasses set ``_resolver`` (a function returning ``(cost_spo, mixer_spo)``
    or ``(cost_spo, mixer_spo, metadata)``) and the two ``_*_state_cls`` class
    attributes, then call ``super().__init__``.

    Args:
        graph: Graph the problem is defined on.
        is_constrained: Forwarded to the resolver as ``constrained=...`` when
            the resolver accepts that keyword, and used to pick which of the
            two initial-state classes is instantiated. Defaults to True.
        config: Partitioning configuration consumed by :meth:`decompose`.
            When ``None`` (the default), :meth:`decompose` raises.
    """

    # Callable producing the Hamiltonians; assigned by each concrete subclass.
    _resolver: staticmethod
    # Initial-state class instantiated when ``is_constrained`` is True.
    _constrained_state_cls: type[InitialState]
    # Initial-state class instantiated when ``is_constrained`` is False.
    _unconstrained_state_cls: type[InitialState]

    def __init__(
        self,
        graph: GraphProblemTypes,
        *,
        is_constrained: bool = True,
        config: GraphPartitioningConfig | None = None,
    ):
        self._graph = graph
        self._is_constrained = is_constrained

        # The resolver returns 2 or 3 items; the starred target captures the
        # optional metadata as a list of length 0 or 1 (see ``metadata``).
        cost_spo, self._mixer_hamiltonian, *self._metadata = self._resolve(
            graph, is_constrained
        )

        # Split the cost operator into a cleaned operator and a constant term.
        # NOTE(review): exact semantics of ``raise_on_constant`` live in
        # ``_clean_hamiltonian_spo`` -- confirm there.
        cleaned, ham_constant = _clean_hamiltonian_spo(cost_spo, raise_on_constant=True)

        self._cost_hamiltonian = cleaned
        self._loss_constant = ham_constant
        self._wire_labels = self._compute_wire_labels(graph)
        self._initial_state = (
            self._constrained_state_cls
            if is_constrained
            else self._unconstrained_state_cls
        )()
        self._config = config
        # Populated by ``decompose``: prog_id -> {local index: original node id}.
        self._reverse_index_maps = {}

    @classmethod
    def _resolve(cls, graph, is_constrained):
        """Build cost/mixer SPOs for this problem type.

        ``constrained=is_constrained`` is passed only when the resolver's
        signature accepts it. Inspecting the signature (rather than calling
        and catching ``TypeError``) keeps a ``TypeError`` raised *inside* the
        resolver from being mistaken for "keyword not supported", which would
        otherwise trigger a second call that silently drops ``is_constrained``.

        Args:
            graph: Graph handed to the resolver.
            is_constrained: Value for the resolver's ``constrained`` keyword.

        Returns:
            Whatever the resolver returns (2- or 3-item sequence per the class
            docstring).
        """
        import inspect  # local import keeps this block self-contained

        resolver = cls._resolver
        try:
            params = inspect.signature(resolver).parameters
        except (TypeError, ValueError):
            # Signature cannot be introspected: fall back to the legacy
            # probe-and-retry behaviour.
            try:
                return resolver(graph, constrained=is_constrained)
            except TypeError:
                return resolver(graph)

        named = params.get("constrained")
        accepts_constrained = (
            named is not None and named.kind is not inspect.Parameter.POSITIONAL_ONLY
        ) or any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())

        if accepts_constrained:
            return resolver(graph, constrained=is_constrained)
        return resolver(graph)

    @staticmethod
    def _compute_wire_labels(graph: GraphProblemTypes) -> tuple:
        """Map qubit positions back to original node values in node-iteration order."""
        # The same call serves both NetworkX and rustworkx graphs; the result
        # is meant to mirror the relabeling done inside the SPO builders.
        return tuple(graph.nodes())

    @property
    def graph(self) -> GraphProblemTypes:
        """The underlying graph."""
        return self._graph

    @property
    def cost_hamiltonian(self) -> SparsePauliOp:
        """Cleaned cost operator (constant term split off into ``loss_constant``)."""
        return self._cost_hamiltonian

    @property
    def mixer_hamiltonian(self) -> SparsePauliOp:
        """Mixer operator exactly as returned by the resolver."""
        return self._mixer_hamiltonian

    @property
    def wire_labels(self) -> tuple:
        """Labels associated with each qubit position."""
        return self._wire_labels

    @property
    def loss_constant(self) -> float:
        """Constant returned alongside the cleaned cost operator."""
        return self._loss_constant

    @property
    def recommended_initial_state(self) -> InitialState:
        """Initial-state instance chosen at construction from ``is_constrained``."""
        return self._initial_state

    @property
    def decode_fn(self) -> Callable[[str], Any]:
        """Function turning a bitstring into the list of selected wire labels.

        Position ``i`` of the bitstring maps to ``wire_labels[i]``; positions
        at or beyond ``len(wire_labels)`` are ignored.
        """
        wires = self._wire_labels

        def _decode(bitstring: str) -> list:
            return [
                wires[idx]
                for idx, bit in enumerate(bitstring)
                if bit == "1" and idx < len(wires)
            ]

        return _decode

    @property
    def metadata(self) -> dict[str, Any]:
        """Third item returned by the resolver, or ``{}`` if it returned only two."""
        return self._metadata[0] if self._metadata else {}

    def decompose(self) -> dict[Hashable, QAOAProblem]:
        """Partition the graph and build one sub-problem per partition.

        Also rebuilds the per-partition local-index -> original-node maps that
        :meth:`extend_solution` relies on.

        Returns:
            Mapping from ``(f"P{i}", len(subgraph))`` to a new problem of the
            same concrete type built on that subgraph.

        Raises:
            ValueError: If no partitioning config was given at construction.
        """
        if self._config is None:
            raise ValueError(
                "Cannot decompose: no config was provided at construction."
            )

        # Warn if this problem type has known partitioning risks
        tier = _PARTITIONING_COMPATIBILITY_TIERS.get(type(self))
        if tier is not None:
            risk_level, rationale = tier
            prefix = "High-risk" if risk_level == "high-risk" else "Heuristic-risk"
            detail = (
                "Aggregation is heuristic and may miss globally valid/high-quality "
                f"solutions because {rationale}"
                if risk_level == "high-risk"
                else "Results may be sensitive to partition boundaries because "
                f"{rationale}"
            )
            warn(
                f"{prefix} graph partitioning objective: "
                f"{type(self).__name__}. {detail}",
                UserWarning,
                stacklevel=2,  # attribute the warning to the caller of decompose()
            )

        subgraphs = _node_partition_graph(
            self.graph,
            partitioning_config=self._config,
        )

        # Start from a clean slate so maps from an earlier call do not linger.
        self._reverse_index_maps = {}
        sub_problems: dict[Hashable, QAOAProblem] = {}

        for i, (subgraph, cluster_ids) in enumerate(subgraphs):
            prog_id = (f"P{i}", len(subgraph))
            # ``cluster_ids[local_idx] == original_node_id``; the partitioner
            # has already relabeled each subgraph to ``0..M-1``.
            self._reverse_index_maps[prog_id] = dict(enumerate(cluster_ids))
            # Sub-problems are created without a config of their own.
            sub_problems[prog_id] = type(self)(
                subgraph, is_constrained=self._is_constrained
            )

        return sub_problems

    def initial_solution_size(self) -> int:
        """Return ``len(self.graph)``."""
        return len(self.graph)

    def extend_solution(
        self,
        current_solution: list[int],
        prog_id: Hashable,
        candidate_decoded: list[int],
    ) -> list[int]:
        """Overlay one partition's decoded candidate onto a global solution.

        Args:
            current_solution: Global 0/1 assignment; it is copied, not mutated.
            prog_id: Key of the partition, as produced by :meth:`decompose`.
            candidate_decoded: Local node ids selected within that partition.

        Returns:
            A new list in which every position owned by the partition is reset
            to 0 and the positions of the selected nodes are set to 1.

        Raises:
            KeyError: If ``prog_id`` has no recorded index map, or a local node
                id is not part of that map.
        """
        extended = list(current_solution)
        reverse_map = self._reverse_index_maps[prog_id]

        # Reset all positions belonging to this partition to 0
        for global_idx in reverse_map.values():
            extended[global_idx] = 0

        # Set positions for nodes in the candidate's decoded solution to 1
        for local_node in candidate_decoded:
            global_idx = reverse_map[local_node]
            extended[global_idx] = 1

        return extended

    def evaluate_global_solution(self, solution: list[int]) -> float:
        """Evaluate the cost Hamiltonian on a classical 0/1 assignment.

        Args:
            solution: Bit value per qubit index.

        Returns:
            ``loss_constant`` plus, for every Pauli term, the real part of its
            coefficient times the product of ``1 - 2 * bit`` over its ``Z``
            positions.

        Raises:
            ValueError: If any term contains a character other than ``I``/``Z``.
        """
        spo: SparsePauliOp = self.cost_hamiltonian
        energy = self.loss_constant
        for label, coeff in zip(spo.paulis.to_labels(), spo.coeffs):
            eigenvalue = 1.0
            # Labels are walked right-to-left so the last character is qubit 0.
            for qubit, char in enumerate(reversed(label)):
                if char == "I":
                    continue
                if char != "Z":
                    raise ValueError(
                        f"Cost Hamiltonian contains non-diagonal term {label!r}; "
                        f"evaluate_global_solution requires Z-only operators."
                    )
                # Z contributes +1 for bit 0 and -1 for bit 1.
                eigenvalue *= 1 - 2 * solution[qubit]
            energy += float(np.real(coeff)) * eigenvalue
        return energy

    def postprocess_candidates(
        self, candidates: list[tuple[float, list[int]]], *, strict: bool = False
    ) -> list[tuple[list[int], float]]:
        """Convert ``(score, bit-vector)`` pairs into ``(selected indices, score)``.

        Args:
            candidates: Pairs of score and 0/1 assignment.
            strict: Accepted for interface compatibility; not used here.

        Returns:
            One ``(indices, score)`` pair per candidate, where ``indices`` are
            the positions of the non-zero entries of the assignment.
        """
        return [(list(np.where(solution)[0]), score) for score, solution in candidates]


class MaxCutProblem(_GraphProblemBase):
    """QAOA problem definition for MaxCut on a graph.

    Args:
        graph: NetworkX or RustworkX graph.
    """

    # The same initial-state class is used in both modes for this problem.
    _unconstrained_state_cls = SuperpositionState
    _constrained_state_cls = SuperpositionState
    _resolver = staticmethod(maxcut_hamiltonians)  # type: ignore[assignment, bad-override]
class MaxCliqueProblem(_GraphProblemBase):
    """QAOA problem definition for max clique on a graph.

    Args:
        graph: NetworkX or RustworkX graph.
        is_constrained: Use constrained mixer. Defaults to True.
    """

    # Constrained runs instantiate ZerosState; unconstrained runs instantiate
    # SuperpositionState.
    _unconstrained_state_cls = SuperpositionState
    _constrained_state_cls = ZerosState
    _resolver = staticmethod(max_clique_hamiltonians)  # type: ignore[assignment, bad-override]
class MaxIndependentSetProblem(_GraphProblemBase):
    """QAOA problem definition for max independent set on a graph.

    Args:
        graph: NetworkX or RustworkX graph.
        is_constrained: Use constrained mixer. Defaults to True.
    """

    # Constrained runs instantiate ZerosState; unconstrained runs instantiate
    # SuperpositionState.
    _unconstrained_state_cls = SuperpositionState
    _constrained_state_cls = ZerosState
    _resolver = staticmethod(max_independent_set_hamiltonians)  # type: ignore[assignment, bad-override]
class MinVertexCoverProblem(_GraphProblemBase):
    """QAOA problem definition for min vertex cover on a graph.

    Args:
        graph: NetworkX or RustworkX graph.
        is_constrained: Use constrained mixer. Defaults to True.
    """

    # Constrained runs instantiate OnesState; unconstrained runs instantiate
    # SuperpositionState.
    _unconstrained_state_cls = SuperpositionState
    _constrained_state_cls = OnesState
    _resolver = staticmethod(min_vertex_cover_hamiltonians)  # type: ignore[assignment, bad-override]
class MaxWeightCycleProblem(_GraphProblemBase):
    """QAOA problem definition for max weight cycle on a directed graph.

    Args:
        graph: NetworkX DiGraph or RustworkX PyDiGraph with weighted edges.
        is_constrained: Use cycle-mixer (preserves valid cycles). Defaults to True.
    """

    # The same initial-state class is used in both modes for this problem.
    _unconstrained_state_cls = SuperpositionState
    _constrained_state_cls = SuperpositionState
    _resolver = staticmethod(max_weight_cycle_hamiltonians)  # type: ignore[assignment, bad-override]

    @staticmethod
    def _compute_wire_labels(graph: GraphProblemTypes) -> tuple:
        """Return ``(0, 1, ..., E - 1)`` where ``E`` is the graph's edge count."""
        # Cycle problems use edge variables; wires are 0-indexed by edge count.
        # Graphs exposing ``number_of_edges`` are asked directly; otherwise the
        # count is taken from the length of ``edge_list()``.
        edge_total = (
            graph.number_of_edges()
            if hasattr(graph, "number_of_edges")
            else len(graph.edge_list())  # type: ignore[attr-defined]
        )
        return tuple(range(edge_total))
# Partitioning is most robust for cut-style objectives (e.g. MaxCut).
# Structure-dependent objectives may lose cross-partition constraints.
#
# Maps a problem class to ``(risk_level, rationale)``. Classes that are absent
# from this table trigger no warning in ``_GraphProblemBase.decompose``.
_PARTITIONING_COMPATIBILITY_TIERS: dict[type, tuple[str, str]] = {
    MaxWeightCycleProblem: (
        "high-risk",
        "partitioning can break cycles across cluster boundaries.",
    ),
    MaxCliqueProblem: (
        "heuristic-risk",
        "partitioning can hide cross-partition adjacency needed for global cliques.",
    ),
    MaxIndependentSetProblem: (
        "heuristic-risk",
        "partitioning can hide cross-partition conflicts between selected vertices.",
    ),
    MinVertexCoverProblem: (
        "heuristic-risk",
        "partitioning can hide cross-partition edges that must be covered globally.",
    ),
}
def draw_graph_solution_nodes(main_graph: nx.Graph, partition_nodes):
    """Plot a graph, highlighting the nodes that belong to a solution.

    Nodes found in ``partition_nodes`` are drawn in red; every other node is
    drawn in light blue. The figure is laid out with ``nx.spring_layout`` and
    shown via ``plt.show()``.

    Args:
        main_graph (nx.Graph): NetworkX graph to visualize.
        partition_nodes: Collection of node indices that are part of the solution.
    """
    # One colour per node, in the graph's own node-iteration order.
    palette = []
    for vertex in main_graph.nodes():
        if vertex in partition_nodes:
            palette.append("red")
        else:
            palette.append("lightblue")

    plt.figure(figsize=(10, 8))
    layout = nx.spring_layout(main_graph)

    nx.draw_networkx_nodes(main_graph, layout, node_color=palette, node_size=500)
    nx.draw_networkx_edges(main_graph, layout)
    nx.draw_networkx_labels(main_graph, layout, font_size=10, font_weight="bold")

    plt.axis("off")
    plt.tight_layout()
    plt.show()