from math import pi
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.colors import LogNorm, PowerNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable
from pytorch_msssim import ms_ssim
from tqdm import tqdm
from radionets.evaluation.contour import compute_area_ratio
from radionets.evaluation.dynamic_range import calc_dr, get_boxsize
from radionets.evaluation.utils import check_vmin_vmax, make_axes_nice, reshape_2d
[docs]
def plot_target(h5_dataset, log=False):
    """Show the target image of one randomly chosen dataset entry.

    Parameters
    ----------
    h5_dataset : sequence
        Object supporting ``len()`` and integer indexing. Element ``[1]``
        of the chosen entry is passed through ``reshape_2d`` and displayed.
    log : bool, optional
        If ``True``, the image is drawn with a
        :class:`~matplotlib.colors.LogNorm` normalization.
        Default: ``False``
    """
    # The upper bound of randint is exclusive, so the full length is passed:
    # the last entry can be selected and a one-entry dataset does not raise.
    index = np.random.randint(len(h5_dataset))

    plt.figure(figsize=(5.78, 3.57))
    target = reshape_2d(h5_dataset[index][1]).squeeze(0)
    if log:
        plt.imshow(target, norm=LogNorm())
    else:
        plt.imshow(target)
    plt.xlabel("Pixels")
    plt.ylabel("Pixels")
    plt.colorbar(label="Intensity / a.u.")
[docs]
def plot_inp_tar(h5_dataset, fourier=False, amp_phase=False):
    """Plot input and target of one randomly chosen dataset entry.

    Parameters
    ----------
    h5_dataset : sequence
        Object supporting ``len()`` and integer indexing. Each entry is
        indexed as ``entry[0]`` (input, two channels) and ``entry[1]``
        (target).
    fourier : bool, optional
        If falsy, the target is treated as a single image (passed through
        ``reshape_2d``) and a 1x3 figure is drawn. If truthy, the target is
        treated as two channels and a 2x2 figure is drawn.
        Default: ``False``
    amp_phase : bool, optional
        Only used when ``fourier`` is truthy. Selects amplitude/phase
        titles and color scaling instead of real/imaginary.
        Default: ``False``
    """
    # The upper bound of randint is exclusive, so the full length is passed:
    # the last entry can be selected and a one-entry dataset does not raise.
    index = np.random.randint(len(h5_dataset))
    # Fetch the entry once instead of re-indexing the dataset per panel.
    sample = h5_dataset[index]
    inp, tar = sample[0], sample[1]

    def show_symmetric(fig, ax, image, title):
        # Color limits are symmetric around zero, taken from check_vmin_vmax.
        lim = check_vmin_vmax(image)
        im = ax.imshow(image, cmap="RdBu", vmin=-lim, vmax=lim)
        make_axes_nice(fig, ax, im, title)

    def show_amplitude(fig, ax, image, title):
        im = ax.imshow(image, cmap="inferno")
        make_axes_nice(fig, ax, im, title)

    def show_phase(fig, ax, image, title):
        # Phase panels use the fixed range [-pi, pi].
        im = ax.imshow(image, cmap="RdBu", vmin=-pi, vmax=pi)
        make_axes_nice(fig, ax, im, title)

    if not fourier:
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(14.45, 3.57))
        show_symmetric(fig, ax1, inp[0], "Input: real part")
        show_symmetric(fig, ax2, inp[1], "Input: imaginary part")

        source = reshape_2d(tar).squeeze(0)
        im3 = ax3.imshow(source, cmap="inferno")
        make_axes_nice(fig, ax3, im3, "Target: source image")
    elif not amp_phase:
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14.45, 8.92))
        show_symmetric(fig, ax1, inp[0], "Input: real part")
        show_symmetric(fig, ax2, inp[1], "Input: imaginary part")
        show_symmetric(fig, ax3, tar[0], "Target: real part")
        show_symmetric(fig, ax4, tar[1], "Target: imaginary part")
    else:
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14.45, 8.92))
        show_amplitude(fig, ax1, inp[0], "Input: amplitude")
        show_phase(fig, ax2, inp[1], "Input: phase")
        show_amplitude(fig, ax3, tar[0], "Target: amplitude")
        show_phase(fig, ax4, tar[1], "Target: phase")

    fig.tight_layout()
