PyTorch module __call__() vs forward()
Sep 6, 2021
In Python, a class can define the special method __call__(), which makes its instances callable like functions.
class MyClass:
    def __init__(self):
        pass  # ... some init

    def __call__(self, input1):
        self.my_function(input1)

    def my_function(self, input1):
        print(f"MyClass - print {input1}")

my_obj = MyClass()
# same as calling my_obj.my_function("haha")
my_obj("haha")  # expect to print "MyClass - print haha"
In PyTorch, nn.Module implements __call__() so that a module instance can be called in the same way, e.g.:
import torch
import torch.nn as nn

class myLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 1)

    def forward(self, input_tensor):
        return self.layer1(input_tensor)

model = myLayer()
input_tensor = torch.rand((2, 10))
# treat the module as callable: __call__ runs model.forward(input_tensor) for us
output = model(input_tensor)
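So what's the difference between model(input_tensor) and model.forward(input_tensor)? It's the hook mechanism that PyTorch builds into the nn.Module class: calling the module goes through __call__(), which runs any registered hooks around forward(), while calling forward() directly bypasses them.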
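To see the hook mechanism in action, here is a minimal sketch, assuming a reasonably recent PyTorch. It registers a forward hook with nn.Module.register_forward_hook and compares calling the module with calling forward() directly. The names record_hook and calls are hypothetical, chosen only for illustration.

```python
import torch
import torch.nn as nn

# A single Linear(10, 1) layer, the same shape of model as myLayer above.
model = nn.Linear(10, 1)

calls = []  # hypothetical list that records each time the hook fires

def record_hook(module, inputs, output):
    # Forward hooks receive the module, its positional inputs (a tuple),
    # and the output returned by forward(). Returning None keeps the output.
    calls.append(tuple(output.shape))

handle = model.register_forward_hook(record_hook)

x = torch.rand((2, 10))
out_call = model(x)             # goes through __call__, so the hook runs
out_forward = model.forward(x)  # calls forward() directly, so the hook is skipped

print(calls)                                  # [(2, 1)]
print(torch.allclose(out_call, out_forward))  # True: same computation either way

handle.remove()  # detach the hook once it is no longer needed
```

Both calls compute the same result, but only model(x) triggers the hook. This is why the usual advice is to call the module instance rather than model.forward(x).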
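Looking at the nn.Module source (torch/nn/modules/module.py), __call__ is assigned to a _call_impl(…) function. We can see that it performs some actions before and after the main call, result = forward_call(*input, **kwargs) (around line 1071 of that file at the time of writing), which is where the module's forward() is invoked: registered forward pre-hooks run before it, and forward hooks run after it.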
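To make that before/after structure concrete, here is a toy, pure-Python sketch of the pattern. TinyModule, Doubler, and their methods are hypothetical names; this is a heavily simplified illustration of the idea, not PyTorch's actual _call_impl, which also handles tracing, global hooks, and backward hooks.

```python
class TinyModule:
    """Hypothetical, heavily simplified stand-in for nn.Module (not PyTorch code)."""

    def __init__(self):
        self._forward_pre_hooks = []
        self._forward_hooks = []

    def register_forward_pre_hook(self, hook):
        self._forward_pre_hooks.append(hook)

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def _call_impl(self, *args, **kwargs):
        # "Before" part: each pre-hook may replace the positional inputs.
        for hook in self._forward_pre_hooks:
            new_args = hook(self, args)
            if new_args is not None:
                args = new_args if isinstance(new_args, tuple) else (new_args,)
        # The main call, analogous to result = forward_call(*input, **kwargs).
        result = self.forward(*args, **kwargs)
        # "After" part: each forward hook may replace the output.
        for hook in self._forward_hooks:
            new_result = hook(self, args, result)
            if new_result is not None:
                result = new_result
        return result

    # Calling an instance routes through _call_impl, hooks included.
    __call__ = _call_impl


class Doubler(TinyModule):
    def forward(self, x):
        return 2 * x


m = Doubler()
m.register_forward_pre_hook(lambda mod, args: (args[0] + 1,))  # add 1 to the input
m.register_forward_hook(lambda mod, args, out: out * 10)       # scale the output
print(m(3))          # 80: (3 + 1) * 2 * 10, hooks applied
print(m.forward(3))  # 6: forward() called directly, hooks skipped
```

The key design point is the class-level assignment __call__ = _call_impl: the user only overrides forward(), and the base class decides what happens around it.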
For a deeper understanding of the hook system, one can refer to: