Source code for gaggle.problem.dataset.base_datasets.mnist

import numpy as np
import torch
import torchvision
from torchvision.transforms import transforms

from gaggle.arguments.problem_args import ProblemArgs
from gaggle.arguments.sys_args import SysArgs
from gaggle.problem.dataset.dataset import Dataset
from gaggle.global_configs import global_configs


[docs]class MNIST(Dataset): """MNIST dataset. Deng, L. (2012). The mnist database of handwritten digit images for machine learning research. IEEE Signal Processing Magazine, 29(6), 141–142. """ def __init__(self, problem_args: ProblemArgs = None, train: bool = True, sys_args: SysArgs = None): super().__init__(problem_args, train, sys_args) self.dataset = torchvision.datasets.MNIST(root=global_configs.CACHE_DIR, download=True, train=train, transform=None) self.idx = list(range(len(self.dataset))) max_size = self.problem_args.max_size_train if train else self.problem_args.max_size_val if max_size is not None: self.idx = np.random.choice(self.idx, max_size) self.transform = self._build_transform() self.classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
[docs] def get_data_and_targets(self): if isinstance(self.dataset.data, torch.Tensor): data = (self.dataset.data.to(torch.float), self.dataset.targets.to(torch.float)) else: data = (torch.Tensor(self.dataset.data), torch.Tensor(self.dataset.targets)) return data
def _build_transform(self): if self.train: transform = transforms.Compose([ transforms.Resize((32, 32)), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)), ]) else: transform = transforms.Compose([ transforms.Resize((32, 32)), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)), ]) return transform