Source code for radionets.plotting.visualization

from math import pi
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.colors import LogNorm, PowerNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable
from pytorch_msssim import ms_ssim
from tqdm import tqdm

from radionets.evaluation.contour import compute_area_ratio
from radionets.evaluation.dynamic_range import calc_dr, get_boxsize
from radionets.evaluation.utils import check_vmin_vmax, make_axes_nice, reshape_2d


def plot_target(h5_dataset, log=False):
    """Plot the target image of a randomly chosen dataset entry.

    Parameters
    ----------
    h5_dataset : sequence
        Indexable dataset; ``h5_dataset[k][1]`` is the target, which is
        reshaped to 2d with ``reshape_2d``.
    log : bool, optional
        Use a logarithmic color normalization. Default: ``False``
    """
    # ``randint``'s upper bound is exclusive: ``len(h5_dataset)`` makes every
    # entry (including the last one) selectable and keeps a dataset with a
    # single entry working (``randint(0)`` would raise).
    index = np.random.randint(len(h5_dataset))
    target = reshape_2d(h5_dataset[index][1]).squeeze(0)

    plt.figure(figsize=(5.78, 3.57))
    # ``norm=None`` is imshow's default, i.e. linear scaling.
    plt.imshow(target, norm=LogNorm() if log else None)
    plt.xlabel("Pixels")
    plt.ylabel("Pixels")
    plt.colorbar(label="Intensity / a.u.")
def plot_inp_tar(h5_dataset, fourier=False, amp_phase=False):
    """Plot input and target of a randomly chosen dataset entry.

    Parameters
    ----------
    h5_dataset : sequence
        Indexable dataset; ``h5_dataset[k][0]`` is the two-channel input and
        ``h5_dataset[k][1]`` the target.
    fourier : bool, optional
        If falsy, the target is shown as a source image (reshaped with
        ``reshape_2d``). If truthy, the two target channels are shown.
        Default: ``False``
    amp_phase : bool, optional
        Only used when ``fourier`` is truthy. Set it if the channels hold
        amplitude/phase rather than real/imaginary parts. Phases are shown on
        a fixed [-pi, pi] scale. Default: ``False``
    """
    # ``randint``'s upper bound is exclusive, so use the full length.
    index = np.random.randint(len(h5_dataset))
    sample = h5_dataset[index]
    inp, tar = sample[0], sample[1]

    def show_diverging(fig, ax, img, title, lim=None):
        # Symmetric limits keep zero at the centre of the diverging colormap.
        if lim is None:
            lim = check_vmin_vmax(img)
        im = ax.imshow(img, cmap="RdBu", vmin=-lim, vmax=lim)
        make_axes_nice(fig, ax, im, title)

    def show_sequential(fig, ax, img, title):
        im = ax.imshow(img, cmap="inferno")
        make_axes_nice(fig, ax, im, title)

    # Plain truthiness instead of ``is True``/``is False``: previously any
    # non-bool value left ``fig`` undefined and crashed at ``tight_layout``.
    if not fourier:
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(14.45, 3.57))
        show_diverging(fig, ax1, inp[0], "Input: real part")
        show_diverging(fig, ax2, inp[1], "Input: imaginary part")
        show_sequential(
            fig, ax3, reshape_2d(tar).squeeze(0), "Target: source image"
        )
    else:
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14.45, 8.92))
        if amp_phase:
            show_sequential(fig, ax1, inp[0], "Input: amplitude")
            show_diverging(fig, ax2, inp[1], "Input: phase", lim=pi)
            show_sequential(fig, ax3, tar[0], "Target: amplitude")
            show_diverging(fig, ax4, tar[1], "Target: phase", lim=pi)
        else:
            show_diverging(fig, ax1, inp[0], "Input: real part")
            show_diverging(fig, ax2, inp[1], "Input: imaginary part")
            show_diverging(fig, ax3, tar[0], "Target: real part")
            show_diverging(fig, ax4, tar[1], "Target: imaginary part")
    fig.tight_layout()
