import numpy as np
import torch
from torch import nn
# Names exported by a wildcard import of this module.
__all__ = [
    "beta_nll_loss",
    "create_circular_mask",
    "jet_seg",
    "l1",
    "mse",
    "splitted_L1",
    "splitted_L1_masked",
]
def l1(x, y):
    """Return the mean absolute error between ``x`` and ``y``.

    Thin wrapper around :class:`torch.nn.L1Loss` with its default
    ``reduction="mean"``.

    Parameters
    ----------
    x : :class:`torch.Tensor`
        Prediction.
    y : :class:`torch.Tensor`
        Target, same shape as ``x``.

    Returns
    -------
    :class:`torch.Tensor`
        Zero-dimensional loss tensor.
    """
    return nn.L1Loss()(x, y)
def create_circular_mask(h, w, center=None, radius=None, bs=64):
    """Build a stack of identical boolean disk masks.

    Parameters
    ----------
    h, w : int
        Height and width of each mask.
    center : sequence, optional
        ``(x, y)`` pixel position of the disk centre (column, row).
        Defaults to ``(int(w / 2), int(h / 2))``.
    radius : number, optional
        Disk radius in pixels. Defaults to the distance from ``center``
        to the nearest image border.
    bs : int
        Number of copies stacked along a new leading axis.

    Returns
    -------
    numpy.ndarray
        Boolean array of shape ``(bs, h, w)``; ``True`` inside the disk
        (distance from ``center`` <= ``radius``).
    """
    if center is None:
        center = (int(w / 2), int(h / 2))
    cx, cy = center[0], center[1]
    if radius is None:
        radius = min(cx, cy, w - cx, h - cy)
    # Open grids: rows has shape (h, 1), cols has shape (1, w); they broadcast.
    rows, cols = np.ogrid[:h, :w]
    inside = np.sqrt((cols - cx) ** 2 + (rows - cy) ** 2) <= radius
    return np.repeat([inside], bs, axis=0)
def splitted_L1_masked(x, y):
    """Sum of amplitude and phase L1 losses with a down-weighted outer region.

    Every pixel outside a centred disk of radius 50 is multiplied by 0.3 in
    both prediction and target before the mean L1 loss is computed, so
    errors there count 0.3 times as much as errors inside the disk.

    Parameters
    ----------
    x : :class:`torch.Tensor`
        Prediction. Channel 0 (axis 1) is the amplitude, channel 1 the
        phase; the last two axes are the image plane (assumed layout
        ``(B, C, H, W)`` -- TODO confirm with callers).
    y : :class:`torch.Tensor`
        Ground truth with the same layout as ``x``.

    Returns
    -------
    :class:`torch.Tensor`
        Zero-dimensional tensor: amplitude L1 loss + phase L1 loss.

    Notes
    -----
    ``x`` and ``y`` are not modified.
    """
    # Derive the mask size from the input instead of assuming 256x256.
    h, w = x.shape[-2], x.shape[-1]
    # A single (1, h, w) mask broadcasts over the batch axis; build it on the
    # input's device so the weighting stays on the same device as x.
    mask = torch.tensor(
        create_circular_mask(h, w, radius=50, bs=1), device=x.device
    )
    weight = torch.ones(mask.shape, dtype=x.dtype, device=x.device)
    weight[~mask] = 0.3

    # Scale with out-of-place products: in-place updates on slices such as
    # x[:, 0] would write through the view and mutate the caller's tensors.
    l1_loss = nn.L1Loss()
    loss_amp = l1_loss(x[:, 0] * weight, y[:, 0] * weight)
    loss_phase = l1_loss(x[:, 1] * weight, y[:, 1] * weight)
    return loss_amp + loss_phase
def splitted_L1(x, y):
    """Sum of the mean L1 losses on the amplitude and phase channels.

    Parameters
    ----------
    x : :class:`torch.Tensor`
        Prediction; channel 0 (axis 1) is the amplitude, channel 1 the phase.
    y : :class:`torch.Tensor`
        Ground truth with the same channel layout.

    Returns
    -------
    :class:`torch.Tensor`
        Zero-dimensional tensor: amplitude L1 loss + phase L1 loss.
    """
    criterion = nn.L1Loss()
    amp_loss = criterion(x[:, 0, :], y[:, 0, :])
    phase_loss = criterion(x[:, 1, :], y[:, 1, :])
    return amp_loss + phase_loss
def beta_nll_loss(x: torch.Tensor, y: torch.Tensor, beta: float = 0.5):
    """Compute the beta-NLL loss for amplitude and phase predictions.

    Per element, the Gaussian negative log-likelihood
    ``0.5 * ((target - mean) ** 2 / variance + log(variance))`` is computed
    and, when ``beta > 0``, multiplied by ``variance.detach() ** beta``.

    Parameters
    ----------
    x : :class:`torch.Tensor`
        Prediction of the model, channels along axis 1:
        0 = amplitude mean, 1 = amplitude variance,
        2 = phase mean, 3 = phase variance.
        Variances must be positive (they are divided by and logged).
    y : :class:`torch.Tensor`
        Ground truth, channels along axis 1: 0 = amplitude, 1 = phase.
    beta : float
        Weighting exponent, intended range [0, 1]. ``beta = 0`` gives the
        plain Gaussian NLL (no reweighting); ``beta = 1`` scales each term
        by its variance, which makes the gradient w.r.t. the mean equal to
        that of a squared error.

    Returns
    -------
    :class:`torch.Tensor`
        Zero-dimensional tensor: the loss averaged over all elements
        (batch, both components and any remaining axes).
    """
    mean = torch.stack([x[:, 0], x[:, 2]], dim=1)
    variance = torch.stack([x[:, 1], x[:, 3]], dim=1)
    target = torch.stack([y[:, 0], y[:, 1]], dim=1)

    loss = 0.5 * ((target - mean) ** 2 / variance + variance.log())
    if beta > 0:
        # Detached: the factor reweights each term without itself receiving
        # gradient.
        loss = loss * variance.detach() ** beta
    return loss.mean()
def mse(x, y):
    """Return the mean squared error between ``x`` and ``y``.

    Thin wrapper around :class:`torch.nn.MSELoss` with its default
    ``reduction="mean"``.

    Parameters
    ----------
    x : :class:`torch.Tensor`
        Prediction.
    y : :class:`torch.Tensor`
        Target, same shape as ``x``.

    Returns
    -------
    :class:`torch.Tensor`
        Zero-dimensional loss tensor.
    """
    return nn.MSELoss()(x, y)
def jet_seg(x, y):
    """Channel-weighted L1 loss for jet segmentation.

    Channel ``i`` (axis 1) contributes its mean L1 loss multiplied by
    ``i + 1``, so components farther outside are weighted more heavily.

    Parameters
    ----------
    x : :class:`torch.Tensor`
        Prediction with components along axis 1.
    y : :class:`torch.Tensor`
        Ground truth with the same layout.

    Returns
    -------
    :class:`torch.Tensor`
        Weighted sum of per-channel L1 losses (the integer ``0`` if ``x``
        has no channels).
    """
    criterion = nn.L1Loss()
    return sum(
        criterion(x[:, i], y[:, i]) * (i + 1) for i in range(x.shape[1])
    )