Source code for gaggle.utils.special_images

import matplotlib
import matplotlib.pyplot as plt
from os.path import dirname, isdir

import numpy as np
import torch
import torchvision
from PIL import Image


def plot_images(x: torch.Tensor, n_row=2, dpi=400, savefig=None, title=None):
    """Display a batch of images as one grid and optionally save the figure.

    Args:
        x: Batch of images. It is detached and moved to the CPU before being
            handed to ``torchvision.utils.make_grid``.
        n_row: Forwarded as ``nrow`` to ``make_grid``, i.e. the number of
            images placed in each row of the grid.
        dpi: Written to ``matplotlib.rcParams["figure.dpi"]``. This is a
            global setting and stays in effect after the call returns.
        savefig: Optional output path. The figure is written only when the
            target directory exists; otherwise saving is silently skipped.
        title: Optional title drawn above the grid.
    """
    x = x.detach().cpu()
    matplotlib.rcParams["figure.dpi"] = dpi

    # ``value_range`` replaces the deprecated ``range`` keyword of make_grid.
    # Each image is clamped to (-1, 1) and rescaled individually.
    grid_img = torchvision.utils.make_grid(
        x, nrow=n_row, value_range=(-1, 1), scale_each=True, normalize=True
    )

    plt.axis("off")
    if title is not None:
        plt.title(title)
    # make_grid yields channels-first data; imshow wants channels last.
    plt.imshow(grid_img.permute(1, 2, 0))

    if savefig is not None:
        # A bare file name has an empty dirname, which isdir() rejects;
        # treat it as the current working directory instead.
        target_dir = dirname(savefig) or "."
        if isdir(target_dir):
            plt.savefig(savefig)
    plt.show()
def image_to_tensor(image_file):
    """Read an image with PIL and return it as a float tensor scaled to [0, 1].

    The image is converted to RGB first, so the result is laid out
    channels-first as (3, height, width).
    """
    with Image.open(image_file) as handle:
        rgb_image = handle.convert("RGB")
        pixels = np.array(rgb_image)

    # Reorder (height, width, channel) -> (channel, height, width).
    channels_first = torch.from_numpy(pixels).permute(2, 0, 1)
    as_float = channels_first.float()
    return as_float / 255.