import torch.nn as nn
import torch
# Public API of this module: only the network class is exported.
__all__ = [
    "LeNet5",
]
# PyTorch implementation of the classic LeNet5 (copied from https://towardsdatascience.com/implementing-yann-lecuns-lenet-5-in-pytorch-5e05a0911320)
class LeNet5(nn.Module):
    """Classic LeNet-5 convolutional network for single-channel images.

    Feature extractor: three 5x5 convolutions (1 -> 6 -> 16 -> 120 channels)
    with tanh activations, the first two each followed by 2x2 average pooling.
    Classifier: two fully connected layers (120 -> 84 -> ``num_classes``)
    with a tanh in between.

    The flattened feature vector has exactly 120 entries only when the last
    convolution produces a 1x1 spatial map, e.g. for 32x32 inputs
    (32 -> 28 -> 14 -> 10 -> 5 -> 1). Other input sizes generally fail with a
    shape error, either in the last convolution or in the first linear layer.

    Args:
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1),
            nn.Tanh(),
        )
        self.classifier = nn.Sequential(
            nn.Linear(in_features=120, out_features=84),
            nn.Tanh(),
            nn.Linear(in_features=84, out_features=num_classes),
        )

    def forward(self, x):
        """Compute class logits for a batch of single-channel images.

        Args:
            x: Tensor of shape (N, 1, H, W). A 3-D tensor is treated as
                (N, H, W) and gets a channel axis inserted at dim 1.

        Returns:
            Unnormalized logits of shape (N, num_classes). No softmax is
            applied, because ``nn.CrossEntropyLoss`` applies log-softmax
            internally; call ``torch.softmax(logits, dim=1)`` if
            probabilities are needed.
        """
        # Promote a batch of plain 2-D images to the (N, C=1, H, W) layout
        # that the first Conv2d expects.
        if x.dim() == 3:
            x = x.unsqueeze(1)
        features = self.feature_extractor(x)
        # Keep the batch axis and collapse (120, 1, 1) into 120 features.
        features = torch.flatten(features, 1)
        return self.classifier(features)