Source code for radionets.core.learner

import torch.nn as nn
from fastai.callback.schedule import ParamScheduler, combined_cos
from fastai.data.core import DataLoaders
from fastai.learner import Learner
from fastai.optimizer import Adam

import radionets.core.loss_functions as loss_functions
from radionets.core.callbacks import (
    AvgLossCallback,
    CometCallback,
    CudaCallback,
    DataAug,
    Normalize,
    SaveTempCallback,
    SwitchLoss,
)
from radionets.core.model import init_cnn

__all__ = ["get_learner", "define_learner"]


def get_learner(
    data, arch, lr, loss_func=None, cb_funcs=None, opt_func=Adam, **kwargs
):
    """Create a fastai ``Learner`` for ``arch`` on the datasets held by ``data``.

    Parameters
    ----------
    data : object
        Container exposing ``train_ds``, ``valid_ds`` and
        ``train_dl.batch_size``; the datasets are re-wrapped into fastai
        ``DataLoaders`` using that batch size.
    arch : object
        Model handed to ``init_cnn`` and then to the ``Learner``
        (presumably a ``torch.nn.Module`` -- confirm against callers).
    lr : float
        Learning rate forwarded to the ``Learner``.
    loss_func : callable, optional
        Loss function. ``nn.MSELoss()`` is used when ``None``.
    cb_funcs : list, optional
        Callbacks forwarded to the ``Learner`` as ``cbs``.
    opt_func : callable, optional
        Optimizer factory forwarded to the ``Learner``. Defaults to ``Adam``.
    **kwargs
        Accepted for call-site compatibility; currently not forwarded
        anywhere.

    Returns
    -------
    Learner
        The configured fastai learner.
    """
    # Test for None explicitly instead of truthiness: a legitimate loss object
    # can evaluate as falsy (e.g. a module container that defines __len__) and
    # must not be silently swapped for the MSE default.
    if loss_func is None:
        loss_func = nn.MSELoss()

    # NOTE(review): presumably initialises the model weights in place --
    # confirm in radionets.core.model.
    init_cnn(arch)

    dls = DataLoaders.from_dsets(
        data.train_ds, data.valid_ds, bs=data.train_dl.batch_size
    )
    return Learner(dls, arch, loss_func, lr=lr, cbs=cb_funcs, opt_func=opt_func)
def define_learner(data, arch, train_conf, lr_find=False, plot_loss=False):
    """Assemble callbacks and loss from ``train_conf`` and build the learner.

    Parameters
    ----------
    data : object
        Data container passed through to ``get_learner``.
    arch : object
        Model passed through to ``get_learner``.
    train_conf : dict
        Training configuration. Keys read here: ``model_path``, ``lr``,
        ``param_scheduling`` (with ``lr_ratio``, ``lr_start``, ``lr_max``,
        ``lr_stop``), ``gpu``, ``switch_loss`` (with ``when_switch``),
        ``comet_ml`` (with ``project_name``, ``data_path``,
        ``plot_n_epochs``, ``amp_phase``, ``scale``), ``normalize`` and
        ``loss_func``.
    lr_find : bool, optional
        When true, the Comet callback is not attached.
    plot_loss : bool, optional
        When true, neither the Comet callback nor the ``Normalize``
        callback is attached.

    Returns
    -------
    Learner
        The learner returned by ``get_learner``.
    """
    callbacks = []
    model_path = train_conf["model_path"]
    learning_rate = train_conf["lr"]

    # Optional learning-rate schedule.
    if train_conf["param_scheduling"]:
        lr_schedule = combined_cos(
            train_conf["lr_ratio"],
            train_conf["lr_start"],
            train_conf["lr_max"],
            train_conf["lr_stop"],
        )
        callbacks.append(ParamScheduler({"lr": lr_schedule}))

    if train_conf["gpu"]:
        callbacks.append(CudaCallback)

    # Callbacks that are always attached.
    callbacks += [
        SaveTempCallback(model_path=model_path),
        AvgLossCallback,
        DataAug,
    ]

    # Optionally change to ``comb_likelihood`` during training.
    if train_conf["switch_loss"]:
        switch_cb = SwitchLoss(
            second_loss=loss_functions.comb_likelihood,
            when_switch=train_conf["when_switch"],
        )
        callbacks.append(switch_cb)

    # Comet logging is skipped for lr-finder and loss-plotting runs.
    if train_conf["comet_ml"] and not (lr_find or plot_loss):
        comet_cb = CometCallback(
            name=train_conf["project_name"],
            test_data=train_conf["data_path"],
            plot_n_epochs=train_conf["plot_n_epochs"],
            amp_phase=train_conf["amp_phase"],
            scale=train_conf["scale"],
        )
        callbacks.append(comet_cb)

    if not plot_loss:
        if train_conf["normalize"] != "none":
            callbacks.append(Normalize(train_conf))

    # Resolve the loss: ``feature_loss`` needs its initialiser, every other
    # name is looked up as an attribute of the loss_functions module.
    loss_name = train_conf["loss_func"]
    if loss_name == "feature_loss":
        criterion = loss_functions.init_feature_loss()
    else:
        criterion = getattr(loss_functions, loss_name)

    # Combine model and data in the learner.
    return get_learner(
        data,
        arch,
        lr=learning_rate,
        opt_func=Adam,
        cb_funcs=callbacks,
        loss_func=criterion,
    )