# SPDX-FileCopyrightText: 2025-2026 Qoro Quantum Ltd <divi@qoroquantum.de>
#
# SPDX-License-Identifier: Apache-2.0
import os
from enum import Enum
from queue import Empty, Queue
from threading import Event, Lock
from typing import Any, cast
from rich.live import Live
from rich.progress import (
BarColumn,
MofNCompleteColumn,
Progress,
ProgressColumn,
SpinnerColumn,
Task,
TaskID,
TextColumn,
TimeElapsedColumn,
)
from rich.text import Text
from divi.reporting._qlogger import _ensure_unbuffered_stdout
[docs]
class TerminalStatus(str, Enum):
    """Final outcome recorded on a progress row.

    Because the enum mixes in ``str``, every member compares equal to its
    plain string value (``TerminalStatus.SUCCESS == "Success"``), so rows
    may carry either form in their ``final_status`` field.
    """

    SUCCESS = "Success"
    FAILED = "Failed"
    CANCELLED = "Cancelled"
    ABORTED = "Aborted"
#: Palette cycled through when tinting flush groups in ensemble progress
#: displays.
#:
#: Every running flush group takes the next color from this tuple, so its
#: batch row and the program rows participating in it share a marker.
BATCH_COLORS = ("green", "cyan", "magenta", "yellow", "red", "blue")

# Environment variable that switches the progress display off, plus the
# values of it (after strip + lower-case) that count as "switched off".
_PROGRESS_DISABLE_ENV = "DIVI_DISABLE_PROGRESS"
_PROGRESS_DISABLE_TRUTHY = frozenset(("1", "true", "yes", "on"))
[docs]
def progress_disabled() -> bool:
    """Report whether the progress display has been switched off.

    Reads the ``DIVI_DISABLE_PROGRESS`` environment variable. After the
    value is stripped and lower-cased, any of ``1``, ``true``, ``yes`` or
    ``on`` disables progress. An unset variable or any other value leaves
    progress enabled.
    """
    flag = os.getenv(_PROGRESS_DISABLE_ENV, "")
    normalized = flag.strip().lower()
    return normalized in _PROGRESS_DISABLE_TRUTHY
class BatchIndicatorColumn(ProgressColumn):
    """Leading cell that links a row to its batch by color.

    Rows with a non-empty ``batch_color`` field show a ``■`` square in that
    color. All other rows show a single space.
    """

    def render(self, task):
        batch_color = task.fields.get("batch_color", "")
        if not batch_color:
            return Text(" ")
        return Text("■ ", style=batch_color)
class _UnfinishedTaskWrapper:
    """Proxy around a task whose ``finished`` attribute always reads False.

    Every other attribute lookup is delegated to the wrapped task. A
    spinner column therefore sees the real task data but keeps animating.
    """

    def __init__(self, task):
        self._task = task

    def __getattr__(self, name):
        # Reached only for names that normal lookup on the proxy misses,
        # so ``_task`` itself never recurses through here.
        return False if name == "finished" else getattr(self._task, name)
class _ProgramOnlyColumn(ProgressColumn):
    """Decorator column that shows *inner* only on program rows.

    A task whose ``row_kind`` field equals ``"program"`` is rendered by the
    wrapped column. Any other task (e.g. a batch row) gets an empty cell.
    """

    def __init__(self, inner: ProgressColumn):
        super().__init__()
        self._inner = inner

    def render(self, task):
        is_program_row = task.fields.get("row_kind") == "program"
        return self._inner.render(task) if is_program_row else Text("")
class ConditionalSpinnerColumn(ProgressColumn):
    """Spinner that runs until a row reports a terminal status.

    Rows whose ``final_status`` field is a :class:`TerminalStatus` value get
    an empty cell. Every other row gets a ``"point"`` spinner. The task is
    wrapped in :class:`_UnfinishedTaskWrapper` before rendering so the
    spinner keeps animating regardless of the task's own ``finished``
    state.
    """

    _FINAL_STATUSES = frozenset(TerminalStatus)

    def __init__(self):
        super().__init__()
        self.spinner = SpinnerColumn("point")

    def render(self, task):
        if task.fields.get("final_status") in self._FINAL_STATUSES:
            return Text("")
        always_running = cast(Task, _UnfinishedTaskWrapper(task))
        return self.spinner.render(always_running)
