import copy
from abc import ABC
from copy import deepcopy
from pprint import pprint
from typing import List, Union
import numpy as np
import torch.utils.data
from torch.utils.data.dataset import T_co
from torchvision import transforms
import torchvision
from gaggle.arguments.problem_args import ProblemArgs
from gaggle.arguments.sys_args import SysArgs
from gaggle.utils.special_images import plot_images
class DataWrapper(torch.utils.data.Dataset):
    """Minimal dataset that exposes ``.data`` and ``.targets`` attributes.

    The two attributes are stored exactly as given so that they can be read
    back by ``Dataset.get_data_and_targets``.

    Args:
        data: Container holding the samples, indexed by item position.
        targets: Container holding the labels, indexed the same way as ``data``.

    See Also:
        This class creates attributes that are used by Dataset.get_data_and_targets.
    """

    def __init__(self, data: Union[torch.Tensor, None] = None,
                 targets: Union[torch.Tensor, None] = None):
        self.data = data
        self.targets = targets

    def __len__(self):
        """Return the number of targets, or 0 when no targets were provided."""
        # Both constructor arguments default to None; an empty wrapper has
        # length 0 instead of failing on ``None.size``.
        if self.targets is None:
            return 0
        # len() equals ``.size(0)`` for tensors and also supports other
        # sized containers such as lists or numpy arrays.
        return len(self.targets)

    def __getitem__(self, item):
        """Return the ``(sample, target)`` pair stored at position ``item``."""
        return self.data[item], self.targets[item]
