Experimenting with PyTorch Lightning and Catalyst: high-level frameworks for PyTorch

Note: this article is a condensed version of my own experiments and experience. It does not systematically explain what the frameworks are or how to start using them; for that, it is better to explore them yourself and read the documentation and examples on the official sites or in other articles.

Note: this is also not a comparison of the frameworks. Rather, it describes an approach that worked for me when getting things done with them.

Why I am writing this and where I come from

I used to write my own training code, and I helped create a deep learning framework (just for my team to use).

Writing training code is always tedious. The standard code normally looks like this:
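As a rough illustration of that standard loop, here is a framework-free sketch. It assumes a toy one-weight linear model trained with mean squared error and a hand-computed gradient, standing in for a PyTorch model so the example runs without PyTorch installed; the comments point at the PyTorch calls that would normally sit on those lines. The function names are hypothetical.

```python
# Framework-free sketch of the usual training-loop boilerplate.
# A one-weight model (y = w * x) stands in for a PyTorch model.

def make_batches(data, batch_size):
    """Split a list of (x, y) pairs into consecutive mini-batches."""
    return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]


def train(train_data, val_data, epochs=20, lr=0.05, batch_size=4):
    w = 0.0  # the only parameter (in PyTorch: model.parameters())
    history = []
    for epoch in range(epochs):
        # --- training phase (model.train()) ---
        train_losses = []
        for batch in make_batches(train_data, batch_size):
            preds = [w * x for x, _ in batch]  # forward pass
            loss = sum((p - y) ** 2 for p, (_, y) in zip(preds, batch)) / len(batch)
            # gradient of the MSE by hand (in PyTorch: optimizer.zero_grad(); loss.backward())
            grad = sum(2 * (p - y) * x for p, (x, y) in zip(preds, batch)) / len(batch)
            w -= lr * grad  # parameter update (in PyTorch: optimizer.step())
            train_losses.append(loss)
        # --- validation phase (model.eval() inside torch.no_grad()) ---
        val_loss = sum((w * x - y) ** 2 for x, y in val_data) / len(val_data)
        # the kind of custom code that gets copied into every project:
        # averaging the loss per epoch and keeping it for logging
        avg_train = sum(train_losses) / len(train_losses)
        history.append((epoch, avg_train, val_loss))
    return w, history


train_data = [(x / 10, 2 * x / 10) for x in range(20)]
val_data = [(x / 10, 2 * x / 10) for x in range(20, 25)]
w, history = train(train_data, val_data)
print(round(w, 3))  # 2.0
```

Every small variation (extra logging, a different averaging rule, a scheduler) means editing this loop itself, which is exactly how the many “template” versions pile up.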

The problem always comes from the bits of experimentation that add variety (e.g. logging tensor distributions, custom code to average the loss per epoch…): you end up with 1000 versions of “template” or “past code” for starting a new project, and endlessly combine versions to get a better starting point next time.

The purpose of using a framework (for me)

Among frameworks, the first few my team and I explored were fast.ai and AllenNLP (and maybe Keras, when deciding which callbacks were needed). What they provide is ready-made code for the regular training/inference flow, while still allowing the flexibility to adapt it.

A framework usually encapsulates the complexity by providing an easy start for the common flow (like the one above), plus some configuration for adding functionality, which comes in handy. It normally also comes with callbacks like the following, so that users can override a callback to do something else without modifying the training loop directly.
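The callback idea can be sketched in plain Python. Everything here (Callback, EpochLossAverager, Trainer and the hook names) is a hypothetical illustration of the general pattern, not the API of fast.ai, AllenNLP, Keras or any other framework: the loop stays fixed, and custom behaviour such as per-epoch loss averaging lives in a callback that overrides only the hooks it needs.

```python
class Callback:
    """Base class: every hook is a no-op, so subclasses override only what they need."""

    def on_epoch_start(self, trainer):
        pass

    def on_batch_end(self, trainer, loss):
        pass

    def on_epoch_end(self, trainer):
        pass


class EpochLossAverager(Callback):
    """Hypothetical user callback: averages the batch losses of each epoch."""

    def __init__(self):
        self.epoch_averages = []
        self._losses = []

    def on_epoch_start(self, trainer):
        self._losses = []

    def on_batch_end(self, trainer, loss):
        self._losses.append(loss)

    def on_epoch_end(self, trainer):
        self.epoch_averages.append(sum(self._losses) / len(self._losses))


class Trainer:
    """Hypothetical framework loop: never edited; behaviour is added via callbacks."""

    def __init__(self, step_fn, callbacks=()):
        self.step_fn = step_fn  # user code: batch -> loss
        self.callbacks = list(callbacks)
        self.epoch = 0

    def _call(self, hook, *args):
        # dispatch one hook to every registered callback, in list order
        for cb in self.callbacks:
            getattr(cb, hook)(self, *args)

    def fit(self, batches, epochs):
        for epoch in range(epochs):
            self.epoch = epoch
            self._call("on_epoch_start")
            for batch in batches:
                loss = self.step_fn(batch)
                self._call("on_batch_end", loss)
            self._call("on_epoch_end")


averager = EpochLossAverager()
trainer = Trainer(step_fn=lambda batch: sum(batch) / len(batch), callbacks=[averager])
trainer.fit(batches=[[1.0, 3.0], [5.0, 7.0]], epochs=2)
print(averager.epoch_averages)  # [4.0, 4.0]
```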

PyTorch Lightning

The first framework I personally started using seriously was PyTorch Lightning, and I loved it (until I built my vanilla GAN). There are a lot of advantages to using it.

First of all, the documentation is very well written; as a beginner, I found it super easy to learn how to convert ordinary PyTorch training code into PyTorch Lightning.

The core items are the Lightning Module and the Trainer.

The Lightning Module

The Lightning Module defines how training runs through a set of predefined callback methods (Lightning calls them hooks, e.g. training_step, training_epoch_end…), so you can override any of them when you want behaviour different from the default.

The minimal override needed, as described in the documentation, is:

Think of it as a super torch.nn.Module: you define your layers (or submodules) in __init__(), and your forward() connects those layers, just like nn.Module's forward().

training_step() and configure_optimizers(), on the other hand, hold code and objects that live outside nn.Module but inside the usual training loop.
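To make this division of labour concrete, here is a toy sketch of how a Trainer drives such a module: it calls configure_optimizers() once, then training_step() for every batch, and reads the “loss” key from the returned dict. ToyModule, ToyOptimizer and toy_fit are hypothetical plain-Python stand-ins (no PyTorch or Lightning); real code would subclass pytorch_lightning.LightningModule, return a torch optimizer, and leave loss.backward() to the Trainer.

```python
class ToyOptimizer:
    """Stand-in for a torch.optim optimizer; only counts how often it is stepped."""

    def __init__(self):
        self.steps = 0

    def zero_grad(self):
        pass

    def step(self):
        self.steps += 1


class ToyModule:
    """Hypothetical stand-in for a LightningModule subclass."""

    def __init__(self):
        self.w = 0.5  # "layers"/parameters are created in __init__()

    def forward(self, x):  # how the layers connect, like nn.Module.forward()
        return self.w * x

    def __call__(self, x):  # nn.Module routes module(x) to forward()
        return self.forward(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = (self(x) - y) ** 2
        return {"loss": loss}  # the trainer looks for this exact key

    def configure_optimizers(self):
        return ToyOptimizer()


def toy_fit(module, batches, max_epochs):
    """Hypothetical mini trainer: it owns the loop; the module owns the logic."""
    optimizer = module.configure_optimizers()  # called once, before training
    losses = []
    for _ in range(max_epochs):
        for batch_idx, batch in enumerate(batches):
            out = module.training_step(batch, batch_idx)
            loss = out["loss"]  # KeyError if the step returned e.g. "train_loss"
            optimizer.zero_grad()
            # a real trainer would call loss.backward() here
            optimizer.step()
            losses.append(loss)
    return optimizer, losses


opt, losses = toy_fit(ToyModule(), batches=[(1.0, 2.0), (2.0, 4.0)], max_epochs=3)
print(opt.steps, losses[:2])  # 6 [2.25, 9.0]
```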

My own “template” is a bit more complicated than their example:

The Trainer

The Trainer, as I understand it, provides all the common features needed while training, for example checkpointing, early stopping, and setting how many epochs to run…

It provides a lot of different features (flags); please refer to the documentation, which explains each of them in full.

A sample Trainer I use most of the time:
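Early stopping is a good example of what such a flag hides. The sketch below is a generic patience-based rule in plain Python, not Lightning's EarlyStopping implementation; the function name, the patience/min_delta parameters and the sample validation-loss history are illustrative assumptions. It also reports the best epoch, i.e. the checkpoint a “keep the best model” policy would retain.

```python
def early_stopping_epoch(val_losses, patience=3, min_delta=0.0):
    """Return (stop_epoch, best_epoch) for a patience-based early stopping rule.

    Training stops once the monitored value has failed to improve by more than
    min_delta for `patience` consecutive epochs.
    """
    best = float("inf")
    best_epoch = 0
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # improvement: this is where a checkpoint callback would save the model
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch  # stop training at this epoch
    return len(val_losses) - 1, best_epoch  # ran until the epoch limit


history = [0.90, 0.70, 0.60, 0.62, 0.61, 0.65, 0.50]
print(early_stopping_epoch(history, patience=3))  # (5, 2)
```

With patience=3 the run stops at epoch 5 and never sees the late improvement at epoch 6; a larger patience trades training time for that chance.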

The early hurdles for me

One of my early hurdles was the expected input/output of each callback function. For example, the training step function seems to expect “loss” as the output key rather than an arbitrary key like “train_loss”, while validation and test also seem to expect “val_loss”.

The second hurdle was TPU usage. The sample in the documentation works for training (make sure you set tpu_cores to 1 or 8 in the Trainer config), but the early stopping callback threw an error every time it tried to collect the data from the different cores.

The third hurdle is the reason I wanted to try another framework (Catalyst): the way it handles multiple optimizers (e.g. an encoder and a decoder with different optimizers, or a GAN with a discriminator and a generator). The framework processes each optimizer separately for every batch, which means part of forward() is executed x times, where x is the number of optimizers. The GAN sample in the official documentation does work; I just “believe” there could be an implementation that fits the structure of the Lightning Module better. I would like to write another article on this part in the future.
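The per-optimizer dispatch can be sketched as follows. As far as I know, in Lightning versions before 2.0 (with automatic optimization), training_step received an extra optimizer_idx argument and was called once per optimizer for each batch. The toy below (ToyGAN and toy_fit_multi, both hypothetical, plain Python) only counts how often the shared generator forward runs under that scheme.

```python
class ToyGAN:
    """Hypothetical GAN module; counts how often the generator runs."""

    def __init__(self):
        self.generator_calls = 0

    def generator(self, z):
        self.generator_calls += 1
        return [2 * v for v in z]  # toy "fake samples"

    def training_step(self, batch, batch_idx, optimizer_idx):
        fake = self.generator(batch)  # shared forward work, recomputed per optimizer
        if optimizer_idx == 0:
            return {"loss": sum(fake)}  # toy "generator loss"
        return {"loss": -sum(fake)}  # toy "discriminator loss"


def toy_fit_multi(module, batches, num_optimizers):
    """Mimics per-batch, per-optimizer dispatch: one training_step per optimizer."""
    for batch_idx, batch in enumerate(batches):
        for optimizer_idx in range(num_optimizers):
            module.training_step(batch, batch_idx, optimizer_idx)
            # a real trainer would backward() and step optimizers[optimizer_idx] here


gan = ToyGAN()
toy_fit_multi(gan, batches=[[0.1, 0.2], [0.3, 0.4], [0.5]], num_optimizers=2)
print(gan.generator_calls)  # 6: 3 batches x 2 optimizers
```

Caching the generator output between the two calls is one obvious idea, but it has to be done by hand, which is the kind of structural mismatch mentioned above.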

Finally, for developers who want to learn the core flow by reading code

Read this file and you will get an idea of how the callbacks run and what their parameters are, along with code snippets.


Catalyst

Honestly, I am super new to Catalyst. Unlike PyTorch Lightning, it is a bit harder for me to understand, but I am excited to learn more about it.

First of all, the documentation is not complete yet, so my understanding comes from reading the source code and the example code they provide.

For another reference, you can read: https://medium.com/pytorch/catalyst-101-accelerated-pytorch-bd766a556d92

The core items of Catalyst, I believe, are the Runner and the different types of Callbacks.


The simplest sample the documentation provides is as follows:

So this _handle_batch() looks like a combination of PyTorch Lightning's training_step() and validation_step() (and maybe test_step()).

The runner itself then runs training (train, or run) and takes its configuration as follows:

When I first looked at this, I was worried, as it seemed I would have to manage loss.backward(), optimizer.step()… myself as well, and the optimizer, model and schedulers are not defined within the Runner.
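A toy version of that kind of runner, loosely modelled on the documentation's custom-runner idea, shows both points: one _handle_batch() serves the train and valid loaders, and the update step is the user's job. ToyRunner, MyRunner and ToyOptimizer are hypothetical plain-Python stand-ins; the is_train_loader and batch_metrics attribute names mirror what I remember from Catalyst 20.x and should be checked against the version you use.

```python
class ToyOptimizer:
    """Stand-in for a torch optimizer; only counts steps."""

    def __init__(self):
        self.steps = 0

    def zero_grad(self):
        pass

    def step(self):
        self.steps += 1


class ToyRunner:
    """Hypothetical stand-in for a Catalyst Runner: loops over named loaders."""

    def __init__(self, model, optimizer):
        self.model, self.optimizer = model, optimizer  # passed in, not defined here
        self.is_train_loader = False
        self.batch_metrics = {}
        self.history = []

    def _handle_batch(self, batch):
        raise NotImplementedError  # the part the user overrides

    def run(self, loaders, num_epochs):
        for epoch in range(num_epochs):
            for name, loader in loaders.items():
                self.is_train_loader = name == "train"
                for batch in loader:
                    self._handle_batch(batch)
                    self.history.append((epoch, name, dict(self.batch_metrics)))


class MyRunner(ToyRunner):
    def _handle_batch(self, batch):
        # one method plays the role of training_step and validation_step at once
        x, y = batch
        loss = (self.model(x) - y) ** 2
        self.batch_metrics = {"loss": loss}
        if self.is_train_loader:
            # the user manages the update step here
            self.optimizer.zero_grad()
            # a real runner body would call loss.backward() here
            self.optimizer.step()


runner = MyRunner(model=lambda x: 3 * x, optimizer=ToyOptimizer())
runner.run({"train": [(1.0, 2.0), (2.0, 4.0)], "valid": [(3.0, 6.0)]}, num_epochs=1)
print(runner.optimizer.steps, len(runner.history))  # 2 3
```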

Then I saw another example that uses the second important item: Callbacks.


The example from the documentation (https://catalyst-team.github.io/catalyst/api/callbacks.html#catalyst.callbacks.batch_overfit.BatchOverfitCallback):

Note the difference: there is no overriding of the Runner; instead it uses a subclass, SupervisedRunner(), where most standard procedures are already in place.

The callbacks are passed as a list (sometimes you will see a dictionary instead, which I want to explore more but have not managed to make work yet). It looks like they all run, and I did not know how they determine the order until I read the source code, but they work magically.

A bit deeper into Callbacks (class and subclasses)

With my limited understanding, the callbacks are designed very differently from those of other frameworks I have used.

Catalyst seems to implement the different components (e.g. the optimizer) and features (like logging) of training as Callback objects, and each Callback subclass has callback functions like on_batch_start()… (read the Callback class definition here: https://github.com/catalyst-team/catalyst/blob/master/catalyst/core/callback.py)

I admire the vision of implementing each component and feature as a Callback: it minimizes the need to override callbacks and gives better code separation.
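Here is a plain-Python sketch of that design: the runner only does the forward pass, while the metric, the optimizer step and the logging are separate callbacks, executed in the order given by an integer order attribute rather than by their position in the list. All class names are hypothetical; Catalyst has its own ordering mechanism (CallbackOrder, as far as I can tell from the source), which this toy only imitates.

```python
class Callback:
    order = 0  # lower values run first

    def on_batch_start(self, runner):
        pass

    def on_batch_end(self, runner):
        pass


class MetricCallback(Callback):
    order = 10

    def on_batch_end(self, runner):
        _, y = runner.batch
        runner.batch_metrics["loss"] = (runner.output - y) ** 2


class OptimizerCallback(Callback):
    order = 20  # must run after the loss exists

    def __init__(self):
        self.steps = 0

    def on_batch_end(self, runner):
        if runner.is_train_loader:
            # a real optimizer callback would run loss.backward() and optimizer.step()
            self.steps += 1


class LoggerCallback(Callback):
    order = 30

    def __init__(self):
        self.lines = []

    def on_batch_end(self, runner):
        self.lines.append(f"{runner.loader_name}: loss={runner.batch_metrics['loss']}")


class ToySupervisedRunner:
    """Hypothetical runner: only the forward pass; everything else is a callback."""

    def __init__(self, model, callbacks):
        self.model = model
        self.callbacks = sorted(callbacks, key=lambda cb: cb.order)  # list position ignored

    def run(self, loaders):
        for name, loader in loaders.items():
            self.loader_name = name
            self.is_train_loader = name == "train"
            for batch in loader:
                self.batch = batch
                self.batch_metrics = {}
                for cb in self.callbacks:
                    cb.on_batch_start(self)
                self.output = self.model(batch[0])  # the runner's own job
                for cb in self.callbacks:
                    cb.on_batch_end(self)


opt_cb, log_cb = OptimizerCallback(), LoggerCallback()
# deliberately listed out of order: sorting by `order` fixes the sequence
runner = ToySupervisedRunner(model=lambda x: 2 * x, callbacks=[log_cb, opt_cb, MetricCallback()])
runner.run({"train": [(1.0, 3.0)], "valid": [(2.0, 4.0)]})
print(opt_cb.steps, log_cb.lines)  # 1 ['train: loss=1.0', 'valid: loss=0.0']
```

Swapping the optimizer step or the logger then means swapping one object in the list, without touching the runner.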

Finally, for developers who want to learn the core flow by reading code

I believe this is where the core flow lives:

Final words

Thank you so much for making it through the article; I know how messily it is written and that it is not beginner friendly (I still consider myself a beginner with these frameworks).

I hope to explore them further and write more about them as I work with them more. Both are very good frameworks to start with; choosing between them is a matter of preference rather than one being better than the other.

Good Reference(s)