# Source code for radionets.architecture.activation
import torch.nn.functional as F
from torch import nn
# Names exported by ``from radionets.architecture.activation import *``.
__all__ = ["GeneralELU", "GeneralReLU", "Lambda"]
class Lambda(nn.Module):
    """Adapt an arbitrary callable into an ``nn.Module``.

    Parameters
    ----------
    func : callable
        Invoked with the module input on every forward pass; whatever it
        returns becomes the module output.
    """

    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        """Delegate to the wrapped callable and return its result."""
        wrapped = self.func
        result = wrapped(x)
        return result
class GeneralReLU(nn.Module):
    """ReLU / leaky ReLU with an optional constant shift and upper clamp.

    The forward pass computes ``act(x)``, then subtracts ``sub`` and clamps
    from above at ``maxv``. ``act`` is ``F.leaky_relu`` with negative slope
    ``leak`` when ``leak`` is given, otherwise ``F.relu``. The shift and the
    clamp are each skipped when their argument is ``None``.

    Parameters
    ----------
    leak : float or None
        Negative slope passed to ``F.leaky_relu``; ``None`` selects ``F.relu``.
    sub : number or None
        Value subtracted from the activation output.
    maxv : number or None
        Upper bound applied to the (shifted) output.
    """

    def __init__(self, leak=None, sub=None, maxv=None):
        super().__init__()
        self.leak = leak
        self.sub = sub
        self.maxv = maxv

    def forward(self, x):
        """Apply the configured activation to ``x`` and return a new tensor."""
        if self.leak is not None:
            x = F.leaky_relu(x, negative_slope=self.leak)
        else:
            x = F.relu(x)
        if self.sub is not None:
            x = x - self.sub
        if self.maxv is not None:
            # Out-of-place on purpose: ``F.relu`` saves its output for the
            # backward pass, so an in-place ``clamp_max_`` on it (the
            # ``sub is None`` path) makes ``backward()`` raise.
            x = x.clamp_max(self.maxv)
        return x
class GeneralELU(nn.Module):
    """ELU with an optional constant offset and upper clamp.

    The forward pass computes ``F.elu(x)`` with its default ``alpha``, then
    adds ``add`` and clamps from above at ``maxv``. The offset and the clamp
    are each skipped when their argument is ``None``.

    Parameters
    ----------
    add : number or None
        Value added to the ELU output.
    maxv : number or None
        Upper bound applied to the (offset) output.
    """

    def __init__(self, add=None, maxv=None):
        super().__init__()
        self.add = add
        self.maxv = maxv

    def forward(self, x):
        """Apply the configured activation to ``x`` and return a new tensor."""
        x = F.elu(x)
        if self.add is not None:
            x = x + self.add
        if self.maxv is not None:
            # Out-of-place clamp so the tensor produced by ``F.elu`` is never
            # mutated after autograd may have recorded it.
            x = x.clamp_max(self.maxv)
        return x