def visualize_with_fourier(
    i: int,
    img_input: torch.tensor,
    img_pred: torch.tensor,
    img_truth: torch.tensor,
    amp_phase: bool,
    out_path: Path,
    plot_format: str = "png",
    return_fig: bool = False,
    kwargs: list[dict] | None = None,
):
    """Visualize input, prediction, truth and difference in Fourier space.

    The first row shows the first channel (amplitude or real part). The
    second row shows the second channel (phase or imaginary part).

    Parameters
    ----------
    i : int
        Current index given from the loop. It is used in the output file
        name ``prediction_{i}.{plot_format}``.
    img_input : :func:`torch.tensor`
        Current input image as a :func:`~numpy.array` or
        :func:`~torch.tensor`. ``img_input[0]`` and ``img_input[1]`` are the
        two channels, each of shape [M, N].
    img_pred : :func:`torch.tensor`
        Current prediction, same layout as ``img_input``.
    img_truth : :func:`torch.tensor`
        Current true image, same layout as ``img_input``.
    amp_phase : bool
        ``True`` if the channels hold amplitude/phase information, ``False``
        for real/imaginary information.
    out_path : str or :class:`~pathlib.Path`
        Output directory of the figure. Skipped if ``return_fig`` is
        ``True``.
    plot_format : str, optional
        Output file format. Default: ``"png"``
    return_fig : bool, optional
        Whether to return the figure and axes instead of saving the figure to
        a file. Default: ``False``
    kwargs : list[dict] or None, optional
        Up to eight dictionaries with keyword arguments for
        :meth:`~matplotlib.axes.Axes.imshow`, one per subplot in row-major
        order. Missing ``cmap``/``vmin``/``vmax`` keys and missing entries
        fall back to the defaults. The dictionaries are not modified.
        Default: ``None``

    Returns
    -------
    fig, ax : tuple
        Only if ``return_fig`` is ``True``: the figure and the flattened
        array of its eight axes. Otherwise the figure is saved and closed.
    """
    n_panels = 8
    inp_real, inp_imag = img_input[0], img_input[1]
    real_pred, imag_pred = img_pred[0], img_pred[1]
    real_truth, imag_truth = img_truth[0], img_truth[1]

    if amp_phase:
        a = check_vmin_vmax(inp_imag)
        defaults = dict(
            cmap=["inferno"] * 3 + ["radionets.PuOr"] * 5,
            vmin=[None, None, None, None, -a, -np.pi, -np.pi, None],
            vmax=[None, None, None, None, a, np.pi, np.pi, None],
            name=["Amplitude"] * 4 + ["Phase"] * 4,
        )
    else:
        defaults = dict(
            cmap=["radionets.PuOr"] * n_panels,
            vmin=[None] * n_panels,
            vmax=[None] * n_panels,
            name=["Real"] * 4 + ["Imaginary"] * 4,
        )

    # Build a fresh dict per panel. The old ``[{}] * 8`` shared ONE dict, so
    # only panel 0's defaults were ever applied. It also mutated the
    # caller's dicts in place. User keys take precedence over the defaults.
    user_kwargs = list(kwargs) if kwargs else []
    panel_kwargs = []
    for idx in range(n_panels):
        user = user_kwargs[idx] if idx < len(user_kwargs) else {}
        panel_kwargs.append(
            {
                "cmap": defaults["cmap"][idx],
                "vmin": defaults["vmin"][idx],
                "vmax": defaults["vmax"][idx],
                **user,
            }
        )

    panels = [
        (inp_real, "Input"),
        (real_pred, "Prediction"),
        (real_truth, "Truth"),
        (real_truth - real_pred, "Difference"),
        (inp_imag, "Input"),
        (imag_pred, "Prediction"),
        (imag_truth, "Truth"),
        (imag_truth - imag_pred, "Difference"),
    ]
    # Second-row input/prediction/truth get phase-specific colorbar handling.
    phase_panels = {4, 5, 6}

    fig, ax = plt.subplots(2, 4, figsize=(16, 10), sharex=True, sharey=True)
    ax = ax.ravel()
    # ``idx`` (not ``i``) so the ``i`` parameter used in the file name is not
    # overwritten; previously every file was saved as ``prediction_7``.
    for idx, (axis, (img, label), kw) in enumerate(zip(ax, panels, panel_kwargs)):
        im = axis.imshow(img, **kw)
        title = f"{defaults['name'][idx]} {label}"
        if idx in phase_panels:
            make_axes_nice(fig, axis, im, title, phase=bool(amp_phase))
        else:
            make_axes_nice(fig, axis, im, title)

    ax[0].set_ylabel("Pixels")
    ax[4].set_ylabel("Pixels")
    for axs in ax[4:]:
        axs.set_xlabel("Pixels")

    if return_fig:
        return fig, ax

    fig.tight_layout(pad=1.5)
    outpath = str(out_path) + f"/prediction_{i}.{plot_format}"
    fig.savefig(outpath, bbox_inches="tight", pad_inches=0.01)
    # Release the figure; this is typically called once per sample in a loop.
    plt.close(fig)
