from pathlib import Path
import h5py
import numpy as np
import torch
import torch.nn.functional as F
from numba import set_num_threads, vectorize
from torch.utils.data import DataLoader
from radionets import architecture
from radionets.core.data import load_data
from radionets.core.model import load_pre_model
[docs]
def source_list_collate(batch):
    """Collate samples that carry a source list next to the images.

    Parameters
    ----------
    batch : sequence
        Samples as handed over by the DataLoader. Element 0 and 1 of each
        sample are tensors (input and target); element 2 is indexable and
        its first entry is kept as the sample's source list.

    Returns
    -------
    tuple
        ``(inputs, targets, sources)`` where ``inputs`` and ``targets`` are
        the stacked tensors and ``sources`` is a plain list with one entry
        per sample.
    """
    inputs, targets, sources = [], [], []
    for sample in batch:
        inputs.append(sample[0])
        targets.append(sample[1])
        # only the first entry of the third element is forwarded
        sources.append(sample[2][0])
    return torch.stack(inputs), torch.stack(targets), sources
[docs]
def create_databunch(data_path, fourier, source_list, batch_size):
    """Build the DataLoader that serves the test data batch-wise.

    Parameters
    ----------
    data_path : str
        path to the data
    fourier : bool
        forwarded to ``load_data`` as its ``fourier`` argument
    source_list : bool
        if true, batches are assembled with ``source_list_collate`` and the
        data is shuffled; otherwise the default collation is used and the
        data is served in order
    batch_size : int
        number of images for one batch

    Returns
    -------
    DataLoader
        dataloader object wrapping the test data set
    """
    dataset = load_data(data_path, mode="test", fourier=fourier)

    if not source_list:
        return DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # source lists need the custom collate function
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=source_list_collate,
    )
[docs]
def create_sampled_databunch(data_path, batch_size):
    """Build a shuffling DataLoader on top of a ``sampled_dataset``.

    Parameters
    ----------
    data_path : str
        path handed to ``sampled_dataset``
    batch_size : int
        number of images for one batch

    Returns
    -------
    DataLoader
        dataloader object, created with ``shuffle=True``
    """
    return DataLoader(
        sampled_dataset(data_path),
        batch_size=batch_size,
        shuffle=True,
    )
[docs]
def read_config(config):
    """Flatten the nested evaluation config into a single-level dict.

    Parameters
    ----------
    config : dict
        nested dict as produced by ``toml.load``; must contain the sections
        ``paths``, ``mode``, ``general``, ``inspection`` and ``eval`` with
        all options listed below

    Returns
    -------
    dict
        dict containing all configurations with unique keywords

    Raises
    ------
    KeyError
        if a section or option is missing from ``config``
    """
    # (flat keyword, section, option inside that section)
    layout = (
        ("data_path", "paths", "data_path"),
        ("model_path", "paths", "model_path"),
        ("model_path_2", "paths", "model_path_2"),
        ("quiet", "mode", "quiet"),
        ("format", "general", "output_format"),
        ("fourier", "general", "fourier"),
        ("amp_phase", "general", "amp_phase"),
        ("arch_name", "general", "arch_name"),
        ("source_list", "general", "source_list"),
        ("arch_name_2", "general", "arch_name_2"),
        ("diff", "general", "diff"),
        ("vis_pred", "inspection", "visualize_prediction"),
        ("vis_source", "inspection", "visualize_source_reconstruction"),
        ("sample_unc", "inspection", "sample_uncertainty"),
        ("unc", "inspection", "visualize_uncertainty"),
        ("plot_contour", "inspection", "visualize_contour"),
        ("vis_dr", "inspection", "visualize_dynamic_range"),
        ("vis_ms_ssim", "inspection", "visualize_ms_ssim"),
        ("num_images", "inspection", "num_images"),
        ("random", "inspection", "random"),
        ("viewing_angle", "eval", "evaluate_viewing_angle"),
        ("dynamic_range", "eval", "evaluate_dynamic_range"),
        ("ms_ssim", "eval", "evaluate_ms_ssim"),
        ("intensity", "eval", "evaluate_intensity"),
        ("mean_diff", "eval", "evaluate_mean_diff"),
        ("area", "eval", "evaluate_area"),
        ("batch_size", "eval", "batch_size"),
        ("point", "eval", "evaluate_point"),
        ("predict_grad", "eval", "predict_grad"),
        ("gan", "eval", "evaluate_gan"),
        ("save_vals", "eval", "save_vals"),
        ("save_path", "eval", "save_path"),
    )
    return {name: config[section][option] for name, section, option in layout}