[docs]
def visualize_with_fourier(
    i: int,
    img_input: torch.Tensor,
    img_pred: torch.Tensor,
    img_truth: torch.Tensor,
    amp_phase: bool,
    out_path: Path,
    plot_format: str = "png",
    return_fig: bool = False,
    kwargs: list[dict] | None = None,
):
    """Plot input, prediction, truth and difference for both channels.

    A 2x4 figure is created. The top row shows channel ``[0]`` and the
    bottom row channel ``[1]`` of the given images, each as
    input / prediction / truth / (truth - prediction).

    Parameters
    ----------
    i : int
        Index of the current sample; used in the output file name.
    img_input : :class:`torch.Tensor` or :class:`numpy.ndarray`
        Current input image; indexed as ``img_input[0]`` and
        ``img_input[1]``.
    img_pred : :class:`torch.Tensor` or :class:`numpy.ndarray`
        Current prediction; indexed like ``img_input``.
    img_truth : :class:`torch.Tensor` or :class:`numpy.ndarray`
        Current true image; indexed like ``img_input``.
    amp_phase : bool
        Whether the channels hold amplitude/phase (``True``) or
        real/imaginary (``False``) information. Selects titles and the
        default color settings.
    out_path : str or :class:`~pathlib.Path`
        Output directory of the figure. Unused if ``return_fig`` is
        ``True``.
    plot_format : str, optional
        Output file format. Default: png
    return_fig : bool, optional
        Whether to return the figure and axes instead of saving the
        figure to a file. Default: ``False``
    kwargs : list[dict] or None, optional
        One dictionary of ``imshow`` keyword arguments per subplot, in
        row-major order (eight entries). Missing ``cmap``, ``vmin`` and
        ``vmax`` keys are filled with defaults. The dictionaries passed
        in are copied and not modified. Default: ``None``

    Returns
    -------
    fig : :class:`~matplotlib.figure.Figure`
        Only if ``return_fig`` is ``True``.
    ax : :class:`numpy.ndarray` of :class:`~matplotlib.axes.Axes`
        Flattened array of the eight axes. Only if ``return_fig`` is
        ``True``.
    """
    inp_real, inp_imag = img_input[0], img_input[1]
    real_pred, imag_pred = img_pred[0], img_pred[1]
    real_truth, imag_truth = img_truth[0], img_truth[1]

    n_panels = 8
    if not kwargs:
        # One independent dict per panel. ``[{}] * 8`` would alias a single
        # dict, so every panel would end up with the defaults of panel 0.
        panel_kwargs = [{} for _ in range(n_panels)]
    else:
        # Shallow copies keep the caller's dictionaries untouched.
        panel_kwargs = [dict(entry) for entry in kwargs]

    if amp_phase:
        # Symmetric limit for the phase input panel.
        a = check_vmin_vmax(inp_imag)
        defaults = dict(
            cmap=["inferno"] * 3 + ["radionets.PuOr"] * 5,
            vmin=[None, None, None, None, -a, -np.pi, -np.pi, None],
            vmax=[None, None, None, None, a, np.pi, np.pi, None],
            name=["Amplitude"] * 4 + ["Phase"] * 4,
        )
    else:
        defaults = dict(
            cmap=["radionets.PuOr"] * 8,
            vmin=[None] * 8,
            vmax=[None] * 8,
            name=["Real"] * 4 + ["Imaginary"] * 4,
        )

    # The loop variable must not be called ``i``: that would overwrite the
    # sample index used for the output file name below.
    for panel, entry in enumerate(panel_kwargs):
        for key in ("cmap", "vmin", "vmax"):
            if key not in entry:
                entry[key] = defaults[key][panel]

    fig, ax = plt.subplots(2, 4, figsize=(16, 10), sharex=True, sharey=True)
    ax = ax.ravel()

    images = [
        inp_real,
        real_pred,
        real_truth,
        real_truth - real_pred,
        inp_imag,
        imag_pred,
        imag_truth,
        imag_truth - imag_pred,
    ]
    suffixes = ["Input", "Prediction", "Truth", "Difference"] * 2
    # Second-row input/prediction/truth panels forward the ``phase`` flag.
    phase_panels = (4, 5, 6)

    for panel, (axis, image, suffix) in enumerate(zip(ax, images, suffixes)):
        im = axis.imshow(image, **panel_kwargs[panel])
        title = f"{defaults['name'][panel]} {suffix}"
        if panel in phase_panels:
            make_axes_nice(fig, axis, im, title, phase=bool(amp_phase))
        else:
            make_axes_nice(fig, axis, im, title)

    ax[0].set_ylabel("Pixels")
    ax[4].set_ylabel("Pixels")
    for axs in ax[4:]:
        axs.set_xlabel("Pixels")

    if return_fig:
        return fig, ax

    plt.tight_layout(pad=1.5)
    outpath = str(out_path) + f"/prediction_{i}.{plot_format}"
    fig.savefig(outpath, bbox_inches="tight", pad_inches=0.01)
