Source code for radionets.architecture.layers
import torch
from torch import nn
from torch.nn.modules.utils import _pair
[docs]
class LocallyConnected2d(nn.Module):
    """2D locally connected layer.

    Works like a 2D convolution applied to sliding windows of the input,
    except that every output position owns its own set of kernel weights
    (no weight sharing across spatial positions).

    Parameters
    ----------
    in_channels : int
        Number of channels of the input tensor.
    out_channels : int
        Number of channels produced by the layer.
    output_size : int or tuple of int
        Spatial size ``(height, width)`` of the output. A single int is used
        for both dimensions. It has to match the number of windows that
        ``kernel_size`` and ``stride`` extract from the input, otherwise the
        multiplication in ``forward`` fails to broadcast.
    kernel_size : int or tuple of int
        Size ``(height, width)`` of the sliding window. A single int is used
        for both dimensions.
    stride : int or tuple of int
        Step ``(height, width)`` between two windows. A single int is used
        for both dimensions.
    bias : bool, optional
        If ``True``, a learnable bias with one value per output channel and
        output position is added. Default: ``False``.
    """

    def __init__(
        self, in_channels, out_channels, output_size, kernel_size, stride, bias=False
    ):
        super().__init__()
        # Normalise every size argument up front so that ints and
        # (height, width) pairs are handled the same way below.
        output_size = _pair(output_size)
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)

        kh, kw = self.kernel_size
        # The last dimension holds the flattened kernel window. Using
        # kh * kw (instead of kernel_size**2) keeps square kernels unchanged
        # and makes rectangular kernels work.
        self.weight = nn.Parameter(
            torch.randn(
                1,
                out_channels,
                in_channels,
                output_size[0],
                output_size[1],
                kh * kw,
            )
        )
        if bias:
            self.bias = nn.Parameter(
                torch.randn(1, out_channels, output_size[0], output_size[1])
            )
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the locally connected layer.

        Parameters
        ----------
        x : torch.Tensor
            Input with dimensions ``(batch, in_channels, height, width)``.

        Returns
        -------
        torch.Tensor
            Output with dimensions
            ``(batch, out_channels, output_size[0], output_size[1])``.
        """
        kh, kw = self.kernel_size
        dh, dw = self.stride
        # Extract sliding windows along height (dim 2) and width (dim 3):
        # (batch, in_channels, out_h, out_w, kh, kw)
        x = x.unfold(2, kh, dh).unfold(3, kw, dw)
        # Flatten each window into one trailing dimension of length kh * kw
        x = x.contiguous().view(*x.size()[:-2], -1)
        # Broadcast against the out_channels dimension of the weight, then
        # sum over the in_channels and flattened kernel dimensions
        out = (x.unsqueeze(1) * self.weight).sum([2, -1])
        if self.bias is not None:
            out += self.bias
        return out