[docs]
def reshape_2d(array):
    """Turn flattened square images back into 2d images.

    Parameters
    ----------
    array : array
        input whose last axis holds a flattened square image

    Returns
    -------
    array
        array of shape ``(-1, side, side)`` where ``side`` is the integer
        square root of the length of the last input axis
    """
    side = int(np.sqrt(array.shape[-1]))
    return array.reshape(-1, side, side)
[docs]
def make_axes_nice(fig, ax, im, title, phase=False, phase_diff=False, unc=False):
    """Attach a labelled colorbar to an axis and set the axis title.

    The colorbar is placed in a new axis appended to the right of ``ax``.
    Phase plots additionally get fixed ticks with pi-based tick labels.

    Parameters
    ----------
    fig : figure object
        current figure
    ax : axis object
        current axis
    im : ndarray
        plotted image
    title : str
        title of subplot
    phase : bool, optional
        use ticks from -pi to pi and the label ``Phase / rad``
    phase_diff : bool, optional
        use ticks from -2pi to 2pi and the label ``Phase / rad``;
        only considered if ``phase`` is false
    unc : bool, optional
        use the uncertainty label; only considered if ``phase`` and
        ``phase_diff`` are false
    """
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    ax.set_title(title)

    if phase:
        cbar = fig.colorbar(
            im,
            cax=cax,
            orientation="vertical",
            ticks=[-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi],
        )
        cbar.set_label("Phase / rad")
        # replace the numeric tick labels by multiples of pi
        cbar.ax.set_yticklabels([r"$-\pi$", r"$-\pi/2$", r"$0$", r"$\pi/2$", r"$\pi$"])
        return

    if phase_diff:
        cbar = fig.colorbar(
            im,
            cax=cax,
            orientation="vertical",
            ticks=[-2 * np.pi, -np.pi, 0, np.pi, 2 * np.pi],
        )
        cbar.set_label("Phase / rad")
        # replace the numeric tick labels by multiples of pi
        cbar.ax.set_yticklabels([r"$-2\pi$", r"$-\pi$", r"$0$", r"$\pi$", r"$2\pi$"])
        return

    cbar = fig.colorbar(im, cax=cax, orientation="vertical")
    if unc:
        cbar.set_label(r"$\sigma$ / $\mathrm{Jy \cdot px^{-1}}$")
    else:
        cbar.set_label(r"$\mathrm{Flux \ density / Jy \cdot px^{-1}}$")
[docs]
def check_vmin_vmax(inp):
    """Pick the value of ``inp`` with the largest magnitude as a color limit.

    If the absolute value of the minimum is strictly larger than the
    absolute value of the maximum, the negated minimum is returned.
    Otherwise the maximum is returned.

    Parameters
    ----------
    inp : ndarray
        input image

    Returns
    -------
    float
        negated minimum or maximum of ``inp``
    """
    lowest = inp.min()
    highest = inp.max()
    if np.abs(lowest) > np.abs(highest):
        return -lowest
    return highest
[docs]
def load_pretrained_model(arch_name, model_path, img_size=63):
    """Instantiate an architecture and load its pretrained weights.

    Architectures whose name contains ``filter_deep``, ``resnet`` or
    ``Uncertainty`` are constructed with ``img_size`` as their only
    argument; all others are constructed without arguments.

    Parameters
    ----------
    arch_name : str
        name of the architecture, looked up in ``radionets.architecture``
    model_path : str
        path to pretrained model
    img_size : int, optional
        image size handed to size-dependent architectures

    Returns
    -------
    arch : architecture object
        architecture with pretrained weights
    norm_dict
        value returned by ``load_pre_model``
    """
    size_dependent_tags = ("filter_deep", "resnet", "Uncertainty")
    takes_img_size = any(tag in arch_name for tag in size_dependent_tags)

    arch_factory = getattr(architecture, arch_name)
    arch = arch_factory(img_size) if takes_img_size else arch_factory()

    norm_dict = load_pre_model(arch, model_path, visualize=True)
    return arch, norm_dict