[docs]
def visualize_with_fourier_diff(
    i,
    img_pred,
    img_truth,
    amp_phase,
    out_path,
    plot_format="png",
):
    """Plot prediction, truth and their difference for both channels.

    A 2x3 figure is created and saved. The top row shows channel ``[0]``
    and the bottom row channel ``[1]``, each as
    prediction / truth / (prediction - truth). All open figures are
    closed afterwards.

    Parameters
    ----------
    i : int
        Index of the current sample; used in the output file name.
    img_pred : array_like
        Current prediction; indexed as ``img_pred[0]`` and ``img_pred[1]``.
    img_truth : array_like
        Current true image; indexed like ``img_pred``.
    amp_phase : bool
        Whether the channels hold amplitude/phase (``True``) or
        real/imaginary (``False``) information.
    out_path : str
        Output directory of the figure.
    plot_format : str, optional
        Output file format. Default: png
    """
    real_pred, imag_pred = img_pred[0], img_pred[1]
    real_truth, imag_truth = img_truth[0], img_truth[1]

    fig, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(
        2, 3, figsize=(16, 10), sharex=True, sharey=True
    )

    # The first row is drawn the same way in both modes; only titles differ.
    first = "Amplitude" if amp_phase else "Real"
    im1 = ax1.imshow(real_pred, cmap="inferno")
    make_axes_nice(fig, ax1, im1, f"{first} Prediction")

    im2 = ax2.imshow(real_truth, cmap="inferno")
    make_axes_nice(fig, ax2, im2, f"{first} Truth")

    first_diff = real_pred - real_truth
    a = check_vmin_vmax(first_diff)
    im3 = ax3.imshow(first_diff, cmap="radionets.PuOr", vmin=-a, vmax=a)
    make_axes_nice(fig, ax3, im3, f"{first} Difference")

    second_diff = imag_pred - imag_truth
    if amp_phase:
        # Phase panels use fixed limits: [-pi, pi] for the values and
        # [-2 pi, 2 pi] for the difference.
        im4 = ax4.imshow(imag_pred, cmap="radionets.PuOr", vmin=-np.pi, vmax=np.pi)
        make_axes_nice(fig, ax4, im4, "Phase Prediction", phase=True)

        im5 = ax5.imshow(imag_truth, cmap="radionets.PuOr", vmin=-np.pi, vmax=np.pi)
        make_axes_nice(fig, ax5, im5, "Phase Truth", phase=True)

        im6 = ax6.imshow(
            second_diff,
            cmap="radionets.PuOr",
            vmin=-2 * np.pi,
            vmax=2 * np.pi,
        )
        make_axes_nice(fig, ax6, im6, "Phase Difference", phase_diff=True)
    else:
        im4 = ax4.imshow(imag_pred, cmap="radionets.PuOr")
        make_axes_nice(fig, ax4, im4, "Imaginary Prediction")

        im5 = ax5.imshow(imag_truth, cmap="radionets.PuOr")
        make_axes_nice(fig, ax5, im5, "Imaginary Truth")

        im6 = ax6.imshow(second_diff, cmap="radionets.PuOr")
        make_axes_nice(fig, ax6, im6, "Imaginary Difference")

    ax1.set_ylabel("Pixels")
    ax4.set_ylabel("Pixels")
    for axis in (ax4, ax5, ax6):
        axis.set_xlabel("Pixels")

    plt.tight_layout(pad=1)
    outpath = str(out_path) + f"/prediction_{i}.{plot_format}"
    fig.savefig(outpath, bbox_inches="tight", pad_inches=0.05)
    plt.close("all")
[docs]
def visualize_source_reconstruction(
    ifft_pred,
    ifft_truth,
    out_path,
    i,
    dr=False,
    msssim=False,
    plot_format="png",
):
    """Plot predicted and true source image next to their difference.

    A 1x3 figure (prediction, truth, prediction - truth) is saved to
    ``out_path`` and all open figures are closed afterwards.

    Parameters
    ----------
    ifft_pred : array_like
        Predicted image.
    ifft_truth : array_like
        True image. Its maximum is used as upper color limit of the
        prediction panel.
    out_path : str
        Output directory of the figure.
    i : int
        Index of the current sample; used in the output file name.
    dr : bool, optional
        If ``True``, ``calc_dr`` is evaluated, its results are added as
        legend entries to the prediction and truth panels and
        ``plot_box`` is called on both. Default: ``False``
    msssim : bool, optional
        If ``True``, the MS-SSIM between prediction and truth is computed
        and added as legend entry to the prediction panel.
        Default: ``False``
    plot_format : str, optional
        Output file format. Default: png

    Returns
    -------
    tuple
        ``(np.abs(ifft_pred), np.abs(ifft_truth))``
    """
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 10), sharey=True)

    # The prediction shares the upper color limit of the truth.
    im1 = ax1.imshow(ifft_pred, vmax=ifft_truth.max(), cmap="inferno")
    im2 = ax2.imshow(ifft_truth, cmap="inferno")

    residual = ifft_pred - ifft_truth
    a = check_vmin_vmax(residual)
    im3 = ax3.imshow(residual, cmap="radionets.PuOr", vmin=-a, vmax=a)

    make_axes_nice(fig, ax1, im1, "FFT Prediction")
    make_axes_nice(fig, ax2, im2, "FFT Truth")
    make_axes_nice(fig, ax3, im3, "FFT Diff")

    ax1.set_ylabel("Pixels")
    for axis in (ax1, ax2, ax3):
        axis.set_xlabel("Pixels")

    if dr:
        dr_truth, dr_pred, num_boxes, corners = calc_dr(
            ifft_truth[None, ...], ifft_pred[None, ...]
        )
        # Invisible lines: they only exist to carry a legend entry.
        ax1.plot([], [], " ", label=f"DR: {int(dr_pred[0])}")
        ax2.plot([], [], " ", label=f"DR: {int(dr_truth[0])}")

        plot_box(ax1, num_boxes, corners[0])
        plot_box(ax2, num_boxes, corners[0])

    if msssim:
        val = ms_ssim(
            torch.tensor(ifft_pred).unsqueeze(0).unsqueeze(0),
            torch.tensor(ifft_truth).unsqueeze(0).unsqueeze(0),
            data_range=1,
            win_size=7,
            size_average=False,
        )
        val = val.numpy()[0]
        ax1.plot([], [], " ", label=f"MS-SSIM: {val:.2f}")

    # A legend is drawn on exactly those panels that received labelled
    # entries above, so the DR value of the truth panel is shown as well.
    if dr or msssim:
        ax1.legend(loc="best")
    if dr:
        ax2.legend(loc="best")

    outpath = str(out_path) + f"/fft_pred_{i}.{plot_format}"

    fig.tight_layout(pad=1)
    fig.savefig(outpath, bbox_inches="tight", pad_inches=0.05)
    plt.close("all")
    return np.abs(ifft_pred), np.abs(ifft_truth)