def visualize_with_fourier_diff(
    i,
    img_pred,
    img_truth,
    amp_phase,
    out_path,
    plot_format="png",
):
    """Plot prediction, truth and their difference for both Fourier channels.

    Parameters
    ----------
    i : int
        Current index given from the loop. It is used in the output file
        name ``prediction_{i}.{plot_format}``.
    img_pred : array_like
        Current prediction. ``img_pred[0]`` and ``img_pred[1]`` are the two
        channels (amplitude/phase or real/imaginary part).
    img_truth : array_like
        Current true image, same layout as ``img_pred``.
    amp_phase : bool
        ``True`` if the channels hold amplitude/phase information, ``False``
        for real/imaginary information.
    out_path : str or Path
        Output directory of the figure.
    plot_format : str, optional
        Output file format. Default: ``"png"``
    """
    real_pred, imag_pred = img_pred[0], img_pred[1]
    real_truth, imag_truth = img_truth[0], img_truth[1]
    real_diff = real_pred - real_truth
    imag_diff = imag_pred - imag_truth

    fig, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(
        2, 3, figsize=(16, 10), sharex=True, sharey=True
    )

    # The first row is drawn identically for both representations; only the
    # titles differ.
    first = "Amplitude" if amp_phase else "Real"
    im1 = ax1.imshow(real_pred, cmap="inferno")
    make_axes_nice(fig, ax1, im1, f"{first} Prediction")
    im2 = ax2.imshow(real_truth, cmap="inferno")
    make_axes_nice(fig, ax2, im2, f"{first} Truth")
    a = check_vmin_vmax(real_diff)
    im3 = ax3.imshow(real_diff, cmap="radionets.PuOr", vmin=-a, vmax=a)
    make_axes_nice(fig, ax3, im3, f"{first} Difference")

    if amp_phase:
        # Phases lie in [-pi, pi], so their difference lies in [-2pi, 2pi].
        im4 = ax4.imshow(imag_pred, cmap="radionets.PuOr", vmin=-np.pi, vmax=np.pi)
        make_axes_nice(fig, ax4, im4, r"Phase Prediction", phase=True)
        im5 = ax5.imshow(imag_truth, cmap="radionets.PuOr", vmin=-np.pi, vmax=np.pi)
        make_axes_nice(fig, ax5, im5, r"Phase Truth", phase=True)
        im6 = ax6.imshow(
            imag_diff,
            cmap="radionets.PuOr",
            vmin=-2 * np.pi,
            vmax=2 * np.pi,
        )
        make_axes_nice(fig, ax6, im6, r"Phase Difference", phase_diff=True)
    else:
        im4 = ax4.imshow(imag_pred, cmap="radionets.PuOr")
        make_axes_nice(fig, ax4, im4, r"Imaginary Prediction")
        im5 = ax5.imshow(imag_truth, cmap="radionets.PuOr")
        make_axes_nice(fig, ax5, im5, r"Imaginary Truth")
        im6 = ax6.imshow(imag_diff, cmap="radionets.PuOr")
        make_axes_nice(fig, ax6, im6, r"Imaginary Difference")

    ax1.set_ylabel(r"Pixels")
    ax4.set_ylabel(r"Pixels")
    for axis in (ax4, ax5, ax6):
        axis.set_xlabel(r"Pixels")

    fig.tight_layout(pad=1)
    outpath = str(out_path) + f"/prediction_{i}.{plot_format}"
    fig.savefig(outpath, bbox_inches="tight", pad_inches=0.05)
    plt.close("all")
