Model Registration

Before training a model, every module that should be trained must be correctly registered. Otherwise, the unregistered modules are silently skipped during training: no error or exception is thrown. Moreover, when we call model.cuda(), the unregistered modules stay on the CPU and are not moved to the GPU. Because nothing fails loudly, this gotcha is usually hard to notice.
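A minimal sketch of why the unregistered modules are skipped: an optimizer only sees what model.parameters() returns, and that call only walks registered submodules. The class names PlainListNet and ModuleListNet and the helper count_parameters are hypothetical and used for illustration only. The sketch runs on CPU.

```python
import torch
import torch.nn as nn


class PlainListNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain Python list: the layers are NOT registered as submodules.
        self.layers = [nn.Linear(4, 4) for _ in range(3)]


class ModuleListNet(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers every layer as a submodule.
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])


def count_parameters(model):
    # Total number of scalar parameters reachable through model.parameters().
    return sum(p.numel() for p in model.parameters())


plain_count = count_parameters(PlainListNet())
registered_count = count_parameters(ModuleListNet())
print(plain_count)       # 0
print(registered_count)  # 60 (3 layers x (16 weights + 4 biases))
```

Anything missing from model.parameters() never reaches the optimizer, so its weights keep their initial values for the whole training run.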

When does this gotcha usually occur?

  1. Submodules are stored in a Python list or dict, but the container is not wrapped with nn.ModuleList or nn.ModuleDict.
    • In this case, PyTorch cannot recognize the elements as submodules. Therefore, they are NOT registered and their parameters are NOT trained.
  2. An attribute of the model is a Python list or dict of modules that was never wrapped with nn.ModuleList or nn.ModuleDict.
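The dict case behaves the same way as the list case. The sketch below, which uses a hypothetical DictNet class with made-up attribute names, keeps one plain dict and one nn.ModuleDict side by side and lists which parameters PyTorch actually registered.

```python
import torch.nn as nn


class DictNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain dict: these two layers are invisible to PyTorch.
        self.unregistered_heads = {"a": nn.Linear(4, 2), "b": nn.Linear(4, 2)}
        # nn.ModuleDict: these two layers are registered under "heads".
        self.heads = nn.ModuleDict({"a": nn.Linear(4, 2), "b": nn.Linear(4, 2)})


model = DictNet()
registered_names = [name for name, _ in model.named_parameters()]
print(registered_names)
# ['heads.a.weight', 'heads.a.bias', 'heads.b.weight', 'heads.b.bias']
```

Only the entries under heads show up. Listing named_parameters() like this is a quick way to confirm that every layer you expect to train is registered.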

Example

import torch
import torch.nn as nn

class DummyModule(nn.Module):

    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        print("dummy")
        return x


class Net(nn.Module):

    def __init__(self, num_dummy_modules=4):
        super().__init__()
        # Here self.dummy_module_list is just a Python list,
        # as we do not wrap it with nn.ModuleList
        self.dummy_module_list = [DummyModule().cuda() for _ in range(num_dummy_modules)]
        print(f"#dummy modules: {len(self.dummy_module_list)}")

    def forward(self, x):
        for dummy_module in self.dummy_module_list:
            x = dummy_module(x)
        return x

Now we initialize the model and move it to the GPU:

device = torch.device("cuda")  # a GPU is assumed, since Net already calls .cuda()
model = Net().to(device)
print(model)
#dummy modules: 4
Net()

We can see that the printed Net contains no submodules: the four DummyModule instances were not registered.

Now we use nn.ModuleList to wrap self.dummy_module_list and convert its elements to registered, trainable modules.

class Net(nn.Module):

    def __init__(self, num_dummy_modules=4):
        super().__init__()
        self.dummy_module_list = [DummyModule().cuda() for _ in range(num_dummy_modules)]
        # Register elements in self.dummy_module_list as trainable modules
        self.dummy_module_list = nn.ModuleList(self.dummy_module_list)
        print(f"#dummy modules: {len(self.dummy_module_list)}")

    def forward(self, x):
        for dummy_module in self.dummy_module_list:
            x = dummy_module(x)
        return x

model = Net().to(device)
print(model)
#dummy modules: 4
Net(
  (dummy_module_list): ModuleList(
    (0): DummyModule()
    (1): DummyModule()
    (2): DummyModule()
    (3): DummyModule()
  )
)
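
The earlier point that model.cuda() skips unregistered modules can be checked without a GPU. The sketch below assumes a dtype conversion is a fair stand-in for a device move, since both are applied recursively to registered submodules only. MixedNet is a hypothetical class, and the default dtype is assumed to be float32.

```python
import torch
import torch.nn as nn


class MixedNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain list: not registered, so .to() / .cuda() will not reach it.
        self.unregistered = [nn.Linear(2, 2)]
        # nn.ModuleList: registered, so .to() / .cuda() is applied to it.
        self.registered = nn.ModuleList([nn.Linear(2, 2)])


# Convert the model to float64; only registered submodules are converted.
model = MixedNet().to(torch.float64)
unregistered_dtype = model.unregistered[0].weight.dtype
registered_dtype = model.registered[0].weight.dtype
print(unregistered_dtype)  # torch.float32
print(registered_dtype)    # torch.float64
```

The layer in the plain list keeps its original dtype, just as it would keep its original device after model.cuda().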

References