import torch
from torch import nn
from radionets.architecture.activation import GeneralELU
from radionets.architecture.archs import SRResNet34
from radionets.architecture.layers import LocallyConnected2d
__all__ = [
"Uncertainty",
"UncertaintyWrapper",
]
[docs]
class Uncertainty(nn.Module):
def __init__(self, img_size):
super().__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(4, 16, 9, stride=1, padding=4, groups=2),
nn.InstanceNorm2d(16),
nn.ReLU(),
)
self.conv2 = nn.Sequential(
nn.Conv2d(16, 32, 3, stride=1, padding=1),
nn.InstanceNorm2d(32),
nn.ReLU(),
)
self.conv3 = nn.Sequential(
nn.Conv2d(32, 64, 9, stride=1, padding=4, groups=2),
nn.InstanceNorm2d(64),
nn.ReLU(),
)
self.final = nn.Sequential(
LocallyConnected2d(
64,
2,
[img_size // 2 + 1, img_size],
1,
stride=1,
bias=False,
)
)
self.elu = GeneralELU(add=+(1 + 1e-7))
[docs]
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.final(x)
return self.elu(x)
[docs]
class UncertaintyWrapper(SRResNet34):
def __init__(self, img_size):
super().__init__()
self.uncertainty = Uncertainty(img_size)
[docs]
def forward(self, x):
# Get prediction from SRResNet34
pred = super.forward(x)
inp = x.clone()
# x = torch.abs(pred - inp)
x = torch.cat([pred, inp], dim=1)
unc = self.uncertainty(x)
val = torch.cat(
[
pred[:, 0].unsqueeze(1),
unc[:, 0].unsqueeze(1),
pred[:, 1].unsqueeze(1),
unc[:, 1].unsqueeze(1),
],
dim=1,
)
return val