PyTorch module __call__() vs forward()

```python
class MyClass:
    def __init__(self):
        pass  # ... some init

    def __call__(self, input1):
        # Calling the instance, my_obj(...), is routed here
        self.my_function(input1)

    def my_function(self, input1):
        print(f"MyClass - print {input1}")


my_obj = MyClass()
# Same as calling my_obj.my_function("haha")
my_obj("haha")  # expect to print "MyClass - print haha"
```

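One detail matters before looking at nn.Module: Python looks up special methods such as __call__ on the class, not on the instance. The sketch uses a made-up Greeter class (not from the original post) to show that assigning __call__ on an instance does not change what obj(...) does; only the class-level definition counts.

```python
class Greeter:
    # Illustrative class, not part of PyTorch
    def __call__(self, name):
        return f"class-level __call__: {name}"


g = Greeter()
# An instance attribute named __call__ is NOT used by g(...),
# because implicit special-method lookup goes through type(g).
g.__call__ = lambda name: f"instance-level: {name}"

print(g("haha"))                    # class-level __call__: haha
print(type(g).__call__(g, "haha"))  # class-level __call__: haha
print(g.__call__("haha"))           # instance-level: haha (explicit attribute access)
print(callable(g))                  # True
```

This is why a library that wants to customise what happens on obj(...) has to set __call__ on the class itself.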
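The same pattern applies to an nn.Module subclass: we define forward(), but we call the module itself.

```python
import torch
import torch.nn as nn


class myLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 1)

    def forward(self, input_tensor):
        return self.layer1(input_tensor)


model = myLayer()
input_tensor = torch.rand((2, 10))
# Treat the module as a callable: this runs model.forward(input_tensor),
# plus any hooks registered on the module
model(input_tensor)
```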
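To see the practical difference between model(x) and model.forward(x), here is a minimal sketch assuming a standard PyTorch install; log_hook, calls and handle are illustrative names. It registers a forward hook with register_forward_hook and shows that only the call that goes through __call__ triggers it, while both paths compute the same output.

```python
import torch
import torch.nn as nn


class myLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 1)

    def forward(self, input_tensor):
        return self.layer1(input_tensor)


calls = []


def log_hook(module, inputs, output):
    # Forward hooks receive (module, inputs tuple, output) after forward() runs
    calls.append(tuple(output.shape))


model = myLayer()
handle = model.register_forward_hook(log_hook)
x = torch.rand((2, 10))

out_call = model(x)             # goes through __call__, so the hook fires
out_forward = model.forward(x)  # bypasses __call__, so the hook does NOT fire

print(calls)                                  # [(2, 1)]
print(torch.allclose(out_call, out_forward))  # True
handle.remove()                               # detach the hook when done
```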
nn.Module overrides __call__ by assigning it to _call_impl at class level, so calling a module instance runs _call_impl (https://github.com/pytorch/pytorch/blob/v1.9.0/torch/nn/modules/module.py#L1101).
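The _call_impl(…) function (https://github.com/pytorch/pytorch/blob/v1.9.0/torch/nn/modules/module.py#L1045) runs any registered forward pre-hooks, calls forward(), then runs any forward hooks before returning the result. Calling model.forward(x) directly skips all of this, which is why model(x) is the recommended way to run a module.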
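The control flow of _call_impl can be approximated in plain Python. TinyModule, AddOne and the hook lambdas are hypothetical names; this is a heavily simplified sketch that ignores backward hooks, global hooks, keyword arguments and JIT tracing, so it is not PyTorch's actual code. It mirrors two conventions of the real hooks: a pre-hook may return replacement inputs (wrapped into a tuple if needed) and a forward hook may return a replacement output.

```python
class TinyModule:
    """Hypothetical, simplified stand-in for nn.Module (not PyTorch code)."""

    def __init__(self):
        self._forward_pre_hooks = []
        self._forward_hooks = []

    def register_forward_pre_hook(self, hook):
        self._forward_pre_hooks.append(hook)

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def forward(self, *args):
        raise NotImplementedError

    def _call_impl(self, *args):
        # 1. Pre-hooks run first and may replace the inputs
        for hook in self._forward_pre_hooks:
            result = hook(self, args)
            if result is not None:
                args = result if isinstance(result, tuple) else (result,)
        # 2. The user-defined computation
        output = self.forward(*args)
        # 3. Forward hooks run afterwards and may replace the output
        for hook in self._forward_hooks:
            result = hook(self, args, output)
            if result is not None:
                output = result
        return output

    # Same trick as nn.Module: bind __call__ to _call_impl at class level
    __call__ = _call_impl


class AddOne(TinyModule):
    def forward(self, x):
        return x + 1


m = AddOne()
m.register_forward_pre_hook(lambda mod, args: (args[0] * 10,))
m.register_forward_hook(lambda mod, args, out: out * 2)
print(m(1))          # 22: pre-hook gives 10, forward gives 11, hook gives 22
print(m.forward(1))  # 2: calling forward() directly skips both hooks
```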

Stephen Cow Chau