PyTorch Lightning Hooks and Callbacks — my limited understanding

Background — the what and why of callbacks in frameworks

One of the most important aspects of a deep learning framework is the balance between ease of use and the flexibility to change things.

On the ease-of-use side, I like PyTorch Lightning for its rich features, which are already encapsulated in the core structure (flow), while one can still control the run through config/settings (flags passed to the Trainer object).

The simplified idea of a framework, and where callbacks sit in it, is as follows:
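Roughly, the framework owns a fixed training flow, something like the sketch below (SimpleTrainer and its methods are hypothetical names, purely to illustrate the idea):

```python
class SimpleTrainer:
    """Hypothetical framework: it owns a fixed training flow."""

    def __init__(self, max_epochs=1):
        self.max_epochs = max_epochs

    def fit(self, model, dataloader, optimizer):
        for epoch in range(self.max_epochs):
            for batch in dataloader:
                # <-- this is the spot where user code would need to be injected
                loss = model.training_step(batch)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
```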

Imagine what we would need to do if we wanted to manipulate the batch: we could add code right before passing it into the model:
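For example, sticking with the hypothetical SimpleTrainer above, normalising the batch would mean editing fit() itself:

```python
# the same hypothetical SimpleTrainer.fit(), now with a hard-coded tweak
def fit(self, model, dataloader, optimizer):
    for epoch in range(self.max_epochs):
        for batch in dataloader:
            # our change, baked directly into the framework's own loop
            batch = (batch - batch.mean()) / (batch.std() + 1e-8)
            loss = model.training_step(batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
```

This works, but it means every user tweak has to be patched into the framework's own training loop.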

How does the framework offer such flexibility, letting users add their own code without amending the function itself? The answer is callbacks (planned function calls at specific locations).

Assume a base class like the one below, with an on_batch_begin() that takes in a batch and returns a batch (the default is to return the same batch):
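A minimal sketch of that idea (again, all names are hypothetical):

```python
class BaseCallback:
    """Hypothetical base class: the default hook returns the batch unchanged."""

    def on_batch_begin(self, batch):
        return batch


class SimpleTrainer:
    """Same hypothetical trainer, now with a planned hook in its loop."""

    def __init__(self, callback=None, max_epochs=1):
        self.callback = callback or BaseCallback()
        self.max_epochs = max_epochs

    def fit(self, model, dataloader, optimizer):
        for epoch in range(self.max_epochs):
            for batch in dataloader:
                # the planned function call ("hook") at a specific location
                batch = self.callback.on_batch_begin(batch)
                loss = model.training_step(batch)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
```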

If a user wants to amend the behavior, they can inherit from the class and override on_batch_begin():
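Continuing the sketch, the user's subclass plugs into the training loop without the loop itself being touched:

```python
class NormalizeBatchCallback(BaseCallback):
    """User-defined callback: overrides the default hook to tweak the batch."""

    def on_batch_begin(self, batch):
        return (batch - batch.mean()) / (batch.std() + 1e-8)


trainer = SimpleTrainer(callback=NormalizeBatchCallback())
# trainer.fit(model, dataloader, optimizer)  # the framework's flow is unchanged
```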

How PyTorch Lightning approaches callbacks

For a full description and the high-level placement of all callbacks (hooks) available in PyTorch Lightning, the documentation gives good detail.

A valid implementation (which I used to do) is something like:
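A sketch of what that looks like; on_train_start and on_train_epoch_start are real LightningModule hooks, though the exact set of hooks and their signatures varies a bit between Lightning versions:

```python
import torch
from torch import nn
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    # hooks overridden directly on the LightningModule
    def on_train_start(self):
        print("LightningModule: training is starting")

    def on_train_epoch_start(self):
        print("LightningModule: a new epoch is starting")
```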

One can imagine that if we override all the callback hooks, the LightningModule itself can become huge and difficult to keep track of.

So what PyTorch Lightning does is provide a Callback class, and common behaviours like the ones above already ship as built-in callbacks:
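For instance, early stopping and checkpointing come as ready-made Callback subclasses (the arguments below are just illustrative):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# behaviour packaged as Callback subclasses rather than code baked into the module
early_stop = EarlyStopping(monitor="val_loss", patience=3, mode="min")
checkpoint = ModelCheckpoint(monitor="val_loss", save_top_k=1, mode="min")

trainer = Trainer(max_epochs=10, callbacks=[early_stop, checkpoint])
```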

The official documentation has an example along these lines:
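Paraphrasing the kind of example the docs show (not an exact quote): a custom Callback that prints messages and is handed to the Trainer:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback


class MyPrintingCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        print("Training is starting")

    def on_train_end(self, trainer, pl_module):
        print("Training is ending")


trainer = Trainer(callbacks=[MyPrintingCallback()])
```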

Note that the callbacks passed to the Trainer are a list of Callback instances, so one can separate different categories of actions into different callbacks.

How it is implemented in PyTorch Lightning

The core piece, I believe, is a set of mixin classes, for instance TrainerCallbackHookMixin for the main training flow.

According to the code, whenever the main training flow calls a particular planned hook, it loops through all the callbacks that were passed in to the Trainer instance.
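A simplified paraphrase of what such a trainer-side hook looks like in older Lightning releases (not the exact source; attribute names and signatures have changed between versions):

```python
class TrainerCallbackHookMixin:
    # paraphrased sketch, not the exact Lightning source
    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
        """Called when the training batch begins."""
        for callback in self.callbacks:
            callback.on_train_batch_start(
                self, self.lightning_module, batch, batch_idx, dataloader_idx
            )
```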

And inside the main training flow, this is how the hook gets invoked: by calling the call_hook() function:
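Roughly like this, paraphrased from an older version of the training loop (the exact arguments differ between releases):

```python
# paraphrased from an older Lightning training loop, not the exact source
def run_training_batch(self, batch, batch_idx, dataloader_idx):
    # the planned hook is dispatched through the Trainer's call_hook()
    response = self.trainer.call_hook(
        "on_train_batch_start", batch, batch_idx, dataloader_idx
    )
    if response == -1:
        return  # the hook can signal that this batch should be skipped
    ...
```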

And the call_hook function is implemented roughly as below. Note the order of the two calls: it implies that the callbacks are invoked before the hook overridden inside the LightningModule.
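A simplified paraphrase of call_hook() from an older release (not the exact source), with the relevant ordering marked in the comments:

```python
# simplified paraphrase of Trainer.call_hook(), not the exact Lightning source
def call_hook(self, hook_name, *args, **kwargs):
    output = None

    # (1) the trainer-side hook runs first; via TrainerCallbackHookMixin this
    #     loops over every Callback instance passed to the Trainer
    if hasattr(self, hook_name):
        trainer_hook = getattr(self, hook_name)
        trainer_hook(*args, **kwargs)

    # (2) only then is the hook overridden on the LightningModule invoked
    #     (the real code uses a Lightning utility, is_overridden(), for this check)
    model = self.lightning_module
    if hasattr(model, hook_name):
        output = getattr(model, hook_name)(*args, **kwargs)

    return output
```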

This aligns with the result observed: the message printed from the callback's hook function appears before the one from the hook overridden in the LightningModule.
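Putting the earlier sketches together, if both MyPrintingCallback and LitModel override on_train_start, the expected print order (based on the dispatch order sketched above) is callback first, then module:

```python
# train_dataloader is assumed to be any DataLoader yielding (x, y) batches
trainer = Trainer(max_epochs=1, callbacks=[MyPrintingCallback()])
trainer.fit(LitModel(), train_dataloader)

# expected print order, following the dispatch order above:
#   "Training is starting"                    <- MyPrintingCallback.on_train_start
#   "LightningModule: training is starting"   <- LitModel.on_train_start
```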