def visualize_source_reconstruction(
    ifft_pred,
    ifft_truth,
    out_path,
    i,
    dr=False,
    msssim=False,
    plot_format="png",
):
    """Plot the reconstructed source image, the true image and their difference.

    Parameters
    ----------
    ifft_pred : array_like
        Reconstructed (image-space) prediction, 2d.
    ifft_truth : array_like
        True image, 2d. Its maximum sets the upper color limit of the
        prediction panel.
    out_path : str or Path
        Output directory. The figure is saved as ``fft_pred_{i}.{plot_format}``.
    i : int
        Index used in the output file name.
    dr : bool, optional
        Annotate the dynamic range of prediction and truth and outline the
        boxes used by ``calc_dr``. Default: ``False``
    msssim : bool, optional
        Annotate the MS-SSIM between prediction and truth. It is computed
        with ``data_range=1``, so the images presumably need to be scaled to
        [0, 1]. Default: ``False``
    plot_format : str, optional
        Output file format. Default: ``"png"``

    Returns
    -------
    tuple
        ``np.abs(ifft_pred)`` and ``np.abs(ifft_truth)``.
    """
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 10), sharey=True)

    diff = ifft_pred - ifft_truth
    im1 = ax1.imshow(ifft_pred, vmax=ifft_truth.max(), cmap="inferno")
    im2 = ax2.imshow(ifft_truth, cmap="inferno")
    a = check_vmin_vmax(diff)
    im3 = ax3.imshow(diff, cmap="radionets.PuOr", vmin=-a, vmax=a)

    make_axes_nice(fig, ax1, im1, r"FFT Prediction")
    make_axes_nice(fig, ax2, im2, r"FFT Truth")
    make_axes_nice(fig, ax3, im3, r"FFT Diff")

    ax1.set_ylabel(r"Pixels")
    for axis in (ax1, ax2, ax3):
        axis.set_xlabel(r"Pixels")

    if dr:
        dr_truth, dr_pred, num_boxes, corners = calc_dr(
            ifft_truth[None, ...], ifft_pred[None, ...]
        )
        # Invisible handles whose labels carry the numbers into the legend.
        ax1.plot([], [], " ", label=f"DR: {int(dr_pred[0])}")
        ax2.plot([], [], " ", label=f"DR: {int(dr_truth[0])}")
        plot_box(ax1, num_boxes, corners[0])
        plot_box(ax2, num_boxes, corners[0])
        # Without this the truth DR label was added but never displayed.
        ax2.legend(loc="best")

    if msssim:
        val = ms_ssim(
            torch.tensor(ifft_pred).unsqueeze(0).unsqueeze(0),
            torch.tensor(ifft_truth).unsqueeze(0).unsqueeze(0),
            data_range=1,
            win_size=7,
            size_average=False,
        )
        val = val.numpy()[0]
        ax1.plot([], [], " ", label=f"MS-SSIM: {val:.2f}")

    # Only draw a legend when there is something labelled to show; otherwise
    # matplotlib warns about an empty legend.
    if dr or msssim:
        ax1.legend(loc="best")

    outpath = str(out_path) + f"/fft_pred_{i}.{plot_format}"
    fig.tight_layout(pad=1)
    fig.savefig(outpath, bbox_inches="tight", pad_inches=0.05)
    plt.close("all")
    return np.abs(ifft_pred), np.abs(ifft_truth)
def visualize_uncertainty(
    i, img_pred, img_truth, img_unc, amp_phase, out_path, plot_format="png"
):
    """Save one uncertainty overview per channel (amplitude and phase).

    Each figure is a 2x2 grid with four panels:

    - the simulation,
    - the predicted mean,
    - the predicted variance,
    - the difference simulation minus prediction.

    The files are written to ``out_path`` as ``unc_amp{i}.{plot_format}`` and
    ``unc_phase{i}.{plot_format}``. All figures are closed afterwards.

    Parameters
    ----------
    i : int
        Index used in the output file names.
    img_pred, img_truth, img_unc : array_like
        Prediction, truth and uncertainty. Element ``[0]`` is plotted as
        amplitude and element ``[1]`` as phase.
    amp_phase : bool
        Accepted for interface compatibility; not used by this function.
    out_path : str or Path
        Output directory.
    plot_format : str, optional
        File format of the saved figures. Default: ``"png"``
    """

    def render(truth, mean, variance, cmap, stem):
        # One 2x2 overview for a single channel; ``cmap=None`` means the
        # matplotlib default colormap.
        fig, ((top_l, top_r), (bot_l, bot_r)) = plt.subplots(
            2, 2, sharey=True, sharex=True, figsize=(12, 10)
        )
        im_truth = top_l.imshow(truth, cmap=cmap)
        im_mean = top_r.imshow(mean, cmap=cmap)
        im_var = bot_l.imshow(variance)
        residual = truth - mean
        lim = check_vmin_vmax(residual)
        im_res = bot_r.imshow(residual, cmap="radionets.PuOr", vmin=-lim, vmax=lim)

        make_axes_nice(fig, top_l, im_truth, r"Simulation")
        make_axes_nice(fig, top_r, im_mean, r"Predicted $\mu$")
        make_axes_nice(fig, bot_l, im_var, r"Predicted $\sigma^2$", unc=True)
        make_axes_nice(fig, bot_r, im_res, r"Difference")

        for axis in (top_l, bot_l):
            axis.set_ylabel(r"pixels")
        for axis in (bot_l, bot_r):
            axis.set_xlabel(r"pixels")

        fig.tight_layout(pad=1)
        fig.savefig(
            str(out_path) + f"/{stem}{i}.{plot_format}",
            bbox_inches="tight",
            pad_inches=0.05,
        )

    render(img_truth[0], img_pred[0], img_unc[0], None, "unc_amp")
    render(img_truth[1], img_pred[1], img_unc[1], "radionets.PuOr", "unc_phase")
    plt.close("all")
