resolve_loss_reduction

resolve_loss_reduction(reduction)

Resolve a LossReductionFn, given either as a literal name or as a callable, to a concrete callable.

User-supplied callables are wrapped in float(...) so that bare NumPy reductions (e.g. loss_reduction=np.mean), which return a NumPy scalar such as np.float64 or a 0-d ndarray rather than a built-in float, produce a plain Python float. This matches the contract that downstream stages and losses_history expect.
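A minimal sketch of the resolution logic described above, not the library's actual implementation. It assumes the accepted literal names are "mean" and "sum", and that unknown names raise ValueError. Both assumptions are hypothetical, since the real LossReductionFn literals are not listed here. The sketch shows why the float(...) wrapper matters: np.mean returns np.float64, and a reduction that returns a 0-d ndarray is also collapsed to a built-in float.

```python
from typing import Any, Callable, Literal, Union

import numpy as np
import numpy.typing as npt

# Array type matching the documented signature: ndarray of float64 ("double").
FloatArray = npt.NDArray[np.float64]

# Hypothetical: the real set of accepted literal names is not given in this doc.
LossReductionFn = Union[Literal["mean", "sum"], Callable[[FloatArray], Any]]

# Built-in reductions keyed by (assumed) literal name; each already returns float.
_BUILTIN_REDUCTIONS: dict[str, Callable[[FloatArray], float]] = {
    "mean": lambda losses: float(np.mean(losses)),
    "sum": lambda losses: float(np.sum(losses)),
}


def resolve_loss_reduction(reduction: LossReductionFn) -> Callable[[FloatArray], float]:
    """Sketch: map a literal name or a callable to a callable returning a built-in float."""
    if isinstance(reduction, str):
        try:
            return _BUILTIN_REDUCTIONS[reduction]
        except KeyError:
            # Assumed behavior for unknown names.
            raise ValueError(f"unknown loss reduction: {reduction!r}") from None
    if not callable(reduction):
        raise TypeError(
            f"expected a literal name or a callable, got {type(reduction).__name__}"
        )

    def wrapped(losses: FloatArray) -> float:
        # float(...) turns np.float64 scalars and 0-d ndarrays into a plain Python float.
        return float(reduction(losses))

    return wrapped


reduce_fn = resolve_loss_reduction(np.mean)
value = reduce_fn(np.array([1.0, 2.0, 3.0]))
print(type(value).__name__, value)  # float 2.0
```

In this sketch the coercion happens once, at resolution time, so any callable whose result float() accepts is admitted and downstream consumers never see a NumPy scalar.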

Return type:

Callable[[ndarray[tuple[Any, ...], dtype[double]]], float]