import hashlib
import torch
[docs]class UnNormalize(object):
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, tensor):
"""
Args:
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
Returns:
Tensor: Normalized image.
"""
for t, m, s in zip(tensor, self.mean, self.std):
t.mul_(s).add_(m)
# The normalize code -> t.sub_(m).div_(s)
return tensor
[docs]def tensor_hash(tensor):
tensor = tensor.detach().cpu().numpy().tobytes()
return hashlib.sha256(tensor).hexdigest()
[docs]def yield_batches(images, batch_size):
for i in range(0, len(images), batch_size):
yield images[i:i+batch_size]
[docs]def set_multiple_indices_to_zero(tensor, indices):
# Get the batch size
batch_size = tensor.shape[0]
# Get the indices in 1D
indices = indices[:,0]*tensor.shape[1]+indices[:,1] # calculate 1D indices
# Create a mask of the same shape as the tensor
mask = torch.ones_like(tensor)
# reshape the mask to 1D
mask = mask.reshape(-1)
# set the corresponding element in the mask to 0
mask[indices] = 0
# reshape the mask back to 2D
mask = mask.reshape(tensor.shape)
# Apply the mask to the tensor
tensor = tensor * mask
return tensor