# Source code for radionets.architecture.activation

import torch.nn.functional as F
from torch import nn

__all__ = [
    "GeneralELU",
    "GeneralReLU",
    "Lambda",
]


class Lambda(nn.Module):
    """Wrap an arbitrary callable so it can be used as a layer.

    Parameters
    ----------
    func : callable
        Called with the module input on every forward pass; whatever it
        returns is the module output.
    """

    def __init__(self, func):
        super(Lambda, self).__init__()
        self.func = func

    def forward(self, x):
        """Apply the wrapped callable to ``x`` and return its result."""
        fn = self.func
        return fn(x)


class GeneralReLU(nn.Module):
    """ReLU / leaky ReLU with an optional constant shift and upper clamp.

    The forward pass computes ``min(act(x) - sub, maxv)``, where ``act`` is
    ``F.leaky_relu`` with negative slope ``leak`` when ``leak`` is given and
    ``F.relu`` otherwise. The shift and the clamp are each skipped when their
    parameter is ``None``.

    Parameters
    ----------
    leak : float, optional
        Negative slope passed to ``F.leaky_relu``; ``None`` selects ``F.relu``.
    sub : float, optional
        Constant subtracted from the activation.
    maxv : float, optional
        Upper bound applied to the output.
    """

    def __init__(self, leak=None, sub=None, maxv=None):
        super().__init__()
        self.leak = leak
        self.sub = sub
        self.maxv = maxv

    def forward(self, x):
        x = F.leaky_relu(x, self.leak) if self.leak is not None else F.relu(x)
        if self.sub is not None:
            x = x - self.sub
        if self.maxv is not None:
            # Out-of-place on purpose: autograd's ReLU backward reads the ReLU
            # output, so clamping it in place (the sub=None path) would make
            # backward() fail with an "inplace operation" RuntimeError.
            x = x.clamp_max(self.maxv)
        return x

    def extra_repr(self):
        """Show the configuration when the module is printed."""
        return f"leak={self.leak}, sub={self.sub}, maxv={self.maxv}"


class GeneralELU(nn.Module):
    """ELU with an optional constant offset and upper clamp.

    The forward pass computes ``min(F.elu(x) + add, maxv)``. The offset and
    the clamp are each skipped when their parameter is ``None``.

    Parameters
    ----------
    add : float, optional
        Constant added to the ELU output.
    maxv : float, optional
        Upper bound applied to the output.
    """

    def __init__(self, add=None, maxv=None):
        super().__init__()
        self.add = add
        self.maxv = maxv

    def forward(self, x):
        x = F.elu(x)
        if self.add is not None:
            x = x + self.add
        if self.maxv is not None:
            # Out-of-place so no tensor recorded in the autograd graph is
            # mutated in place; an in-place clamp only works as long as no
            # backward formula needs this intermediate.
            x = x.clamp_max(self.maxv)
        return x

    def extra_repr(self):
        """Show the configuration when the module is printed."""
        return f"add={self.add}, maxv={self.maxv}"