PyTorch Lightning Hooks and Callbacks — my limited understanding

Background — What and why of callbacks in framework(s)

One important aspect of a deep learning framework is the balance between ease of use and the flexibility to change its behaviour.

# the extremely simplified high-level structure of a training loop
for epoch in epochs:
    for batch in dataloader:
        # assumes each batch is an (input, target) pair
        x_in_batch, target = batch
        model_output = model(x_in_batch)
        # PyTorch loss functions take (prediction, target)
        loss = loss_function(model_output, target)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

# without hooks, a batch transform has to be edited into the loop itself
for epoch in epochs:
    for batch in dataloader:
        # code added here to transform the batch
        x_in_batch, target = batch
        model_output = model(x_in_batch)
        loss = loss_function(model_output, target)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

# with a hook, the base trainer calls on_batch_begin at a fixed point
class BaseTrainer:
    def train(self):
        for epoch in epochs:
            for batch in dataloader:
                # call the hook that runs at the start of processing a batch
                batch = self.on_batch_begin(batch)
                x_in_batch, target = batch
                model_output = model(x_in_batch)
                loss = loss_function(model_output, target)

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

    def on_batch_begin(self, batch):
        # default: bypass and return the batch unchanged
        return batch

# a subclass overrides only the hook; the training loop is untouched
class ChildTrainer(BaseTrainer):
    def on_batch_begin(self, batch):
        # do some transform (transform stands in for user code)
        new_batch = transform(batch)
        return new_batch
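
To make the hook pattern above concrete, here is a minimal runnable sketch. It is a toy under stated assumptions: the "training" step just sums each batch so that no deep learning library is needed, and `ScalingTrainer`, `seen`, and the sample data are hypothetical names introduced only for illustration.

```python
class BaseTrainer:
    """Toy trainer: summing a batch stands in for forward/backward/step."""

    def __init__(self, dataloader, epochs=1):
        self.dataloader = dataloader
        self.epochs = epochs
        self.seen = []  # records what the stand-in "model" received

    def train(self):
        for _ in range(self.epochs):
            for batch in self.dataloader:
                # the hook runs at a fixed point, before the batch is used
                batch = self.on_batch_begin(batch)
                self.seen.append(sum(batch))

    def on_batch_begin(self, batch):
        # default hook: return the batch unchanged
        return batch


class ScalingTrainer(BaseTrainer):
    def on_batch_begin(self, batch):
        # override only the hook; train() is inherited as-is
        return [x * 10 for x in batch]


data = [[1, 2], [3, 4]]

base = BaseTrainer(data)
base.train()

child = ScalingTrainer(data)
child.train()

print(base.seen)   # [3, 7]
print(child.seen)  # [30, 70]
```

The point of the design is that `train()` never changes: the subclass only states what should happen at the start of a batch, and the base class decides when that happens.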

How PyTorch Lightning approaches callbacks

For a full description and the high-level placement of all callbacks (hooks) available in PyTorch Lightning, the documentation gives good detail.

Screen capture of the relevant section of the documentation

import pytorch_lightning as pl

class MyTrainer(pl.LightningModule):
    def on_epoch_end(self):
        # do some custom visualization for the result of the last epoch
        ...
        # check if we should stop early
        ...

        # check if we need to save a checkpoint of the model
        ...

The built-in callbacks; see the documentation for more detail

How it is implemented in PyTorch Lightning

I believe the core pieces are several abstract classes, for instance TrainerCallbackHookMixin for the main training flow.

This is in the GitHub project at the folder path: pytorch_lightning/trainer/callback_hook.py
This is in the GitHub project at the folder path: pytorch_lightning/loops/batch/training_batch_loop.py
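
To illustrate the general idea of a trainer mixin that forwards hooks to callbacks, here is a simplified pure-Python sketch. It is my own reconstruction of the pattern, not the actual contents of the files above: `Callback`, `CallbackHookMixin`, `ToyTrainer`, and `Recorder` are hypothetical names, and the real Lightning classes have many more hooks and arguments.

```python
class Callback:
    """Base class with no-op hooks; subclasses override what they need."""

    def on_train_epoch_start(self, trainer):
        pass

    def on_train_epoch_end(self, trainer):
        pass


class CallbackHookMixin:
    """Forwards each hook call to every registered callback, in order."""

    def on_train_epoch_start(self):
        for callback in self.callbacks:
            callback.on_train_epoch_start(self)

    def on_train_epoch_end(self):
        for callback in self.callbacks:
            callback.on_train_epoch_end(self)


class ToyTrainer(CallbackHookMixin):
    def __init__(self, callbacks, max_epochs):
        self.callbacks = list(callbacks)
        self.max_epochs = max_epochs
        self.current_epoch = 0

    def fit(self):
        for epoch in range(self.max_epochs):
            self.current_epoch = epoch
            self.on_train_epoch_start()
            # the batch loop would run here
            self.on_train_epoch_end()


class Recorder(Callback):
    def __init__(self):
        self.events = []

    def on_train_epoch_start(self, trainer):
        self.events.append(("start", trainer.current_epoch))

    def on_train_epoch_end(self, trainer):
        self.events.append(("end", trainer.current_epoch))


recorder = Recorder()
ToyTrainer([recorder], max_epochs=2).fit()
print(recorder.events)
# [('start', 0), ('end', 0), ('start', 1), ('end', 1)]
```

In this sketch the loop only knows the hook names. What happens at each point is decided by whichever callbacks were registered, which is what lets users add behaviour without touching the loop.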