[docs]
def get_images(test_ds, num_images, rand=False, indices=None):
    """Get test and truth images, or mean, standard deviation and true
    images from an already sampled dataset.

    The data set is read exactly once per call.

    Parameters
    ----------
    test_ds : H5DataSet
        data set with test images. If it has an attribute ``tar_fourier``
        it is treated as a regular test data set, otherwise as an already
        sampled data set.
    num_images : int
        number of test images; only used for data sets with ``tar_fourier``
    rand : bool
        true if images should be drawn randomly (without duplicates); only
        used for data sets with ``tar_fourier``
    indices : list
        indices to be used; only used for data sets without
        ``tar_fourier`` (for the others the indices are generated here)

    Returns
    -------
    tuple
        ``(img_test, img_true, indices)`` for data sets with
        ``tar_fourier``, where ``indices`` is the sorted tensor of the
        indices that were used, or ``(mean, std, img_true)`` for sampled
        data sets.

    Raises
    ------
    ValueError
        if ``rand`` is true and more distinct images are requested than
        the data set contains
    """
    if not hasattr(test_ds, "tar_fourier"):
        # already sampled data set: entries are (mean, std, truth)
        sample = test_ds[indices]
        return sample[0], sample[1], sample[2]

    if rand:
        num_available = len(test_ds)
        if num_images > num_available:
            # the de-duplication below could never finish in this case
            raise ValueError(
                f"Cannot draw {num_images} distinct images from a data set "
                f"with {num_available} entries."
            )
        indices = torch.randint(0, num_available, size=(num_images,))
        # replace duplicated indices until all of them are distinct
        unique_indices = torch.unique(indices)
        while len(unique_indices) < len(indices):
            refill = torch.randint(
                0, num_available, size=(num_images - len(unique_indices),)
            )
            indices = torch.cat((unique_indices, refill))
            unique_indices = torch.unique(indices)
        indices, _ = torch.sort(indices)
    else:
        indices = torch.arange(num_images)

    # one read serves both the test and the truth images
    sample = test_ds[indices]
    return sample[0], sample[1], indices
[docs]
def eval_model(img, model):
    """Put model into eval mode and evaluate test images.

    If CUDA is available, the model is moved to the GPU and evaluated
    there. The prediction is always returned on the CPU.

    Parameters
    ----------
    img : torch.Tensor
        test image(s); a 3d input gets a leading batch axis, and the input
        is cast to float before it is fed to the model
    model : architecture object
        architecture with pretrained weights; calling it must return a
        mapping with the key ``pred``

    Returns
    -------
    pred : torch.Tensor
        predicted images (entry ``pred`` of the model output)
    """
    batch = img.unsqueeze(0) if len(img.shape) == 3 else img

    model.eval()
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        model.cuda()

    with torch.no_grad():
        batch = batch.float()
        if use_gpu:
            batch = batch.cuda()
        pred = model(batch)["pred"]
    return pred.cpu()
[docs]
def get_ifft(array, amp_phase=False, scale=False):
    """Compute the absolute value of the inverse Fourier transformation.

    Parameters
    ----------
    array : ndarray
        array with shape (2, img_size, img_size) with optional batch size
    amp_phase : bool, optional
        true, if the two channels hold amplitude and phase instead of
        real and imaginary part, by default False
    scale : bool, optional
        only used together with ``amp_phase``: if true, the amplitude
        channel ``x`` is replaced by ``10 ** (10 * x - 10) - 1e-10`` before
        the complex values are built, by default False

    Returns
    -------
    ndarray
        image(s) in image space; a batch axis of length one is removed
    """
    if len(array.shape) == 3:
        # add the batch axis, for tensors as well as for ndarrays
        if hasattr(array, "numpy"):
            array = array.unsqueeze(0)
        else:
            array = array[np.newaxis, :]

    first_channel = array[:, 0]
    second_channel = array[:, 1]

    if amp_phase:
        if scale:
            amp = 10 ** (10 * first_channel - 10) - 1e-10
        else:
            amp = first_channel
        real_part = amp * np.cos(second_channel)
        imag_part = amp * np.sin(second_channel)
    else:
        real_part = first_channel
        imag_part = second_channel
    compl = real_part + imag_part * 1j

    if compl.shape[0] == 1:
        compl = compl.squeeze(0)

    shifted = np.fft.fftshift(compl)
    return np.abs(np.fft.ifftshift(np.fft.ifft2(shifted)))
