Source code for radionets.core.loss_functions

import numpy as np
import torch
from torch import nn

__all__ = [
    "beta_nll_loss",
    "create_circular_mask",
    "jet_seg",
    "l1",
    "mse",
    "splitted_L1",
    "splitted_L1_masked",
]


[docs] def l1(x, y): l1 = nn.L1Loss() loss = l1(x, y) return loss
[docs] def create_circular_mask(h, w, center=None, radius=None, bs=64): if center is None: center = (int(w / 2), int(h / 2)) if radius is None: radius = min(center[0], center[1], w - center[0], h - center[1]) Y, X = np.ogrid[:h, :w] dist_from_center = np.sqrt((X - center[0]) ** 2 + (Y - center[1]) ** 2) mask = dist_from_center <= radius return np.repeat([mask], bs, axis=0)
[docs] def splitted_L1_masked(x, y): inp_amp = x[:, 0, :] inp_phase = x[:, 1, :] tar_amp = y[:, 0, :] tar_phase = y[:, 1, :] mask = torch.tensor(create_circular_mask(256, 256, radius=50, bs=y.shape[0])) inp_amp[~mask] *= 0.3 inp_phase[~mask] *= 0.3 tar_amp[~mask] *= 0.3 tar_phase[~mask] *= 0.3 l1 = nn.L1Loss() loss_amp = l1(inp_amp, tar_amp) loss_phase = l1(inp_phase, tar_phase) loss = loss_amp + loss_phase return loss
[docs] def splitted_L1(x, y): inp_amp = x[:, 0, :] inp_phase = x[:, 1, :] tar_amp = y[:, 0, :] tar_phase = y[:, 1, :] l1 = nn.L1Loss() loss_amp = l1(inp_amp, tar_amp) loss_phase = l1(inp_phase, tar_phase) loss = loss_amp + loss_phase return loss
[docs] def beta_nll_loss(x: torch.tensor, y: torch.tensor, beta: float = 0.5): """Compute beta-NLL loss Parameters ---------- x : :func:`torch.tensor` Prediction of the model. y : :func:`torch.tensor` Ground truth. beta : float Parameter from range [0, 1] controlling relative weighting between data points, where "0" corresponds to high weight on low error points and "1" to an equal weighting. Returns ------- float : Loss per batch element of shape B """ pred_amp = x[:, 0, :] pred_phase = x[:, 2, :] mean = torch.stack([pred_amp, pred_phase], axis=1) unc_amp = x[:, 1, :] unc_phase = x[:, 3, :] variance = torch.stack([unc_amp, unc_phase], axis=1) tar_amp = y[:, 0, :] tar_phase = y[:, 1, :] target = torch.stack([tar_amp, tar_phase], axis=1) loss = 0.5 * ((target - mean) ** 2 / variance + variance.log()) if beta > 0: loss = loss * variance.detach() ** beta return loss.mean()
[docs] def mse(x, y): mse = nn.MSELoss() loss = mse(x, y) return loss
[docs] def jet_seg(x, y): # weight components farer outside more loss_l1_weighted = 0 for i in range(x.shape[1]): loss_l1_weighted += l1(x[:, i], y[:, i]) * (i + 1) return loss_l1_weighted