- Whenever you define a neural network, either as a vanilla Python function or as an `nn.Module`, it should take only a mini-batch of training data as input. For example, for a net that works on images, the only input is `X`, whose shape could be `[256, 1, 28, 28]`. Here, 256 is the number of items in each mini-batch, 1 is the number of channels (for grayscale images), and 28 × 28 is the size of each image.
- The first dimension of the input is always the batch size.
- The loss function should return only a scalar (or a `Tensor` of size 1). That is because PyTorch’s `backward` function, when called without a `gradient` argument, only works on scalars.
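
The points above can be sketched in a few lines. This is a minimal illustration and not a recommended architecture: `TinyNet`, the single linear layer, the 10 output classes, and the random stand-in data are all assumptions made for the example. It shows a module whose `forward` takes only the mini-batch `X`, with the batch size as the first dimension, and a loss that is reduced to a scalar before `backward` is called.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    """Hypothetical model used only for illustration."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.fc = nn.Linear(1 * 28 * 28, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape [batch_size, 1, 28, 28]; flatten all but the batch dimension.
        return self.fc(x.flatten(start_dim=1))


torch.manual_seed(0)
X = torch.randn(256, 1, 28, 28)    # one mini-batch of random stand-in "images"
y = torch.randint(0, 10, (256,))   # random stand-in labels (assumed 10 classes)

net = TinyNet()
logits = net(X)                    # the net receives only the mini-batch

# The default reduction ("mean") collapses the per-item losses into a 0-dim tensor.
loss = F.cross_entropy(logits, y)
loss.backward()                    # fine: the loss is a scalar

# Without a reduction the loss has one entry per item, and backward() with no
# arguments raises a RuntimeError.
per_item = F.cross_entropy(net(X), y, reduction="none")
try:
    per_item.backward()
    raised = False
except RuntimeError:
    raised = True
```

If a per-item loss is what you have, reducing it first (for example with `per_item.mean()`) gives `backward` the scalar it expects.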