# Source code for radionets.training.utils

import sys
from pathlib import Path

import click
import torch
from tqdm import tqdm

from radionets import architecture
from radionets.core.data import DataBunch, get_dls, load_data
from radionets.core.logging import setup_logger
from radionets.core.model import save_model
from radionets.evaluation.train_inspection import create_inspection_plots
from radionets.plotting.inspection import plot_loss

# Module-level logger, created once at import time via ``setup_logger`` and
# used by the helpers in this file for their status messages.
LOGGER = setup_logger()


def create_databunch(data_path, fourier, batch_size):
    """Build a ``DataBunch`` from the training and validation data sets.

    Parameters
    ----------
    data_path
        Location handed to ``load_data`` for both splits.
    fourier
        Forwarded to ``load_data`` as the ``fourier`` keyword.
    batch_size
        Forwarded to ``get_dls`` together with the two data sets.

    Returns
    -------
    DataBunch
        Constructed from the unpacked result of ``get_dls``.
    """
    # The generator is consumed left to right, so "train" is loaded
    # before "valid".
    train_ds, valid_ds = (
        load_data(data_path, split, fourier=fourier)
        for split in ("train", "valid")
    )

    # Wrap both data sets into loaders using the requested batch size
    loaders = get_dls(train_ds, valid_ds, batch_size)
    return DataBunch(*loaders)
def read_config(config):
    """Flatten the nested training configuration into a single dict.

    Parameters
    ----------
    config
        Nested mapping with the sections ``paths``, ``mode``, ``logging``,
        ``hypers``, ``general`` and ``param_scheduling``.

    Returns
    -------
    dict
        Flat mapping of training options. ``"separate"`` is not read from
        ``config``; it is always set to ``False``.

    Raises
    ------
    KeyError
        If a section or option that is looked up is missing from ``config``.
    """
    # (flat key, section, option) triples, listed in the order in which the
    # flat dict is filled. "separate" sits between the two groups.
    before_separate = (
        ("data_path", "paths", "data_path"),
        ("model_path", "paths", "model_path"),
        ("pre_model", "paths", "pre_model"),
        ("quiet", "mode", "quiet"),
        ("gpu", "mode", "gpu"),
        ("comet_ml", "logging", "comet_ml"),
        ("plot_n_epochs", "logging", "plot_n_epochs"),
        ("project_name", "logging", "project_name"),
        ("scale", "logging", "scale"),
        ("batch_size", "hypers", "batch_size"),
        ("lr", "hypers", "lr"),
        ("fourier", "general", "fourier"),
        ("amp_phase", "general", "amp_phase"),
        ("normalize", "general", "normalize"),
        ("arch_name", "general", "arch_name"),
        ("loss_func", "general", "loss_func"),
        ("num_epochs", "general", "num_epochs"),
        ("inspection", "general", "inspection"),
    )
    after_separate = (
        # The flat key "format" maps to the option "output_format".
        ("format", "general", "output_format"),
        ("switch_loss", "general", "switch_loss"),
        ("when_switch", "general", "when_switch"),
        # The flat key "param_scheduling" maps to the option "use".
        ("param_scheduling", "param_scheduling", "use"),
        ("lr_start", "param_scheduling", "lr_start"),
        ("lr_max", "param_scheduling", "lr_max"),
        ("lr_stop", "param_scheduling", "lr_stop"),
        ("lr_ratio", "param_scheduling", "lr_ratio"),
        ("source_list", "general", "source_list"),
    )

    train_conf = {
        flat_key: config[section][option]
        for flat_key, section, option in before_separate
    }
    train_conf["separate"] = False
    train_conf.update(
        (flat_key, config[section][option])
        for flat_key, section, option in after_separate
    )
    return train_conf
def check_outpath(model_path, train_conf):
    """Remove an already existing model file before a new one is written.

    Nothing happens when ``model_path`` does not exist. Otherwise the file
    is deleted, either silently (``train_conf["quiet"]`` is truthy) or after
    an interactive confirmation prompt.

    Parameters
    ----------
    model_path
        Path of the model file to check.
    train_conf
        Training configuration; only ``"quiet"`` is read, and only when the
        file exists.
    """
    target = Path(model_path)
    if not target.exists():
        return

    # ``or`` short-circuits: in quiet mode the prompt is never shown. The
    # prompt is issued with ``abort=True``, so a declined confirmation is
    # handled by click itself rather than by a ``False`` return value.
    overwrite = train_conf["quiet"] or click.confirm(
        "Do you really want to overwrite existing model file?", abort=True
    )
    if overwrite:
        LOGGER.info("Overwriting existing model file!")
        target.unlink()