class PhaseStatusColumn(ProgressColumn):
    """Status text for a row: terminal outcome, message, loss and polling.

    A row whose ``final_status`` is a :class:`TerminalStatus` renders a
    styled outcome label. The label is followed by ``(message)`` when a
    message is set, and then by the formatted loss. Any other row renders
    ``[message]`` and the loss. When a ``service_job_id`` is known, a job
    polling note is appended, and the short job id (the text before the
    first ``-``) is highlighted in blue.
    """

    _STATUS_MESSAGES = {
        TerminalStatus.SUCCESS: ("• Success! ✅ ", "bold green"),
        TerminalStatus.FAILED: ("• Failed! ❌ ", "bold red"),
        TerminalStatus.CANCELLED: ("• Cancelled ⏹️ ", "bold yellow"),
        TerminalStatus.ABORTED: ("• Aborted ⚠️ ", "dim magenta"),
    }

    def __init__(self, table_column=None):
        super().__init__(table_column)

    def _build_polling_string(
        self, split_job_id: str, job_status: str, poll_attempt: int, max_retries: int
    ) -> str:
        """Describe the polling state of a service job.

        Returns a completion note when *job_status* is ``"COMPLETED"``. Once
        *poll_attempt* is positive, returns a note with the status and
        ``attempt / max_retries``. Otherwise returns an empty string.
        """
        if job_status == "COMPLETED":
            return f" [Job {split_job_id} is complete.]"
        if poll_attempt > 0:
            return (
                f" [Job {split_job_id} is {job_status}. "
                f"Polling attempt {poll_attempt} / {max_retries}]"
            )
        return ""

    @staticmethod
    def _build_loss_string(loss: float | None) -> str:
        """Format *loss* with six decimals, or return ``""`` when it is None."""
        return "" if loss is None else f" [loss: {float(loss):.6f}]"

    def render(self, task):
        fields = task.fields
        loss_str = self._build_loss_string(fields.get("loss"))

        # Terminal rows: fixed label + style, optional "(detail)" suffix.
        outcome = self._STATUS_MESSAGES.get(fields.get("final_status"))
        if outcome is not None:
            label, style = outcome
            detail = fields.get("message", "")
            suffix = f" ({detail})" if detail else ""
            return Text(label + suffix + loss_str, style=style)

        message = fields.get("message")
        msg_str = f"[{message}]" if message else ""

        service_job_id = fields.get("service_job_id")
        if service_job_id is None:
            return Text(f"{msg_str}{loss_str}")

        # Only the first dash-separated segment of the job id is shown.
        short_id = service_job_id.split("-")[0]
        polling_str = self._build_polling_string(
            short_id,
            fields.get("job_status"),
            fields.get("poll_attempt", 0),
            fields.get("max_retries"),
        )
        text = Text(f"{msg_str}{loss_str}{polling_str}")
        text.highlight_words([short_id], "blue")
        return text
[docs]
def make_progress_bar() -> Progress:
    """Build the single Rich ``Progress`` shared by program and batch rows.

    Columns, left to right: batch color marker, job name, bar, ``M/N``
    counter, elapsed time, spinner, status text. The bar, counter and
    elapsed-time columns are wrapped in :class:`_ProgramOnlyColumn`. They
    therefore render only on tasks whose ``row_kind`` field is
    ``"program"``, and batch rows leave those cells blank.
    """
    program_only_columns = [
        _ProgramOnlyColumn(column)
        for column in (BarColumn(), MofNCompleteColumn(), TimeElapsedColumn())
    ]
    return Progress(
        BatchIndicatorColumn(),
        TextColumn("[bold blue]{task.fields[job_name]}"),
        *program_only_columns,
        ConditionalSpinnerColumn(),
        PhaseStatusColumn(),
    )
[docs]
def make_progress_display(
    is_jupyter: bool = False,
) -> tuple[Progress | None, Live | None]:
    """Build the progress bar and the ``Live`` display that hosts it.

    The bar comes from :func:`make_progress_bar`, so one display covers
    both per-program and batch rows. When *is_jupyter* is true,
    ``auto_refresh`` is turned off to avoid double rendering (rich#1737).
    In that case the caller must call ``live.refresh()`` after each update.
    Otherwise the display refreshes 10 times per second.

    Returns ``(None, None)`` without touching stdout when
    :func:`progress_disabled` reports that progress is switched off.
    """
    if not progress_disabled():
        _ensure_unbuffered_stdout()
        bar = make_progress_bar()
        display = Live(bar, auto_refresh=not is_jupyter, refresh_per_second=10)
        return bar, display
    return None, None