class Dataset(torch.utils.data.Dataset, ABC):
    r"""Dataset class that allows for more flexible custom indexing and other behavior.

    Items are addressed through ``self.idx``, a list of indices into the
    wrapped ``self.dataset``. Subsets are created by copying the instance and
    replacing that list.

    Note:
        ``get_class_to_idx`` and ``normalize`` are called by methods of this
        class but are not defined in it here; presumably they are supplied
        elsewhere (e.g. by subclasses) -- confirm before relying on
        ``remove_classes``, ``print_class_distribution`` or ``__getitem__``.
    """

    def __init__(self, problem_args: "ProblemArgs | None" = None, train: bool = True,
                 sys_args: "SysArgs | None" = None):
        """Set up an empty dataset with the default transform.

        Args:
            problem_args: Problem configuration; a default ``ProblemArgs()`` is
                created when omitted.
            train: Stored as ``self.train``.
            sys_args: System configuration; a default ``SysArgs()`` is created
                when omitted.
        """
        self.train = train
        super(Dataset, self).__init__()
        self.problem_args = problem_args if problem_args is not None else ProblemArgs()
        self.sys_args = sys_args if sys_args is not None else SysArgs()
        self.apply_transform: bool = True

        # pre-processors receive an index, the image and the label of each item
        self.idx: List[int] = []  # all indices that this dataset returns
        self.idx_to_backdoor = {}  # store idx and backdoor mapping

        self.dataset: torch.utils.data.Dataset | None = None
        self.classes: List[str] = []

        # ``real_normalize_transform`` is kept so that normalization can be
        # switched back on after it has been disabled.
        self.real_normalize_transform = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        self.normalize_transform = self.real_normalize_transform
        self.transform = self._build_transform()

        self.class_to_idx = None  # reset to None whenever ``idx`` changes
        self.disable_fetching = False

    def get_data_and_targets(self):
        r"""Gets the data and the targets of the wrapped ``self.dataset`` object.

        The ``self.dataset`` object should have ``.data`` and ``.targets``
        attributes to be returned.

        Returns:
            A tuple containing (data, targets) or (None, None) if the dataset is not initialized.
        """
        if self.dataset is None:
            return (None, None)
        return (self.dataset.data, self.dataset.targets)

    def num_classes(self) -> int:
        """ Return the number of classes"""
        return len(self.classes)

    def _build_transform(self) -> "transforms.Compose":
        """ Internal function to build a default transformation. Override this
        if necessary.

        Returns:
            ``ToTensor`` followed by the currently active ``normalize_transform``.
        """
        return transforms.Compose([
            transforms.ToTensor(),
            self.normalize_transform,
        ])

    def random_subset(self, n: int):
        """ Creates a random subset of this dataset.

        Args:
            n: Maximum number of indices to keep.

        Returns:
            A copy of this dataset whose ``idx`` holds the first ``n`` entries
            of a shuffled copy of ``self.idx``.
        """
        shuffled = deepcopy(self.idx)
        np.random.shuffle(shuffled)
        clone = self.copy()
        clone.idx = shuffled[:n]
        clone.class_to_idx = None  # the cached mapping no longer matches idx
        return clone

    def subset(self, idx: Union[List[int], int]):
        """ Creates a subset of this dataset.

        Args:
            idx: Either positions into ``self.idx`` to keep, or an integer
                count (Python or numpy integer) selecting the first ``idx``
                positions.

        Returns:
            A copy of this dataset restricted to the selected entries.
        """
        # numpy integers are accepted as a count as well; previously they fell
        # through to the iteration below and raised a TypeError.
        if isinstance(idx, (int, np.integer)):
            idx = range(idx)
        clone = self.copy()
        clone.idx = [self.idx[i] for i in idx]
        clone.class_to_idx = None  # the cached mapping no longer matches idx
        return clone

    def remove_classes(self, target_classes: List[int]):
        """ Creates a copy of this dataset without samples from the given target classes.

        Args:
            target_classes: Classes whose entries are removed from ``idx``.

        Returns:
            A copy of this dataset with the listed entries removed.
        """
        clone = self.copy()
        for target_class in target_classes:
            for index_of_target_class in clone.get_class_to_idx()[target_class]:
                clone.idx.remove(index_of_target_class)
        clone.class_to_idx = None  # the cached mapping no longer matches idx
        return clone

    def visualize(self, sq_size: int = 3) -> None:
        """ Plot samples from this dataset as a square.

        Args:
            sq_size: Side length of the square; the first ``sq_size ** 2``
                items are plotted.
        """
        n = sq_size ** 2
        imgs = torch.stack([self[i][0] for i in range(n)], 0)
        # A non-empty backdoor mapping marks the dataset as poisoned in the title.
        if self.idx_to_backdoor:
            title = f"{self.problem_args.dataset_name} (Poisoned)"
        else:
            title = self.problem_args.dataset_name
        plot_images(imgs, n_row=sq_size, title=title)

    def print_class_distribution(self):
        """ Pretty-print the percentage of items per class. """
        class_to_idx = self.get_class_to_idx(verbose=False)
        total = len(self)
        pprint({c: 100 * len(v) / total for c, v in class_to_idx.items()})

    def without_normalization(self) -> 'Dataset':
        """ Return a copy of this data without normalization.
        """
        clone = self.copy()
        clone.enable_normalization(False)
        return clone

    @staticmethod
    def _identity(x):
        """ Return ``x`` unchanged; stands in for normalization when it is disabled. """
        return x

    def enable_normalization(self, enable: bool) -> None:
        """ Method to enable or disable normalization.

        Rebuilds ``self.transform`` so the change takes effect immediately.

        Args:
            enable: ``True`` restores ``real_normalize_transform``; ``False``
                installs an identity transform.
        """
        if enable:
            self.normalize_transform = self.real_normalize_transform
        else:
            # A named function (instead of an inline lambda) keeps the dataset
            # picklable, e.g. when it is sent to data-loader worker processes.
            self.normalize_transform = transforms.Lambda(Dataset._identity)
        self.transform = self._build_transform()

    def size(self):
        """ Alternative function to get the size. """
        return len(self.idx)

    def __len__(self):
        """ Return the number of elements in this dataset """
        return self.size()

    def copy(self):
        """ Return a (deep) copy of this dataset instance. """
        return deepcopy(self)

    def __getitem__(self, index) -> tuple:
        """ Return the transformed item and its label at position ``index``.

        ``index`` is first translated through ``self.idx`` before the wrapped
        dataset is accessed.
        """
        index = self.idx[index]
        x, y = self.dataset[index]
        x = self.transform(x)
        return self.normalize(x), y