A recurrent neural network (RNN) processes an input sequence arriving as a stream. It maintains a state, i.e., a memory, which captures whatever it has seen in the input so far that it deems relevant for predicting the output (see below).
At each step, the RNN first derives a new state from the current state combined with the new input value. This becomes the new current state. It then outputs a value derived from its current state.
Thus, an RNN may be viewed as transforming an input sequence into an output sequence, with the state capturing whatever features it thinks will help it produce the desired output sequence. Learning happens when the network’s output does not match its target.
RNNs have many uses. One notable use is machine translation: the input might be a sequence of English words, and the desired output a good French translation of it, another sequence of words drawn from a different (French) vocabulary.
In this post, we focus on the simplest RNN that is interesting and useful. ‘Simplest’, ‘interesting’, and ‘useful’ are not falsifiable, so you be the judge.
This RNN has one input neuron, one hidden neuron that is sigmoidal, and one output neuron that is also sigmoidal.
This RNN evolves as follows:
h(t) = f(hh*h(t-1) + ih*x(t))
y^(t) = g(ho*h(t))
Here f and g are sigmoids. The quantities ih, hh, and ho are the input-to-hidden weight, the hidden-to-hidden weight, and the hidden-to-output weight respectively. These weights are what change when learning happens.
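As a minimal sketch (not the post’s own code), here is a forward pass of these equations in Python. It assumes the initial state h(0) is 0, as in the training code at the end of the post; the helper names `sigmoid` and `forward` and the example weights are ours, for illustration only.

```python
import math


def sigmoid(z):
    """Logistic sigmoid, used here for both f and g."""
    return 1.0 / (1.0 + math.exp(-z))


def forward(xs, ih, hh, ho, h0=0.0):
    """Apply h(t) = f(hh*h(t-1) + ih*x(t)) and y^(t) = g(ho*h(t)) to each x(t).

    Returns the hidden states h(1..T) and the outputs y^(1..T).
    """
    h = h0
    hs, ys = [], []
    for x in xs:
        h = sigmoid(hh * h + ih * x)  # new state from the current state and the new input
        y_hat = sigmoid(ho * h)       # output derived from the new state
        hs.append(h)
        ys.append(y_hat)
    return hs, ys


# Illustrative weights only; they are not learned values from the post.
hs, ys = forward([1, 0, 1, 1], ih=0.5, hh=0.3, ho=2.0)
print([round(y, 3) for y in ys])
```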
Can this thing do anything interesting? Let’s find out.
Consider a binary sequence, i.e., a sequence of 0s and 1s, that runs forever. We’d like to predict x(t+1) from x(1) through x(t). So at time t, the target y(t) is x(t+1).
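Concretely, a finite stretch of the stream turns into one (input, target) pair per step, with the last value held back because its successor has not arrived yet. The training loop at the end of the post does the same thing with `X[t]` and `X[t+1]`; the helper name `training_pairs` is hypothetical.

```python
def training_pairs(seq):
    """Pair each input x(t) with its target y(t) = x(t+1).

    The final value has no successor yet, so it produces no pair.
    """
    return list(zip(seq[:-1], seq[1:]))


print(training_pairs([1, 1, 0, 1]))  # [(1, 1), (1, 0), (0, 1)]
```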
This problem has a deceptively simple formulation. It has its uses, though. To get a sense of this, imagine trying to predict whether tomorrow will be sunny (1) or cloudy (0) from the binary sequence of daily outcomes (sunny or cloudy) so far. Think of doing this at a per-city level. Now imagine running this continually over all the cities on Earth (each with its own RNN). If you created a web service out of it, you might get some visitors.
This imagined use case was put in front of you to get you thinking. No doubt there are many other uses for predicting the next value of a binary sequence.
Let’s first write out the learning equations, as derived from first principles. These equations will provide the scaffolding on which we will build qualitative insights into what the various weights are learning.
We are given the network’s output at time t. Let’s call it y^(t). The target output is y(t). We will define the error of the network as (½)(y(t)-y^(t))². We could use some other error function. This one is familiar and it suffices for our purpose in this post.
The aim is to change the weights ih, hh, and ho in ways that reduce the error.
We will use the principle of gradient descent in error space, made famous in the multilayer neural network setting as the back-propagation algorithm.
First, let’s write out the negated gradient of the error with respect to the weight ho. (Negated because we want to reduce the error.) It is
(y–y^)*g*(1-g)*h
where y == y(t), g == g(ho*h(t)) == y^(t), and h == h(t).
Our update rule for this weight will be
delta ho = eta*(y–y^)*g*(1-g)*h
Here eta is the learning rate, a positive constant.
First, we note that g*(1-g) is always positive, since g lies strictly between 0 and 1. So is h, since it is the output of a sigmoid. From this, we see that
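One way to sanity-check this update rule is to compare it with a central finite-difference estimate of the negated gradient for a single step, holding h fixed. This is a minimal sketch; the values of y, h, ho, and eps are arbitrary choices for illustration, and the helper names are ours.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def error(y, y_hat):
    """E = (1/2)(y - y^)^2."""
    return 0.5 * (y - y_hat) ** 2


def delta_ho(y, h, ho, eta):
    """delta ho = eta*(y - y^)*g*(1-g)*h, where g = y^ = g(ho*h)."""
    g = sigmoid(ho * h)
    return eta * (y - g) * g * (1 - g) * h


# Central finite difference of -dE/dho at one step (eta = 1, h held fixed).
y, h, ho, eps = 1.0, 0.6, 0.8, 1e-6
numeric = -(error(y, sigmoid((ho + eps) * h)) - error(y, sigmoid((ho - eps) * h))) / (2 * eps)
analytic = delta_ho(y, h, ho, eta=1.0)
print(analytic, numeric)
```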
sign(delta ho) = sign(y–y^)
Here sign(a) is 1 if a is positive, 0 if a is 0, and -1 if a is negative.
This means that the weight ho should increase (decrease) when the predicted output is less than (greater than) the target output.
In short, ho learns to chase y(t).
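The sign argument can also be checked numerically. The sketch below samples random targets, hidden values in (0, 1), and output weights, and asserts that sign(delta ho) always equals sign(y–y^). The sampling ranges, the seed, and the helper names are arbitrary assumptions.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def sign(a):
    """1 if a > 0, 0 if a == 0, -1 if a < 0 (as defined above)."""
    return (a > 0) - (a < 0)


def delta_ho(y, h, ho, eta=0.1):
    g = sigmoid(ho * h)
    return eta * (y - g) * g * (1 - g) * h


random.seed(0)
for _ in range(1000):
    y = random.choice([0.0, 1.0])
    h = random.uniform(0.01, 0.99)  # h is a sigmoid output, so 0 < h < 1
    ho = random.uniform(-5.0, 5.0)
    y_hat = sigmoid(ho * h)
    assert sign(delta_ho(y, h, ho)) == sign(y - y_hat)
print("sign(delta ho) == sign(y - y^) on all samples")
```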
Next, consider the negated gradient of the error with respect to the weight ih, treating h(t-1) as a constant. It is
(y–y^)*g*(1-g)*ho*f*(1-f)*x
where f == f(hh*h(t-1) + ih*x(t)) and x == x(t). Our update rule for this weight will be
delta ih = eta*(y–y^)*g*(1-g)*ho*f*(1-f)*x
Just like g*(1-g), f*(1-f) is always positive. Plus, when x(t) is 0, delta ih is also 0.
So, when x(t) is 1,
sign(delta ih) = sign((y–y^)*ho) = sign(y–y^)*sign(ho)
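Here is a similar randomized check for ih, as a minimal sketch that treats h(t-1) as a constant, just as the update rule above does. It confirms that delta ih vanishes when x(t) is 0 and follows sign(y–y^)*sign(ho) when x(t) is 1. The sampling ranges, seed, and helper names are arbitrary.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def sign(a):
    return (a > 0) - (a < 0)


def delta_ih(y, x, h_prev, ih, hh, ho, eta=0.1):
    """delta ih = eta*(y - y^)*g*(1-g)*ho*f*(1-f)*x for one step."""
    f = sigmoid(hh * h_prev + ih * x)  # f == h(t)
    g = sigmoid(ho * f)                # g == y^(t)
    return eta * (y - g) * g * (1 - g) * ho * f * (1 - f) * x


random.seed(1)
for _ in range(1000):
    y = random.choice([0.0, 1.0])
    h_prev = random.uniform(0.01, 0.99)
    ih, hh = random.uniform(-3, 3), random.uniform(-3, 3)
    ho = random.choice([-1, 1]) * random.uniform(0.1, 3)
    # x = 0: ih receives no update at all.
    assert delta_ih(y, 0, h_prev, ih, hh, ho) == 0
    # x = 1: sign(delta ih) = sign(y - y^) * sign(ho).
    y_hat = sigmoid(ho * sigmoid(hh * h_prev + ih))
    assert sign(delta_ih(y, 1, h_prev, ih, hh, ho)) == sign(y - y_hat) * sign(ho)
print("ih sign rule holds on all samples")
```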
Similarly, the negated gradient of the error with respect to the weight hh (treating h(t-1) as a constant) is
(y–y^)*g*(1-g)*ho*f*(1-f)*h(t-1)
from which, noting that g*(1-g), f*(1-f), and h(t-1) are all positive, we get
sign(delta hh) = sign(y–y^)*sign(ho)
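To check the hh expression, the sketch below compares it against a central finite difference of the one-step error, treating h(t-1) as a constant, as the training code at the end of the post does. The argument values are arbitrary; with y = 0 and a positive ho, the sign rule predicts a negative delta hh.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def step_error(y, x, h_prev, ih, hh, ho):
    """One-step error E = (1/2)(y - y^)^2, with h(t-1) held fixed."""
    y_hat = sigmoid(ho * sigmoid(hh * h_prev + ih * x))
    return 0.5 * (y - y_hat) ** 2


def neg_grad_hh(y, x, h_prev, ih, hh, ho):
    """-dE/dhh = (y - y^)*g*(1-g)*ho*f*(1-f)*h(t-1)."""
    f = sigmoid(hh * h_prev + ih * x)
    g = sigmoid(ho * f)
    return (y - g) * g * (1 - g) * ho * f * (1 - f) * h_prev


args = dict(y=0.0, x=1, h_prev=0.7, ih=0.4, hh=0.9, ho=1.5)
eps = 1e-6
plus = dict(args, hh=args["hh"] + eps)
minus = dict(args, hh=args["hh"] - eps)
numeric = -(step_error(**plus) - step_error(**minus)) / (2 * eps)
analytic = neg_grad_hh(**args)
print(analytic, numeric)
```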
How the weights evolve
Let’s start by tracking how ho changes over time on particular input sequences that are illuminating.
All the discussion below is grounded in empirical analysis. The Python code for it is included at the end of this post, along with the experimental conditions, in case someone wishes to repeat the experiment.
On a long streak of the same value
First, let’s see what happens on the input sequence 1¹⁰. (Ten straight 1s.) The weight ho increases monotonically, settling at 1.98. The monotonic increase makes sense. As ho sees more and more 1s, the vigor with which it chases a 1 increases.
The weight hh also increases monotonically, settling at 0.43. This increase also makes sense. As the observed streak gets longer, hh’s confidence that the streak will continue increases. Why 0.43 here vs 1.98 for ho? Because the error reaches hh only after being propagated back through both sigmoids, which shrinks it relative to the error that ho sees.
In the previous paragraph, the expression “observed streak gets longer” holds for streaks of 0 as well. Let’s elaborate on this by considering the input sequence 0¹⁰.
The weight ho decreases monotonically, settling at -1.98. This makes sense. As ho sees more and more 0s, the vigor with which it chases a 0 increases.
The weight hh, on the other hand, increases monotonically, as it did before, this time settling at 0.5. To understand this better, consider the values of the weights after training on the first few 0s. After training on the first 0, ho is negative. After training on the second 0, hh has increased. This is because ho is negative, and so is y–y^ (since y is 0), so their product is positive.
Why does hh settle at 0.5 on 0¹⁰, whereas it settled at 0.43 on 1¹⁰? This is an artifact of our setup. The weight ih does not learn at all while processing 0¹⁰, since all the inputs are 0 (recall that delta ih is 0 when x(t) is 0). So hh learns to compensate a bit.
hh is learning that long streaks tend to continue.
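The trends described in this section can be reproduced with a compact re-implementation of the training loop. This is a sketch under stated assumptions, not the post’s code: like `rnn` at the end of the post (whose weights a, b, and c play hh, ih, and ho), it starts all weights and h at 0, uses eta = 5, and treats h(t-1) as a constant, but unlike `rnn` it computes all three updates from the pre-update weights, so the settled values may differ somewhat from those quoted above. The function name `train` is hypothetical.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def train(seq, eta=5.0):
    """Online training on seq; returns the weight history [(ih, hh, ho), ...].

    One update per step t, with target y(t) = x(t+1). Gradients are truncated
    to one step (h(t-1) is treated as a constant), and all three updates use
    the pre-update weights.
    """
    ih = hh = ho = 0.0
    h = 0.0
    history = []
    for t in range(len(seq) - 1):
        x, y = seq[t], seq[t + 1]
        h_prev = h
        h = sigmoid(hh * h_prev + ih * x)  # f
        y_hat = sigmoid(ho * h)            # g
        common = eta * (y - y_hat) * y_hat * (1 - y_hat)
        d_ho = common * h
        d_ih = common * ho * h * (1 - h) * x
        d_hh = common * ho * h * (1 - h) * h_prev
        ho, ih, hh = ho + d_ho, ih + d_ih, hh + d_hh
        history.append((ih, hh, ho))
    return history


ones = train([1] * 10)
zeros = train([0] * 10)
print("1^10: final ih, hh, ho =", [round(w, 2) for w in ones[-1]])
print("0^10: final ih, hh, ho =", [round(w, 2) for w in zeros[-1]])
```

On 1¹⁰ the sign rules guarantee that ho rises at every step and hh never falls; on 0¹⁰, ho falls at every step, hh never falls, and ih stays at exactly 0.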
On a long streak of 1s followed by a long streak of 0s
Next, let’s see how the weights evolve while training on the input 1¹⁰ 0¹⁰. Below we will use + to denote that the weight increases and - to denote that it decreases. A comma marks the point at which ho crosses zero.
ho’s sequence is +¹⁰ -⁵ , -⁵. This is easy to explain. As the first streak (comprising 1s) unfolds, ho increases monotonically. As the second streak (comprising 0s) unfolds, ho decreases monotonically. In the middle of the second streak, ho crosses 0 from above.
hh’s sequence is more interesting. It is +¹⁰ -⁵ , +⁵. The explanation goes like this. When the first few 0s are seen in the second streak, ho is still positive (although it has started decreasing). Since y–y^ is negative, hh must decrease. After the 5th 0 in the second streak, ho becomes negative. From then on, hh increases.
How does this play out in the predictions? ho is still positive after training on the 4th 0 in the second streak, so the network still predicts that the next value will be 1, albeit with less confidence than before. Fortunately, hh is also still positive (albeit decreasing) after the 4th 0, so it mildly predicts 0, since that extends the current streak of 0s. This divergence further tempers ho’s already mild enthusiasm for 1.
After the 5th 0 is seen, hh and ho are on the same page. ho has switched to chasing 0s, and hh reinforces this prediction, since it extends the streak of 0s. This harmony drives the prediction towards 0 even faster.
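The + and - patterns can be generated programmatically. The sketch below retrains on 1¹⁰ 0¹⁰ with zero initial weights, eta = 5, one-step gradients, and simultaneous updates, and prints a + or - for each step; it prints 0 for a step where the weight did not change (hh’s very first update is 0 because ho and h(t-1) both start at 0). Because the update order and step counting differ from the post’s code, the exact counts need not match the patterns quoted above. The helper names `weight_trace` and `pattern` are hypothetical.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def weight_trace(seq, eta=5.0):
    """Return (hh, ho) after each online update (one-step gradients)."""
    ih = hh = ho = 0.0
    h = 0.0
    trace = []
    for x, y in zip(seq[:-1], seq[1:]):
        h_prev, h = h, sigmoid(hh * h + ih * x)
        y_hat = sigmoid(ho * h)
        common = eta * (y - y_hat) * y_hat * (1 - y_hat)
        # The right-hand side is evaluated with the old ho, ih, hh.
        ho, ih, hh = (ho + common * h,
                      ih + common * ho * h * (1 - h) * x,
                      hh + common * ho * h * (1 - h) * h_prev)
        trace.append((hh, ho))
    return trace


def pattern(values, start=0.0):
    """'+' where a value rose from the previous one, '-' where it fell, '0' otherwise."""
    out, prev = [], start
    for v in values:
        out.append("+" if v > prev else "-" if v < prev else "0")
        prev = v
    return "".join(out)


trace = weight_trace([1] * 10 + [0] * 10)
print("hh:", pattern([w[0] for w in trace]))
print("ho:", pattern([w[1] for w in trace]))
```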
On alternating 1s and 0s
Let’s see what happens on the sequence (10)¹⁰. After training on this sequence completes, the hh, ih, and ho weights are 0.25, 0.51, and -0.43 respectively. Hmm.
Let’s step back and check what the predicted y^ is after each training step. It’s in the range from 0.44 to 0.5. This suggests that the training resulted in the network becoming conservative. The training apparently was unable to capture the alternating pattern.
Intuition suggests we should be able to rig up an RNN of the same structure that generates an alternating binary sequence. Here is one. Replace the sigmoids with tanh functions, set ih to -1, hh to 0, and ho to 1, and multiply the input of g (the output tanh) by a gain of 10. Then initialize x(0) to 1 and run the network forward, feeding each output back in as the next input. It produces an alternating sequence of values close to 1 and close to -1.
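A minimal sketch of this hand-wired generator follows. It uses `math.tanh` for both f and g (the `run` function at the end of the post instead uses 2*sigmoid(x)-1, which equals tanh(x/2); both alternate), starts from h = 0.5 as `run` does, and feeds each output back in as the next input. The names `generate` and `gain` are ours.

```python
import math


def generate(x0, ih=-1.0, hh=0.0, ho=1.0, gain=10.0, n=10, h0=0.5):
    """Run the hand-wired RNN forward, feeding each output back in as the next input."""
    h, x = h0, x0
    outputs = []
    for _ in range(n):
        h = math.tanh(hh * h + ih * x)  # ih = -1 flips the sign of the input
        x = math.tanh(gain * ho * h)    # the large gain pushes the output towards +/-1
        outputs.append(x)
    return outputs


print([round(v, 3) for v in generate(1.0)])
```

Starting from x(0) = 1, the first output is close to -1, and the signs then alternate.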
The point of this exercise was just to demonstrate that within this structure we can produce this behavior. We won’t cover whether or not we can learn this behavior.
The learning rate eta was set to 5.
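The function run runs the network forward for n steps from a given initial value of x, with specified weights ih, hh, and ho, feeding each output back in as the next input. This implementation is specific to generating the alternating sequence. It uses a tanh-shaped activation, 2*sigmoid(x)-1 (which equals tanh(x/2)), plus a large gain (10) on the input to g.
The function rnn trains the network on an input sequence X. Its weights a, b, and c play the roles of hh, ih, and ho, respectively.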
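```python
import math

def derivative(f, x):
    # Derivative of a logistic sigmoid f at x: f'(x) = f(x)*(1-f(x)).
    return f(x)*(1-f(x))

def run(X0, ih, hh, ho, n = 10):
    # 2*sigmoid(x) - 1, a tanh-shaped function (it equals tanh(x/2)).
    f = lambda x: 2*(1.0/(1.0+math.exp(-x)))-1
    g = f
    h = 0.5
    X = X0
    outputs = []  # assumed: collects the generated values (the original output line was lost)
    for t in range(n):
        o = hh*h + ih*X
        h = f(o)
        y = g(10*ho*h)  # gain of 10 on the input to g
        X = y           # feed the output back in as the next input
        outputs.append(X)
    return outputs

def rnn(X, eta = 1):
    f = lambda x: 1.0/(1.0+math.exp(-x))
    g = f
    h = 0.0
    a, b, c = 0.0,0.0,0.0  # a plays hh, b plays ih, c plays ho
    res = [['y(t)','a','b','c','y^(t+1)','y^(t+1) c only']]
    for t in range(len(X)-1):
        htminus1 = h
        o = a*h + b*X[t]
        h = f(o)
        y = g(c*h)
        err = X[t+1] - y
        # c is updated first; the b and a updates below then use the new c.
        c += eta*err*derivative(g,c*h)*h
        b += eta*err*derivative(g,c*h)*c*derivative(f,o)*X[t]
        a += eta*err*derivative(g,c*h)*c*derivative(f,o)*htminus1
        yhattplus1 = g(c*f(a*h + b*X[t+1]))  # prediction for the next step
        yhattplus1conly = g(c*X[t+1])        # same, from c alone (X[t+1] stands in for h)
        # Assumed: one row per step, in the column order of the header above.
        res.append([X[t+1], a, b, c, yhattplus1, yhattplus1conly])
    return res
```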
In this post, we described one of the simplest recurrent neural networks that exhibit interesting behaviors. It meets some minimal requirements: it has a hidden layer whose neuron computes a nonlinear activation function.
On the task of learning to predict the next value in a binary sequence, the output weight learns to chase the target. The hidden neuron behaves as if it were tracking the length of the current streak, and its recurrent weight hh learns that the longer the streak, the more likely it is to continue. This has the effect of tempering the output weight’s enthusiasm for chasing the target when a streak is interrupted by sporadic noise.
For readers who like looking at code, and maybe even running it, we have included the Python code.