# Source code for radionets.core.model
from pathlib import Path
import torch
from torch import nn
from radionets.core.logging import setup_logger
# Public names of this module (what a star-import exposes).
__all__ = [
"init_cnn",
"load_pre_model",
"save_model",
]
# Module-level logger, obtained from setup_logger with this module's name as
# the namespace.
LOGGER = setup_logger(namespace=__name__)
def _init_cnn(m, f):
    """Apply the initializer ``f`` to every ``nn.Conv2d`` inside ``m``.

    ``m`` itself and all of its descendants are visited in depth-first
    pre-order (a module before its children, children in the order given by
    ``children()``). For each ``nn.Conv2d`` the weight is initialized via
    ``f(weight, a=0.1)`` and the bias, when present, is zeroed in place.
    Modules of any other type are only traversed.

    Parameters
    ----------
    m : nn.Module
        Root module of the traversal.
    f : callable
        Initializer invoked as ``f(tensor, a=0.1)``.
    """
    pending = [m]
    while pending:
        module = pending.pop()
        if isinstance(module, nn.Conv2d):
            f(module.weight, a=0.1)
            bias = getattr(module, "bias", None)
            if bias is not None:
                bias.data.zero_()
        # Push children reversed so they pop in their declared order; this
        # keeps the visiting order (and therefore the sequence of random
        # draws made by ``f``) the same as a recursive walk.
        pending.extend(reversed(list(module.children())))
# [docs] (Sphinx source-link marker; not part of the module code)
def init_cnn(m, uniform=False):
    """Kaiming-initialize the ``nn.Conv2d`` layers of ``m`` in place.

    Selects the Kaiming initializer and hands it to ``_init_cnn``, which
    applies it to the module tree rooted at ``m``.

    Parameters
    ----------
    m : nn.Module
        Model (or sub-module) whose convolution layers are initialized.
    uniform : bool
        If True use ``nn.init.kaiming_uniform_``, otherwise
        ``nn.init.kaiming_normal_``. Default: False
    """
    if uniform:
        initializer = nn.init.kaiming_uniform_
    else:
        initializer = nn.init.kaiming_normal_
    _init_cnn(m, initializer)
# [docs] (Sphinx source-link marker; not part of the module code)
def load_pre_model(learn, pre_path, visualize=False, plot_loss=False):
    """Restore state from a checkpoint file into ``learn``.

    The checkpoint is read with ``torch.load``. It is mapped to the CPU
    unless CUDA is available and ``plot_loss`` is False. What is restored
    depends on the flags; ``visualize`` takes precedence over ``plot_loss``.

    Parameters
    ----------
    learn : learner
        Target object. With ``visualize=True`` it must provide
        ``load_state_dict`` itself; otherwise its ``avg_loss`` (and, for a
        full restore, ``model``, ``opt``, ``epoch`` and ``recorder``)
        attributes are written.
    pre_path : str
        Path to the pre-model.
    visualize : bool
        Only load the ``"model"`` state dict into ``learn`` and return the
        stored normalization dictionary. Default: False
    plot_loss : bool
        Only restore the loss and learning-rate history on
        ``learn.avg_loss``. Default: False

    Returns
    -------
    dict or None
        ``checkpoint["norm_dict"]`` when ``visualize`` is True, else None.
    """
    name_pretrained = Path(pre_path).stem
    LOGGER.info(f"Load pretrained model: {name_pretrained}")

    keep_saved_device = torch.cuda.is_available() and not plot_loss
    load_options = {} if keep_saved_device else {"map_location": torch.device("cpu")}
    checkpoint = torch.load(pre_path, **load_options)

    if visualize:
        learn.load_state_dict(checkpoint["model"])
        return checkpoint["norm_dict"]

    full_restore = not plot_loss

    if full_restore:
        learn.model.load_state_dict(checkpoint["model"])
        learn.opt.load_state_dict(checkpoint["opt"])
        learn.epoch = checkpoint["epoch"]

    # The loss / learning-rate history is restored in both remaining modes.
    learn.avg_loss.loss_train = checkpoint["train_loss"]
    learn.avg_loss.loss_valid = checkpoint["valid_loss"]
    learn.avg_loss.lrs = checkpoint["lrs"]

    if full_restore:
        learn.recorder.iters = checkpoint["iters"]
        learn.recorder.values = checkpoint["vals"]

    return None
# [docs] (Sphinx source-link marker; not part of the module code)
def save_model(learn, model_path):
    """Write a training checkpoint for ``learn`` with ``torch.save``.

    The checkpoint is a dict with the keys ``"model"``, ``"opt"``,
    ``"epoch"``, ``"iters"``, ``"vals"``, ``"train_loss"``, ``"valid_loss"``,
    ``"lrs"`` and ``"norm_dict"``. The normalization dictionary depends on
    ``learn.normalize.mode``: the four mean/std values for ``"mean"``, a
    placeholder entry for ``"max"`` and ``"all"``, and an empty dict when the
    mode is falsy or ``learn`` has no ``normalize`` attribute.

    Parameters
    ----------
    learn : learner
        Object providing ``model``, ``opt``, ``epoch``, ``recorder`` and
        ``avg_loss``; ``normalize`` is optional.
    model_path : str
        Destination handed to ``torch.save``.

    Raises
    ------
    ValueError
        If ``learn.normalize.mode`` is truthy but not one of ``"mean"``,
        ``"max"`` or ``"all"``.
    """
    norm_dict = {}
    if hasattr(learn, "normalize"):
        mode = learn.normalize.mode
        if mode == "mean":
            stat_names = ("mean_real", "mean_imag", "std_real", "std_imag")
            norm_dict = {key: getattr(learn.normalize, key) for key in stat_names}
        elif mode == "max":
            norm_dict = {"max_scaling": 0}
        elif mode == "all":
            norm_dict = {"all": 0}
        elif mode:
            # Any remaining truthy mode is unknown; a falsy mode keeps the
            # empty default.
            raise ValueError(f"Undefined mode {learn.normalize.mode}, check for typos")

    checkpoint = {
        "model": learn.model.state_dict(),
        "opt": learn.opt.state_dict(),
        "epoch": learn.epoch,
        "iters": learn.recorder.iters,
        "vals": learn.recorder.values,
        "train_loss": learn.avg_loss.loss_train,
        "valid_loss": learn.avg_loss.loss_valid,
        "lrs": learn.avg_loss.lrs,
        "norm_dict": norm_dict,
    }
    torch.save(checkpoint, model_path)