def define_arch(arch_name, img_size):
    """Instantiate the architecture named ``arch_name``.

    Parameters
    ----------
    arch_name
        Name of an attribute of the ``architecture`` package. If the name
        contains ``"filter_deep"``, ``"resnet"`` or ``"Uncertainty"`` the
        architecture is called with ``img_size``; otherwise it is called
        without arguments.
    img_size
        Passed as the only positional argument to the architectures
        selected by the substrings above.

    Returns
    -------
    The object returned by calling the looked-up architecture.
    """
    size_dependent_tags = ("filter_deep", "resnet", "Uncertainty")
    needs_img_size = any(tag in arch_name for tag in size_dependent_tags)

    arch_factory = getattr(architecture, arch_name)
    return arch_factory(img_size) if needs_img_size else arch_factory()
def pop_interrupt(learn, train_conf):
    """Handle a keyboard interrupt during training, then exit.

    The user is asked whether the current model should be saved. On
    confirmation the model is saved, the loss is plotted and, if
    ``train_conf["inspection"]`` is truthy, inspection plots are created.
    Otherwise only a stop message is logged. The process then exits with
    status 1.

    Parameters
    ----------
    learn
        Learner object; its ``epoch`` attribute is used in the log messages
        and it is handed to the save/plot helpers.
    train_conf
        Training configuration; ``"model_path"`` and ``"inspection"`` are
        read when the user chooses to save.
    """
    # NOTE(review): the original layout of this function was lost. This
    # version assumes the "Stopping ..." message belongs to a declined prompt
    # and that the exit happens on both paths -- confirm against upstream.
    wants_save = click.confirm(
        "KeyboardInterrupt, do you want to save the model?", abort=False
    )
    if not wants_save:
        LOGGER.info(f"Stopping after epoch {learn.epoch}")
        sys.exit(1)

    out_path = train_conf["model_path"]

    # Persist the current state of the model
    LOGGER.info(f"Saving the model after epoch {learn.epoch}")
    save_model(learn, out_path)

    # Loss curve next to the saved model
    plot_loss(learn, out_path)

    # Input, prediction and true image only when requested
    if train_conf["inspection"]:
        create_inspection_plots(learn, train_conf)

    sys.exit(1)
def end_training(learn, train_conf):
    """Save the model and plot the loss once training is done.

    Parameters
    ----------
    learn
        Learner object handed to ``save_model`` and ``plot_loss``.
    train_conf
        Training configuration; ``"model_path"`` is converted to a ``Path``
        and used as the target for both helpers.
    """
    out_path = Path(train_conf["model_path"])

    # Model first, loss plot second
    save_model(learn, out_path)
    plot_loss(learn, out_path)
def get_normalisation_factors(data):
    """Compute mean and standard deviation factors from the training set.

    For every item of ``data.train_ds`` the mean and standard deviation of
    ``inp[0]`` (stored under the ``*_real`` keys) and ``inp[1]`` (stored
    under the ``*_imag`` keys) are computed. Each of the four per-sample
    statistics is then averaged over all samples.

    Parameters
    ----------
    data
        Object with a ``train_ds`` attribute that yields ``(inp, target)``
        pairs; ``inp`` must be indexable with ``0`` and ``1`` and the results
        must provide ``mean()`` and ``std()``.

    Returns
    -------
    dict
        Keys ``"mean_real"``, ``"mean_imag"``, ``"std_real"`` and
        ``"std_imag"``, each holding a scalar tensor.
    """
    per_sample = {
        "mean_real": [],
        "mean_imag": [],
        "std_real": [],
        "std_imag": [],
    }

    for inp, _ in tqdm(data.train_ds):
        real, imag = inp[0], inp[1]
        per_sample["mean_real"].append(real.mean())
        per_sample["mean_imag"].append(imag.mean())
        per_sample["std_real"].append(real.std())
        per_sample["std_imag"].append(imag.std())

    # Every statistic is averaged over the samples. Taking ``.std()`` of the
    # per-sample standard deviations would measure how much the spread
    # varies between samples (close to zero for homogeneous data and NaN for
    # a single sample), not the spread itself.
    norm_factors = {
        name: torch.tensor(values).mean()
        for name, values in per_sample.items()
    }
    return norm_factors