Source code for gaggle.population.base_individuals.nn_individual

import os
import copy
from typing import Iterator

import torch
import torch.nn as nn
from torch.nn import Parameter


from gaggle.arguments import IndividualArgs, SysArgs
from gaggle.population.individual import Individual
from gaggle.arguments.outdir_args import OutdirArgs
from gaggle.utils.special_print import print_highlighted
from gaggle.utils.web import is_valid_url


class NNIndividual(Individual):
    """An Individual whose initial parameters come from a torch.nn.Module."""

    CONFIG_BASE_MODEL_STATE_DICT = "base_model_state_dict"

    def __init__(self, individual_args: IndividualArgs, sys_args: SysArgs = None,
                 model: nn.Module = None, *args, **kwargs):
        # Initialize the shared Individual state first.
        super().__init__(individual_args, sys_args)
        if model is None:
            # Build the base model from the arguments and move it to the target device.
            self.model = individual_args.get_base_model(*args, **kwargs).to(self.sys_args.device)
        else:
            # Deep-copy the given model so this individual owns its own parameters.
            self.model = copy.deepcopy(model).to(self.sys_args.device)
        self.hooks = []
        self.return_val = {}
        self.print_flags = {}
        self._debug_mode = False
        self._tick = 0
        if individual_args.model_ckpt is not None:
            self.load()
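    # A minimal construction sketch (hedged: the exact IndividualArgs fields are
    # defined elsewhere in gaggle, and default-constructed arguments are an
    # assumption here for illustration):
    #
    #     individual = NNIndividual(IndividualArgs(), SysArgs())
    #
    # Wrapping an existing module also works; it is deep-copied, so the original
    # module is never mutated:
    #
    #     individual = NNIndividual(IndividualArgs(), SysArgs(), model=nn.Linear(4, 2))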
    def initialize(self):
        if self.individual_args.random_init:
            for m in self.model.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
        return self
    def generate_gene_pool(self):
        gene_pool = {}
        self.genome_size = 0
        for i, m in enumerate(self.model.parameters()):
            # Each parameter tensor becomes one gene; its size is the tensor's element count.
            gene_size = m.numel()
            gene_pool[i] = {"param": m, "gene_size": gene_size}
            self.genome_size += gene_size
        return gene_pool
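    # Sketch of the resulting gene pool, assuming the wrapped model is a single
    # nn.Linear(4, 2) (one 2x4 weight tensor and one 2-element bias):
    #
    #     pool = individual.generate_gene_pool()
    #     # pool == {0: {"param": <weight>, "gene_size": 8},
    #     #          1: {"param": <bias>,   "gene_size": 2}}
    #     assert individual.genome_size == sum(g["gene_size"] for g in pool.values())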
    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        # Forward the recurse flag instead of silently ignoring it.
        return self.model.parameters(recurse=recurse)
    def train(self, mode: bool = True):
        # Pass mode through so train(False) actually switches the model to eval mode.
        return self.model.train(mode)
    def eval(self):
        return self.model.eval()
    def forward(self, x):
        return self.model(x)
    def save(self, outdir_args: OutdirArgs = None) -> dict:
        data = {
            self.CONFIG_INDIVIDUAL_ARGS: self.individual_args,
            self.CONFIG_BASE_MODEL_STATE_DICT: self.model.state_dict(),
        }
        if outdir_args is not None:
            folder = outdir_args.get_folder_path()
            fn = os.path.join(folder, f"{self.individual_args.model_name}.pt")
            torch.save(data, fn)
            print_highlighted(f"Saved model at {os.path.abspath(fn)}")
        return data
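    # The returned dict mirrors the on-disk checkpoint layout (CONFIG_INDIVIDUAL_ARGS
    # is presumably inherited from Individual):
    #
    #     data = individual.save()  # no outdir_args: nothing is written to disk
    #     # data == {CONFIG_INDIVIDUAL_ARGS: <IndividualArgs>,
    #     #          CONFIG_BASE_MODEL_STATE_DICT: <base model state dict>}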
    def debug(self, mode: bool = True) -> None:
        self._debug_mode = mode
    def debug_tick(self) -> None:
        """Advances the debug tick, clearing any per-step debug plotting."""
        self._tick += 1
    def first_time(self, name) -> bool:
        """Checks whether `name` is being seen for the first time, e.g. to act only once."""
        state = name not in self.print_flags
        self.print_flags[name] = True
        return state
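    # Typical use: gate a one-off debug message, e.g.
    #
    #     if self.first_time("forward_shape"):
    #         print_highlighted(f"input shape: {x.shape}")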
    def load(self, content=None, ckpt=None) -> nn.Module:
        if content is None:
            # Fall back to the checkpoint path from the individual arguments.
            ckpt = ckpt if ckpt is not None else self.individual_args.model_ckpt
            if is_valid_url(ckpt):
                content = torch.hub.load_state_dict_from_url(ckpt, progress=False)
            else:
                content = torch.load(ckpt)
        if IndividualArgs.CONFIG_KEY in content:
            content = content[IndividualArgs.CONFIG_KEY]
        # Hacky part: decide whether this is a checkpoint written by save() above
        # (carrying the individual args) or a plain state dict for the base model.
        if self.CONFIG_INDIVIDUAL_ARGS in content:
            self.individual_args = content[self.CONFIG_INDIVIDUAL_ARGS]
            self.model.load_state_dict(content[self.CONFIG_BASE_MODEL_STATE_DICT])
        else:
            # Assume this is just a state dict for the base model.
            self.model.load_state_dict(content)
        self.model.eval()
        return self.model
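
# A hedged save/load round-trip sketch (passing the in-memory dict to load()
# avoids touching disk; attribute names follow the code above):
#
#     data = individual.save()
#     clone = NNIndividual(individual.individual_args, individual.sys_args)
#     clone.load(content=data)  # restores the base model weights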