"""Source code for gaggle.ga.simple_ga."""

import os
from typing import Callable, List
import time
import json

import torch
import matplotlib.pyplot as plt

from gaggle.arguments import GAArgs, SysArgs, IndividualArgs, ProblemArgs
from gaggle.arguments.outdir_args import OutdirArgs
from gaggle.population.individual import Individual
from gaggle.population.population_manager import PopulationManager
from gaggle.utils.special_print import print_dict_highlighted, print_highlighted
from gaggle.operators import Crossover, Mutation, Selection
from gaggle.ga import GA
from gaggle.problem import Problem


class SimpleGA(GA):
    r"""Implements a Simple Genetic Algorithm following Mitchell."""

    def __init__(self, population_manager: PopulationManager = None, ga_args: GAArgs = None,
                 selection: Selection = None, crossover: Crossover = None, mutation: Mutation = None,
                 problem_args: ProblemArgs = None, sys_args: SysArgs = None,
                 outdir_args: OutdirArgs = None, individual_args: IndividualArgs = None,
                 problem: Problem = None):
        """Set up the GA state and the metric bookkeeping used during training.

        All constructor arguments are forwarded unchanged to the ``GA`` base class.
        """
        super().__init__(population_manager, ga_args, selection, crossover, mutation,
                         problem_args, sys_args, outdir_args, individual_args, problem)
        # Best fitness observed so far; save_individual uses it as a checkpoint threshold.
        self.best = 0
        # One history dict per phase, each tracking the same five statistics per generation.
        stat_keys = ("best", "avg", "worst", "generation", "time_taken")
        self.saved_metrics = {phase: {key: [] for key in stat_keys}
                              for phase in ("train_metrics", "test_metrics")}
        # Moving-average window used by display_metrics, and which statistics get plotted.
        self.window_size = 10
        self.metrics_to_plot = ["best", "avg", "worst"]
