Source code for gaggle.base_nns.dqn

import torch.nn as nn
import torch.nn.functional as F


class DQN(nn.Module):
    """Small Q-network: ``Linear -> sigmoid -> Linear``.

    Maps a state vector of size ``num_inputs`` to one Q-value per action
    (``num_outputs`` values) through a single sigmoid-activated hidden layer.

    Args:
        num_inputs: Dimension of the state (input) vector.
        num_outputs: Number of actions, i.e. the dimension of the output.
        hidden_size: Width of the single hidden layer.
    """

    def __init__(self, num_inputs=4, num_outputs=2, hidden_size=16):
        super().__init__()
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.hidden_size = hidden_size
        self.fc1 = nn.Linear(num_inputs, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_outputs)

    def forward(self, x):
        """Return the Q-values for the state(s) in ``x``.

        Args:
            x: Tensor whose total number of elements is a multiple of
                ``num_inputs``. It is flattened to ``(-1, num_inputs)``, so a
                single state or a batch of states is accepted.

        Returns:
            Tensor of shape ``(batch, num_outputs)``, where ``batch`` is the
            number of ``num_inputs``-sized rows in ``x``.
        """
        # reshape (not view): view raises on non-contiguous inputs such as
        # transposed/permuted batches; reshape returns a view when possible
        # and copies only when it must.
        x = x.reshape(-1, self.num_inputs)
        hidden = self.fc1(x).sigmoid()
        # The output layer is purely linear, so Q-values are unbounded.
        # NOTE(review): an earlier comment here mentioned applying a second
        # sigmoid to match "leap", but the code applies none -- confirm the
        # raw linear output is what is intended.
        return self.fc2(hidden)