Base_nns

Dqn

class gaggle.base_nns.dqn.DQN(num_inputs=4, num_outputs=2, hidden_size=16)

Bases: Module
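For orientation, here is a minimal stand-in for a Q-network with the same constructor arguments. It assumes a small MLP with one hidden layer of width hidden_size that maps a state vector of length num_inputs to one Q-value per action. The class name TinyDQN is hypothetical, and the actual layers of gaggle.base_nns.dqn.DQN may differ.

```python
import torch
from torch import nn


class TinyDQN(nn.Module):
    """Hypothetical stand-in for a Q-network; not gaggle's DQN implementation."""

    def __init__(self, num_inputs=4, num_outputs=2, hidden_size=16):
        super().__init__()
        # Assumed layout: state -> hidden layer -> one Q-value per action.
        self.net = nn.Sequential(
            nn.Linear(num_inputs, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_outputs),
        )

    def forward(self, x):
        # x: (batch, num_inputs) -> (batch, num_outputs)
        return self.net(x)


q_net = TinyDQN()
states = torch.randn(5, 4)               # a batch of 5 four-dimensional states
q_values = q_net(states)                 # shape (5, 2): one Q-value per action
greedy_actions = q_values.argmax(dim=1)  # index of the highest Q-value per state
```

Taking the argmax over the output dimension is the usual greedy action choice for a Q-network.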

forward(x)

Defines the computation performed at every call.

Note

Although the recipe for the forward pass is defined within this function, call the Module instance (model(x)) rather than model.forward(x): calling the instance runs any registered hooks, while calling forward() directly silently skips them.
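The difference described in the note can be checked with any torch.nn.Module. The sketch below registers a forward hook on a plain nn.Linear layer, standing in for any of the models on this page, and records every time the hook fires.

```python
import torch
from torch import nn

layer = nn.Linear(4, 2)
calls = []


def record_call(module, inputs, output):
    # Forward hooks receive the module, its positional inputs and its output.
    calls.append(output.shape)


handle = layer.register_forward_hook(record_call)
x = torch.randn(3, 4)

layer(x)          # goes through Module.__call__, so the hook runs
layer.forward(x)  # bypasses __call__, so the hook is skipped

handle.remove()   # detach the hook once it is no longer needed
```

After both calls, calls holds a single entry, recorded by layer(x).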

Drqn

class gaggle.base_nns.drqn.DRQN(num_inputs=4, num_outputs=2, hidden_size=16)

Bases: Module

forward(x, hidden=None)

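A sketch of how a hidden argument that defaults to None is typically threaded through time in a recurrent Q-network. It assumes a GRU core and that forward returns both the Q-values and the updated hidden state; the class TinyDRQN is hypothetical, and gaggle's DRQN may use a different recurrent cell or return signature.

```python
import torch
from torch import nn


class TinyDRQN(nn.Module):
    """Hypothetical stand-in for a recurrent Q-network; not gaggle's DRQN."""

    def __init__(self, num_inputs=4, num_outputs=2, hidden_size=16):
        super().__init__()
        self.encoder = nn.Linear(num_inputs, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, num_outputs)

    def forward(self, x, hidden=None):
        # x: (batch, seq_len, num_inputs); hidden=None makes the GRU start from zeros.
        z = torch.relu(self.encoder(x))
        out, hidden = self.rnn(z, hidden)
        return self.head(out), hidden


net = TinyDRQN()
hidden = None                       # fresh episode: no memory yet
for t in range(3):
    obs = torch.randn(1, 1, 4)      # one observation per time step
    q, hidden = net(obs, hidden)    # carry the hidden state to the next step
action = q[:, -1].argmax(dim=-1).item()
```

Passing the returned hidden state back in on the next step is what lets the network accumulate information across observations; resetting it to None starts a new sequence.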

Lenet

class gaggle.base_nns.lenet.LeNet5(num_classes=10)

Bases: Module

forward(x)

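The size of the flattened feature vector that feeds LeNet-5's fully connected layers follows from the convolution output-size formula, out = floor((in + 2*padding - kernel) / stride) + 1. The sketch applies it to the classic LeNet-5 layout (two 5x5 convolutions, each followed by 2x2 pooling with stride 2, and 16 channels after the second convolution) on a 32x32 input. That gaggle's LeNet5 uses exactly these layers and input size is an assumption, and the helpers conv_out and lenet5_flat_features are illustrative, not part of gaggle.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


def lenet5_flat_features(input_size=32, conv2_channels=16):
    # Assumed classic layout: conv5 -> pool2 -> conv5 -> pool2.
    # Sizes in the comments are for the default 32x32 input.
    s = conv_out(input_size, kernel=5)   # 32 -> 28
    s = conv_out(s, kernel=2, stride=2)  # 28 -> 14
    s = conv_out(s, kernel=5)            # 14 -> 10
    s = conv_out(s, kernel=2, stride=2)  # 10 -> 5
    return conv2_channels * s * s


print(lenet5_flat_features())  # 400
```

For a 28x28 MNIST image, padding=2 on the first convolution keeps the spatial size at 28 (conv_out(28, kernel=5, padding=2) returns 28), so the rest of the pipeline then sees the same sizes as a 28x28 feature map after a 32x32 input.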

Resnet

ResNet in PyTorch. For Pre-activation ResNet, see ‘preact_resnet.py’.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. Deep Residual Learning for Image Recognition. arXiv:1512.03385

gaggle.base_nns.resnet.ResNet101(num_classes: int = 10)
gaggle.base_nns.resnet.ResNet152(num_classes: int = 10)
gaggle.base_nns.resnet.ResNet18(num_classes: int = 10)
gaggle.base_nns.resnet.ResNet34(num_classes: int = 10)
gaggle.base_nns.resnet.ResNet50(num_classes: int = 10)
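The factory names encode network depth. In the ResNet paper [1], ResNet18 and ResNet34 stack basic blocks with two convolutions each, and ResNet50, ResNet101 and ResNet152 stack bottleneck blocks with three; adding the stem convolution and the final fully connected layer gives the number in the name (projection shortcuts are not counted, following the paper's convention). The per-stage block counts below are the ones from the paper and are assumed, not confirmed from gaggle's source, to be what these factories use.

```python
# Per-stage block counts from the ResNet paper (assumed to match the factories).
CONFIGS = {
    "ResNet18": ("basic", [2, 2, 2, 2]),
    "ResNet34": ("basic", [3, 4, 6, 3]),
    "ResNet50": ("bottleneck", [3, 4, 6, 3]),
    "ResNet101": ("bottleneck", [3, 4, 23, 3]),
    "ResNet152": ("bottleneck", [3, 8, 36, 3]),
}

# Weighted layers per block: a basic block has 2 convolutions, a bottleneck has 3.
LAYERS_PER_BLOCK = {"basic": 2, "bottleneck": 3}


def depth(block, num_blocks):
    # +2 for the stem convolution and the final fully connected layer.
    return LAYERS_PER_BLOCK[block] * sum(num_blocks) + 2


for name, (block, num_blocks) in CONFIGS.items():
    # Each configuration reproduces the depth in its name.
    assert depth(block, num_blocks) == int(name[len("ResNet"):]), name
```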

Resnet_x

Properly implemented ResNets for CIFAR10 as described in paper [1]. The implementation and structure of this file are heavily influenced by [2], which is implemented for ImageNet and does not have option A (parameter-free identity shortcuts that zero-pad when the number of channels increases). Moreover, most implementations on the web are copied from torchvision's resnet and have the wrong number of parameters. Proper ResNets for CIFAR10 (for fair comparison, etc.) have the following numbers of layers and parameters, which this implementation indeed matches:

name       | layers | params
ResNet20   | 20     | 0.27M
ResNet32   | 32     | 0.46M
ResNet44   | 44     | 0.66M
ResNet56   | 56     | 0.85M
ResNet110  | 110    | 1.7M
ResNet1202 | 1202   | 19.4M

References:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. Deep Residual Learning for Image Recognition. arXiv:1512.03385
[2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py

If you use this implementation in your work, please don't forget to mention the author, Yerlan Idelbayev.
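The parameter counts in the table can be reproduced by hand. The sketch assumes the CIFAR layout from [1]: a 3x3 stem convolution with 16 filters, three stages of n basic blocks with 16, 32 and 64 filters, bias-free convolutions each followed by batch normalization (2 learnable parameters per channel), parameter-free option A shortcuts, and a final linear layer with bias. The helper names are illustrative and not part of gaggle.

```python
def conv3x3_params(c_in, c_out):
    # 3x3 convolution without bias.
    return c_in * c_out * 9


def bn_params(channels):
    # BatchNorm weight and bias; running statistics are buffers, not parameters.
    return 2 * channels


def cifar_resnet_params(n, num_classes=10):
    """Learnable parameters of a 6n+2-layer CIFAR ResNet with option A shortcuts."""
    total = conv3x3_params(3, 16) + bn_params(16)  # stem
    c_in = 16
    for width in (16, 32, 64):
        for _ in range(n):
            # Basic block: conv-bn-conv-bn; option A shortcuts add no parameters.
            total += conv3x3_params(c_in, width) + bn_params(width)
            total += conv3x3_params(width, width) + bn_params(width)
            c_in = width
    total += 64 * num_classes + num_classes  # linear classifier with bias
    return total


for depth in (20, 32, 44, 56, 110, 1202):
    n = (depth - 2) // 6
    print(f"ResNet{depth}: {cifar_resnet_params(n) / 1e6:.2f}M")
```

For ResNet20 (n = 3) this gives 269,722 parameters, and every printed value agrees with the table after rounding. Using option B (1x1 projection convolutions) instead would add parameters at each stage transition, which is one way implementations end up with different counts.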

class gaggle.base_nns.resnet_x.ResNet(block, num_blocks, num_classes=10)

Bases: Module

forward(x)


gaggle.base_nns.resnet_x.resnet110(num_classes: int = 10)
gaggle.base_nns.resnet_x.resnet1202(num_classes: int = 10)
gaggle.base_nns.resnet_x.resnet20(num_classes: int = 10)
gaggle.base_nns.resnet_x.resnet32(num_classes: int = 10)
gaggle.base_nns.resnet_x.resnet44(num_classes: int = 10)
gaggle.base_nns.resnet_x.resnet56(num_classes: int = 10)
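Each factory name gives the depth 6n+2 of a network with n two-convolution basic blocks in each of its three stages (6n convolutions, plus the stem convolution and the classifier), so the num_blocks argument of ResNet can be derived from the name. The helper cifar_num_blocks is hypothetical, not part of gaggle, and assumes the factories pass [n, n, n] as num_blocks.

```python
def cifar_num_blocks(depth):
    """Hypothetical helper: per-stage block counts for a 6n+2-layer CIFAR ResNet."""
    if depth < 8 or (depth - 2) % 6 != 0:
        raise ValueError(f"depth must have the form 6n+2 with n >= 1, got {depth}")
    n = (depth - 2) // 6
    return [n, n, n]


for depth in (20, 32, 44, 56, 110, 1202):
    print(f"resnet{depth}: num_blocks={cifar_num_blocks(depth)}")
```

Depths that do not fit the 6n+2 pattern, such as 21, are rejected rather than silently rounded.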

Snet

class gaggle.base_nns.snet.SNetCIFAR(num_classes: int = 10)

Bases: Module

Small custom convolutional model for CIFAR, with roughly 149K parameters.

forward(x)

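A parameter figure like the one for SNetCIFAR can be checked with the usual PyTorch idiom of summing numel() over model.parameters(). The sketch uses a tiny hypothetical CIFAR-shaped network so that it runs without gaggle; count_parameters is an illustrative helper, not a gaggle function. Applied to SNetCIFAR(), it should report roughly 149K, assuming the docstring figure is current.

```python
import torch
from torch import nn


def count_parameters(model, trainable_only=True):
    """Number of (optionally only trainable) parameters in a module."""
    return sum(p.numel() for p in model.parameters()
               if p.requires_grad or not trainable_only)


# Hypothetical stand-in for a small CIFAR model (3x32x32 inputs, 10 classes).
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),  # 3*8*9 weights + 8 biases = 224
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),                            # 8*10 weights + 10 biases = 90
)
print(count_parameters(toy))  # 314
logits = toy(torch.randn(2, 3, 32, 32))
```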

class gaggle.base_nns.snet.SNetMNIST(num_classes: int = 10)

Bases: Module

Small custom convolutional model for MNIST.

forward(x)

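A typical inference call for an MNIST classifier, sketched with a hypothetical stand-in so that it runs on its own. MNIST images are 28x28 grayscale, so a batch has shape (N, 1, 28, 28); that SNetMNIST expects exactly this layout is an assumption. model.eval() switches layers such as dropout and batch normalization to inference behavior, and torch.no_grad() skips gradient tracking.

```python
import torch
from torch import nn

# Hypothetical stand-in for a small MNIST CNN (1x28x28 inputs, 10 classes).
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),               # 28x28 -> 14x14
    nn.Flatten(),
    nn.Linear(8 * 14 * 14, 10),
)

images = torch.randn(4, 1, 28, 28)  # random batch standing in for real digits

model.eval()                       # inference behavior for dropout / batch norm
with torch.no_grad():              # no autograd graph is built
    logits = model(images)         # shape (4, 10)
predictions = logits.argmax(dim=1)  # predicted class index per image
```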