Source code for gaggle.population.base_individuals.pytorch_individual

from gaggle.population import Individual
from gaggle.arguments import SysArgs, IndividualArgs
from gaggle.utils.individual_helper import from_gene_pool_no_metadata
import torch
import torch.nn as nn


class PytorchIndividual(Individual):
    """An Individual whose genes are stored as PyTorch tensors.

    Every gene tensor is copied into an ``nn.Parameter`` and registered on
    the module under the name ``str(index)``. Registering the genes this way
    makes them part of the module's parameters, so ``self.to(...)`` moves
    them along with the module.
    """

    def __init__(self, tensors: dict[int, torch.Tensor] = None,
                 individual_args: IndividualArgs = None,
                 sys_args: SysArgs = None):
        """Build the individual from explicit tensors or randomly.

        Args:
            tensors: Optional mapping whose *values* become the genes, in
                iteration order (keys are ignored). When ``None``, a random
                genome is created by :meth:`initialize`.
            individual_args: Forwarded to ``Individual.__init__``.
            sys_args: Forwarded to ``Individual.__init__``. Its ``device``
                attribute is read by :meth:`generate_gene_pool`.
        """
        super().__init__(individual_args, sys_args)
        if tensors is None:
            self.initialize()
        else:
            self.gene_pool = self.generate_gene_pool(tensors)

    def initialize(self, *args, **kwargs) -> nn.Module:
        """Create a random single-tensor genome if no gene pool exists yet.

        The genome is one 1-D tensor of length
        ``individual_args.individual_size``. Its values are drawn with
        ``Tensor.uniform_`` between ``param_lower_bound`` and
        ``param_upper_bound``. A bound that is ``None`` falls back to
        ``0.`` (lower) or ``1.`` (upper). If ``self.gene_pool`` is already
        set, nothing is changed.

        NOTE(review): relies on ``Individual.__init__`` initializing
        ``self.gene_pool`` (presumably to ``None``) — confirm in the base class.

        Returns:
            ``self``, to allow call chaining.
        """
        if self.gene_pool is None:
            ind_args = self.individual_args
            low = 0. if ind_args.param_lower_bound is None else ind_args.param_lower_bound
            high = 1. if ind_args.param_upper_bound is None else ind_args.param_upper_bound
            # No clone/detach here: generate_gene_pool already copies each tensor.
            genes = torch.empty(ind_args.individual_size).uniform_(low, high)
            self.gene_pool = self.generate_gene_pool({0: genes})
        return self

    def generate_gene_pool(self, tensors: dict[int, torch.Tensor], *args,
                           **kwargs) -> dict[int, dict[str, object]]:
        """Wrap each tensor as a registered parameter and index it as a gene.

        Side effects:
            - Resets ``self.genome_size`` and sets it to the total element count.
            - Registers one parameter per tensor under the names ``"0"``,
              ``"1"``, ....
            - Moves the module to ``self.sys_args.device``.

        Args:
            tensors: Mapping whose values are the gene tensors, consumed in
                iteration order (keys are ignored). Each tensor is cloned and
                detached, so the caller's tensors are never aliased.

        Returns:
            ``{idx: {"param": nn.Parameter, "gene_size": int}}``, where
            ``gene_size`` is the number of elements in that parameter.
        """
        gene_pool = {}
        self.genome_size = 0
        for idx, value in enumerate(tensors.values()):
            param = nn.Parameter(value.clone().detach())
            self.register_parameter(str(idx), param)
            gene_size = param.numel()
            gene_pool[idx] = {"param": param, "gene_size": gene_size}
            self.genome_size += gene_size
        # By default nn.Module.to converts parameters in place, so the
        # Parameter objects stored in gene_pool remain the registered ones.
        # NOTE(review): re-verify if torch.__future__ conversion flags change.
        self.to(self.sys_args.device)
        return gene_pool

    def forward(self, *args, **kwargs):
        """Return the genome without its metadata; all arguments are ignored.

        Delegates to ``from_gene_pool_no_metadata(self.get_gene_pool())``.
        The exact return structure is defined by that helper.
        """
        return from_gene_pool_no_metadata(self.get_gene_pool())