resolve_sample_loss

resolve_sample_loss(loss)

Resolve a SampleLossFn, given either as a literal or as a callable, to a concrete callable.

User-supplied callables are wrapped in float(...), so the per-sample loss is always a plain Python float, even when the callable returns a NumPy scalar type. A custom callable must return a finite value: a NaN or Inf result is not guarded against and propagates through the reduction to the optimizer.
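The resolution and wrapping behaviour described above might look roughly like the sketch below. This is an illustration under stated assumptions, not the library's implementation: the built-in literal names ("squared", "absolute"), the _BUILTIN_LOSSES registry, the argument order of the two floats, and the error types raised for bad input are all hypothetical.

```python
from typing import Callable, Literal, Union

LossFn = Callable[[float, float], float]

# Hypothetical stand-in for SampleLossFn: a literal name or a callable.
# The real set of literal names is not documented here.
SampleLossFn = Union[Literal["squared", "absolute"], LossFn]

# Hypothetical built-in losses keyed by literal name.
_BUILTIN_LOSSES: dict[str, LossFn] = {
    "squared": lambda a, b: (a - b) ** 2,
    "absolute": lambda a, b: abs(a - b),
}


def resolve_sample_loss(loss: SampleLossFn) -> LossFn:
    """Sketch: map a literal to a built-in loss, or wrap a user callable."""
    if isinstance(loss, str):
        try:
            return _BUILTIN_LOSSES[loss]
        except KeyError:
            raise ValueError(f"unknown sample loss: {loss!r}") from None
    if not callable(loss):
        raise TypeError(f"expected a loss name or a callable, got {type(loss).__name__}")

    user_fn = loss

    def wrapped(a: float, b: float) -> float:
        # float(...) converts NumPy scalars, ints, etc. into a plain Python float.
        # There is deliberately no finiteness check: NaN/Inf pass straight through.
        return float(user_fn(a, b))

    return wrapped
```

For example, a callable that returns the int 2 yields 2.0 of type float after resolution, while one that returns float("nan") still yields NaN, which is exactly the unguarded case the note above warns about.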

Return type:

Callable[[float, float], float]
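Because non-finite results are not caught, a caller who wants a hard failure instead of a silently corrupted reduction could wrap a custom loss before passing it in. The require_finite helper below is hypothetical and not part of this API; it only assumes the Callable[[float, float], float] shape given above.

```python
import math
from typing import Callable

LossFn = Callable[[float, float], float]


def require_finite(fn: LossFn) -> LossFn:
    """Hypothetical helper: raise instead of letting NaN/Inf reach the reduction."""

    def guarded(a: float, b: float) -> float:
        value = float(fn(a, b))
        # math.isfinite is False for NaN, +Inf and -Inf.
        if not math.isfinite(value):
            raise ValueError(
                f"sample loss returned non-finite value {value!r} for inputs ({a!r}, {b!r})"
            )
        return value

    return guarded
```

It would be applied as resolve_sample_loss(require_finite(my_loss)), where my_loss is your own callable. The trade-off is one extra check per sample in exchange for failing at the offending input rather than downstream in the optimizer.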