def visualize_sampled_unc(i, mean, std, ifft_truth, out_path, plot_format):
    """Save a 2x2 overview of a sampled reconstruction and its uncertainty.

    The four panels are:

    - simulation (``ifft_truth``),
    - prediction (``mean``),
    - uncertainty (``std``),
    - the difference ``mean - ifft_truth`` on a symmetric color scale.

    Each panel also carries a boxed text label at data position (90, 110).
    The figure is written to ``out_path`` as ``unc_samp{i}.{plot_format}``,
    and all figures are closed afterwards.
    """
    fig, grid = plt.subplots(2, 2, figsize=(12, 10), sharey=True, sharex=True)
    axes = grid.ravel()

    images = [
        axes[0].imshow(ifft_truth),
        axes[1].imshow(mean),
        axes[2].imshow(std),
    ]
    lim = check_vmin_vmax(mean - ifft_truth)
    images.append(
        axes[3].imshow(mean - ifft_truth, cmap="radionets.PuOr", vmin=-lim, vmax=lim)
    )

    titles = ("Simulation", "Prediction", "Uncertainty", "Difference")
    for axis, title in zip(axes, titles):
        axis.text(
            90,
            110,
            title,
            ha="center",
            size=9,
            bbox=dict(boxstyle="round", fc="w", ec="gray", alpha=0.75),
        )
    for axis, image, title in zip(axes, images, titles):
        if title == "Uncertainty":
            make_axes_nice(fig, axis, image, title, unc=True)
        else:
            make_axes_nice(fig, axis, image, title)

    axes[0].set_ylabel(r"pixels")
    axes[2].set_xlabel(r"pixels")
    axes[2].set_ylabel(r"pixels")
    axes[3].set_xlabel(r"pixels")

    fig.tight_layout(pad=1.5)
    fig.savefig(
        str(out_path) + f"/unc_samp{i}.{plot_format}",
        bbox_inches="tight",
        pad_inches=0.05,
    )
    plt.close("all")
def plot_contour(
    ifft_pred,
    ifft_truth,
    out_path,
    i,
    plot_format="png",
    norm_scale: float = 0.4,
    labels: list | None = None,
    colors: list | None = None,
    levels: list | None = None,
):
    """Compare contour lines of prediction and truth side by side.

    Both images share a power-law color scale based on the truth's range.
    The area ratio of the contours (``compute_area_ratio``) is shown in the
    truth title and embedded in the output file name
    ``contour_{ratio}_{i}.{plot_format}``.

    Parameters
    ----------
    ifft_pred, ifft_truth : array_like
        Predicted and true image, 2d.
    out_path : str or Path
        Output directory.
    i : int
        Index used in the output file name.
    plot_format : str, optional
        Output file format. Default: ``"png"``
    norm_scale : float, optional
        Exponent of the :class:`~matplotlib.colors.PowerNorm`. Default: 0.4
    labels, colors : sequence, optional
        Legend labels and line colors of the contour levels. Empty or
        ``None`` selects the five defaults.
    levels : sequence, optional
        Contour levels. Empty or ``None`` selects 5/10/30/50/80 % of the
        truth's maximum. NumPy arrays are accepted.
    """
    # ``len(...) == 0`` rather than ``not ...``: the truth value of a NumPy
    # array is ambiguous and used to raise, while empty lists still fall back.
    if labels is None or len(labels) == 0:
        labels = ["5%", "10%", "30%", "50%", "80%"]
    if colors is None or len(colors) == 0:
        colors = ["#454CC7", "#1984DE", "#50B3D7", "#ABD9DC", "#FFFFFF"]
    if levels is None or len(levels) == 0:
        peak = ifft_truth.max()
        levels = [peak * frac for frac in (0.05, 0.1, 0.3, 0.5, 0.8)]

    vmin, vmax = ifft_truth.min(), ifft_truth.max()
    fig, ax = plt.subplots(1, 2, figsize=(10, 8), sharey=True)

    im1 = ax[0].imshow(
        ifft_pred,
        cmap="inferno",
        norm=PowerNorm(norm_scale, vmin=vmin, vmax=vmax),
    )
    CS1 = ax[0].contour(ifft_pred, levels=levels, colors=colors)
    make_axes_nice(fig, ax[0], im1, "Prediction")

    im2 = ax[1].imshow(
        ifft_truth,
        cmap="inferno",
        norm=PowerNorm(norm_scale, vmin=vmin, vmax=vmax),
    )
    CS2 = ax[1].contour(ifft_truth, levels=levels, colors=colors)

    diff = np.round(compute_area_ratio(CS1, CS2), 2)
    make_axes_nice(fig, ax[1], im2, f"Truth, ratio: {diff}")
    outpath = str(out_path) + f"/contour_{diff}_{i}.{plot_format}"

    cl1, _ = CS1.legend_elements()
    cl2, _ = CS2.legend_elements()
    ax[0].legend(cl1, labels, loc="best")
    ax[1].legend(cl2, labels, loc="best")

    ax[0].set_ylabel(r"Pixels")
    ax[0].set_xlabel(r"Pixels")
    ax[1].set_xlabel(r"Pixels")
    fig.tight_layout(pad=0.75)
    fig.savefig(outpath, bbox_inches="tight", pad_inches=0.05)
    plt.close("all")