[docs]
def save_pred(path, img):
    """Write test data and predictions to h5 file.

    Parameters
    ----------
    path : str
        path of the h5 file; the file is opened in ``"w"`` mode
    img : dict
        mapping from data set name to the data stored under that name
    """
    # the context manager closes the file, no explicit close() needed
    with h5py.File(path, "w") as hf:
        for key, value in img.items():
            hf.create_dataset(key, data=value)
[docs]
def read_pred(path):
    """Read data saved with save_pred from h5 file.

    Parameters
    ----------
    path : str
        path of the h5 file

    Returns
    -------
    dict
        mapping from each top-level key of the file to its content as
        numpy array
    """
    # the context manager closes the file, no explicit close() needed;
    # np.array copies the data into memory while the file is still open
    with h5py.File(path, "r") as hf:
        images = {key: np.array(hf[key]) for key in hf}
    return images
[docs]
def check_outpath(model_path):
    """Check whether a predictions file exists in the evaluation folder.

    The file is expected at
    ``<parent of model_path>/evaluation/predictions_<model stem>.h5``.

    Parameters
    ----------
    model_path : str
        path to the model

    Returns
    -------
    bool
        true, if the file exists
    """
    model_file = Path(model_path)
    predictions_name = f"predictions_{model_file.stem}.h5"
    predictions_file = model_file.parent / "evaluation" / predictions_name
    return predictions_file.exists()
[docs]
def symmetry(image, key):
    """Symmetry function to complete the images.

    The rows above the middle row are rotated by 180 degrees and written
    into the rows below the middle row. For the second channel the copied
    values are negated, unless ``key`` is ``"unc"``. A tensor input is
    modified in place.

    Parameters
    ----------
    image : torch.Tensor or ndarray
        (stack of) padded images with at least two channels; a 3d input is
        treated as a single image
    key : str
        ``"unc"`` keeps the sign of the second channel, any other value
        flips it

    Returns
    -------
    torch.Tensor
        quadratic images after utilizing symmetry, always 4d
    """
    if isinstance(image, np.ndarray):
        image = torch.tensor(image)
    if len(image.shape) == 3:
        image = image.view(1, *image.shape)

    half_image = image.shape[-1] // 2
    lower_rows = slice(half_image + 1, None)

    # 180 degree rotation of the upper part, taken from a copy
    rotated = torch.rot90(image[:, :, :half_image, :].clone(), 2, dims=[-2, -1])

    first = rotated[:, 0]
    second = rotated[:, 1] if key == "unc" else -rotated[:, 1]

    for channel, source in ((0, first), (1, second)):
        image[:, channel, lower_rows, 1:] = source[:, :-1, :-1]
        # the first column is filled from the last column of the rotation
        image[:, channel, lower_rows, 0] = source[:, :-1, -1]
    return image
[docs]
def apply_symmetry(img_dict):
    """Pads and applies symmetry to half images.

    Every entry except ``"indices"`` is converted to a tensor if it is an
    ndarray, zero-padded at the bottom and completed with ``symmetry``.
    The dict is updated in place.

    Parameters
    ----------
    img_dict : dict
        input dict which contains the half images

    Returns
    -------
    dict
        input dict with quadratic images
    """
    for key in img_dict:
        if key == "indices":
            continue

        half = img_dict[key]
        if isinstance(half, np.ndarray):
            half = torch.tensor(half)
            img_dict[key] = half

        half_image = half.shape[-1] // 2
        # append rows of zeros below the half image
        padded = F.pad(
            input=half,
            pad=(0, 0, 0, half_image - 5),
            mode="constant",
            value=0,
        )
        img_dict[key] = symmetry(padded, key)
    return img_dict