# ---------------------------------------------------------------------------
# Queue listener & batch message handler
# ---------------------------------------------------------------------------
def _safe_log(console, msg: str) -> None:
    """Log *msg* on *console*, ignoring any failure.

    Any ``Exception`` is swallowed, including a missing ``log`` attribute
    or Rich raising during interpreter or live-display teardown. The
    listener thread is therefore never killed by its own diagnostics.
    """
    try:
        log = console.log
        log(msg)
    except Exception:
        return
def _safe_call(fn, /, *args, **kwargs) -> None:
    """Invoke ``fn(*args, **kwargs)`` best-effort and discard the result.

    Any ``Exception`` raised by the call is swallowed. This is the
    non-logging counterpart of :func:`_safe_log`, meant for Rich calls such
    as ``live.refresh``.
    """
    try:
        fn(*args, **kwargs)
    except Exception:
        return None
def _drain_queue_quietly(queue: Queue) -> None:
    """Empty *queue* without blocking, acknowledging every removed item.

    Each successful ``get_nowait()`` is followed by a ``task_done()``, so
    threads waiting in ``queue.join()`` are released even after the
    listener that normally consumes the queue has died. A failure to
    fetch, including the normal ``Empty``, ends the drain. Errors from
    ``task_done()`` are ignored.
    """
    while True:
        try:
            queue.get_nowait()
        except Exception:
            # ``Empty`` is the expected exit; any other error also stops.
            return
        try:
            queue.task_done()
        except Exception:
            pass