def plot_box(ax, num_boxes, corners, img_size=64):
    """Outline the corner boxes selected by ``corners`` in red.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes showing the image.
    num_boxes : int
        Passed to ``get_boxsize`` to obtain the box edge length in pixels.
    corners : sequence of bool
        Four flags selecting which boxes to draw:

        - indices 0/1 span the lower part of the axes,
        - indices 2/3 span the upper part,
        - even indices are on the left (x from 0), odd indices on the right.
    img_size : int, optional
        Image edge length in pixels. Default: 64 (the previously hard-coded
        value).
    """
    size = get_boxsize(num_boxes)
    # x limits are in data coordinates, y limits in axes fractions
    # (``axvspan`` semantics).
    x_left = (0, size)
    x_right = (img_size - size, img_size - 1)
    y_upper = ((img_size - size) / img_size, 0.99)
    y_lower = (0.01, size / img_size)

    # Same drawing order as before: 2, 3, 0, 1.
    spans = {
        2: (x_left, y_upper),
        3: (x_right, y_upper),
        0: (x_left, y_lower),
        1: (x_right, y_lower),
    }
    for idx, ((xmin, xmax), (ymin, ymax)) in spans.items():
        if corners[idx]:
            ax.axvspan(
                xmin=xmin,
                xmax=xmax,
                ymin=ymin,
                ymax=ymax,
                color="red",
                fill=False,
            )
def plot_length_point(length, vals, mask, out_path, plot_format="png"):
    """Scatter the intensity deviation against the linear source extent.

    Parameters
    ----------
    length : numpy.ndarray
        Linear extent of each source in pixels.
    vals : numpy.ndarray
        Mean specific intensity deviation of each source.
    mask : numpy.ndarray of bool
        ``True`` for point sources, ``False`` for extended sources.
    out_path : str or Path
        Output directory. The figure is saved as
        ``extend_point.{plot_format}``.
    plot_format : str, optional
        Output file format (previously ignored; always ``png``).
        Default: ``"png"``
    """
    fig, ax1 = plt.subplots(1, figsize=(6, 4))
    ax1.plot(
        length[mask],
        vals[mask],
        ".",
        markersize=1,
        color="darkorange",
        label="Point sources",
    )
    ax1.plot(
        length[~mask],
        vals[~mask],
        ".",
        markersize=1,
        color="#1f77b4",
        label="Extended sources",
    )
    ax1.set_ylabel("Mean specific intensity deviation")
    ax1.set_xlabel("Linear extent / pixels")
    ax1.grid()
    # Enlarge the tiny markers in the legend only.
    ax1.legend(loc="best", markerscale=10)

    outpath = str(out_path) + f"/extend_point.{plot_format}"
    fig.savefig(outpath, bbox_inches="tight", pad_inches=0.01, dpi=150)
    plt.close(fig)