@vectorize(["float64(float64, float64, float64, float64)"], target="cpu")
def tn_numba_vec_cpu(mu, sig, a, b):
    """Draw from a normal distribution until the value lies inside (a, b).

    Element-wise rejection sampling, compiled for the ``cpu`` target.

    Parameters
    ----------
    mu : float
        mean of the normal distribution
    sig : float
        standard deviation of the normal distribution
    a : float
        exclusive lower bound
    b : float
        exclusive upper bound

    Returns
    -------
    float
        sample with ``a < sample < b``
    """
    sample = np.random.normal(loc=mu, scale=sig)
    # redraw as long as the sample falls outside the open interval
    while not (sample > a and sample < b):
        sample = np.random.normal(loc=mu, scale=sig)
    return sample
@vectorize(["float64(float64, float64, float64, float64)"], target="parallel")
def tn_numba_vec_parallel(mu, sig, a, b):
    """Draw from a normal distribution until the value lies inside (a, b).

    Element-wise rejection sampling, compiled for the ``parallel`` target.

    Parameters
    ----------
    mu : float
        mean of the normal distribution
    sig : float
        standard deviation of the normal distribution
    a : float
        exclusive lower bound
    b : float
        exclusive upper bound

    Returns
    -------
    float
        sample with ``a < sample < b``
    """
    sample = np.random.normal(loc=mu, scale=sig)
    # redraw as long as the sample falls outside the open interval
    while not (sample > a and sample < b):
        sample = np.random.normal(loc=mu, scale=sig)
    return sample
[docs]
def trunc_rvs(mu, sig, num_samples, mode, target="cpu", nthreads=1):
    """Draw samples from truncated normal distributions.

    The truncation interval is chosen by ``mode``: ``(0, inf)`` for
    ``amp``, ``(-pi, pi)`` for ``phase`` and ``(-inf, inf)`` for ``real``
    and ``imag``.

    Parameters
    ----------
    mu : array
        mean values
    sig : array
        standard deviations, same shape as ``mu``
    num_samples : int
        number of samples drawn per entry of ``mu``
    mode : str
        one of ``amp``, ``phase``, ``real`` or ``imag``
    target : str, optional
        ``cpu`` for the single-threaded sampler, ``parallel`` for the
        multi-threaded one
    nthreads : int, optional
        number of numba threads; must be 1 for ``cpu`` and must not be 1
        for ``parallel``

    Returns
    -------
    ndarray
        samples; ``mu`` and ``sig`` are tiled ``num_samples`` times along
        the leading axis and the first two axes of the result are swapped
        afterwards

    Raises
    ------
    ValueError
        for an unsupported ``mode`` or ``target`` or an inconsistent
        combination of ``target`` and ``nthreads``
    """
    bounds_by_mode = {
        "amp": (0, np.inf),
        "phase": (-np.pi, np.pi),
        "real": (-np.inf, np.inf),
        "imag": (-np.inf, np.inf),
    }
    try:
        a, b = bounds_by_mode[mode]
    except (KeyError, TypeError):
        raise ValueError(
            "Unsupported mode, use one of ``amp``, ``phase``, ``real`` or ``imag``."
        ) from None

    mu = np.tile(mu, (num_samples, 1, 1, 1))
    sig = np.tile(sig, (num_samples, 1, 1, 1))

    if target == "cpu":
        if nthreads > 1:
            raise ValueError(
                f"Target is ``cpu`` but nthreads is {nthreads}, "
                "use target=``parallel`` instead."
            )
        res = tn_numba_vec_cpu(mu, sig, a, b)
    elif target == "parallel":
        if nthreads == 1:
            raise ValueError(
                "Target is ``parallel`` but nthreads is 1, use target=``cpu`` instead."
            )
        set_num_threads(int(nthreads))
        res = tn_numba_vec_parallel(mu, sig, a, b)
    else:
        raise ValueError("Unsupported target, use cpu or parallel.")

    return res.swapaxes(0, 1)
