Model Registration
Before training a model, the submodules that need to be trained must be correctly registered. Otherwise, the unregistered submodules will silently NOT be trained: no error or exception is raised. Moreover, when we call model.cuda(), the unregistered submodules stay on the CPU and are not moved to the GPU. As a result, this gotcha is usually hard to notice.
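To make "silently not trained" concrete, here is a minimal sketch. The class names PlainListNet and ModuleListNet and the count_params helper are hypothetical, and nn.Linear stands in for any trainable submodule. The sketch counts the parameters visible through model.parameters(), which is what an optimizer is usually constructed from:

```python
import torch
import torch.nn as nn

class PlainListNet(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain Python list: these Linear layers are NOT registered
        self.layers = [nn.Linear(4, 4) for _ in range(2)]

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class ModuleListNet(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers every Linear layer as a submodule
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

def count_params(model):
    # Only parameters of registered submodules are yielded by parameters()
    return sum(p.numel() for p in model.parameters())

plain, wrapped = PlainListNet(), ModuleListNet()
print(count_params(plain))    # 0: none of these weights reach an optimizer via parameters()
print(count_params(wrapped))  # 40: two Linear(4, 4) layers, 16 weights + 4 biases each
```

Note that plain(x) still runs a normal forward pass, which is exactly why the problem goes unnoticed.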
When does this gotcha usually occur?
- Submodules are collected in a plain Python list or dict, but it is not wrapped with nn.ModuleList or nn.ModuleDict.
- An attribute of the model is a plain Python list or dict of submodules, but it is not wrapped with nn.ModuleList or nn.ModuleDict.

In both cases, PyTorch cannot recognize the elements as trainable submodules. Therefore, they are NOT registered and will NOT be trained.
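The example below uses a list, but the same idea applies to dicts. Here is a minimal sketch, assuming a hypothetical model with two named heads, of wrapping a dict of submodules with nn.ModuleDict so that every value is registered:

```python
import torch
import torch.nn as nn

class DictNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical heads keyed by name; nn.ModuleDict registers each value
        self.heads = nn.ModuleDict({
            "cls": nn.Linear(8, 3),
            "reg": nn.Linear(8, 1),
        })

    def forward(self, x, head):
        # Look up a head by key, exactly like a regular dict
        return self.heads[head](x)

model = DictNet()
print([name for name, _ in model.named_children()])
# ['heads']
print(sorted(name for name, _ in model.named_parameters()))
# ['heads.cls.bias', 'heads.cls.weight', 'heads.reg.bias', 'heads.reg.weight']
```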
Example
```python
import torch
import torch.nn as nn

class DummyModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        print("dummy")
        return x

class Net(nn.Module):
    def __init__(self, num_dummy_modules=4):
        super().__init__()
        # Here self.dummy_module_list is just a Python list,
        # as we do not wrap it with nn.ModuleList.
        # Calling .to(device) on Net will NOT move these modules.
        self.dummy_module_list = [DummyModule() for _ in range(num_dummy_modules)]
        print(f"#dummy modules: {len(self.dummy_module_list)}")

    def forward(self, x):
        for dummy_module in self.dummy_module_list:
            x = dummy_module(x)
        return x
```
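Now we initialize the model and move it to the GPU (falling back to the CPU if no GPU is available):

```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
print(model)
```

```
#dummy modules: 4
Net()
```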
We can see that Net appears to contain nothing: the 4 DummyModule instances are not registered as submodules.
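The missing registration is also why model.cuda() or model.to(device) leaves these modules behind: Module.to() only converts the parameters and buffers of registered submodules. The sketch below uses hypothetical class names and a dtype conversion instead of a device move, so the same effect can be observed on a CPU-only machine:

```python
import torch
import torch.nn as nn

class UnregisteredNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(2, 2)]  # plain list: not registered

class RegisteredNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(2, 2)])  # registered

# .to() walks only registered submodules, whether the target is a dtype or a device
unreg = UnregisteredNet().to(torch.float64)
reg = RegisteredNet().to(torch.float64)
print(unreg.layers[0].weight.dtype)  # torch.float32 (left unchanged)
print(reg.layers[0].weight.dtype)    # torch.float64
```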
Now we wrap self.dummy_module_list with nn.ModuleList, which registers its elements as trainable submodules:
```python
class Net(nn.Module):
    def __init__(self, num_dummy_modules=4):
        super().__init__()
        self.dummy_module_list = [DummyModule() for _ in range(num_dummy_modules)]
        # Register elements in self.dummy_module_list as trainable modules;
        # model.to(device) now moves them along with the rest of the model
        self.dummy_module_list = nn.ModuleList(self.dummy_module_list)
        print(f"#dummy modules: {len(self.dummy_module_list)}")

    def forward(self, x):
        for dummy_module in self.dummy_module_list:
            x = dummy_module(x)
        return x
```
```python
model = Net().to(device)
print(model)
```

```
#dummy modules: 4
Net(
  (dummy_module_list): ModuleList(
    (0): DummyModule()
    (1): DummyModule()
    (2): DummyModule()
    (3): DummyModule()
  )
)
```
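Registration also decides what ends up in a checkpoint, since state_dict() only collects tensors from registered submodules. Here is a minimal sketch, assuming a hypothetical ScaledDummy variant of DummyModule that holds one learnable parameter (the original DummyModule has none, so its state_dict would be empty either way):

```python
import torch
import torch.nn as nn

class ScaledDummy(nn.Module):
    # Hypothetical variant of DummyModule with one learnable parameter
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, x):
        return x * self.scale

class PlainNet(nn.Module):
    def __init__(self, n=2):
        super().__init__()
        self.blocks = [ScaledDummy() for _ in range(n)]  # not registered

class WrappedNet(nn.Module):
    def __init__(self, n=2):
        super().__init__()
        self.blocks = nn.ModuleList([ScaledDummy() for _ in range(n)])  # registered

print(list(PlainNet().state_dict().keys()))    # []
print(list(WrappedNet().state_dict().keys()))  # ['blocks.0.scale', 'blocks.1.scale']
```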