def plot_jet_results(inp, pred, truth, path, save=False, plot_format="pdf"):
    """Plot the input image and the overall prediction residual per sample.

    This shows the overall prediction, not individual components. With more
    than two channels, all channels except the last one are summed before
    the difference ``pred - truth`` is taken. With one or two channels,
    channel 0 is used.

    Parameters
    ----------
    inp : torch.Tensor
        Input images with 1 channel, shape ``(n, 1, H, W)``.
    pred : torch.Tensor
        Predicted images with multiple channels, shape ``(n, c, H, W)``.
    truth : torch.Tensor
        True images, same shape as ``pred``.
    path : str or Path
        Output directory, created if needed when ``save`` is ``True``.
    save : bool, optional
        Save each figure as ``prediction_{i}.{plot_format}`` and close it.
        Otherwise the figures stay open. Default: ``False``
    plot_format : str, optional
        Output file format. Default: ``"pdf"``
    """

    def collapse(images):
        # Index instead of ``squeeze()`` so the batch axis survives n == 1.
        if images.shape[1] > 2:
            return torch.sum(images[:, :-1], dim=1)
        return images[:, 0]

    def add_panel(fig, ax, image):
        # Image with labelled axes and a colorbar of matching height.
        im = ax.imshow(image, cmap=plt.cm.inferno)
        ax.set_xlabel(r"Pixels")
        ax.set_ylabel(r"Pixels")
        cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        cbar = fig.colorbar(im, cax=cax, orientation="vertical")
        cbar.set_label(r"Specific Intensity / a.u.")

    truth = collapse(truth)
    pred = collapse(pred)

    if save:
        Path(path).mkdir(parents=True, exist_ok=True)

    for i in tqdm(range(len(inp))):
        fig, ax = plt.subplots(2, 1, sharex=True, sharey=True, figsize=(4, 7))
        add_panel(fig, ax[0], inp[i, 0])
        add_panel(fig, ax[1], pred[i] - truth[i])
        fig.tight_layout()

        if save:
            outpath = str(path) + f"/prediction_{i}.{plot_format}"
            fig.savefig(outpath, bbox_inches="tight", pad_inches=0.01)
            plt.close(fig)
def plot_jet_components_results(inp, pred, truth, path, save=False, plot_format="pdf"):
    """Plot input, true, predicted and difference image per jet component.

    One 2x2 figure is created per sample and component; the last channel is
    treated as background and skipped. If the true component is not empty,
    contours at 32 % of the maximum are overlaid:

    - truth in white on the input, truth and prediction panels,
    - prediction in dashed cyan on the prediction panel.

    Parameters
    ----------
    inp : torch.Tensor
        Input images with 1 channel, shape ``(n, 1, H, W)``.
    pred : torch.Tensor
        Predicted images with multiple channels, shape ``(n, c, H, W)``.
    truth : torch.Tensor
        True images, same shape as ``pred``.
    path : str or Path
        Output directory, created if needed when ``save`` is ``True``.
    save : bool, optional
        Save each figure as ``prediction_{i}_comp_{j}.{plot_format}`` and
        close it. Otherwise the figures stay open. Default: ``False``
    plot_format : str, optional
        Output file format. Default: ``"pdf"``
    """

    def add_panel(fig, ax, image):
        # Image with labelled axes and a colorbar of matching height.
        im = ax.imshow(image, cmap=plt.cm.inferno)
        ax.set_xlabel(r"Pixels")
        ax.set_ylabel(r"Pixels")
        cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        cbar = fig.colorbar(im, cax=cax, orientation="vertical")
        cbar.set_label(r"Specific Intensity / a.u.")

    # meshgrid(x, y) takes the column count first and the row count second.
    # Using the width for both broke contouring of non-square images.
    X, Y = np.meshgrid(np.arange(inp.shape[-1]), np.arange(inp.shape[-2]))
    num_components = truth.shape[1] - 1  # last channel is the background

    if save:
        Path(path).mkdir(parents=True, exist_ok=True)

    for i in tqdm(range(len(inp))):
        for j in range(num_components):
            comp_truth = truth[i, j]
            comp_pred = pred[i, j]
            truth_max = torch.max(comp_truth)
            fig, axs = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(8, 7))

            if truth_max != 0:
                pred_max = torch.max(comp_pred)
                truth_level = [truth_max * 0.32]
                for ax in (axs[0, 0], axs[0, 1], axs[1, 0]):
                    ax.contour(X, Y, comp_truth, levels=truth_level, colors="white")
                axs[1, 0].contour(
                    X,
                    Y,
                    comp_pred,
                    levels=[pred_max * 0.32],
                    colors="cyan",
                    linestyles="dashed",
                )

            add_panel(fig, axs[0, 0], inp[i, 0])
            add_panel(fig, axs[0, 1], comp_truth)
            add_panel(fig, axs[1, 0], comp_pred)
            add_panel(fig, axs[1, 1], comp_pred - comp_truth)
            fig.tight_layout(w_pad=2)

            if save:
                outpath = str(path) + f"/prediction_{i}_comp_{j}.{plot_format}"
                fig.savefig(outpath, bbox_inches="tight", pad_inches=0.01)
                plt.close(fig)