[docs]
def sample_images(mean, std, num_samples, conf):
    """Samples for every pixel in Fourier space from a
    truncated Gaussian distribution based on the output
    of the network.

    The samples of both channels are stacked, zero-padded at the bottom,
    completed with ``symmetry`` and transformed to image space with
    ``get_ifft``. Mean and standard deviation over the samples are
    returned.

    Parameters
    ----------
    mean : torch.tensor
        mean values of the pixels; index 0 and 1 of axis 1 are used as the
        two channels (amplitude and phase, or real and imaginary part)
    std : torch.tensor
        uncertainty values of the pixels, indexed like ``mean``
    num_samples : int
        number of samples in Fourier space
    conf : dict
        evaluation config; only ``conf["amp_phase"]`` is read. If true the
        channels are sampled as amplitude and phase, otherwise as real and
        imaginary part.

    Returns
    -------
    dict
        resulting mean (key ``mean``) and standard deviation (key ``std``)
        over the samples of each image
    """
    mean_first, mean_second = mean[:, 0], mean[:, 1]
    std_first, std_second = std[:, 0], std[:, 1]

    num_img = mean_first.shape[0]
    half_rows = mean_first.shape[-2]
    width = mean_first.shape[-1]
    flat_shape = (num_img * num_samples, half_rows, width)

    amp_phase = conf["amp_phase"]
    first_mode, second_mode = ("amp", "phase") if amp_phase else ("real", "imag")

    # first channel: amplitude or real part
    sampled_first = trunc_rvs(
        mu=mean_first,
        sig=std_first,
        mode=first_mode,
        num_samples=num_samples,
    ).reshape(flat_shape)

    # second channel: phase or imaginary part
    sampled_second = trunc_rvs(
        mu=mean_second,
        sig=std_second,
        mode=second_mode,
        num_samples=num_samples,
    ).reshape(flat_shape)

    if amp_phase:
        # sanity check: no sample may lie outside the truncation interval
        negative_amp = sampled_first <= (0 - 1e-4)
        phase_too_low = sampled_second <= (-np.pi - 1e-4)
        phase_too_high = sampled_second >= (np.pi + 1e-4)
        assert negative_amp.sum() == 0
        assert (phase_too_low | phase_too_high).sum() == 0

    stacked = np.stack([sampled_first, sampled_second], axis=1)

    # pad resulting images and utilize symmetry
    padded = F.pad(
        input=torch.tensor(stacked),
        pad=(0, 0, 0, half_rows - 2),
        mode="constant",
        value=0,
    )
    completed = symmetry(padded, None)

    images = get_ifft(completed, amp_phase=amp_phase, scale=False).reshape(
        num_img, num_samples, width, width
    )
    return {
        "mean": images.mean(axis=1),
        "std": images.std(axis=1),
    }
[docs]
def mergeDictionary(dict_1, dict_2):
    """Merge two dicts and concatenate the values of shared keys.

    Parameters
    ----------
    dict_1 : dict
        first dict
    dict_2 : dict
        second dict

    Returns
    -------
    dict
        new dict with the keys of both inputs. For keys present in both
        inputs the value is ``np.append(dict_1[key], dict_2[key])``,
        otherwise the value is taken over unchanged.
    """
    merged = {**dict_1, **dict_2}
    for key in dict_1:
        if key in dict_2:
            merged[key] = np.append(dict_1[key], dict_2[key])
    return merged
[docs]
class sampled_dataset:
    """Data set serving ``mean``, ``std`` and ``true`` entries of an h5 file.

    The file is opened for every access and closed again right away, so
    no file handle is kept alive between calls.
    """

    def __init__(self, bundle_path):
        """Store the path to the h5 file.

        Parameters
        ----------
        bundle_path : str
            path to the h5 file with the data sets ``mean``, ``std`` and
            ``true``

        Raises
        ------
        ValueError
            if ``bundle_path`` is an empty list
        """
        if bundle_path == []:
            raise ValueError("No bundles found! Please check the names of your files.")
        self.bundle_path = bundle_path

    def __len__(self):
        """Returns the total number of pictures in this dataset"""
        with h5py.File(self.bundle_path, "r") as bundle:
            return bundle["mean"].shape[0]

    def __getitem__(self, i):
        """Return the ``(mean, std, true)`` entries at index ``i``."""
        mean = self.open_image("mean", i)
        std = self.open_image("std", i)
        true = self.open_image("true", i)
        return mean, std, true

    def open_image(self, var, i):
        """Read entry ``i`` of the data set ``var`` from the h5 file.

        Parameters
        ----------
        var : str
            name of the data set inside the file
        i : index
            index or indices to read

        Returns
        -------
        data
            selected entry, read into memory before the file is closed
        """
        with h5py.File(self.bundle_path, "r") as bundle:
            # indexing the data set copies the values out of the file
            return bundle[var][i]
