Source code for radionets.architecture.archs

from math import pi

import torch
from torch import nn

from radionets.architecture.activation import GeneralReLU
from radionets.architecture.blocks import SRBlock

__all__ = [
    "SRResNet",
    "SRResNet18",
    "SRResNet34",
    "SRResNet34_unc",
    "SRResNet34_unc_no_grad",
]


class SRResNet(nn.Module):
    """Base super-resolution ResNet mapping a 2-channel input to a 2-channel output.

    Layout::

        preBlock  (9x9 conv, groups=2, + PReLU)
          -> blocks    (residual body; built by subclasses via _create_blocks)
          -> postBlock (3x3 conv + BatchNorm), added back onto the preBlock output
          -> final     (9x9 conv, groups=2, to 2 channels)
          -> per-channel output activations (ReLU / Hardtanh(-pi, pi))

    Every convolution uses padding matched to its kernel size with stride 1,
    so the spatial size of the input is preserved.

    Note: this base class does not create ``self.blocks``. A subclass must
    call :meth:`_create_blocks` in its ``__init__`` before :meth:`forward`
    can be used.
    """

    def __init__(self):
        super().__init__()

        # Number of feature maps used throughout the network body.
        self.channels = 64

        # groups=2: each of the two input channels is convolved by its own
        # half of the filters, so the two inputs are not mixed in this layer.
        self.preBlock = nn.Sequential(
            nn.Conv2d(
                in_channels=2,
                out_channels=self.channels,
                kernel_size=9,
                stride=1,
                padding=4,
                groups=2,
            ),
            nn.PReLU(),
        )

        # Applied to the residual body's output and added to the preBlock
        # features in forward(). bias=False because a conv bias is redundant
        # directly in front of a normalization layer.
        self.postBlock = nn.Sequential(
            nn.Conv2d(
                in_channels=self.channels,
                out_channels=self.channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(self.channels),
        )

        # groups=2 with 2 outputs: output channel 0 is computed only from the
        # first half of the feature maps, output channel 1 only from the second.
        self.final = nn.Sequential(
            nn.Conv2d(
                in_channels=self.channels,
                out_channels=2,
                kernel_size=9,
                stride=1,
                padding=4,
                groups=2,
            ),
        )

        # Output activations: channel 1 is clamped to [-pi, pi] and channel 0
        # is made non-negative (presumably phase and amplitude -- confirm).
        self.hardtanh = nn.Hardtanh(-pi, pi)
        self.relu = nn.ReLU()

    def _create_blocks(self, n_blocks):
        """Build the residual body as ``n_blocks`` chained ``SRBlock`` modules.

        Each block maps ``self.channels`` feature maps to ``self.channels``
        feature maps, so the body keeps the width set in ``__init__``.

        Parameters
        ----------
        n_blocks : int
            Number of ``SRBlock`` modules to chain.
        """
        self.blocks = nn.Sequential(
            *[SRBlock(self.channels, self.channels) for _ in range(n_blocks)]
        )

    def forward(self, x):
        """Run the network.

        Parameters
        ----------
        x : torch.Tensor
            Input with expected shape ``(batch, 2, H, W)``.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch, 2, H, W)``. Channel 0 has passed through
            a ReLU (non-negative); channel 1 is clamped to ``[-pi, pi]``.
        """
        x = self.preBlock(x)
        # Global skip connection around the residual body + postBlock.
        x = x + self.postBlock(self.blocks(x))
        x = self.final(x)

        x0 = self.relu(x[:, 0].unsqueeze(1))
        x1 = self.hardtanh(x[:, 1].unsqueeze(1))
        return torch.cat([x0, x1], dim=1)
class SRResNet18(SRResNet):
    """SRResNet variant whose residual body is made of 8 ``SRBlock`` modules.

    All other layers and ``forward`` are inherited unchanged from ``SRResNet``.
    """

    def __init__(self):
        super(SRResNet18, self).__init__()
        # Eight residual blocks form the SRResNet18 body.
        self._create_blocks(n_blocks=8)
class SRResNet34(SRResNet):
    """SRResNet variant with 16 residual ``SRBlock`` modules and InstanceNorm.

    Compared with the base class, the residual body is 16 blocks deep and the
    BatchNorm in ``postBlock`` is swapped for ``InstanceNorm2d``. The conv in
    front of it has the same geometry as in the base class.
    """

    def __init__(self):
        super(SRResNet34, self).__init__()
        # Sixteen residual blocks form the SRResNet34 body.
        self._create_blocks(n_blocks=16)

        # Rebuild postBlock: identical 3x3 conv, but normalized per instance
        # instead of per batch.
        post_conv = nn.Conv2d(
            self.channels,
            self.channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        post_norm = nn.InstanceNorm2d(self.channels)
        self.postBlock = nn.Sequential(post_conv, post_norm)
class SRResNet34_unc(SRResNet):
    """16-block SRResNet with InstanceNorm that also predicts two uncertainty maps.

    The body matches SRResNet34: 16 ``SRBlock`` modules and an
    InstanceNorm-based ``postBlock``. The ``final`` layer is widened to four
    output channels. ``forward`` returns them in the order
    ``[ch0, elu(ch2), ch1, elu(ch3)]``, presumably each prediction followed
    by its uncertainty -- confirm against the training/loss code.
    """

    def __init__(self):
        super().__init__()

        self._create_blocks(16)

        # Same 3x3 conv as the base class, but with InstanceNorm instead of BatchNorm.
        self.postBlock = nn.Sequential(
            nn.Conv2d(
                in_channels=self.channels,
                out_channels=self.channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.InstanceNorm2d(self.channels),
        )

        # The inherited `final` only emits 2 channels, but forward() reads
        # channels 0..3, so the head must produce 4 outputs.
        # groups=2 (as in the base class): output channels 0-1 are computed
        # from the first half of the feature maps, channels 2-3 from the second.
        self.final = nn.Sequential(
            nn.Conv2d(
                in_channels=self.channels,
                out_channels=4,
                kernel_size=9,
                stride=1,
                padding=4,
                groups=2,
            ),
        )

        # Activation for the two uncertainty channels. NOTE(review): presumably
        # a ReLU shifted up by 1e-10 to keep uncertainties strictly positive --
        # confirm against radionets.architecture.activation.GeneralReLU.
        self.elu = GeneralReLU(sub=-1e-10)

    def forward(self, x):
        """Run the network.

        Parameters
        ----------
        x : torch.Tensor
            Input with expected shape ``(batch, 2, H, W)``.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch, 4, H, W)`` ordered
            ``[ch0, elu(ch2), ch1, elu(ch3)]``. Channels 0 and 1 are returned
            without an output activation.
        """
        x = self.preBlock(x)
        x = x + self.postBlock(self.blocks(x))
        x = self.final(x)

        # unsqueeze(1) keeps whatever (H, W) the input had. For half-plane
        # inputs (H == W // 2 + 1) it matches the former
        # reshape(-1, 1, W // 2 + 1, W); for other shapes that reshape failed
        # or mixed samples across the batch.
        x0 = x[:, 0].unsqueeze(1)
        x1 = x[:, 1].unsqueeze(1)
        x3 = self.elu(x[:, 2].unsqueeze(1))
        x4 = self.elu(x[:, 3].unsqueeze(1))
        return torch.cat([x0, x3, x1, x4], dim=1)
class SRResNet34_unc_no_grad(SRResNet34_unc):
    """Variant of SRResNet34_unc whose uncertainty outputs carry no autograd history.

    Construction is inherited unchanged from ``SRResNet34_unc``. In
    ``forward``, the activation of the two uncertainty channels runs under
    ``torch.no_grad()``. Those outputs are therefore detached, and no gradient
    flows from them back into the network.
    """

    def forward(self, x):
        """Run the network.

        Parameters
        ----------
        x : torch.Tensor
            Input with expected shape ``(batch, 2, H, W)``.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch, 4, H, W)`` ordered
            ``[ch0, elu(ch2), ch1, elu(ch3)]``. The two ``elu`` channels have
            ``requires_grad=False``.
        """
        x = self.preBlock(x)
        x = x + self.postBlock(self.blocks(x))
        x = self.final(x)

        # unsqueeze(1) keeps the input's (H, W). It is identical to the former
        # reshape(-1, 1, W // 2 + 1, W) for half-plane inputs and does not
        # scramble the batch for other shapes.
        x0 = x[:, 0].unsqueeze(1)
        x1 = x[:, 1].unsqueeze(1)

        # Uncertainty heads are computed without gradient tracking.
        with torch.no_grad():
            x3 = self.elu(x[:, 2].unsqueeze(1))
            x4 = self.elu(x[:, 3].unsqueeze(1))

        return torch.cat([x0, x3, x1, x4], dim=1)