[docs]
def queue_listener(
    queue: Queue,
    progress_bar: Progress,
    pb_task_map: dict[Any, TaskID],
    done_event: Event,
    lock: Lock,
    *,
    hide_program_rows: bool = False,
    prep_task_id: TaskID | None = None,
) -> None:
    """Drain a message queue and update the unified progress bar.

    Loops until *done_event* is set. *queue* is polled with a 0.1 s
    timeout so the event is re-checked regularly. Each dequeued message is
    a dict, routed by its keys:

    * ``batch`` truthy: forwarded to :func:`handle_batch_message`. Batch
      row TaskIDs live in a dict local to this call.
    * ``prep_advance`` truthy: advances the *prep_task_id* row by one. The
      message is ignored when *prep_task_id* is None. The advance that
      reaches the row's total instead fills the row and marks it
      ``TerminalStatus.SUCCESS``.
    * anything else: a per-program update. ``job_id`` is resolved to a
      TaskID through *pb_task_map*, read under *lock*. Unknown ids are
      logged and dropped. ``progress`` is applied as an advance. The keys
      ``poll_attempt``, ``max_retries``, ``service_job_id``,
      ``job_status``, ``loss``, a non-empty ``message`` and
      ``final_status`` are copied onto the row when present.

    A ``final_status`` of Success fills the row to its total, when the
    total is known. When *hide_program_rows* is set, program rows were
    created hidden. A Failed/Cancelled/Aborted status then makes the row
    visible so failures stay diagnosable; successful rows stay hidden.

    Error handling:

    * Every dequeued message is acknowledged with ``queue.task_done()`` in
      a ``finally`` clause, even when handling it raises or hits
      ``continue``. ``queue.join()`` callers are therefore never left
      waiting on a malformed message.
    * Per-message exceptions (a missing ``job_id``/``progress`` key, Rich
      raising during teardown, ...) are logged via :func:`_safe_log` and
      the loop moves on.
    * An outer ``except Exception`` keeps the loop alive if anything else
      escapes.
    * ``BaseException`` is not caught here. NOTE(review): presumably the
      thread wrapper that runs this function handles it and drains the
      queue (cf. :func:`_drain_queue_quietly`); confirm at the call site.

    Args:
        queue: Source of message dicts. ``task_done()`` is called once per
            successful ``get``.
        progress_bar: The Rich ``Progress`` holding all rows. Its console
            is used for diagnostics.
        pb_task_map: ``job_id -> TaskID`` for per-program rows.
        done_event: Stops the loop once set.
        lock: Held while reading *pb_task_map*; also passed to
            :func:`handle_batch_message`.
        hide_program_rows: Whether program rows start hidden (see above).
        prep_task_id: Row advanced by ``prep_advance`` messages, if any.
    """
    console = progress_bar.console
    # Batch rows share the Progress with program rows. Their TaskIDs are
    # tracked here, keyed by batch_id, so split flush groups (expval vs
    # shots) each get their own row.
    batch_task_ids: dict[int, TaskID] = {}
    while not done_event.is_set():
        # Outer-loop guard: a dead listener would freeze the display and
        # leave ``queue.join()`` callers waiting. Any ``Exception`` that
        # escapes the per-message handling below (e.g. from
        # ``queue.task_done()`` in its ``finally``) is logged and the loop
        # continues. The inner per-message try is the primary defence;
        # this is belt-and-suspenders.
        try:
            try:
                msg: dict[str, Any] = queue.get(timeout=0.1)
            except Empty:
                # Nothing was dequeued, so no task_done() is owed; loop
                # back and re-check done_event.
                continue
            except Exception as e:
                _safe_log(console, f"[queue_listener] queue.get failed: {e}")
                continue
            try:
                # Every ``continue`` inside this try still runs the
                # ``finally`` below, so the message is always acknowledged.
                # --- Batch-level messages from the coordinator ---
                if msg.get("batch"):
                    handle_batch_message(
                        msg,
                        progress_bar,
                        batch_task_ids,
                        lock,
                    )
                    continue
                # --- Prep-progress signals from the coordinator ---
                if msg.get("prep_advance"):
                    if prep_task_id is None:
                        continue
                    prep_update: dict[str, Any] = {"advance": 1}
                    prep_task = progress_bar._tasks.get(prep_task_id)
                    if (
                        prep_task is not None
                        and prep_task.total is not None
                        and prep_task.completed + 1 >= prep_task.total
                    ):
                        # This advance completes the prep row (per the
                        # original rationale: the last program reached
                        # submit and the barrier is about to fire). Mark
                        # the row final and set ``completed`` directly
                        # instead of advancing, so the display reads
                        # "Success ✅" rather than an active spinner.
                        prep_update["final_status"] = TerminalStatus.SUCCESS
                        prep_update.pop("advance", None)
                        prep_update["completed"] = prep_task.total
                    progress_bar.update(prep_task_id, **prep_update)
                    continue
                # --- Regular per-program messages ---
                # A missing ``job_id`` raises KeyError here; the
                # per-message handler below logs it and moves on.
                with lock:
                    task_id = pb_task_map.get(msg["job_id"])
                if task_id is None:
                    # Stale or unknown job_id (e.g. a late progress message
                    # arriving after the program's task was torn down).
                    # Drop it rather than letting a KeyError kill the
                    # listener.
                    _safe_log(
                        console,
                        f"[queue_listener] dropped message for unknown job_id "
                        f"{msg.get('job_id')!r}",
                    )
                    continue
                # ``progress`` is required on per-program messages; a
                # message without it raises KeyError and is logged/dropped.
                update_args: dict[str, Any] = {"advance": msg["progress"]}
                # Optional display fields are copied only when present so a
                # message never clears a value set by an earlier one.
                for key in (
                    "poll_attempt",
                    "max_retries",
                    "service_job_id",
                    "job_status",
                    "loss",
                ):
                    if key in msg:
                        update_args[key] = msg[key]
                if msg.get("message"):
                    update_args["message"] = msg["message"]
                if "final_status" in msg:
                    final_status = msg["final_status"]
                    update_args["final_status"] = final_status
                    if final_status == TerminalStatus.SUCCESS:
                        # Fill the bar so a successful program isn't
                        # displayed as 0/N when it didn't tick the
                        # counter incrementally.
                        task = progress_bar._tasks.get(task_id)
                        if task is not None and task.total is not None:
                            update_args["completed"] = task.total
                            update_args.pop("advance", None)
                    elif hide_program_rows and final_status in (
                        TerminalStatus.FAILED,
                        TerminalStatus.CANCELLED,
                        TerminalStatus.ABORTED,
                    ):
                        # Per-program rows were created hidden; reveal
                        # this one so the user can see what went wrong.
                        update_args["visible"] = True
                try:
                    progress_bar.update(task_id, **update_args)
                except Exception as e:
                    _safe_log(
                        console,
                        f"[queue_listener] progress_bar.update failed: {e}",
                    )
            except Exception as e:
                # Per-message safety net: log and swallow so one bad
                # message neither kills the listener nor reaches the
                # outer guard. The ``finally`` still acknowledges it.
                _safe_log(
                    console,
                    f"[queue_listener] unexpected exception while handling message: {e}",
                )
            finally:
                queue.task_done()
        except Exception as e:
            _safe_log(
                console,
                f"[queue_listener] outer-loop exception (continuing): {e}",
            )