[docs]
def apply_normalization(img_test, norm_dict):
    """Applies one of the supported normalization methods
    if the training was normalized.

    The method is selected by the keys of ``norm_dict``: ``mean_real``
    (fixed mean and std per channel, applied to non-zero pixels only),
    ``max_scaling`` (division by a per-image maximum) or ``all``
    (per-image, per-channel statistics). The first two modify
    ``img_test`` in place.

    Parameters
    ----------
    img_test : torch.Tensor
        input image
    norm_dict : dictionary
        either empty (no normalization) or containing the factors

    Returns
    -------
    img_test : torch.Tensor
        normalized image
    norm_dict : dictionary
        updated dictionary; ``max_scaling`` adds ``max_factors_real`` and
        ``max_factors_imag``, ``all`` adds ``means`` and ``stds``
    """
    if not norm_dict:
        return img_test, norm_dict

    if "mean_real" in norm_dict:
        # normalize using mean and std for whole dataset, zeros stay zero
        channel_stats = (
            (0, "mean_real", "std_real"),
            (1, "mean_imag", "std_imag"),
        )
        for channel, mean_key, std_key in channel_stats:
            nonzero = img_test[:, channel] != 0
            img_test[:, channel][nonzero] = (
                img_test[:, channel][nonzero] - norm_dict[mean_key]
            ) / norm_dict[std_key]

    elif "max_scaling" in norm_dict:
        # scale with the maximum value of each image
        max_factors_real = torch.amax(img_test[:, 0], dim=(-2, -1), keepdim=True)
        max_factors_imag = torch.amax(
            torch.abs(img_test[:, 1]), dim=(-2, -1), keepdim=True
        )
        img_test[:, 0] *= 1 / max_factors_real
        img_test[:, 1] *= 1 / max_factors_imag
        norm_dict["max_factors_real"] = max_factors_real
        norm_dict["max_factors_imag"] = max_factors_imag

    elif "all" in norm_dict:
        # normalize each image with its own statistics
        stat_shape = (img_test.shape[0], img_test.shape[1], 1, 1)
        means = img_test.mean(axis=-1).mean(axis=-1).reshape(stat_shape)
        # NOTE(review): this is the std of the row-wise stds, not the std
        # over the whole image - confirm this matches the training code
        stds = img_test.std(axis=-1).std(axis=-1).reshape(stat_shape)
        img_test = (img_test - means) / stds
        norm_dict["means"] = means
        norm_dict["stds"] = stds

    return img_test, norm_dict
[docs]
def rescale_normalization(pred, norm_dict):
    """Rescale the prediction after normalized training.

    Reverts the normalization selected by the keys of ``norm_dict``
    (``mean_real``, ``max_scaling`` or ``all``). ``pred`` is modified in
    place. For ``mean_real`` and ``all`` the second rescaled channel is
    channel 2 if the prediction has four channels and channel 1 otherwise;
    for ``max_scaling`` it is always channel 1.

    Parameters
    ----------
    pred : torch.Tensor
        predicted image
    norm_dict : dictionary
        either empty (no normalization) or containing the factors

    Returns
    -------
    pred : torch.Tensor
        rescaled predicted image
    """
    if not norm_dict:
        return pred

    if "mean_real" in norm_dict:
        pred[:, 0] = pred[:, 0] * norm_dict["std_real"] + norm_dict["mean_real"]
        second = 2 if pred.shape[1] == 4 else 1
        pred[:, second] = (
            pred[:, second] * norm_dict["std_imag"] + norm_dict["mean_imag"]
        )
    elif "max_scaling" in norm_dict:
        pred[:, 0] *= norm_dict["max_factors_real"]
        pred[:, 1] *= norm_dict["max_factors_imag"]
    elif "all" in norm_dict:
        stds, means = norm_dict["stds"], norm_dict["means"]
        pred[:, 0] = pred[:, 0] * stds[:, 0] + means[:, 0]
        second = 2 if pred.shape[1] == 4 else 1
        pred[:, second] = pred[:, second] * stds[:, 1] + means[:, 1]
    return pred