[docs]
def visualize_uncertainty(
    i, img_pred, img_truth, img_unc, amp_phase, out_path, plot_format="png"
):
    """Save one 2x2 uncertainty figure per channel.

    For channel ``[0]`` (file ``unc_amp{i}``) and channel ``[1]`` (file
    ``unc_phase{i}``) a figure with the panels simulation, predicted mean,
    predicted uncertainty and (truth - prediction) is written to
    ``out_path``. All open figures are closed afterwards.

    Parameters
    ----------
    i : int
        Index of the current sample; used in the output file names.
    img_pred : array_like
        Prediction; indexed as ``img_pred[0]`` and ``img_pred[1]``.
    img_truth : array_like
        True image; indexed like ``img_pred``.
    img_unc : array_like
        Predicted uncertainty; indexed like ``img_pred``.
    amp_phase : bool
        Not used inside this function.
    out_path : str
        Output directory of the figures.
    plot_format : str, optional
        Output file format. Default: png
    """
    pred_amp, pred_phase = img_pred[0], img_pred[1]
    true_amp, true_phase = img_truth[0], img_truth[1]
    unc_amp, unc_phase = img_unc[0], img_unc[1]

    # (file tag, truth, prediction, uncertainty, imshow style of truth/pred)
    channels = (
        ("amp", true_amp, pred_amp, unc_amp, {}),
        ("phase", true_phase, pred_phase, unc_phase, {"cmap": "radionets.PuOr"}),
    )

    for tag, truth, pred, unc, value_style in channels:
        fig, axes = plt.subplots(2, 2, sharey=True, sharex=True, figsize=(12, 10))
        ax_true, ax_pred, ax_unc, ax_diff = axes.ravel()

        im_true = ax_true.imshow(truth, **value_style)
        im_pred = ax_pred.imshow(pred, **value_style)
        im_unc = ax_unc.imshow(unc)

        residual = truth - pred
        limit = check_vmin_vmax(residual)
        im_diff = ax_diff.imshow(
            residual, cmap="radionets.PuOr", vmin=-limit, vmax=limit
        )

        make_axes_nice(fig, ax_true, im_true, r"Simulation")
        make_axes_nice(fig, ax_pred, im_pred, r"Predicted $\mu$")
        make_axes_nice(fig, ax_unc, im_unc, r"Predicted $\sigma^2$", unc=True)
        make_axes_nice(fig, ax_diff, im_diff, r"Difference")

        for axis in (ax_true, ax_unc):
            axis.set_ylabel("pixels")
        for axis in (ax_unc, ax_diff):
            axis.set_xlabel("pixels")

        fig.tight_layout(pad=1)
        target = str(out_path) + f"/unc_{tag}{i}.{plot_format}"
        fig.savefig(target, bbox_inches="tight", pad_inches=0.05)

    plt.close("all")
