Source code for radionets.architecture.layers

import torch
from torch import nn
from torch.nn.modules.utils import _pair


[docs] class LocallyConnected2d(nn.Module): """ A 2D locally connected layer implementation. Unlike convolutional layers that share weights across spatial locations, locally connected layers use different weights for each spatial position. This allows the layer to learn location-specific features while maintaining the sliding window approach of convolutions. Parameters ---------- in_channels : int Number of input channels. out_channels : int Number of output channels. output_size : tuple of int Expected output spatial dimensions as (height, width). kernel_size : int Size of the sliding window (assumes square kernel). stride : int Stride of the sliding window (assumes same stride for both dimensions). bias : bool, optional If True, adds a learnable bias parameter. Default is False. Attributes ---------- weight : nn.Parameter Learnable weights with shape (1, out_channels, in_channels, output_height, output_width, kernel_size²). bias : nn.Parameter or None Learnable bias with shape (1, out_channels, output_height, output_width) if bias=True, else None. kernel_size : tuple of int Kernel size as (height, width). stride : tuple of int Stride as (height, width). """ def __init__( self, in_channels, out_channels, output_size, kernel_size, stride, bias=False ): super().__init__() self.weight = nn.Parameter( torch.randn( 1, out_channels, in_channels, output_size[0], output_size[1], kernel_size**2, ) ) if bias: self.bias = nn.Parameter( torch.randn(1, out_channels, output_size[0], output_size[1]) ) else: self.register_parameter("bias", None) self.kernel_size = _pair(kernel_size) self.stride = _pair(stride)
[docs] def forward(self, x): _, c, h, w = x.size() kh, kw = self.kernel_size dh, dw = self.stride x = x.unfold(2, kh, dh).unfold(3, kw, dw) x = x.contiguous().view(*x.size()[:-2], -1) # Sum in in_channel and kernel_size dims out = (x.unsqueeze(1) * self.weight).sum([2, -1]) if self.bias is not None: out += self.bias return out
[docs] class ComplexConv2d(nn.Module): """ 2D convolution layer for complex-valued tensors. This layer performs 2D convolution on complex-valued inputs by decomposing the operation into separate real and imaginary components. It implements the mathematical formula for complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i Parameters ---------- in_channels : int Number of channels in the input tensor. out_channels : int Number of channels produced by the convolution. kernel_size : int or tuple of int Size of the convolving kernel. If int, the same value is used for both height and width dimensions. stride : int or tuple of int, optional Stride of the convolution. If int, the same value is used for both height and width dimensions. Default is 1. bias : bool, optional If True, adds a learnable bias to the output. Default is True. Attributes ---------- conv_real : torch.nn.Conv2d Convolution layer for processing real components of the complex multiplication formula. conv_imag : torch.nn.Conv2d Convolution layer for processing imaginary components of the complex multiplication formula. """ def __init__( self, in_channels, out_channels, kernel_size, stride=1, padding="same", bias=True, ): """ Initialize the ComplexConv2d layer. Parameters ---------- in_channels : int Number of channels in the input tensor. out_channels : int Number of channels produced by the convolution. kernel_size : int or tuple of int Size of the convolving kernel. stride : int or tuple of int, optional Stride of the convolution. Default is 1. bias : bool, optional If True, adds a learnable bias to the output. Default is True. """ super().__init__() # Initialize real component convolution layer self.conv_real = nn.Conv2d( in_channels // 2, out_channels // 2, kernel_size, stride=stride, padding=padding, bias=bias, ) # Initialize imaginary component convolution layer self.conv_imag = nn.Conv2d( in_channels // 2, out_channels // 2, kernel_size, stride=stride, padding=padding, bias=bias, )
[docs] def forward(self, x): """ Forward pass of the complex convolution layer. Performs complex-valued 2D convolution by applying separate convolutions for real and imaginary values, and combining results according to complex multiplication rules. Parameters ---------- x : torch.Tensor Input tensor of shape (batch_size, channels, height, width) with dtype (torch.float32 or torch.float64). Expected channels are equally split into real and imag channels, e.g., num channels is 2 for first network layer. Returns ------- torch.Tensor Output tensor of shape (batch_size, out_channels, out_height, out_width). """ real, imag = x.chunk(2, dim=1) real_out = self.conv_real(real) - self.conv_imag(imag) imag_out = self.conv_real(imag) + self.conv_imag(real) return torch.cat([real_out, imag_out], dim=1)
[docs] class ComplexInstanceNorm2d(nn.Module): """ 2D instance normalization layer for complex-valued tensors. This layer performs instance normalization on complex-valued inputs by treating real and imaginary parts separately. The normalization is applied as: normalized = (x - mean) / sqrt(variance + eps) If affine=True, learnable scale and shift parameters are applied: output = normalized * weight + bias Parameters ---------- num_features : int Number of channels in the input tensor. For complex inputs, this represents the number of complex channels (input will have 2*num_features channels representing real and imaginary parts). eps : float, optional A small value added to the denominator for numerical stability. Default is 1e-5. affine : bool, optional If True, adds learnable affine parameters (scale and shift). Default is True. Attributes ---------- num_features : int Number of complex channels which get split into real and imaginary part equally. (num_channels // 2 for real and imag) eps : float Epsilon value for numerical stability. affine : bool Whether affine transformation is enabled. weight_real : torch.nn.Parameter or None Learnable scale parameter for real part. Shape: (num_features // 2,). Only exists if affine=True. weight_imag : torch.nn.Parameter or None Learnable scale parameter for imaginary part. Shape: (num_features // 2,). Only exists if affine=True. bias_real : torch.nn.Parameter or None Learnable shift parameter for real part. Shape: (num_features // 2,). Only exists if affine=True. bias_imag : torch.nn.Parameter or None Learnable shift parameter for imaginary part. Shape: (num_features // 2,). Only exists if affine=True. """ def __init__(self, num_features, eps=1e-5, affine=True): """ Initialize the ComplexInstanceNorm2d layer. Parameters ---------- num_features : int Number of channels in the input tensor. Num_features will be equally split into channels for real and imag. eps : float, optional Small value added to variance for numerical stability. Must be positive. Default is 1e-5. affine : bool, optional If True, creates learnable affine parameters (weights and biases) for both real and imaginary components. Default is True. """ super().__init__() # Store configuration self.num_features = num_features // 2 # Divide by 2 for equal real and imag self.eps = eps self.affine = affine if self.affine: # Separate parameters for real and imaginary parts # Initialize weights to ones for identity scaling self.weight_real = nn.Parameter(torch.ones(self.num_features)) self.weight_imag = nn.Parameter(torch.ones(self.num_features)) # Initialize biases to zeros for no initial shift self.bias_real = nn.Parameter(torch.zeros(self.num_features)) self.bias_imag = nn.Parameter(torch.zeros(self.num_features))
[docs] def forward(self, x): """ Forward pass of the complex instance normalization layer. Performs instance normalization on complex-valued input by processing real and imaginary components separately. Computes statistics across spatial dimensions for each sample and channel independently. Complex values have to be passed in separate channels as torch.float, e.g., one channel for real and one channel for imag, leading to a shape [bs, 2, h, w]. Parameters ---------- x : torch.Tensor Input tensor of shape (batch_size, num_features, height, width) where the first half of channels represents real parts and the second half represents imaginary parts of complex values. Returns ------- torch.Tensor Normalized tensor of the same shape as input (batch_size, num_features, height, width). """ # Split complex input into real and imaginary components real, imag = x.chunk(2, dim=1) # Instance normalization for each part separately # Calc mean and variance across dimensions (H, W) for each sample and channel # unbiased=False: uses N denominator instead of N-1 for variance real_mean = real.mean(dim=[2, 3], keepdim=True) imag_mean = imag.mean(dim=[2, 3], keepdim=True) real_var = real.var(dim=[2, 3], keepdim=True, unbiased=False) imag_var = imag.var(dim=[2, 3], keepdim=True, unbiased=False) # Normalize real_norm = (real - real_mean) / torch.sqrt(real_var + self.eps) imag_norm = (imag - imag_mean) / torch.sqrt(imag_var + self.eps) if self.affine: # Apply learnable affine transformation real_norm = real_norm * self.weight_real.view( 1, -1, 1, 1 ) + self.bias_real.view(1, -1, 1, 1) imag_norm = imag_norm * self.weight_imag.view( 1, -1, 1, 1 ) + self.bias_imag.view(1, -1, 1, 1) return torch.cat([real_norm, imag_norm], dim=1)
[docs] class ComplexPReLU(nn.Module): """ Parametric ReLU activation function for complex-valued tensors. This layer applies Parametric ReLU activation to complex-valued inputs by treating real and imaginary parts separately. PReLU allows the negative slope to be learned during training, providing more flexibility than standard ReLU activation. The activation is applied as: - For positive values: f(x) = x - For negative values: f(x) = a * x where 'a' is the learnable negative slope parameter. Parameters ---------- num_parameters : int, optional Number of learnable parameters. Can be: - 1: Single shared parameter for all channels (default) - num_channels: Per-channel parameters for fine-grained control Default is 1. init : float, optional Initial value for the negative slope parameter(s). Should be a small positive value. Default is 0.25. Attributes ---------- num_parameters : int Number of learnable parameters (1 for shared, num_channels for per-channel). weight_real : torch.nn.Parameter Learnable negative slope parameter(s) for real channel(s). Shape: (num_parameters // 2,) weight_imag : torch.nn.Parameter Learnable negative slope parameter(s) for imaginary channel(s). Shape: (num_parameters // 2,) """ def __init__(self, num_parameters=1, init=0.25): """ Initialize the ComplexPReLU activation layer. Parameters ---------- num_parameters : int, optional Number of learnable parameters. Options: - 1: Single parameter shared across all channels (default) - num_channels: Individual parameter per channel Must be positive integer. Default is 1. init : float, optional Initial value for the negative slope parameter(s). Typically a small positive value (e.g., 0.01 to 0.25). Must be finite and typically in range [0, 1]. Default is 0.25. """ super().__init__() # Store configuration self.num_parameters = num_parameters # Create separate learnable parameters for real and imaginary parts n_params = self.num_parameters // 2 if self.num_parameters >= 2 else 1 self.weight_real = nn.Parameter(torch.full((n_params,), init)) self.weight_imag = nn.Parameter(torch.full((n_params,), init))
[docs] def forward(self, x): """ Forward pass of the complex PReLU activation function. Parameters ---------- x : torch.Tensor Input tensor of shape (batch_size, num_channels, height, width) where the first half of channels represents real parts and the second half represents imaginary parts of complex values. Returns ------- torch.Tensor Activated tensor of the same shape and dtype as input (batch_size, num_channels, height, width). """ # Split channels into real and imaginary components real, imag = x.chunk(2, dim=1) if self.num_parameters == 1: # Shared parameter across all channels real_out = torch.where(real >= 0, real, self.weight_real * real) imag_out = torch.where(imag >= 0, imag, self.weight_imag * imag) else: # Per-channel parameters weight_real = self.weight_real.view(1, -1, 1, 1) weight_imag = self.weight_imag.view(1, -1, 1, 1) real_out = torch.where(real >= 0, real, weight_real * real) imag_out = torch.where(imag >= 0, imag, weight_imag * imag) return torch.cat([real_out, imag_out], dim=1)