Blog: Implementing SPADE using fastai
I was fascinated by the results of Nvidia’s latest research paper when it came out. If you haven’t looked at the paper’s results, you are missing out. Also, have a look at a SPADE-based GAN in action in the GIF. I couldn’t wait for the official implementation to be released, so I decided to implement the paper myself using our favorite fastai library. fastai provides a very neat API that can be used to develop highly customizable models. In particular, its data block and callbacks APIs are mind-blowing.
The paper is built on a very simple idea that is reported to give a huge performance boost on the task of photo-realistic image synthesis, using semantic maps as inputs to the GAN model. Today I am going to implement it block by block.
If you follow this blog to the end, you will learn how to use fastai and PyTorch to implement new architectures and work with new datasets.
What is SPADE?
SPADE stands for SPatially ADaptivE Normalization, which is a normalization technique like batch norm, instance norm, etc. It is used in GANs to generate synthetic photo-realistic images from segmentation masks. The paper uses this normalization technique in all the layers of the generator. The idea of SPADE is that instead of learning the affine parameters of the batch norm layer, the semantic map is used to compute those affine parameters. Confused about what affine parameters are?
This is an image from the batch norm paper. Those little gamma and beta guys are the affine parameters. They are learnable and give the model the freedom to shift the normalized activations toward whatever distribution it wants. So SPADE says: why not use the semantic maps to compute those gammas and betas, also called the scaling and shifting parameters respectively?
Instead of using randomly initialized scaling and shifting parameters, SPADE utilizes the semantic maps to compute them, and THAT’S IT! The authors did this because conventional normalization layers wash away the information in the input semantic masks; SPADE helps in effectively propagating the semantic information throughout the network. Below is the architecture of a SPADE block.
Now, everything I said above should have started to make sense. The traditional normalization is still performed; only the affine parameters are different. Basically, they are just a couple of convolutions applied to the input semantic map. I am really thankful to the NVIDIA AI researchers for making this paper so visual. So now let’s get set and write the code.
All of the code that I will be showing here is from my GitHub repository. I will show you the implementation of a version of the SPADE paper with fewer features, which is in this notebook; I have implemented other additions in other notebooks. Since it is a GAN, it has a generator and a discriminator. The generator is what contains the SPADE layers, so we will start by implementing the basic SPADE block.
This image is just a more detailed version of the earlier picture. It tells exactly which convolutions are to be performed on the semantic map.
It accepts the segmentation mask and the features. The segmentation mask is just a plain 2D long-integer mask. The paper advises using an embedding layer for the classes, but I decided to keep things simple, so the first convolutional layer has a single input channel. The block then resizes the mask to the spatial size of the features. It does this because the SPADE layer is used at every layer of the generator, so it needs to know the size of the features so that the mask can be resized before computing the affine parameters. Note that when I initialize the BatchNorm2d layer I set affine to False, so the default affine parameters are not used. Spectral normalization is applied to all the convolutional blocks, as in the paper, to stabilize GAN training. In PyTorch, a new layer is implemented by inheriting from nn.Module and implementing the __init__ and forward functions. The variable names ni and nf refer to the number of input filters and the number of output filters of a convolutional layer, respectively.
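The SPADE layer described above can be sketched roughly as follows. This is a minimal version, not the exact code from my repository: the hidden width of 128 follows the paper's SPADE figure, and the single-channel float mask input reflects my "no embedding layer" simplification.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm

class SPADE(nn.Module):
    """Batch norm without learnable affine parameters; gamma and beta
    are instead predicted from the resized segmentation mask."""
    def __init__(self, nf, n_hidden=128):  # nf: number of feature channels
        super().__init__()
        # affine=False: do not use the default learnable affine parameters
        self.bn = nn.BatchNorm2d(nf, affine=False)
        # shared conv applied to the one-channel mask
        self.shared = nn.Sequential(
            spectral_norm(nn.Conv2d(1, n_hidden, 3, padding=1)),
            nn.ReLU(),
        )
        # separate convs predict the scaling (gamma) and shifting (beta) maps
        self.gamma = spectral_norm(nn.Conv2d(n_hidden, nf, 3, padding=1))
        self.beta = spectral_norm(nn.Conv2d(n_hidden, nf, 3, padding=1))

    def forward(self, x, mask):
        # resize the mask to the spatial size of the features
        mask = F.interpolate(mask.float(), size=x.shape[2:], mode='nearest')
        h = self.shared(mask)
        # normalize, then apply the spatially varying affine parameters
        return self.bn(x) * (1 + self.gamma(h)) + self.beta(h)
```

Because gamma and beta are full feature maps rather than per-channel scalars, the modulation varies spatially, which is exactly what lets the mask information survive normalization.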
SPADE Residual Block
After implementing the SPADE block, we can now use it in our SPADE residual block, and it’s pretty straightforward.
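A sketch of the residual block is below, with a compact copy of the SPADE layer included so the snippet is self-contained. The skip path with its own SPADE and a 1×1 convolution when the channel count changes follows the paper's figure; the exact structure in my repository may differ slightly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm

class SPADE(nn.Module):
    # compact SPADE: batch norm without affine params, gamma/beta
    # predicted from the resized one-channel mask (hidden width assumed)
    def __init__(self, nf, n_hidden=128):
        super().__init__()
        self.bn = nn.BatchNorm2d(nf, affine=False)
        self.shared = nn.Sequential(
            spectral_norm(nn.Conv2d(1, n_hidden, 3, padding=1)), nn.ReLU())
        self.gamma = spectral_norm(nn.Conv2d(n_hidden, nf, 3, padding=1))
        self.beta = spectral_norm(nn.Conv2d(n_hidden, nf, 3, padding=1))

    def forward(self, x, mask):
        mask = F.interpolate(mask.float(), size=x.shape[2:], mode='nearest')
        h = self.shared(mask)
        return self.bn(x) * (1 + self.gamma(h)) + self.beta(h)

class SPADEResBlk(nn.Module):
    """Two SPADE -> ReLU -> conv stages, plus a SPADE'd 1x1-conv skip
    path when the number of channels changes from ni to nf."""
    def __init__(self, ni, nf):
        super().__init__()
        self.spade1 = SPADE(ni)
        self.conv1 = spectral_norm(nn.Conv2d(ni, nf, 3, padding=1))
        self.spade2 = SPADE(nf)
        self.conv2 = spectral_norm(nn.Conv2d(nf, nf, 3, padding=1))
        self.skip = None
        if ni != nf:
            self.spade_skip = SPADE(ni)
            self.skip = spectral_norm(nn.Conv2d(ni, nf, 1, bias=False))

    def forward(self, x, mask):
        # identity skip when channels match, learned skip otherwise
        res = x if self.skip is None else self.skip(self.spade_skip(x, mask))
        h = self.conv1(F.relu(self.spade1(x, mask)))
        h = self.conv2(F.relu(self.spade2(h, mask)))
        return h + res
```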
Now we have the basic blocks set up, and it is time to stack them up as shown in the architecture below.
I listed the number of feature maps that the individual layers output and created the generator using a for loop.
To keep the code simple, I use tricks like initializing a module’s parameters from a global variable. You can see that the nfs variable contains the output sizes of all the SPADE residual blocks, which is used when initializing the layers of the generator. At the end, I rescale the output of the tanh layer to the range 0–1 so that it is simpler to visualize.
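The stacking pattern can be sketched like this. The nfs values here are illustrative, not the paper's exact table, and I use a plain conv + upsample block as a stand-in for the SPADE residual block so the snippet stays self-contained (the real generator also passes the mask into every block).

```python
import torch
import torch.nn as nn

# Illustrative channel progression from wide to narrow; the real
# values come from the paper's generator table.
nfs = [1024, 1024, 1024, 512, 256, 128, 64]

def up_block(ni, nf):
    # stand-in for a SPADE residual block followed by 2x upsampling
    return nn.Sequential(
        nn.Conv2d(ni, nf, 3, padding=1),
        nn.ReLU(),
        nn.Upsample(scale_factor=2),
    )

class GeneratorSketch(nn.Module):
    def __init__(self):
        super().__init__()
        # build the tower with a for loop over consecutive nfs pairs
        self.blocks = nn.ModuleList(
            up_block(ni, nf) for ni, nf in zip(nfs[:-1], nfs[1:]))
        self.final = nn.Conv2d(nfs[-1], 3, 3, padding=1)

    def forward(self, x):
        for b in self.blocks:
            x = b(x)
        # tanh gives [-1, 1]; rescale to [0, 1] for easier visualization
        return (torch.tanh(self.final(x)) + 1) / 2
```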
The discriminator is a multi-scale, PatchGAN-based discriminator. Multi-scale means that it classifies the given image at different scales. In a PatchGAN-based discriminator, the final output layer is convolutional and its spatial mean is taken. The outputs from the multiple scales are simply added to give the final output.
The discriminator takes both the mask and the generated/real image at the same time and outputs the activations. As can be seen in the code, the forward method of the discriminator takes both the mask and the image as input.
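The idea can be sketched as below, assuming a 3-channel image and a 1-channel mask concatenated channel-wise at the input. The layer widths and the use of exactly two scales are illustrative, not the repository's exact configuration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm

class PatchDiscriminator(nn.Module):
    """One PatchGAN discriminator: a small conv stack whose final layer
    is convolutional; the spatial mean of that map is the score."""
    def __init__(self, n_in=4, nf=64):  # 3 image channels + 1 mask channel
        super().__init__()
        self.layers = nn.Sequential(
            spectral_norm(nn.Conv2d(n_in, nf, 4, stride=2, padding=1)),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2d(nf, nf * 2, 4, stride=2, padding=1)),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2d(nf * 2, 1, 4, padding=1)),  # patch scores
        )

    def forward(self, img, mask):
        # forward takes both the image and the mask at the same time
        x = torch.cat([img, mask.float()], dim=1)
        # spatial mean over the patch-score map -> one score per example
        return self.layers(x).mean(dim=[1, 2, 3])

class MultiScaleDiscriminator(nn.Module):
    """Runs a patch discriminator at two scales and adds the scores."""
    def __init__(self):
        super().__init__()
        self.d1 = PatchDiscriminator()
        self.d2 = PatchDiscriminator()

    def forward(self, img, mask):
        score = self.d1(img, mask)
        # downsample both inputs for the second, coarser scale
        img2 = F.avg_pool2d(img, 2)
        mask2 = F.avg_pool2d(mask.float(), 2)
        return score + self.d2(img2, mask2)
```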
Now we have implemented the complete architecture of the model. The complete architecture looks like this.
It’s now time to implement the loss functions.
Since it’s a GAN, there are two loss functions: one for the generator and one for the discriminator. The loss function is the hinge loss from the SAGAN paper, which I mentioned in my earlier blog. The loss function is very simple and is literally just one line of code. BUT it is the part where I spent the most time, and it made me realize how important the loss function is in a deep learning problem.
The above two equations can be written as the code block below. Sometimes these scary-looking equations are just a single line of code.
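As a sketch, the hinge losses come out to roughly this in PyTorch (the function names are mine):

```python
import torch

def gen_loss(fake_pred):
    # generator hinge loss: -E[D(G(z))]
    return -fake_pred.mean()

def disc_loss(real_pred, fake_pred):
    # discriminator hinge loss:
    # E[relu(1 - D(x))] + E[relu(1 + D(G(z)))]
    return (torch.relu(1 - real_pred).mean()
            + torch.relu(1 + fake_pred).mean())
```

The relu terms implement the hinge: real predictions above 1 and fake predictions below -1 contribute nothing, so the discriminator stops pushing on examples it already classifies with enough margin.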
Here we have the paper implemented in code, but we need two more things: data to pass to the model, along with preprocessing, and a training loop to train our model. The data preparation step used to haunt me, but not so much since fastai version 1 was released. It makes a lot of these things very simple and quick.
I used land cover classification data provided by the Chesapeake Conservancy Land Cover Data Project. I extracted all the images from the classified raster using ArcGIS Pro’s “Export Training Data for Deep Learning” tool. Once I had the images on disk, I used the fastai data block API to load them, created a fastai learner, and called its fit method. After creating the appropriate classes, I used the following code to create a fastai DataBunch object.
The class I created was SpadeItemList; it’s just fastai’s SegmentationItemList reversed. What fastai lets you do is create a class inheriting from a fastai item or label class and override a couple of methods according to your needs, and it will create the DataBunch object using that class. The DataBunch object contains attributes holding the PyTorch datasets and dataloaders. It also has a method called show_batch that displays your data. Here is the output of my show_batch.
Now the data preparation step is complete, so we have to train the model.
Let’s Train It.
In fastai we create a learner, which contains the fit method that actually trains the model. But this is not a simple model like an image classifier: in a GAN we need to switch between the generator and the discriminator and vice versa. To do that with the fastai library, we need a callback mechanism. I copied fastai’s GANLearner and modified it a bit.
The self.gen_mode attribute tells the GANModule when to use the generator and when to use the discriminator. fastai implements a callback that switches the GAN at fixed intervals. I set the discriminator to take five steps for each generator step. This is done using FixedGANSwitcher.
There are other callbacks in use as well; please look at the code in my GitHub repository.
Now we can run the fit method to train the model.
The results are not that photorealistic, but given enough time and compute, and after removing any remaining bugs, the model should generate good images. Below are the initial images that the model generated.
And after 100 epochs of training, it started to produce more detailed images, but with some artifacts.
SPADE is just a normalization layer that helps in generating photorealistic images. I implemented it here: https://github.com/divyanshj16/SPADE