"""Source code for gaggle.base_nns.lenet."""

import torch.nn as nn
import torch


__all__ = ["LeNet5"]


# PyTorch implementation of the classic LeNet-5 (adapted from https://towardsdatascience.com/implementing-yann-lecuns-lenet-5-in-pytorch-5e05a0911320)
class LeNet5(nn.Module):
    """Classic LeNet-5 convolutional network (LeCun et al., 1998).

    Expects single-channel 32x32 inputs (the conv/pool arithmetic below
    assumes 32x32: 32 -> 28 -> 14 -> 10 -> 5 -> 1) and returns raw class
    logits; softmax is intentionally omitted because training is expected
    to use ``nn.CrossEntropyLoss``, which applies it internally.

    Args:
        num_classes: Size of the final logit layer (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Feature stack: (N, 1, 32, 32) -> (N, 120, 1, 1).
        # Tanh + average pooling follow the original 1998 architecture
        # (modern nets would use ReLU/MaxPool, but behavior is kept as-is).
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1),
            nn.Tanh(),
        )
        # Classifier head on the flattened 120-dim feature vector.
        self.classifier = nn.Sequential(
            nn.Linear(in_features=120, out_features=84),
            nn.Tanh(),
            nn.Linear(in_features=84, out_features=num_classes),
        )

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Tensor of shape ``(N, 1, H, W)`` or ``(N, H, W)``; a 3-D
               input gets a singleton channel dimension inserted.

        Returns:
            Logits tensor of shape ``(N, num_classes)`` (no softmax —
            pair with ``nn.CrossEntropyLoss``).
        """
        if x.dim() == 3:
            # Treat (N, H, W) as single-channel images: -> (N, 1, H, W).
            x = x.unsqueeze(1)
        x = self.feature_extractor(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)