Source code for gaggle.arguments.individual_args

from typing import Callable
from dataclasses import dataclass, field

from gaggle.base_nns.resnet_x import resnet20
from gaggle.base_nns.lenet import LeNet5
from gaggle.base_nns.snet import SNetCIFAR, SNetMNIST
from gaggle.base_nns.drqn import DRQN
from gaggle.base_nns.dqn import DQN


[docs]@dataclass class IndividualArgs: """ Argument class that contains the arguments relating to individuals """ CONFIG_KEY = "individual_args" individual_name: str = field(default="nn", metadata={ "help": "Individual class to use", }) param_lower_bound: float = field(default=None, metadata={ "help": "lower bound of restricted parameter value, if None, no clipping is performed" }) param_upper_bound: float = field(default=None, metadata={ "help": "upper bound of restricted parameter value, if None, no clipping is performed" }) individual_size: int = field(default=100, metadata={ "help": "length of the parameter tensor for the basic NumpyIndividual and PytorchIndividual. This argument" "is irrelevant for other individuals (unless custom coded it to be so)" }) model_name: str = field(default="lenet", metadata={ "help": "name of the model architecture. Only relevant for neural network individuals. " "Please see the 'get_base_model' method of this class.", }) resolution: int = field(default=32, metadata={ "help": "input resolution of the model", "choices": [32] }) base_model_weights: str = field(default=None, metadata={ "help": "Path to pre-trained base model weights. Can be a path or a weights file" "from the cloud (see below for defaults). The base_model_weights only specify" "the weight initialization strategy for the base model, but these weights would" "be overwritten, if there is different data in the model file itself." "" "Example Mappings: " " - resnet18: ResNet18_Weights.DEFAULT " " - resnet34: ResNet34_Weights.DEFAULT " " - resnet50: ResNet50_Weights.DEFAULT ", }) model_ckpt: str = field(default=None, metadata={ "help": "path to the model's checkpoint" }) random_init: bool = field(default=True, metadata={ "help": "initialization process for the models" }) models = { "resnet20": resnet20, "lenet": LeNet5, "snet_cifar": SNetCIFAR, "snet_mnist": SNetMNIST, "drqn": DRQN, "dqn": DQN }
[docs] @classmethod def get_keys(cls): r"""Gets the list of available NN pre-built models that can be created Returns: list of strings of model names """ return list(cls.models.keys())
[docs] @classmethod def update(cls, key, model): r"""Add a new model to the list of models that can be created Args: key: model name that will be used as the dictionary lookup key model: model class object, it needs to not be already initialized """ assert isinstance(model, Callable) cls.models[key] = model
[docs] def get_base_model(self, *args, **kwargs): r"""Gets a base model as specified by the self.model_name field Args: *args: args that will get passed down to the model initialization **kwargs: kwargs that will get passed down to the model initialization Returns: NN model (nn.Module unless custom model was added using cls.update) """ model = self.models.get(self.model_name, None) if model is None: raise ValueError(self.model_name) return model(*args, **kwargs)