Source code for radionets.plotting.inspection

from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt

from radionets.core.logging import setup_logger

# Module-level logger used by the plotting helpers below to announce each plot.
LOGGER = setup_logger(namespace=__name__)


def plot_loss(learn, model_path: str | Path, output_format: str = "png") -> None:
    """
    Plot train and valid loss of model and save the figure to disk.

    The figure is written next to the model file as
    ``<model_path without suffix>_loss.<output_format>``.

    Parameters
    ----------
    learn : learner-object
        learner containing data and model; ``learn.avg_loss.plot_loss()``
        draws the curves and its return value decides whether the y-axis
        is switched to log scale
    model_path : str or Path
        path to trained model
    output_format : str, optional
        file extension of the saved figure, default ``"png"``
    """
    checkpoint = Path(model_path) if isinstance(model_path, str) else model_path
    target_stem = checkpoint.with_suffix("")

    LOGGER.info(f"Plotting Loss for: {checkpoint.stem}")
    use_log_axis = learn.avg_loss.plot_loss()

    # Underscores in the file name become spaces in the figure title.
    plt.title(checkpoint.stem.replace("_", " "))
    if use_log_axis:
        plt.yscale("log")

    figure_file = f"{target_stem}_loss.{output_format}"
    plt.savefig(figure_file, bbox_inches="tight", pad_inches=0.01)

    # Clear the current figure and restore matplotlib's default rcParams.
    plt.clf()
    mpl.rcParams.update(mpl.rcParamsDefault)
def plot_lr(learn, model_path: str | Path, output_format: str = "png") -> None:
    """
    Plot learning rate of model and save the figure to disk.

    The figure is written next to the model file as
    ``<model_path without suffix>_lr.<output_format>``.

    Parameters
    ----------
    learn : learner-object
        learner containing data and model; ``learn.avg_loss.plot_lrs()``
        draws the curve
    model_path : str or Path
        path to trained model
    output_format : str, optional
        file extension of the saved figure, default ``"png"``
    """
    checkpoint = Path(model_path) if isinstance(model_path, str) else model_path
    target_stem = checkpoint.with_suffix("")

    LOGGER.info(f"Plotting Learning rate for: {checkpoint.stem}")
    learn.avg_loss.plot_lrs()

    figure_file = f"{target_stem}_lr.{output_format}"
    plt.savefig(figure_file, bbox_inches="tight", pad_inches=0.01)

    # Clear the current figure and restore matplotlib's default rcParams.
    plt.clf()
    mpl.rcParams.update(mpl.rcParamsDefault)
def plot_lr_loss(
    learn,
    arch_name: str,
    out_path: str | Path,
    skip_last: int,
    output_format: str = "png",
) -> None:
    """
    Plot loss of learning rate finder and save the figure to disk.

    The figure is written to ``<out_path>/lr_loss.<output_format>``;
    ``out_path`` is created (including parents) if it does not exist.

    Parameters
    ----------
    learn : learner-object
        learner containing data and model; ``learn.recorder.plot_lr_find()``
        draws the curve
    arch_name : str
        name of the architecture (only used in the log message)
    out_path : str or Path
        directory to save loss plot in
    skip_last : int
        skip n last points.
        NOTE(review): accepted for interface compatibility but currently
        not forwarded to the plotting call.
    output_format : str, optional
        file extension of the saved figure, default ``"png"``
    """
    if isinstance(out_path, str):
        out_path = Path(out_path)

    LOGGER.info(f"Plotting Lr vs Loss for architecture: {arch_name}")
    learn.recorder.plot_lr_find()

    out_path.mkdir(parents=True, exist_ok=True)
    plt.savefig(
        out_path / f"lr_loss.{output_format}", bbox_inches="tight", pad_inches=0.01
    )

    # Clear the current figure, as plot_loss and plot_lr do, so later
    # pyplot calls do not draw on top of this one.
    plt.clf()
    mpl.rcParams.update(mpl.rcParamsDefault)