Source code for radionets.architecture.layers

import torch
from torch import nn
from torch.nn.modules.utils import _pair


class LocallyConnected2d(nn.Module):
    """Locally connected 2-D layer: sliding windows with unshared weights.

    Patches are extracted from the input the same way a convolution without
    padding would, but every output position owns its own set of weights
    instead of sharing a single kernel across the image.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input.
    out_channels : int
        Number of channels produced by the layer.
    output_size : int or tuple of int
        Spatial size ``(height, width)`` of the output. An int is used for
        both dimensions. It has to agree with the number of windows that
        ``kernel_size`` and ``stride`` produce on the input passed to
        ``forward``.
    kernel_size : int or tuple of int
        Size ``(height, width)`` of the sliding window. An int is used for
        both dimensions.
    stride : int or tuple of int
        Step ``(height, width)`` of the sliding window. An int is used for
        both dimensions.
    bias : bool, optional
        If ``True`` a learnable bias per output channel and output position
        is added. Default: ``False``.
    """

    def __init__(
        self, in_channels, out_channels, output_size, kernel_size, stride, bias=False
    ):
        super().__init__()

        # Normalise before building the parameters so that ints and
        # (height, width) tuples are both accepted for every size argument.
        out_h, out_w = _pair(output_size)
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        kh, kw = self.kernel_size

        # One weight per (out channel, in channel, output position, kernel
        # element). The leading 1 broadcasts over the batch dimension.
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, out_h, out_w, kh * kw)
        )
        if bias:
            self.bias = nn.Parameter(torch.randn(1, out_channels, out_h, out_w))
        else:
            self.register_parameter("bias", None)

    def forward(self, x):
        """Apply the layer to a 4-D input.

        Parameters
        ----------
        x : torch.Tensor
            Input with four dimensions; dimension 1 is matched against
            ``in_channels`` and dimensions 2 and 3 are the ones the window
            slides over.

        Returns
        -------
        torch.Tensor
            Tensor with dimensions (batch, out_channels, windows along dim 2,
            windows along dim 3).

        Raises
        ------
        ValueError
            If ``x`` does not have exactly four dimensions.
        """
        if x.dim() != 4:
            raise ValueError(f"expected a 4-D input, got {x.dim()} dimensions")

        kh, kw = self.kernel_size
        dh, dw = self.stride

        # Sliding windows over dims 2 and 3 end up in two trailing dims of
        # size kh and kw; merge them into a single kernel dimension.
        patches = x.unfold(2, kh, dh).unfold(3, kw, dw)
        patches = patches.flatten(start_dim=-2)

        # Sum in in_channel and kernel_size dims
        out = (patches.unsqueeze(1) * self.weight).sum([2, -1])
        if self.bias is not None:
            out += self.bias
        return out