[docs] def save_individual(self, individual: Individual, fitness): """Saves the individual provided as well as its fitness to the folder path specified in the outdir_args as 'best.pt'. Args: individual: fitness: Returns: """ if (self.ga_args.save_every_epoch or (fitness > self.best)): data = { "fitness": fitness, GAArgs.CONFIG_KEY: vars(self.ga_args), IndividualArgs.CONFIG_KEY: individual.save(), } fn = os.path.join(self.outdir_args.create_folder_name(), f"best.pt") torch.save(data, fn) print_highlighted(f"New best: {fitness:.2f}%>{self.best:.2f}%. Saved at '{os.path.abspath(fn)}'") self.best = fitness
[docs] @torch.no_grad() def get_fitness_metric(self, fitness: dict[int:float], save: bool = False, mode: str = "train"): """ Outputs basic fitness metrics for a population, like avg fitness, best & worst fitness :param fitness: dictionary of fitness :param save: whether to save the metrics :param mode: what to save :return: metrics: dictionary of metrics """ metrics = {} # best best_idx = max(fitness, key=fitness.get) best_fitness = fitness[best_idx] metrics["best"] = [best_idx, best_fitness] # worst worst_idx = min(fitness, key=fitness.get) worst_fitness = fitness[worst_idx] metrics["worst"] = [worst_idx, worst_fitness] # average fitness_list = list(fitness.values()) avg_fitness = sum(fitness_list)/len(fitness_list) metrics["avg"] = [avg_fitness] if save: if mode == "train": self.saved_metrics["train_metrics"]["best"].append(best_fitness) self.saved_metrics["train_metrics"]["avg"].append(avg_fitness) self.saved_metrics["train_metrics"]["worst"].append(worst_fitness) elif mode == "eval": self.saved_metrics["test_metrics"]["best"].append(best_fitness) self.saved_metrics["test_metrics"]["avg"].append(avg_fitness) self.saved_metrics["test_metrics"]["worst"].append(worst_fitness) else: raise NotImplementedError return metrics
[docs] def display_metrics(self, display_train: bool = True, display_test: bool = True): """Displays the metrics computed and stored in self.saved_metrics. Only displays and saves the metrics whose keys are in self.metrics_to_plot. The graphs are also saved to a file in the output folder specified in outdir_args. Args: display_train: display_test: Returns: """ if display_train: plt.figure() plt.clf() plt.title('Training Metrics') plt.xlabel('Generations') plt.ylabel('Fitness') moving_averages = {} for key in self.saved_metrics["train_metrics"].keys(): if key in self.metrics_to_plot: moving_averages[key] = [] for i in range(len(self.saved_metrics["train_metrics"][key]) - self.window_size): moving_averages[key].append(sum(self.saved_metrics["train_metrics"][key][i: i+self.window_size])/ self.window_size) plt.plot(list(range(len(moving_averages[key]))), moving_averages[key], label=key) fn = os.path.join(self.outdir_args.create_folder_name(), f"training_metrics.png") plt.legend() plt.savefig(fn) plt.show() if display_test: plt.figure() plt.clf() plt.title('Eval Metrics') plt.xlabel('Generations') plt.ylabel('Fitness') moving_averages = {} for key in self.saved_metrics["test_metrics"].keys(): if key in self.metrics_to_plot: moving_averages[key] = [] for i in range(len(self.saved_metrics["test_metrics"][key]) - self.window_size): moving_averages[key].append(sum(self.saved_metrics["test_metrics"][key][i: i+self.window_size])/ self.window_size) plt.plot(list(range(len(moving_averages[key]))), moving_averages[key], label=key) fn = os.path.join(self.outdir_args.create_folder_name(), f"testing_metrics.png") plt.legend() plt.savefig(fn) plt.show()
[docs] def save_metrics(self, save_train: bool = True, save_test: bool = True): """Saves the metrics computed and stored in self.saved_metrics. The metrics are saved to files in the output folder specified in outdir_args. Args: save_train: save_test: Returns: """ if save_train: train_fn = os.path.join(self.outdir_args.create_folder_name(), f"training_metrics.json") with open(train_fn, "w") as f: json.dump(self.saved_metrics["train_metrics"], f) if save_test: eval_fn = os.path.join(self.outdir_args.create_folder_name(), f"test_metrics.json") with open(eval_fn, "w") as f: json.dump(self.saved_metrics["test_metrics"], f)
[docs] def test(self, save: bool = True): """Tests population by computing the fitness and its associated metrics without updating the population manager. Args: save: Returns: """ start_time = time.time() self.population_manager.eval() test_fitness = self.problem.evaluate_population(self.population_manager, use_freshness=False, update_manager=False, train=False) test_metrics = self.get_fitness_metric(test_fitness, save=save, mode="eval") if self.outdir_args is not None: self.save_individual(self.population_manager.get_individual(test_metrics["best"][0]), test_metrics["best"][1]) test_metrics["time_taken"] = time.time() - start_time self.saved_metrics["test_metrics"]["time_taken"].append(test_metrics["time_taken"]) print("Test Metrics") print_dict_highlighted(test_metrics)
[docs] def train_one_generation(self): """ Standard one generation GA pipeline """ self.population_manager.train() train_fitness = self.problem.evaluate_population(self.population_manager, use_freshness=self.ga_args.use_freshness, update_manager=True, train=True) self.population_manager = self.selection_fn.select_all(self.population_manager, self.crossover_fn.mates_per_crossover, self.crossover_fn.children_per_crossover) self.population_manager = self.crossover_fn.crossover_pop(self.population_manager) self.population_manager = self.mutation_fn.mutate_pop(self.population_manager) return train_fitness
[docs] def train(self, test: bool = True, callbacks: List[Callable] = None, display_train_metrics: bool = True, display_test_metrics: bool = True): """Call to begin the training process of the population using the arguments stored in this SimpleGA object. Args: test: callbacks: display_train_metrics: display_test_metrics: Returns: """ print(f"Genome size: {self.population_manager.get_gene_count():.3e} params") print_dict_highlighted(vars(self.ga_args)) if callbacks is None: callbacks = [] for generation in range(self.ga_args.generations): start_time = time.time() train_fitness = self.train_one_generation() time_taken = time.time() metrics = self.get_fitness_metric(train_fitness, save=True, mode="train") metrics["generation"] = f"{generation+1}/{self.ga_args.generations}" metrics["time_taken"] = time_taken - start_time print_dict_highlighted(metrics) # we add the generation and time taken to the saved metrics self.saved_metrics["train_metrics"]["generation"].append(generation+1) self.saved_metrics["train_metrics"]["time_taken"].append(metrics["time_taken"]) if test and self.ga_args.eval_every_generation \ and generation % self.ga_args.eval_every_generation == 0: self.test() self.saved_metrics["test_metrics"]["generation"].append(generation+1) for callback in callbacks: callback(generation) if test: self.test() if display_test_metrics or display_train_metrics: self.display_metrics(display_train=display_train_metrics, display_test=display_test_metrics)