# SPDX-FileCopyrightText: 2025-2026 Qoro Quantum Ltd <divi@qoroquantum.de>
#
# SPDX-License-Identifier: Apache-2.0
"""Weighted matching problem for QAOA-based quantum optimization."""
import warnings
from collections.abc import Callable, Hashable
from functools import partial
from typing import Any, Literal
import networkx as nx
import numpy as np
import scipy.sparse.linalg as spla
from divi.qprog.problems import BinaryOptimizationProblem, QAOAProblem
# ------------------------------------------------------------------
# Matching utility functions
# ------------------------------------------------------------------
def _construct_matching_qubo(
graph: nx.Graph,
edge_to_qubit: dict[tuple, int],
penalty_scale: float = 10.0,
) -> np.ndarray:
"""Build a QUBO matrix encoding the maximum-weight matching problem.
Linear terms ``-w_e`` maximize edge weight. Quadratic penalty terms
``+lambda`` for each pair of incident edges enforce the matching
constraint (at most one edge per node).
Args:
graph: Weighted graph.
edge_to_qubit: Mapping from ``(u, v)`` edge tuples to qubit indices.
penalty_scale: Multiplier for the penalty strength. The actual
penalty is ``penalty_scale * sum(all_edge_weights)``.
Returns:
Symmetric QUBO matrix of shape ``(n_edges, n_edges)``.
"""
n = len(set(edge_to_qubit.values()))
qubo = np.zeros((n, n), dtype=float)
total_weight = sum(d.get("weight", 1.0) for _, _, d in graph.edges(data=True))
penalty = penalty_scale * total_weight
# Linear terms: -w_e on the diagonal
for (u, v), idx in edge_to_qubit.items():
if u > v:
continue # skip reverse entries
w = graph[u][v].get("weight", 1.0)
qubo[idx, idx] = -w
# Quadratic terms: +penalty for pairs of incident edges
edges_by_idx = {}
for (u, v), idx in edge_to_qubit.items():
if u > v:
continue
edges_by_idx[idx] = (u, v)
node_to_qubits: dict[Any, list[int]] = {}
for idx, (u, v) in edges_by_idx.items():
node_to_qubits.setdefault(u, []).append(idx)
node_to_qubits.setdefault(v, []).append(idx)
for _node, qubits in node_to_qubits.items():
for i in range(len(qubits)):
for j in range(i + 1, len(qubits)):
qi, qj = qubits[i], qubits[j]
qubo[qi, qj] += penalty / 2
qubo[qj, qi] += penalty / 2
return qubo
def _sort_matching(matching: list[tuple]) -> list[tuple]:
"""Canonical sort: sort nodes within each edge, then sort edges."""
return sorted(tuple(sorted(edge)) for edge in matching)
[docs]
def is_valid_matching(edges: list[tuple]) -> bool:
"""Check that no node appears in more than one selected edge."""
seen: set = set()
for u, v in edges:
if u in seen or v in seen:
return False
seen.add(u)
seen.add(v)
return True
def _bitstring_to_matching(
bitstring: str, edge_to_qubit: dict[tuple, int]
) -> list[tuple]:
"""Decode a measurement bitstring into a list of matching edges.
Uses left-to-right qubit ordering: ``bitstring[i]`` corresponds to
qubit *i* of the cost Hamiltonian.
"""
matching = []
for edge, qubit in edge_to_qubit.items():
if edge[0] > edge[1]:
continue # skip reverse entries
if bitstring[qubit] == "1":
matching.append(edge)
return _sort_matching(matching)
[docs]
def check_matching_matrix(M: np.ndarray, A: np.ndarray) -> bool:
"""Validate that adjacency matrix *M* is a valid matching in graph *A*.
Checks:
1. ``M`` has no edges where ``A`` has none.
2. Each row and column sum of ``M`` is at most 1.
"""
if np.any(M[A == 0] != 0):
return False
row_sums = M.sum(axis=1)
col_sums = M.sum(axis=0)
return bool(np.all(row_sums <= 1) and np.all(col_sums <= 1))
# ------------------------------------------------------------------
# Edge-based graph partitioning
# ------------------------------------------------------------------
def _partition_graph_by_edges(
graph: nx.Graph,
max_edges: int,
algorithm: Literal["kernighan_lin", "spectral"] = "kernighan_lin",
seed: int | None = None,
) -> list[nx.Graph]:
"""Recursively partition a graph until each subgraph has <= *max_edges* edges.
Args:
graph: The graph to partition.
max_edges: Maximum number of edges per partition.
algorithm: ``"kernighan_lin"`` (weight-aware) or ``"spectral"``
(topology-based Fiedler vector).
seed: Random seed for reproducibility.
Returns:
List of subgraph copies.
"""
if graph.size() <= max_edges:
return [graph.copy()]
if graph.number_of_nodes() < 2:
return [graph.copy()]
if algorithm == "kernighan_lin":
part_a, part_b = _kl_bisect(graph, seed=seed)
elif algorithm == "spectral":
part_a, part_b = _spectral_bisect(graph)
else:
raise ValueError(
f"Unsupported partitioning algorithm: {algorithm!r}. "
"Supported: 'kernighan_lin', 'spectral'."
)
sg_a = graph.subgraph(part_a).copy()
sg_b = graph.subgraph(part_b).copy()
# Bail out if bisection made no progress (e.g. degenerate Fiedler vector)
if not part_b or sg_a.size() == graph.size():
return [graph.copy()]
result = []
for sg in (sg_a, sg_b):
if sg.size() == 0:
continue
result.extend(_partition_graph_by_edges(sg, max_edges, algorithm, seed=seed))
return result
def _kl_bisect(graph: nx.Graph, seed: int | None = None) -> tuple[set, set]:
"""Kernighan-Lin bisection with weight-negated edges.
Negates edge weights so KL preferentially cuts low-weight edges,
keeping high-weight edges within partitions.
"""
G_neg = graph.copy()
max_w = max(
(d.get("weight", 1.0) for _, _, d in G_neg.edges(data=True)),
default=1.0,
)
for u, v, d in G_neg.edges(data=True):
d["kl_weight"] = max_w + 1 - d.get("weight", 1.0)
part_a, part_b = nx.community.kernighan_lin_bisection(
G_neg, weight="kl_weight", seed=seed
)
return set(part_a), set(part_b)
def _spectral_bisect(graph: nx.Graph) -> tuple[set, set]:
"""Fiedler-vector bisection on the graph Laplacian."""
L = nx.laplacian_matrix(graph).astype(float)
_eigenvalues, eigenvectors = spla.eigsh(L, k=2, which="SM")
fiedler = eigenvectors[:, 1]
median = np.median(fiedler)
nodes = list(graph.nodes())
part_a = {nodes[i] for i in range(len(nodes)) if fiedler[i] <= median}
part_b = set(nodes) - part_a
return part_a, part_b
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _count_conflicts(solution: list[int], edges: list[tuple]) -> int:
"""Count matching constraint violations in a solution vector."""
node_count: dict = {}
for idx, bit in enumerate(solution):
if bit:
u, v = edges[idx]
node_count[u] = node_count.get(u, 0) + 1
node_count[v] = node_count.get(v, 0) + 1
return sum(max(0, c - 1) for c in node_count.values())
def _classical_cleanup(
solution: list[int],
graph: nx.Graph,
edges: list[tuple],
edge_to_qubit: dict[tuple, int],
) -> list[int]:
"""Fill unmatched nodes using exact classical matching on the residual graph.
Identifies nodes not covered by the quantum solution, builds the
residual subgraph, and runs :func:`~networkx.algorithms.matching.max_weight_matching` on it.
"""
matched_nodes: set = set()
for idx, bit in enumerate(solution):
if bit:
u, v = edges[idx]
matched_nodes.add(u)
matched_nodes.add(v)
residual_nodes = [n for n in graph.nodes() if n not in matched_nodes]
if not residual_nodes:
return solution
residual = graph.subgraph(residual_nodes)
if residual.number_of_edges() == 0:
return solution
extra_edges = nx.max_weight_matching(residual, maxcardinality=False)
result = list(solution)
for u, v in extra_edges:
key = (u, v) if (u, v) in edge_to_qubit else (v, u)
if key in edge_to_qubit:
result[edge_to_qubit[key]] = 1
return result
def _repair_matching(edges: list[tuple], graph: nx.Graph) -> list[tuple]:
"""Greedily repair an invalid matching by keeping highest-weight edges first."""
weighted = sorted(
edges,
key=lambda e: graph[e[0]][e[1]].get("weight", 1.0),
reverse=True,
)
valid: list[tuple] = []
used: set = set()
for u, v in weighted:
if u not in used and v not in used:
valid.append((u, v))
used.add(u)
used.add(v)
return valid
# ------------------------------------------------------------------
# MaxWeightMatchingProblem
# ------------------------------------------------------------------
[docs]
class MaxWeightMatchingProblem(QAOAProblem):
"""Maximum-weight matching problem for QAOA.
Given a weighted graph, finds a set of edges (matching) that maximizes
total weight while ensuring no two selected edges share a node.
Can be used directly with :class:`~divi.qprog.algorithms.QAOA` for
small graphs, or with
:class:`~divi.qprog.workflows.PartitioningProgramEnsemble` for large
graphs via edge-based partitioning.
Args:
graph: Weighted undirected graph.
penalty_scale: Strength of matching constraint penalties in the
QUBO formulation. Higher values enforce constraints more
strictly.
max_edges_per_partition: Maximum edges per partition. Setting
this enables :meth:`decompose` for partitioned solving.
partition_algorithm: Edge partitioning strategy.
``"kernighan_lin"`` (default, weight-aware) or ``"spectral"``.
use_classical_cleanup: If ``True`` (default), fill unmatched
residual nodes via :func:`~networkx.algorithms.matching.max_weight_matching` during
:meth:`postprocess_candidates`.
seed: Random seed for partitioning reproducibility.
Example::
from divi.qprog.problems import MaxWeightMatchingProblem
from divi.qprog import QAOA
from divi.qprog.optimizers import ScipyOptimizer, ScipyMethod
from divi.backends import MaestroSimulator
import networkx as nx
G = nx.gnm_random_graph(8, 12, seed=42)
for u, v in G.edges():
G[u][v]["weight"] = 1.0
problem = MaxWeightMatchingProblem(G, penalty_scale=10.0)
qaoa = QAOA(problem, n_layers=2,
optimizer=ScipyOptimizer(method=ScipyMethod.COBYLA),
max_iterations=20,
backend=MaestroSimulator())
qaoa.run()
"""
def __init__(
self,
graph: nx.Graph,
penalty_scale: float = 10.0,
*,
max_edges_per_partition: int | None = None,
partition_algorithm: Literal["kernighan_lin", "spectral"] = "kernighan_lin",
use_classical_cleanup: bool = True,
seed: int | None = None,
):
self._graph = graph
self._penalty_scale = penalty_scale
self._max_edges_per_partition = max_edges_per_partition
self._partition_algorithm = partition_algorithm
self._use_classical_cleanup = use_classical_cleanup
self._seed = seed
# Build edge-to-qubit mapping (canonical: u < v)
self._edges = [(u, v) if u < v else (v, u) for u, v in graph.edges()]
self._edge_to_qubit: dict[tuple, int] = {}
for i, (u, v) in enumerate(self._edges):
self._edge_to_qubit[(u, v)] = i
self._edge_to_qubit[(v, u)] = i
# Build full-graph QUBO and delegate Hamiltonian to BinaryOptimizationProblem
qubo_matrix = _construct_matching_qubo(
graph, self._edge_to_qubit, penalty_scale
)
self._bop = BinaryOptimizationProblem(qubo_matrix)
# Decomposition state (populated by decompose())
self._edge_index_maps: dict[Hashable, list[int]] = {}
# ------------------------------------------------------------------
# QAOAProblem interface (delegated to internal BinaryOptimizationProblem)
# ------------------------------------------------------------------
@property
def cost_hamiltonian(self):
return self._bop.cost_hamiltonian
@property
def mixer_hamiltonian(self):
return self._bop.mixer_hamiltonian
@property
def loss_constant(self) -> float:
return self._bop.loss_constant
@property
def decode_fn(self) -> Callable[[str], list[tuple]]:
return partial(_bitstring_to_matching, edge_to_qubit=self._edge_to_qubit)
@property
def graph(self) -> nx.Graph:
"""The input graph."""
return self._graph
[docs]
def is_feasible(self, bitstring: str) -> bool:
"""Check that the decoded matching has no node appearing in more than one edge."""
matching = self.decode_fn(bitstring)
return is_valid_matching(matching)
[docs]
def compute_energy(self, bitstring: str) -> float | None:
"""Compute matching weight (negated, since lower is better).
Returns ``None`` for infeasible bitstrings.
"""
matching = self.decode_fn(bitstring)
if not is_valid_matching(matching):
return None
weight = sum(self._graph[u][v].get("weight", 1.0) for u, v in matching)
return -weight
# ------------------------------------------------------------------
# Decomposition hooks
# ------------------------------------------------------------------
[docs]
def decompose(self) -> dict[Hashable, QAOAProblem]:
if self._max_edges_per_partition is None:
raise ValueError(
"Cannot decompose: max_edges_per_partition was not set at construction."
)
subgraphs = _partition_graph_by_edges(
self._graph,
max_edges=self._max_edges_per_partition,
algorithm=self._partition_algorithm,
seed=self._seed,
)
self._edge_index_maps = {}
sub_problems: dict[Hashable, QAOAProblem] = {}
for i, subgraph in enumerate(subgraphs):
prog_id = (f"P{i}", subgraph.size())
# Local edge-to-qubit mapping for this partition
local_edges = [(u, v) if u < v else (v, u) for u, v in subgraph.edges()]
local_e2q: dict[tuple, int] = {}
for j, (u, v) in enumerate(local_edges):
local_e2q[(u, v)] = j
local_e2q[(v, u)] = j
# Map local indices → global indices
self._edge_index_maps[prog_id] = [
self._edge_to_qubit[e] for e in local_edges
]
# Build per-partition QUBO
qubo = _construct_matching_qubo(subgraph, local_e2q, self._penalty_scale)
sub_problems[prog_id] = BinaryOptimizationProblem(qubo)
return sub_problems
[docs]
def initial_solution_size(self) -> int:
return len(self._edges)
[docs]
def extend_solution(
self,
current_solution: list[int],
prog_id: Hashable,
candidate_decoded: list[int],
) -> list[int]:
extended = list(current_solution)
global_indices = self._edge_index_maps[prog_id]
for local_idx, global_idx in enumerate(global_indices):
extended[global_idx] = int(candidate_decoded[local_idx])
return extended
[docs]
def evaluate_global_solution(self, solution: list[int]) -> float:
"""Score a solution: negative (weight - conflict_penalty * conflicts).
Lower is better for beam search. Maximizing weight while minimizing
conflicts.
"""
weight = 0.0
for idx, bit in enumerate(solution):
if bit:
u, v = self._edges[idx]
weight += self._graph[u][v].get("weight", 1.0)
conflicts = _count_conflicts(solution, self._edges)
avg_weight = sum(
d.get("weight", 1.0) for _, _, d in self._graph.edges(data=True)
) / max(self._graph.number_of_edges(), 1)
# Negate: beam search keeps lowest scores
return -(weight - avg_weight * conflicts)
def _postprocess_solution(self, solution: list[int]) -> tuple[list[tuple], float]:
"""Repair conflicts, apply cleanup, compute weight."""
# Repair first (fix conflicts), then cleanup (fill gaps)
matching = [self._edges[i] for i, bit in enumerate(solution) if bit]
if not is_valid_matching(matching):
matching = _repair_matching(matching, self._graph)
# Rebuild solution vector from repaired matching
solution = [0] * len(self._edges)
for edge in matching:
solution[self._edge_to_qubit[edge]] = 1
if self._use_classical_cleanup:
solution = _classical_cleanup(
solution, self._graph, self._edges, self._edge_to_qubit
)
matching = [self._edges[i] for i, bit in enumerate(solution) if bit]
weight = sum(self._graph[u][v].get("weight", 1.0) for u, v in matching)
return _sort_matching(matching), weight
def _decode_matching_without_repair(
self, solution: list[int]
) -> tuple[list[tuple], float]:
"""Decode a raw solution without repair or classical cleanup."""
matching = [self._edges[i] for i, bit in enumerate(solution) if bit]
weight = sum(self._graph[u][v].get("weight", 1.0) for u, v in matching)
return _sort_matching(matching), weight
[docs]
def postprocess_candidates(
self, candidates: list[tuple[float, list[int]]], *, strict: bool = False
) -> list[tuple[list[tuple], float]]:
"""Post-process matching candidates, optionally hard-filtering invalid ones.
With ``strict=False``, invalid raw candidates are repaired and may be
improved by classical cleanup. With ``strict=True``, invalid raw
candidates are discarded before repair or cleanup.
"""
if strict:
formatted = []
for _, solution in candidates:
matching = [self._edges[i] for i, bit in enumerate(solution) if bit]
if is_valid_matching(matching):
formatted.append(self._decode_matching_without_repair(solution))
if not formatted:
warnings.warn(
"No valid matching candidates found under strict=True. "
"Consider widening beam_width / n_partition_candidates, "
"or running with strict=False to inspect repaired output.",
UserWarning,
stacklevel=2,
)
else:
formatted = []
invalid_seen = False
for _score, solution in candidates:
matching = [self._edges[i] for i, bit in enumerate(solution) if bit]
if not is_valid_matching(matching):
invalid_seen = True
formatted.append(self._postprocess_solution(solution))
if invalid_seen:
warnings.warn(
"At least one partition aggregate was not a valid matching "
"and was repaired. Use get_top_solutions(..., strict=True) "
"to discard invalid raw candidates instead.",
UserWarning,
stacklevel=2,
)
# Sort by weight descending, then deduplicate
formatted.sort(key=lambda x: x[1], reverse=True)
seen: set[tuple] = set()
deduped = []
for edges, w in formatted:
key = tuple(edges)
if key not in seen:
seen.add(key)
deduped.append((edges, w))
return deduped