[docs]
def visualize_sampled_unc(i, mean, std, ifft_truth, out_path, plot_format):
    """Save a 2x2 figure of simulation, prediction, uncertainty and difference.

    The panels show ``ifft_truth``, ``mean``, ``std`` and
    ``mean - ifft_truth``. Each panel additionally gets a boxed text label
    at the fixed data position (90, 110). The figure is written to
    ``out_path`` as ``unc_samp{i}`` and all open figures are closed.

    Parameters
    ----------
    i : int
        Index of the current sample; used in the output file name.
    mean : array_like
        Image shown in the "Prediction" panel.
    std : array_like
        Image shown in the "Uncertainty" panel.
    ifft_truth : array_like
        Image shown in the "Simulation" panel.
    out_path : str
        Output directory of the figure.
    plot_format : str
        Output file format.
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 10), sharey=True, sharex=True)
    ax_sim, ax_pred, ax_unc, ax_diff = axes.ravel()

    im_sim = ax_sim.imshow(ifft_truth)
    im_pred = ax_pred.imshow(mean)
    im_unc = ax_unc.imshow(std)

    residual = mean - ifft_truth
    limit = check_vmin_vmax(residual)
    im_diff = ax_diff.imshow(residual, cmap="radionets.PuOr", vmin=-limit, vmax=limit)

    panels = (ax_sim, ax_pred, ax_unc, ax_diff)
    captions = ("Simulation", "Prediction", "Uncertainty", "Difference")

    # Boxed caption inside every panel, always at the same data coordinates.
    for axis, caption in zip(panels, captions):
        axis.text(
            90,
            110,
            caption,
            ha="center",
            size=9,
            bbox=dict(boxstyle="round", fc="w", ec="gray", alpha=0.75),
        )

    make_axes_nice(fig, ax_sim, im_sim, "Simulation")
    make_axes_nice(fig, ax_pred, im_pred, "Prediction")
    make_axes_nice(fig, ax_unc, im_unc, "Uncertainty", unc=True)
    make_axes_nice(fig, ax_diff, im_diff, "Difference")

    ax_sim.set_ylabel("pixels")
    ax_unc.set_xlabel("pixels")
    ax_unc.set_ylabel("pixels")
    ax_diff.set_xlabel("pixels")

    fig.tight_layout(pad=1.5)
    target = str(out_path) + f"/unc_samp{i}.{plot_format}"
    fig.savefig(target, bbox_inches="tight", pad_inches=0.05)
    plt.close("all")
[docs]
def plot_contour(
    ifft_pred,
    ifft_truth,
    out_path,
    i,
    plot_format="png",
    norm_scale: float = 0.4,
    labels: list | None = None,
    colors: list | None = None,
    levels: list | None = None,
):
    """Plot prediction and truth with contour lines and save the figure.

    Both images are drawn with a :class:`~matplotlib.colors.PowerNorm`
    whose limits are the minimum and maximum of ``ifft_truth``. The value
    returned by ``compute_area_ratio`` for the two contour sets, rounded
    to two decimals, is shown in the title of the truth panel and is part
    of the output file name. All open figures are closed afterwards.

    Parameters
    ----------
    ifft_pred : array_like
        Predicted image.
    ifft_truth : array_like
        True image.
    out_path : str
        Output directory of the figure.
    i : int
        Index of the current sample; used in the output file name.
    plot_format : str, optional
        Output file format. Default: png
    norm_scale : float, optional
        Exponent passed to :class:`~matplotlib.colors.PowerNorm`.
        Default: 0.4
    labels : list or None, optional
        Legend labels of the contour levels. ``None`` or an empty
        sequence selects ``["5%", "10%", "30%", "50%", "80%"]``.
    colors : list or None, optional
        Colors of the contour levels. ``None`` or an empty sequence
        selects a built-in five-color list.
    levels : list or None, optional
        Contour levels. ``None`` or an empty sequence selects 5 %, 10 %,
        30 %, 50 % and 80 % of the maximum of ``ifft_truth``.
    """

    def unset(value):
        # ``not value`` raises for numpy arrays with more than one element,
        # so emptiness is checked through the length instead.
        return value is None or len(value) == 0

    truth_min = ifft_truth.min()
    truth_max = ifft_truth.max()

    if unset(labels):
        labels = ["5%", "10%", "30%", "50%", "80%"]
    if unset(colors):
        colors = ["#454CC7", "#1984DE", "#50B3D7", "#ABD9DC", "#FFFFFF"]
    if unset(levels):
        levels = [truth_max * fraction for fraction in (0.05, 0.1, 0.3, 0.5, 0.8)]

    fig, ax = plt.subplots(1, 2, figsize=(10, 8), sharey=True)

    # Each image gets its own norm instance with identical limits.
    im1 = ax[0].imshow(
        ifft_pred,
        cmap="inferno",
        norm=PowerNorm(norm_scale, vmin=truth_min, vmax=truth_max),
    )
    CS1 = ax[0].contour(ifft_pred, levels=levels, colors=colors)
    make_axes_nice(fig, ax[0], im1, "Prediction")

    im2 = ax[1].imshow(
        ifft_truth,
        cmap="inferno",
        norm=PowerNorm(norm_scale, vmin=truth_min, vmax=truth_max),
    )
    CS2 = ax[1].contour(ifft_truth, levels=levels, colors=colors)
    diff = np.round(compute_area_ratio(CS1, CS2), 2)
    make_axes_nice(fig, ax[1], im2, f"Truth, ratio: {diff}")

    outpath = str(out_path) + f"/contour_{diff}_{i}.{plot_format}"

    # Legend handles of the contour lines.
    cl1, _ = CS1.legend_elements()
    cl2, _ = CS2.legend_elements()
    ax[0].legend(cl1, labels, loc="best")
    ax[1].legend(cl2, labels, loc="best")

    ax[0].set_ylabel("Pixels")
    ax[0].set_xlabel("Pixels")
    ax[1].set_xlabel("Pixels")

    plt.tight_layout(pad=0.75)
    fig.savefig(outpath, bbox_inches="tight", pad_inches=0.05)
    plt.close("all")
[docs]
def plot_box(ax, num_boxes, corners, img_size=64):
    """Outline square regions at the edges of an axis in red.

    Parameters
    ----------
    ax : :class:`~matplotlib.axes.Axes`
        Axis to draw on.
    num_boxes : int
        Passed to ``get_boxsize`` to obtain the edge length of the boxes.
    corners : sequence
        Four flags, indexed ``corners[0]`` to ``corners[3]``. A box is
        drawn for every truthy entry.
    img_size : int, optional
        Edge length of the displayed image in pixels, used to place the
        boxes. Default: 64
    """
    size = get_boxsize(num_boxes)

    # axvspan takes x in data coordinates and y as fraction of the axis.
    x_low = (0, size)
    x_high = (img_size - size, img_size - 1)
    y_low = (0.01, size / img_size)
    y_high = ((img_size - size) / img_size, 0.99)

    # (index into ``corners``, x range, y range); drawing order as before.
    spans = (
        (2, x_low, y_high),
        (3, x_high, y_high),
        (0, x_low, y_low),
        (1, x_high, y_low),
    )

    for corner, (xmin, xmax), (ymin, ymax) in spans:
        if corners[corner]:
            ax.axvspan(
                xmin=xmin,
                xmax=xmax,
                ymin=ymin,
                ymax=ymax,
                color="red",
                fill=False,
            )
[docs]
def plot_length_point(length, vals, mask, out_path, plot_format="png"):
    """Scatter ``vals`` against ``length``, split by ``mask``, and save it.

    Entries selected by ``mask`` are labelled "Point sources", the
    remaining entries "Extended sources". The figure is written to
    ``out_path`` as ``extend_point.<plot_format>``.

    Parameters
    ----------
    length : array_like
        x values ("Linear extent / pixels"); must support boolean-mask
        indexing.
    vals : array_like
        y values ("Mean specific intensity deviation"); must support
        boolean-mask indexing.
    mask : array_like of bool
        Boolean mask; must support ``~mask``.
    out_path : str
        Output directory of the figure.
    plot_format : str, optional
        Output file format. Default: png
    """
    fig, ax1 = plt.subplots(1, figsize=(6, 4))

    groups = (
        (mask, "darkorange", "Point sources"),
        (~mask, "#1f77b4", "Extended sources"),
    )
    for selection, color, label in groups:
        ax1.plot(
            length[selection],
            vals[selection],
            ".",
            markersize=1,
            color=color,
            label=label,
        )

    ax1.set_ylabel("Mean specific intensity deviation")
    ax1.set_xlabel("Linear extent / pixels")
    ax1.grid()
    # markerscale enlarges the tiny markers in the legend only.
    ax1.legend(loc="best", markerscale=10)

    # Honor the requested format instead of always writing a ``.png`` file.
    outpath = str(out_path) + f"/extend_point.{plot_format}"
    fig.savefig(outpath, bbox_inches="tight", pad_inches=0.01, dpi=150)
[docs]
def plot_jet_results(inp, pred, truth, path, save=False, plot_format="pdf"):
    """Plot the input and the prediction error of the overall prediction.

    For every sample a figure with two stacked panels is created: the
    first input channel and the difference ``pred - truth`` (not
    component wise). Each figure is closed after it was handled.

    Parameters
    ----------
    inp : 4d array
        Input images; ``inp[i, 0]`` is displayed.
    pred : 4d array
        Predicted images, reduced along axis 1 as described below.
    truth : 4d array
        True images. The size of axis 1 decides the reduction applied to
        both ``truth`` and ``pred``: more than two channels are summed
        without the last one, two channels keep only the first one, and
        otherwise axis 1 is squeezed.
    path : str
        Output directory; created if it does not exist. Only used if
        ``save`` is ``True``.
    save : bool, optional
        Whether to write the figures to ``path``. Default: ``False``
    plot_format : str, optional
        Output file format. Default: pdf
    """
    # Only axis 1 is squeezed: a bare ``squeeze()`` would also drop the
    # sample axis when a single image is passed.
    if truth.shape[1] > 2:
        truth = torch.sum(truth[:, 0:-1], axis=1)
        pred = torch.sum(pred[:, 0:-1], axis=1)
    elif truth.shape[1] == 2:
        truth = truth[:, 0:-1].squeeze(1)
        pred = pred[:, 0:-1].squeeze(1)
    else:
        truth = truth.squeeze(1)
        pred = pred.squeeze(1)

    if save:
        Path(path).mkdir(parents=True, exist_ok=True)

    def label_and_colorbar(fig, axis, image):
        axis.set_xlabel("Pixels")
        axis.set_ylabel("Pixels")
        # Colorbar in its own axis, attached to the right of the panel.
        cax = make_axes_locatable(axis).append_axes("right", size="5%", pad=0.05)
        cbar = fig.colorbar(image, cax=cax, orientation="vertical")
        cbar.set_label("Specific Intensity / a.u.")

    for i in tqdm(range(len(inp))):
        fig, ax = plt.subplots(2, 1, sharex=True, sharey=True, figsize=(4, 7))

        im1 = ax[0].imshow(inp[i, 0], cmap=plt.cm.inferno)
        label_and_colorbar(fig, ax[0], im1)

        diff = pred[i] - truth[i]
        im2 = ax[1].imshow(diff, cmap=plt.cm.inferno)
        label_and_colorbar(fig, ax[1], im2)

        plt.tight_layout()

        if save:
            outpath = str(path) + f"/prediction_{i}.{plot_format}"
            fig.savefig(outpath, bbox_inches="tight", pad_inches=0.01)

        plt.close(fig)
[docs]
def plot_jet_components_results(inp, pred, truth, path, save=False, plot_format="pdf"):
    """Plot input, truth, prediction and difference for every component.

    For every sample and every channel except the last one, a 2x2 figure
    is created: input (``inp[i, 0]``), true component, predicted
    component and ``pred - truth``. If the true component is not
    identically zero at its maximum, contour lines at 32 % of the
    respective maximum are added (truth: white, prediction: dashed cyan).
    Each figure is closed after it was handled.

    Parameters
    ----------
    inp : 4d array
        Input images; ``inp[i, 0]`` is displayed.
    pred : 4d array with multiple channels
        Predicted images.
    truth : 4d array with multiple channels
        True images.
    path : str
        Output directory; created if it does not exist. Only used if
        ``save`` is ``True``.
    save : bool, optional
        Whether to write the figures to ``path``. Default: ``False``
    plot_format : str, optional
        Output file format. Default: pdf
    """
    # The contour grid must match the contoured images, which have shape
    # (rows, cols); sizing both axes separately also covers non-square data.
    n_rows, n_cols = truth.shape[-2], truth.shape[-1]
    X, Y = np.meshgrid(np.arange(n_cols), np.arange(n_rows))

    n_components = truth.shape[1] - 1  # -1 because last one is the background

    if save:
        Path(path).mkdir(parents=True, exist_ok=True)

    def label_and_colorbar(fig, axis, image):
        axis.set_xlabel("Pixels")
        axis.set_ylabel("Pixels")
        # Colorbar in its own axis, attached to the right of the panel.
        cax = make_axes_locatable(axis).append_axes("right", size="5%", pad=0.05)
        cbar = fig.colorbar(image, cax=cax, orientation="vertical")
        cbar.set_label("Specific Intensity / a.u.")

    for i in tqdm(range(len(inp))):
        for j in range(n_components):
            true_comp = truth[i, j]
            pred_comp = pred[i, j]
            truth_max = torch.max(true_comp)

            fig, axs = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(8, 7))

            if truth_max != 0:
                pred_max = torch.max(pred_comp)
                # Truth outline on input, truth and prediction panels.
                for axis in (axs[0, 0], axs[0, 1], axs[1, 0]):
                    axis.contour(
                        X, Y, true_comp, levels=[truth_max * 0.32], colors="white"
                    )
                axs[1, 0].contour(
                    X,
                    Y,
                    pred_comp,
                    levels=[pred_max * 0.32],
                    colors="cyan",
                    linestyles="dashed",
                )

            im_inp = axs[0, 0].imshow(inp[i, 0], cmap=plt.cm.inferno)
            label_and_colorbar(fig, axs[0, 0], im_inp)

            im_truth = axs[0, 1].imshow(true_comp, cmap=plt.cm.inferno)
            label_and_colorbar(fig, axs[0, 1], im_truth)

            im_pred = axs[1, 0].imshow(pred_comp, cmap=plt.cm.inferno)
            label_and_colorbar(fig, axs[1, 0], im_pred)

            im_diff = axs[1, 1].imshow(pred_comp - true_comp, cmap=plt.cm.inferno)
            label_and_colorbar(fig, axs[1, 1], im_diff)

            plt.tight_layout(w_pad=2)

            if save:
                outpath = str(path) + f"/prediction_{i}_comp_{j}.{plot_format}"
                fig.savefig(outpath, bbox_inches="tight", pad_inches=0.01)

            plt.close(fig)
[docs]
def plot_fitgaussian(
    data, fit_list, params_list, iteration, path, save=False, plot_format="pdf"
):
    """Plot the sky image with the fitted gaussians and their parameters.

    One panel per entry of ``params_list`` is drawn. Each panel shows the
    current image with the contours of the corresponding fit and a text
    box with the fit parameters. After a panel was drawn, its fit is
    subtracted, so the next panel shows the residual. ``data`` itself is
    not modified.

    Parameters
    ----------
    data : 2d array
        Sky map, usually the prediction of the NN.
    fit_list : list of 2d arrays
        Gaussian fits, one per panel.
    params_list : list
        One object per panel whose ``parameters`` attribute unpacks to
        ``(height, x, y, width_x, width_y, theta)``.
    iteration : int
        Used in the output file name.
    path : str
        Output directory; created if it does not exist. Only used if
        ``save`` is ``True``.
    save : bool, optional
        Whether to write the figure to ``path``. Default: ``False``
    plot_format : str, optional
        Output file format. Default: pdf
    """
    n_panels = len(params_list)
    # squeeze=False always returns a 2d array of axes, so a single
    # component does not yield a bare, non-indexable Axes object.
    fig, axs = plt.subplots(
        1,
        n_panels,
        sharex=True,
        sharey=True,
        figsize=(4 * n_panels, 3.5),
        squeeze=False,
    )

    residual = data
    for axis, fit, params in zip(axs[0], fit_list, params_list):
        im = axis.imshow(residual, cmap=plt.cm.inferno)
        axis.set_xlabel("Pixels")
        axis.set_ylabel("Pixels")

        cax = make_axes_locatable(axis).append_axes("right", size="5%", pad=0.05)
        cbar = fig.colorbar(im, cax=cax, orientation="vertical")
        cbar.set_label("Specific Intensity / a.u.")

        axis.contour(fit, cmap=plt.cm.gray_r)

        # Out-of-place subtraction: the caller's array stays untouched.
        residual = residual - fit

        (height, x, y, width_x, width_y, theta) = params.parameters
        summary = "\n".join(
            (
                "",
                f"height : {height:.2f}",
                f"x : {x:.1f}",
                f"y : {y:.1f}",
                f"width_x : {width_x:.1f}",
                f"width_y : {width_y:.1f}",
                f"theta : {theta:.2f}",
            )
        )
        # Attach the text to its own panel rather than the current axes.
        axis.text(
            0.95,
            0.02,
            summary,
            fontsize=8,
            horizontalalignment="right",
            c="w",
            verticalalignment="bottom",
            transform=axis.transAxes,
        )

    plt.tight_layout()

    if save:
        Path(path).mkdir(parents=True, exist_ok=True)
        outpath = str(path) + f"/eval_iterativ_gaussian_{iteration}.{plot_format}"
        fig.savefig(outpath, bbox_inches="tight", pad_inches=0.01)

    plt.close(fig)
[docs]
def plot_data(x, path, rows=1, cols=1, save=False, plot_format="pdf"):
    """Plot a grid of images of the dataset.

    Panel ``(i, j)`` shows ``x[i * cols + j, 0]``. The figure is closed
    at the end.

    Parameters
    ----------
    x : array
        Array of shape (n, 1, size, size), n must be at least rows * cols.
    path : str
        Output directory; created if it does not exist. Only used if
        ``save`` is ``True``.
    rows : int, optional
        Number of rows in the plot. Default: 1
    cols : int, optional
        Number of cols in the plot. Default: 1
    save : bool, optional
        Whether to write the figure to ``path``. Default: ``False``
    plot_format : str, optional
        Output file format. Default: pdf
    """
    # squeeze=False keeps the axes as a 2d array, so ``ax[i, j]`` also works
    # for a single row, a single column and the 1x1 default.
    fig, ax = plt.subplots(
        rows,
        cols,
        sharex=True,
        sharey=True,
        figsize=(4 * cols, 3.5 * rows),
        squeeze=False,
    )

    for i in range(rows):
        for j in range(cols):
            axis = ax[i, j]
            img = axis.imshow(x[i * cols + j, 0], cmap=plt.cm.inferno)
            axis.set_xlabel("Pixels")
            axis.set_ylabel("Pixels")

            # Colorbar in its own axis, attached to the right of the panel.
            cax = make_axes_locatable(axis).append_axes("right", size="5%", pad=0.05)
            cbar = fig.colorbar(img, cax=cax, orientation="vertical")
            cbar.set_label("Specific Intensity / a.u.")

    plt.tight_layout()

    if save:
        Path(path).mkdir(parents=True, exist_ok=True)
        outpath = str(path) + f"/simulation_examples.{plot_format}"
        fig.savefig(outpath, bbox_inches="tight", pad_inches=0.01)

    plt.close(fig)