[docs]
def preprocessing(conf):
    """Makes the necessary preprocessing for the evaluation
    methods analyzing the whole test dataset.

    Creates the DataLoader, makes sure the evaluation folder next to the
    model file exists and loads the model(s).

    Parameters
    ----------
    conf : dictionary
        config file containing the settings

    Returns
    -------
    model : architecture
        model initialized with save file
    model_2 : architecture
        second model initialized with save file, or None if
        ``conf["model_path_2"]`` is ``"none"``
    loader : torch.Dataloader
        feeds the data batch-wise
    norm_dict : dictionary
        dict containing the normalization factors; taken from the second
        model if one is loaded
    out_path : Path object
        path to the evaluation folder
    """
    loader = create_databunch(
        conf["data_path"], conf["fourier"], conf["source_list"], conf["batch_size"]
    )

    # evaluation folder lives next to the model file
    out_path = Path(conf["model_path"]).parent / "evaluation"
    out_path.mkdir(parents=True, exist_ok=True)

    # image size is read from the first sample of the data set
    img_size = loader.dataset[0][0][0].shape[-1]
    model, norm_dict = load_pretrained_model(
        conf["arch_name"], conf["model_path"], img_size
    )

    if conf["model_path_2"] == "none":
        return model, None, loader, norm_dict, out_path

    # second model, used if the two channels were trained separately
    model_2, norm_dict = load_pretrained_model(
        conf["arch_name_2"], conf["model_path_2"], img_size
    )
    return model, model_2, loader, norm_dict, out_path
[docs]
def process_prediction(conf, img_test, img_true, norm_dict, model, model_2):
    """Applies the normalization, gets and rescales a
    prediction and performs the inverse Fourier transformation.

    Parameters
    ----------
    conf : dictionary
        config files containing the settings; only ``conf["amp_phase"]``
        is read
    img_test : torch.Tensor
        input file for the network
    img_true : torch.tensor
        true image
    norm_dict : dictionary
        dict containing the normalization factors
    model : architecture
        model initialized with save file
    model_2 : architecture or None
        optional second model; its prediction is concatenated to the
        first one along axis 1

    Returns
    -------
    ifft_pred : ndarray
        predicted source in image space
    ifft_truth : ndarray
        true source in image space
    """
    img_test, norm_dict = apply_normalization(img_test, norm_dict)

    pred = rescale_normalization(eval_model(img_test, model), norm_dict)
    if model_2 is not None:
        second_pred = rescale_normalization(eval_model(img_test, model_2), norm_dict)
        pred = torch.cat((pred, second_pred), dim=1)

    # half images (fewer rows than columns) are completed via symmetry
    if pred.shape[-2] < pred.shape[-1]:
        completed = apply_symmetry({"truth": img_true, "pred": pred})
        img_true = completed["truth"]
        pred = completed["pred"]

    amp_phase = conf["amp_phase"]
    ifft_truth = get_ifft(img_true, amp_phase=amp_phase)
    ifft_pred = get_ifft(pred, amp_phase=amp_phase)
    return ifft_pred, ifft_truth
[docs]
def check_samp_file(eval_conf):
    """Checks if a file with sampled images
    is located in the evaluation folder.

    The evaluation folder next to the model file is created if it does
    not exist yet. The file is expected at
    ``<evaluation folder>/sampled_imgs_<model stem>.h5``.

    Parameters
    ----------
    eval_conf : dict
        contains the evaluation parameters; only ``model_path`` is read

    Returns
    -------
    bool
        true if file exists, otherwise false
    """
    model_file = Path(eval_conf["model_path"])

    eval_dir = model_file.parent / "evaluation"
    eval_dir.mkdir(parents=True, exist_ok=True)

    sampled_file = eval_dir / f"sampled_imgs_{model_file.stem}.h5"
    return sampled_file.is_file()