PyTorch module __call__() vs forward()

Stephen Cow Chau
Sep 6, 2021

In Python, a class can define the special method __call__(); doing so makes instances of that class callable like a function.

class MyClass:
    def __init__(self):
        pass  # ... some init

    def __call__(self, input1):
        # delegate to a regular method
        self.my_function(input1)

    def my_function(self, input1):
        print(f"MyClass - print {input1}")

my_obj = MyClass()
# same as calling my_obj.my_function("haha")
my_obj("haha")  # prints "MyClass - print haha"

In PyTorch, nn.Module is implemented so that a module instance can be called in the same way, e.g.

import torch
import torch.nn as nn

class myLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 1)

    def forward(self, input_tensor):
        return self.layer1(input_tensor)

model = myLayer()
input_tensor = torch.rand((2, 10))
# treat the module as callable; this ends up calling model.forward(input_tensor)
model(input_tensor)

So what’s the difference? It’s the hook mechanism that PyTorch builds into the nn.Module class.

Looking at the source code (v1.9.0), there is a _call_impl(…) function, and two places are worth reading:

nn.Module assigns _call_impl to __call__ (https://github.com/pytorch/pytorch/blob/v1.9.0/torch/nn/modules/module.py#L1101)
The body of the _call_impl(…) function (https://github.com/pytorch/pytorch/blob/v1.9.0/torch/nn/modules/module.py#L1045)
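
To make the shape of that function easier to picture, here is a heavily simplified sketch in plain Python. It is not the PyTorch source: TinyModule and Doubler are hypothetical names made up for illustration, keyword arguments and backward hooks are left out, and only the general "pre-hooks, forward, forward hooks" ordering is modelled.

```python
from collections import OrderedDict


class TinyModule:
    """Hypothetical, simplified stand-in for nn.Module (not PyTorch code)."""

    def __init__(self):
        self._forward_pre_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()

    def forward(self, *inputs):
        raise NotImplementedError

    def _call_impl(self, *inputs):
        # 1. pre-hooks run first and may replace the inputs
        for hook in self._forward_pre_hooks.values():
            result = hook(self, inputs)
            if result is not None:
                inputs = result if isinstance(result, tuple) else (result,)
        # 2. the main call, i.e. the module's own forward()
        output = self.forward(*inputs)
        # 3. forward hooks run last and may replace the output
        for hook in self._forward_hooks.values():
            hook_output = hook(self, inputs, output)
            if hook_output is not None:
                output = hook_output
        return output

    # __call__ is simply another name for _call_impl
    __call__ = _call_impl


class Doubler(TinyModule):
    def forward(self, x):
        return 2 * x


calls = []
layer = Doubler()
# both lambdas return None, so inputs and output are left untouched
layer._forward_pre_hooks[0] = lambda module, inputs: calls.append("pre")
layer._forward_hooks[0] = lambda module, inputs, output: calls.append("post")

called = layer(3)          # goes through _call_impl, so the hooks fire
direct = layer.forward(3)  # skips _call_impl, so no hooks fire
print(called, direct, calls)  # 6 6 ['pre', 'post']
```

Both calls return the same value, but only the first one touches the hook lists.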

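So we can see that some actions happen before and after the main call (line 1071 of the linked file, “result = forward_call(*input, **kwargs)”), which is the call to the module’s forward(). Those surrounding actions are the registered hooks: forward pre-hooks run before forward() and forward hooks run after it, so calling forward() directly skips them.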
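
A small experiment can show the difference in practice. The sketch below assumes a recent PyTorch install and uses the public register_forward_pre_hook and register_forward_hook methods of nn.Module; the names MyLayer, pre_hook, post_hook and events are illustrative only.

```python
import torch
import torch.nn as nn


class MyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 1)

    def forward(self, input_tensor):
        return self.layer1(input_tensor)


events = []  # records which hooks fired, in order


def pre_hook(module, inputs):
    # runs before forward(); `inputs` is a tuple of the positional arguments
    events.append("pre")


def post_hook(module, inputs, output):
    # runs after forward(); returning None leaves the output unchanged
    events.append("post")


model = MyLayer()
model.register_forward_pre_hook(pre_hook)
model.register_forward_hook(post_hook)
input_tensor = torch.rand((2, 10))

out_call = model(input_tensor)             # __call__: hooks fire
events_after_call = list(events)

out_forward = model.forward(input_tensor)  # forward(): hooks are bypassed
events_after_forward = list(events)

print(events_after_call)     # ['pre', 'post']
print(events_after_forward)  # ['pre', 'post'] (nothing new was recorded)
```

This is why the usual advice is to call model(input_tensor) rather than model.forward(input_tensor): the outputs match, but only the former runs whatever hooks have been registered on the module.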

For more understanding of the hook system, one can refer to:
