Source code for gaggle.problem.dataset.classification_problem

from gaggle.problem.problem import Problem
from gaggle.population import Individual, PopulationManager
from gaggle.arguments import ProblemArgs, SysArgs
from gaggle.problem.dataset import DatasetFactory
from gaggle.utils.smooth_value import SmoothedValue
from gaggle.utils.metrics import accuracy
import torch


class ClassificationProblem(Problem):
    """A Problem that represents a standard Machine Learning classification problem.

    It stores the associated training and validation datasets. Population
    evaluation is optimized for GPU by default to speed up training.

    To create a classification problem with a custom dataset, register said
    dataset in the DatasetFactory.
    """

    def __init__(self, problem_args: ProblemArgs = None, sys_args: SysArgs = None):
        """Load the training and evaluation data described by ``problem_args``.

        A batch size of ``-1`` (for either training or evaluation) is replaced
        by the size of the corresponding dataset, i.e. no batching. When the
        training batch size equals the training set size, the whole training
        set is moved to ``self.sys_args.device`` once and kept there.

        Args:
            problem_args: problem configuration, forwarded to the parent
                constructor and to ``DatasetFactory.from_problem_args``.
            sys_args: system configuration (notably the target device),
                forwarded to the same two places.
        """
        super(ClassificationProblem, self).__init__(problem_args, sys_args)

        self.train_dataset = DatasetFactory.from_problem_args(problem_args, train=True, sys_args=sys_args)
        self.train_data, self.train_transforms = self.train_dataset.get_data_and_transform()
        train_size = self.train_data[0].size(0)
        if self.problem_args.batch_size == -1:
            # -1 means: use the entire dataset without batching.
            self.problem_args.batch_size = train_size
        if self.problem_args.batch_size == train_size:
            # A single batch covers everything, so the data can live on the
            # device permanently instead of being transferred every epoch.
            # The message reports self.sys_args.device, i.e. the exact device
            # used for the transfer below.
            print(f"Batching is not necessary, will store the entire data on device: {self.sys_args.device}")
            self.train_data = (
                self.train_data[0].to(self.sys_args.device),
                self.train_data[1].to(self.sys_args.device),
            )

        self.eval_dataset = DatasetFactory.from_problem_args(problem_args, train=False, sys_args=sys_args)
        self.eval_data, self.eval_transforms = self.eval_dataset.get_data_and_transform()
        if self.problem_args.eval_batch_size == -1:
            self.problem_args.eval_batch_size = self.eval_data[0].size(0)

        # Batch currently being scored; set by evaluate_population and read
        # by evaluate.
        self.current_batch = None
        self.fitness_function = accuracy

    @torch.no_grad()
    def evaluate_population(self, population_manager: PopulationManager, use_freshness: bool = True,
                            update_manager: bool = True, train: bool = True,
                            *args, **kwargs) -> dict[int, float]:
        """Evaluate the individuals of a population, one data batch at a time.

        Each batch is moved to the device and transformed once, then every
        selected individual is scored on it, so the data transfer is shared
        by the whole population. Should only be modified if specific custom
        behavior is desired; it is usually not recommended to modify this
        function.

        Args:
            population_manager: holder of the individuals to evaluate.
            use_freshness: if True, only individuals for which
                ``population_manager.is_fresh`` returns True are evaluated;
                if False, every individual is evaluated.
            update_manager: if True, each computed fitness is stored through
                ``population_manager.set_individual_fitness``.
            train: if True, use the training data, transforms and batch size;
                otherwise use the evaluation ones.
            *args: forwarded to ``self.evaluate`` after the individual.
            **kwargs: forwarded to ``self.evaluate``.

        Returns:
            ``population_manager.get_fitness()`` when both ``train`` and
            ``use_freshness`` are True; otherwise a dictionary mapping the
            index of each evaluated individual to its average fitness.
        """
        if train:
            inputs, labels = self.train_data[0], self.train_data[1]
            transforms = self.train_transforms
            batch_size = self.problem_args.batch_size
        else:
            inputs, labels = self.eval_data[0], self.eval_data[1]
            transforms = self.eval_transforms
            batch_size = self.problem_args.eval_batch_size
        num_inputs = inputs.size(0)
        device = self.sys_args.device

        # One running average per individual that has to be (re-)evaluated.
        trackers = {
            idx: SmoothedValue()
            for idx in range(population_manager.population_size)
            if not use_freshness or population_manager.is_fresh(idx)
        }

        # A single strided loop covers the full batches and the trailing
        # partial batch alike.
        for start in range(0, num_inputs, batch_size):
            stop = min(start + batch_size, num_inputs)
            data = transforms(inputs[start:stop].to(device))
            targets = labels[start:stop].to(device)
            self.current_batch = (data, targets)
            # Weight every update by the number of samples actually in the
            # batch; the last batch may be smaller than batch_size.
            # (Assumes SmoothedValue.update(value, n) weights value by n.)
            samples_in_batch = stop - start
            for idx, tracker in trackers.items():
                # NOTE(review): `train` is not forwarded here, so evaluate()
                # falls back to its default (train=True) unless the caller
                # supplies it through *args -- confirm this is intended when
                # scoring on the evaluation data.
                score = self.evaluate(population_manager.get_individual(idx), *args, **kwargs)
                tracker.update(score, n=samples_in_batch)

        fitness = {}
        for idx, tracker in trackers.items():
            fitness[idx] = tracker.global_avg
            if update_manager:
                population_manager.set_individual_fitness(idx, fitness[idx])
                # Freshness is only cleared once the fitness has been stored,
                # so an individual is never marked as evaluated without a
                # recorded fitness.
                if use_freshness:
                    population_manager.set_freshness(idx, False)

        if train and use_freshness:
            # Only fresh individuals were re-evaluated, so the complete
            # picture is taken from the manager.
            return population_manager.get_fitness()
        return fitness

    @torch.no_grad()
    def evaluate(self, individual: Individual, train: bool = True, *args, **kwargs) -> float:
        """Evaluate an individual on the current batch of data.

        Args:
            individual: the model to score; it is called on the batch inputs.
            train: whether the individual is switched to training mode
                (``individual.train()``) or inference mode
                (``individual.eval()``) before the forward pass.
            *args: unused by this implementation.
            **kwargs: unused by this implementation.

        Returns:
            The value of ``self.fitness_function`` applied to the predictions
            and the batch targets, as a Python number.
        """
        if train:
            individual.train()
        else:
            individual.eval()

        inputs, targets = self.current_batch
        inputs = inputs.to(self.sys_args.device)
        targets = targets.to(self.sys_args.device)
        predictions = individual(inputs)
        return self.fitness_function(predictions, targets).cpu().item()