Blog: A Brief Summary of the Maths Behind RNNs (Recurrent Neural Networks)
There is a lot of buzz these days about machine learning, deep learning and artificial neural networks. Many programmers are happy to throw these fancy terms around without ever looking at what actually happens under the hood of a neural network. So today we are going to discuss Recurrent Neural Networks and the basic mathematics that lets them do things other neural networks can only dream of.
The purpose of this post is to build an intuition for how recurrent neural networks work, and for their purpose and structure.
A neural network usually takes an independent variable X (or a set of independent variables) and a dependent variable y, then learns the mapping between X and y (we call this training). Once training is done, we give it a new independent variable and it predicts the dependent variable.
But what if the order of the data matters? What if the order of all the independent variables matters?
Let me explain visually.
Just assume every ant is an independent variable. If one ant goes in a different direction, it does not matter to the other ants, right? But what if the order of the ants matters?
If one ant goes missing or turns away from the group, it affects the ants following it.
So which data in our ML space has an order that matters?
- Natural Language Data where the order of words matter
- Speech data
- Time series data
- Video/Music Sequences data
- Stock markets data
So how do RNNs solve the "order matters" problem? Let's take natural text data as an example.
Let's say I am doing sentiment analysis on user reviews of a movie:
“This movie is good” → Positive
“This movie is bad” → Negative
We could classify these with a simple bag-of-words (BOW) model and predict positive or negative, but wait…
What if the review is “This movie is not good”?
The BOW model may still call it positive, because it ignores word order, but it is not.
An RNN reads the words in order, so it can learn that this review is negative.
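To see why a bag-of-words representation misses this, here is a minimal sketch (the example sentences are my own):

```python
from collections import Counter

# A bag-of-words model only counts word occurrences, so word order is lost.
# These two reviews have opposite sentiments but identical word counts:
a = Counter("the movie is not good , it is bad".split())
b = Counter("the movie is not bad , it is good".split())

print(a == b)  # True -- BOW cannot tell these two reviews apart
```

Any model built on top of these counts sees the same input for both reviews, so it cannot possibly separate them.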
First, let's agree that here the order of the text matters. Cool? Okay.
An RNN comes in the following forms:
1. One to Many: the RNN takes one input, say an image, and generates a sequence of words.
2. Many to One: the RNN takes a sequence of words as input and generates one output.
3. Many to Many: the RNN takes a sequence as input and generates another sequence as output (e.g., machine translation).
Currently we are focusing on the 2nd form, "Many to One". In RNNs, the input is split into time steps.
ex : input(X) = [“this”, “movie”, “is”, “good”]
The time step for “this” is x(0), for “movie” x(1), for “is” x(2), and for “good” x(3).
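As a small sketch, here is how such a sentence could be turned into per-time-step input vectors (the 5-word vocabulary is a made-up illustration):

```python
import numpy as np

# Hypothetical tiny vocabulary; each word becomes a one-hot vector,
# and its position in the sentence is its time step.
vocab = ["this", "movie", "is", "good", "bad"]

def one_hot(word):
    v = np.zeros(len(vocab))
    v[vocab.index(word)] = 1.0
    return v

# x[t] is the input fed to the RNN at time step t
x = [one_hot(w) for w in ["this", "movie", "is", "good"]]
print(x[0])  # time step 0 -> one-hot for "this": [1. 0. 0. 0. 0.]
```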
Let’s dive into the mathematical world of RNNs.
First let's see what an RNN cell contains! I hope and assume you know feed-forward neural networks (FFNN); in summary:
In a feed-forward neural network we have X (input), H (hidden) and y (output). We can have as many hidden layers as we want; the weights (W) are different for every hidden layer, and within a layer the weights are different for every neuron with respect to the input.
Above we have weights Wh0 and Wh1, which correspond to two different layers, while Wh00, Wh01 and so on represent the different weights of individual neurons with respect to the input.
The RNN cell contains a set of feed-forward neural networks, one per time step. The RNN has: sequential input, sequential output, multiple time steps, and multiple hidden layers.
Unlike an FFNN, here we calculate the hidden layer values not only from the current input but also from the previous time step's values, and the weights (W) at the hidden layer are the same across time steps. Here is the complete picture of an RNN and its math.
In the picture we are calculating the hidden layer values at time step t:

Ht = ActivationFunction(U · xt + W · Ht-1)
yt = SoftMax(V · Ht)

Here U is the input-to-hidden weight matrix, W is the hidden-to-hidden (recurrent) weight matrix, and V is the hidden-to-output weight matrix.
Ht-1 is the hidden state at the previous time step, and as I said, the weights are the same for all time steps. The activation function can be tanh, ReLU, sigmoid, etc.
Above we calculated only Ht; similarly we can calculate the values for all the other time steps:
- Calculate Ht-1 from U and X
- Calculate yt-1 from Ht-1 and V
- Calculate Ht from U,X,W and Ht-1
- Calculate yt from V and Ht and so on…
1. U, V and W are weight matrices. Like W, they are shared across all time steps (not different at every step).
2. We can even calculate the hidden layer for all time steps first, then calculate the y values.
3. The weight matrices are initialized randomly.
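The steps above can be sketched in NumPy; the dimensions and the random initialization scale here are my own illustrative choices, not values from the post:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())  # subtract max for numerical stability
    return e / e.sum()

# Toy sizes: 5-dim input, 4 hidden units, 2 output classes
n_in, n_hid, n_out = 5, 4, 2
rng = np.random.default_rng(0)

# U: input->hidden, W: hidden->hidden (recurrent), V: hidden->output.
# The same U, W, V are reused at every time step.
U = rng.normal(scale=0.1, size=(n_hid, n_in))
W = rng.normal(scale=0.1, size=(n_hid, n_hid))
V = rng.normal(scale=0.1, size=(n_out, n_hid))

def forward(xs):
    h = np.zeros(n_hid)              # initial hidden state H(-1)
    hs, ys = [], []
    for x in xs:                     # one iteration per time step
        h = np.tanh(U @ x + W @ h)   # Ht = tanh(U · xt + W · Ht-1)
        hs.append(h)
        ys.append(softmax(V @ h))    # yt = SoftMax(V · Ht)
    return hs, ys

xs = [rng.normal(size=n_in) for _ in range(4)]  # 4 time steps
hs, ys = forward(xs)
print(ys[-1])  # output distribution at the last time step
```

For the "Many to One" sentiment model, only the last output ys[-1] would be used as the prediction.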
Once the forward pass is done, we need to calculate the error and back-propagate it using back propagation. We use cross entropy as the cost function (I assume you know it, so I am not going into details; if you don't, just click the respective hyperlinks to learn).
BPTT ( Back propagation through time )
If you know how a normal neural network works, the rest is pretty easy. If you don't, here is my article about Artificial Neural Networks.
We need to calculate the terms below:
- How much does the total error change with respect to the output (at the hidden and output units)? In other words, how much does the error change when the output changes?
- How much does the output change with respect to the weights (U, V, W)? In other words, how much does the output change when the weights change?
Since the W's are the same for all time steps, we need to go all the way back through time to make an update.
Remember, back propagation for an RNN is the same as back propagation for an artificial neural network, but here the current time step is calculated from the previous one, so we have to traverse all the way back.
If we apply the chain rule, the expression expands; and since the W's are the same for all time steps, it expands more and more the further back we go.
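Written out for the shared weight W, the expansion looks like this (a sketch consistent with the notation above, with Et the error at time step t):

```latex
\frac{\partial E}{\partial W} = \sum_t \frac{\partial E_t}{\partial W},
\qquad
\frac{\partial E_t}{\partial W}
  = \sum_{k=0}^{t}
    \frac{\partial E_t}{\partial y_t}\,
    \frac{\partial y_t}{\partial H_t}\,
    \frac{\partial H_t}{\partial H_k}\,
    \frac{\partial H_k}{\partial W},
\quad\text{where}\quad
\frac{\partial H_t}{\partial H_k}
  = \prod_{j=k+1}^{t} \frac{\partial H_j}{\partial H_{j-1}}.
```

The product of Jacobians \(\partial H_j / \partial H_{j-1}\) is exactly the part that keeps growing as we traverse further back in time.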
A similar but slightly different way of working out the equations can be seen in Richard Socher's Recurrent Neural Network lecture slides.
So here Et is the same as our J(θ).
U, V and W then get updated using an optimization algorithm such as gradient descent (take a look at my story on GD here).
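The update itself is a one-liner per matrix; a minimal sketch (the toy matrices and gradients below are made up for illustration):

```python
import numpy as np

# Plain gradient-descent step, applied identically to U, V and W.
# grads are the gradients accumulated by BPTT; lr is the learning rate.
def sgd_step(params, grads, lr=0.1):
    return [p - lr * g for p, g in zip(params, grads)]

U, V, W = np.ones((2, 3)), np.ones((2, 2)), np.ones((2, 2))
dU, dV, dW = np.full((2, 3), 0.5), np.zeros((2, 2)), np.ones((2, 2))

U, V, W = sgd_step([U, V, W], [dU, dV, dW], lr=0.1)
print(U[0, 0])  # 1 - 0.1 * 0.5 = 0.95
```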
Now, going back to our sentiment problem, here is the RNN for it:
We feed word vectors or one-hot encoded vectors for every word as input, run the forward pass and BPTT, and once training is done we can give it new text for prediction. It learns something like: wherever “not” + positive word appears, the sentiment is negative. I hope you get the idea.
Problems with RNNs → the vanishing/exploding gradient problem
Since the W's are the same for all time steps, the signal gets multiplied by the same factor over and over as we go back adjusting the weights during back propagation. It becomes either too weak or too strong, which causes the vanishing or exploding gradient problem.
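A back-of-the-envelope sketch of the effect (the factors 0.9 and 1.1 are arbitrary stand-ins for the per-step recurrent multiplier):

```python
# The backpropagated signal is multiplied by (roughly) the same
# recurrent factor once per time step, so over many steps it
# shrinks or blows up geometrically.
def signal_after(factor, steps):
    return factor ** steps

print(signal_after(0.9, 50))  # ~0.005 -> the gradient vanishes
print(signal_after(1.1, 50))  # ~117   -> the gradient explodes
```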
To avoid this we use either GRU or LSTM cells, which I will cover in upcoming posts.