PyTorch Modules and Classes

TL;DR

  • torch.nn
    • Module: creates a callable which behaves like a function, but can also contain state (such as neural-net layer weights). It knows what Parameters it contains and can zero all their gradients, loop through them for weight updates, etc. (see the first sketch after this list).
    • Parameter: a wrapper for a tensor that tells a Module it has weights that need updating during training. Only tensors with the requires_grad attribute set are updated.
    • functional: a module (usually imported into the F namespace by convention) which contains activation functions, loss functions, etc., as well as non-stateful versions of layers such as convolutional and linear layers.
  • torch.optim: Contains optimizers such as SGD, which update the weights of each Parameter during the optimizer step (optimizer.step()), after gradients have been computed in the backward pass.
  • Dataset: An abstract interface for objects with a __len__ and a __getitem__, including classes provided with PyTorch such as TensorDataset.
  • DataLoader: Takes any Dataset and creates an iterator which returns batches of data (see the training-loop sketch after this list).
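
As a concrete illustration of the torch.nn pieces above, here is a minimal sketch of a linear model built directly from Parameters. The class name LinearModel and the layer sizes are made up for this example, not part of the PyTorch API.

```python
import torch
from torch import nn
import torch.nn.functional as F


class LinearModel(nn.Module):
    """A bare-bones linear layer built directly from Parameters."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # Wrapping tensors in nn.Parameter registers them as trainable
        # state on the Module; requires_grad is set automatically.
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        # F.linear is the stateless, functional counterpart of nn.Linear.
        return F.linear(x, self.weight, self.bias)


model = LinearModel(784, 10)   # a callable: model(x) runs forward(x)
print(sum(p.numel() for p in model.parameters()))  # the Module knows its Parameters
model.zero_grad()              # and can zero all their gradients at once
```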
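
And a training-loop sketch showing torch.optim, TensorDataset and DataLoader working together with the model above. The data is random and purely illustrative.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

x = torch.randn(1000, 784)           # fake inputs (illustrative only)
y = torch.randint(0, 10, (1000,))    # fake class labels

ds = TensorDataset(x, y)             # any object with __len__ and __getitem__
dl = DataLoader(ds, batch_size=64, shuffle=True)  # yields (xb, yb) batches

opt = torch.optim.SGD(model.parameters(), lr=0.1)

for xb, yb in dl:
    loss = F.cross_entropy(model(xb), yb)  # functional loss from torch.nn.functional
    loss.backward()    # autograd fills in .grad for every Parameter
    opt.step()         # the optimizer updates the Parameters from their gradients
    opt.zero_grad()    # clear gradients before the next batch
```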

Notebook

View in nbviewer

Reference