def plot_fitgaussian(
    data, fit_list, params_list, iteration, path, save=False, plot_format="pdf"
):
    """Plot the sky image together with the iteratively fitted Gaussians.

    Panel ``k`` shows the data after subtracting fits ``0 .. k-1``. It is
    overlaid with the contours of fit ``k`` and that fit's parameters.

    Parameters
    ----------
    data : 2d array
        Sky map, usually the prediction of the NN. It is not modified.
    fit_list : list of 2d arrays
        Gaussian fits around the maxima.
    params_list : list
        Fit results whose ``.parameters`` unpack to
        ``(height, x, y, width_x, width_y, theta)``.
    iteration : int
        Used in the output file name.
    path : str or Path
        Output directory, created if needed when ``save`` is ``True``.
    save : bool, optional
        Save the figure as ``eval_iterativ_gaussian_{iteration}.{plot_format}``
        and close it. Default: ``False``
    plot_format : str, optional
        Output file format. Default: ``"pdf"``
    """
    n_panels = len(params_list)
    # squeeze=False keeps a 2d Axes array, so a single fit works too.
    fig, axs = plt.subplots(
        1,
        n_panels,
        sharex=True,
        sharey=True,
        figsize=(4 * n_panels, 3.5),
        squeeze=False,
    )

    for ax, fit, params in zip(axs[0], fit_list, params_list):
        im = ax.imshow(data, cmap=plt.cm.inferno)
        ax.set_xlabel(r"Pixels")
        ax.set_ylabel(r"Pixels")
        cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        cbar = fig.colorbar(im, cax=cax, orientation="vertical")
        cbar.set_label(r"Specific Intensity / a.u.")
        ax.contour(fit, cmap=plt.cm.gray_r)

        # Out-of-place subtraction: the next panel shows the residual without
        # mutating the caller's array (``-=`` did, and failed on int data).
        data = data - fit

        (height, x, y, width_x, width_y, theta) = params.parameters
        info = "\n".join(
            [
                f"height : {height:.2f}",
                f"x : {x:.1f}",
                f"y : {y:.1f}",
                f"width_x : {width_x:.1f}",
                f"width_y : {width_y:.1f}",
                f"theta : {theta:.2f}",
            ]
        )
        # Attach the text to this panel (``plt.text`` used the current axes).
        ax.text(
            0.95,
            0.02,
            info,
            fontsize=8,
            horizontalalignment="right",
            color="w",
            verticalalignment="bottom",
            transform=ax.transAxes,
        )

    fig.tight_layout()
    if save:
        Path(path).mkdir(parents=True, exist_ok=True)
        outpath = str(path) + f"/eval_iterativ_gaussian_{iteration}.{plot_format}"
        fig.savefig(outpath, bbox_inches="tight", pad_inches=0.01)
        plt.close(fig)
def plot_data(x, path, rows=1, cols=1, save=False, plot_format="pdf"):
    """Plot a grid of example images from the dataset.

    Parameters
    ----------
    x : array
        Images of shape ``(n, 1, size, size)``. ``n`` must be at least
        ``rows * cols``. The images fill the grid row by row.
    path : str or Path
        Output directory, created if needed when ``save`` is ``True``.
    rows : int, optional
        Number of rows in the plot. Default: 1
    cols : int, optional
        Number of columns in the plot. Default: 1
    save : bool, optional
        Save the figure as ``simulation_examples.{plot_format}`` and close
        it. Default: ``False``
    plot_format : str, optional
        Output file format. Default: ``"pdf"``
    """
    # squeeze=False always yields a 2d Axes array. Previously the default
    # 1x1 grid (and any single row/column) crashed on ``ax[i, j]``.
    fig, axes = plt.subplots(
        rows,
        cols,
        sharex=True,
        sharey=True,
        figsize=(4 * cols, 3.5 * rows),
        squeeze=False,
    )
    # ``axes.flat`` is row-major, matching the old ``i * cols + j`` indexing.
    for idx, ax in enumerate(axes.flat):
        img = ax.imshow(x[idx, 0], cmap=plt.cm.inferno)
        ax.set_xlabel(r"Pixels")
        ax.set_ylabel(r"Pixels")
        cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        cbar = fig.colorbar(img, cax=cax, orientation="vertical")
        cbar.set_label(r"Specific Intensity / a.u.")

    fig.tight_layout()
    if save:
        Path(path).mkdir(parents=True, exist_ok=True)
        outpath = str(path) + f"/simulation_examples.{plot_format}"
        fig.savefig(outpath, bbox_inches="tight", pad_inches=0.01)
        plt.close(fig)