Source code for gaggle.population.base_individuals.rl_individual

import torch
import torch.nn as nn

from gaggle.arguments import IndividualArgs, SysArgs
from gaggle.population.base_individuals.nn_individual import NNIndividual


[docs]class RLIndividual(NNIndividual): """NNIndividual wrapper that adds an argmax to the prediction as RL problem usually require the output of the model to be an action rather than logits. """ def __init__(self, individual_args: IndividualArgs, sys_args: SysArgs = None, model: nn.Module = None, *args, **kwargs): super(RLIndividual, self).__init__(individual_args, sys_args, model, *args, **kwargs)
[docs] def forward(self, x): return torch.argmax(super(RLIndividual, self).forward(x))