[docs]
def handle_batch_message(
    msg: dict[str, Any],
    progress_bar: Progress,
    batch_task_ids: dict[int, TaskID],
    lock: Lock,
) -> None:
    """Process a batch-level progress message in the unified progress bar.

    Batch rows are created lazily, one per ``batch_id``, so split
    sub-batches (e.g. expval vs shots) each get their own status line.
    Batch rows carry ``row_kind="batch"``, so the
    :class:`_ProgramOnlyColumn` wrappers keep their bar, M-of-N and
    elapsed cells empty.

    Messages whose ``batch_id`` is not an ``int`` are ignored.

    A non-terminal message (no truthy ``final_status``):

    * shows the row;
    * sets its label to ``"Batch (<batch_label>): <n_circuits> circuits,
      <n_programs> programs"``, or ``"Batch: ..."`` without a label;
    * tints the row and every program row listed in ``program_keys`` with
      ``batch_color``;
    * copies ``poll_attempt``, ``max_retries``, ``service_job_id``,
      ``job_status`` and a non-empty ``message`` onto the row.

    A terminal message clears the tint from the listed program rows and
    removes the batch row from *progress_bar*. Its ``batch_id`` is dropped
    from *batch_task_ids*, so a later message with the same id starts a
    fresh row.

    Program-row coloring reads each task's ``program_key`` field (see
    :func:`_apply_color_to_program_rows`), so no separate
    ``program_key -> TaskID`` index is needed.

    Args:
        msg: The batch message dict.
        progress_bar: Progress holding both program and batch rows.
        batch_task_ids: Listener-local ``batch_id -> TaskID`` map; mutated
            in place.
        lock: Passed through to :func:`_apply_color_to_program_rows`,
            which holds it while snapshotting the task table.
    """
    batch_id = msg.get("batch_id")
    if not isinstance(batch_id, int):
        return

    program_keys = msg.get("program_keys", ())

    if msg.get("final_status"):
        _apply_color_to_program_rows(progress_bar, program_keys, "", lock)
        # Remove the finished row instead of merely hiding it. Its TaskID
        # is known only to this map, so a hidden row could never be
        # reused. It would linger in ``progress_bar._tasks`` for the rest
        # of the run, growing the table every color pass snapshots.
        finished_task_id = batch_task_ids.pop(batch_id, None)
        if finished_task_id is not None:
            progress_bar.remove_task(finished_task_id)
        return

    # Lazily create a (hidden) batch row for this batch_id; the update
    # below makes it visible.
    if batch_id not in batch_task_ids:
        batch_task_ids[batch_id] = progress_bar.add_task(
            "",
            job_name="",
            total=0,
            visible=False,
            row_kind="batch",
            program_key=None,
            batch_color="",
            message="",
            final_status=None,
        )
    task_id = batch_task_ids[batch_id]

    color = msg.get("batch_color", "")
    label = msg.get("batch_label", "")
    prefix = f"Batch ({label})" if label else "Batch"
    update_args: dict[str, Any] = {
        "visible": True,
        "job_name": (
            f"{prefix}: {msg.get('n_circuits', 0)} circuits, "
            f"{msg.get('n_programs', 0)} programs"
        ),
        "batch_color": color,
    }
    _apply_color_to_program_rows(progress_bar, program_keys, color, lock)

    # Polling details are copied only when present in this message.
    for key in ("poll_attempt", "max_retries", "service_job_id", "job_status"):
        if key in msg:
            update_args[key] = msg[key]
    if msg.get("message"):
        update_args["message"] = msg["message"]

    progress_bar.update(task_id, **update_args)
def _apply_color_to_program_rows(
    progress_bar: Progress,
    program_keys,
    color: str,
    lock: Lock,
) -> None:
    """Tint the program rows that belong to a batch.

    Every task in *progress_bar* whose ``program_key`` field is one of
    *program_keys* gets its ``batch_color`` field set to *color*. An empty
    *color* clears the tint. Does nothing when *program_keys* is empty.
    """
    if not program_keys:
        return
    wanted = set(program_keys)
    # Copy the task table while holding *lock* so a concurrent add_task
    # (presumably done by callers holding the same lock) cannot resize
    # ``_tasks`` mid-iteration. The per-row updates run outside the lock;
    # they change task fields, not the dict itself.
    with lock:
        rows = list(progress_bar._tasks.items())
    for tid, task in rows:
        if task.fields.get("program_key") not in wanted:
            continue
        progress_bar.update(tid, batch_color=color)