Why ‘Learn To Forget’ In Recurrent Neural Networks

Arun Jagota
March 12, 2021 AI & Machine Learning

Illustrated with a simple example

Consider the following binary classification problem. The input is a binary sequence of arbitrary length. We want the output to be 1 if and only if a 1 occurred in the input but not too recently. Specifically, the last n bits must be 0.

We can also write this problem as one of language recognition. For n = 3, the language, described as a regular expression, is (0 or 1)*10000*: any binary string ending in a 1 followed by at least three 0s.

Below are some labeled instances for the case n = 3.

000 → 0, 101 → 0, 0100100000 → 1, 1000 → 1
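To make the labeling rule concrete, here is a small Python check (the function name label and the default n = 3 are just illustrative choices):

def label(bits, n=3):
    # 1 iff a 1 occurs somewhere in `bits` and none of the last n bits is a 1
    return int('1' in bits and '1' not in bits[-n:])

print(label("000"), label("101"), label("0100100000"), label("1000"))   # 0 0 1 1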

Why this seemingly strange problem? It requires remembering that (i) a 1 occurred in the input and (ii) not too recently. As we will see soon, this example helps explain why simple recurrent neural networks are inadequate and how injecting a mechanism that learns to forget helps.

Simple Recurrent Neural Network

Let’s start with a basic recurrent neural network (RNN). x(t) is the bit, 0 or 1, that arrives at time t in the input. This RNN maintains a state h(t) that tries to remember whether it saw a 1 sometime in the past. The output is just read out from this state after a suitable transformation.

More formally, we have

h(t) = tanh(a*h(t-1) + b*x(t) + c)
y(t) = sigmoid(d*h(t))
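Here is a minimal NumPy sketch of this forward pass. The scalar parameter values are purely illustrative, not learned.

import numpy as np

def simple_rnn(bits, a=0.8, b=1.0, c=0.0, d=2.0):
    # h(t) = tanh(a*h(t-1) + b*x(t) + c), y(t) = sigmoid(d*h(t))
    h, preds = 0.0, []
    for ch in bits:
        x = float(ch)
        h = np.tanh(a * h + b * x + c)
        preds.append(1.0 / (1.0 + np.exp(-d * h)))
    return preds

print([round(p, 2) for p in simple_rnn("10000000")])
# with these untrained values the output simply spikes at the 1 and then decays toward 0.5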

Next, let’s consider the following (input sequence, output sequence) pair and assume n = 3.

x 10000000
y 00011111

To discuss the behavior and learning of the RNN on this pair, it will help to unroll the network in time as is commonly done.

[Figure: Simple RNN unrolled in time]

Think of this as a pipeline with stages. The state travels from left to right and is modified at each stage by that stage’s input.

Let’s walk through what happens inside a stage in a bit more detail. Consider the third stage. It takes as inputs the state h2 and the next input symbol x3. h2 may be thought of as a feature derived from x1 and x2 towards predicting y3. The box first computes the next state h3 from these two inputs. h3 is then carried forward to the next stage. h3 also determines the stage’s output y3.

Consider what happens when the input 1000 is seen. y4 is 1, and since y^4 is less than 1 (which is always the case), there is some error. Following the backpropagation-through-time learning strategy, we ripple the error back through time to the extent needed to update the various parameters.

Consider the parameter b. There are 4 instances of it, attached to x1 through x4 respectively. Since x2 through x4 are all 0, the instances attached to them have no impact on y^4, and so they contribute nothing to the update. The instance of b attached to x1 does: increasing it moves y^4 closer to 1, so b increases.

As we continue seeing x5, x6, x7, x8 and their corresponding targets y5, y6, y7, y8, the same learning behavior recurs: b keeps increasing (albeit by less, since the errors must be backpropagated further back in time to reach x1).
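The direction of these updates can be checked with a small finite-difference sketch. The parameter values a = 0.8, b = 0.5, c = 0, d = 2 below are illustrative, untrained choices: nudging b upward raises y^4 on the input 1000, and since the target y4 is 1, gradient descent pushes b up.

import numpy as np

def y_hat(bits, a, b, c, d):
    # prediction of the simple RNN after consuming all of `bits`
    h = 0.0
    for ch in bits:
        h = np.tanh(a * h + b * float(ch) + c)
    return 1.0 / (1.0 + np.exp(-d * h))

a, b, c, d, eps = 0.8, 0.5, 0.0, 2.0, 1e-5
grad_b = (y_hat("1000", a, b + eps, c, d) - y_hat("1000", a, b - eps, c, d)) / (2 * eps)
print(grad_b > 0)   # True: increasing b moves y^4 toward its target of 1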

Now imagine x9 is 1. y9 must be 0. y^9, however, is large. This is because the parameter b has learned that xi = 1 predicts yj = 1 for j >= i. b has no way of enforcing that the 1 at xi be followed only by 0s, numbering at least 3, before the output may be 1.

In short, this RNN is unable to capture, towards predicting yj, the joint condition that xi = 1 and that all the bits following it are 0s, numbering at least 3. Also note that this is not a long-range influence: n is only 3. So the weakness of the RNN on this example cannot be explained by vanishing error gradients during backpropagation-through-time [2]. Something else is going on here.
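The weakness can also be seen directly by running the simple RNN forward with hand-picked parameters that encode the rule “once a 1 is seen, output 1”, roughly what the updates above push towards. This is a sketch with illustrative values, not a trained network.

import numpy as np

# Hand-picked (illustrative) parameters that make the state latch high once a 1 is seen.
a, b, c, d = 2.0, 4.0, 0.0, 4.0
h, preds = 0.0, []
for ch in "100000001":            # targets for n = 3: 0 0 0 1 1 1 1 1 0
    h = np.tanh(a * h + b * float(ch) + c)
    preds.append(round(1.0 / (1.0 + np.exp(-d * h)), 2))
print(preds)   # every entry is ~0.98: right for y4..y8, wrong for y1..y3 and, crucially, y9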

An RNN that learns to forget

Now consider this version:

z(t)    = sigmoid(a*x(t) + b)
hnew(t) = tanh(c*x(t) + d)
h(t)    = (1 - z(t))*h(t-1) + z(t)*hnew(t)
y(t)    = sigmoid(e*h(t))
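In code, the forward pass of this gated RNN is the following minimal NumPy sketch; it implements only the four equations above, with the parameter values left to the caller.

import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def forget_rnn(bits, a, b, c, d, e):
    h, preds = 0.0, []
    for ch in bits:
        x = float(ch)
        z = sigmoid(a * x + b)        # how much of the old state to forget
        h_new = np.tanh(c * x + d)    # candidate state, built from the input alone
        h = (1 - z) * h + z * h_new   # mix the old state and the candidate
        preds.append(sigmoid(e * h))  # readout
    return preds

We will plug concrete parameter values into this sketch once we have worked out what they need to be.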

We didn’t just pull this model out of a hat. Its state-update equation is a key one in a popular gated recurrent neural network, the GRU. We took the equation from its description in [1].

This RNN has an explicit mechanism to forget! It is z(t), a value between 0 and 1, denoting the degree of forgetfulness. When z(t) approaches 1, the state h(t-1) is completely forgotten.

When h(t-1) is completely forgotten, what should h(t) be? We encapsulate this in an explicit function hnew(t) denoting the “new state”. hnew(t) is derived solely from the present input. This makes sense because if h(t-1) is to be forgotten, all we have in front of us is the new input x(t).

More generally, the next state h(t) is a mixture of the previous state h(t-1) and a new state hnew(t), modulated by z(t).

Does this RNN have the capability to do better on this problem? We will answer this question in the affirmative by prescribing a solution that works. The accompanying explanation will reveal what roles the various neurons play in making this solution work.

Consider x(t) is 1. y(t) must be 0, so we want to drive y^(t) towards 0. We can make this happen by setting e to a sufficiently negative number (say -1) and forcing h(t) to be close to 1. One way to get the desired h(t) is to force z(t) to be close to 1, set c to a sufficiently positive number, and choose d such that c + d is sufficiently positive. We can force z(t) to be close to 1 by setting a to a sufficiently positive number and b such that a + b is sufficiently positive.

This prescription operates as if

If x(t) is 1:
    Set hnew(t) close to 1.
    Reset h(t) to hnew(t).
    Drive y^(t) towards 0 by setting e sufficiently negative.

The case x(t) is 0 is more involved as y(t) depends on the recent past values of x. Let’s explain it in the following setting:

Time  …  t     t+1           t+2            t+3
x     …  1     0             0              0
y     …  0     0             0              1
hnew  …  1     D = tanh(d)   D              D
z     …  1     ½             ½              ½
h     …  1     ½(1+D)        ½(h(t+1)+D)    ½(h(t+2)+D)
h^    …  >> 0  >> 0          >> 0           << 0
y^    …  → 0   → 0           → 0            → 1

There is a lot in here! So let’s walk through it row by row.

We are looking at the situation when processing the last 4 bits of the input x = …1000 in sequence. The corresponding target is y = …0001. We assume that the parameters of the RNN have somehow been chosen just right (or learned), as spelled out below. (These have to be consistent with the settings we used when x(t) was 1, of course.) In short, we are describing the behavior of a fixed network in this situation.

Now look at hnew. When x(t) is 1, we have already discussed that hnew(t) should approach 1. When x(t) is 0, hnew(t) equals tanh(c*x(t) + d) = tanh(d). We call this value D.

Next look at z. When x(t) is 1, we already discussed that z(t) should approach 1. When x(t) is 0, since we want to remember the past, let’s set z(t) to approximately ½. For this we just need to set b to 0, since then z(t) = sigmoid(a*0 + 0) = ½. This can be achieved without unlearning the z(t) that works when x(t) is 1.

For the remaining rows, let’s start from the last row and work our way up. The y^ row describes what we want, given the y targets. Since we have fixed e to a sufficiently negative number, this tells us what we need from the states; we record those requirements in the h^ row.

So now all that remains is to show that h can be made to match up with h^. First, let’s zoom in on these two rows and, while at it, also expand h into a more convenient form:

h  … 1      ½ + ½D     ¼ + ¼D + ½D     ⅛ + ⅛D + ¼D + ½D
h^ … >> 0   >> 0       >> 0            << 0

Collecting terms, the last two entries of h are ¼ + ¾D and ⅛ + ⅞D. Choosing D such that -⅓ < D < -1/7 therefore meets the desiderata: the former is positive exactly when D > -⅓, and the latter is negative exactly when D < -1/7. It’s easy to find d such that tanh(d) is in this range; for example, d = -0.2 gives D ≈ -0.197.

The prescription for the case x(t) = 0 may be summarized as

If x(t) is 0:
    Set hnew(t) to be slightly negative.
    Set h(t) to the average of h(t-1) and hnew(t).

So as 0s that follow a 1 are seen, h(t) keeps dropping. If enough 0s are seen, h(t) becomes negative.
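The two prescriptions can be checked numerically. In the sketch below, the particular values a = 10, b = 0, c = 10, d = -0.2 and e = -20 are one illustrative choice consistent with the constraints above, with D = tanh(-0.2) ≈ -0.197 inside the required range.

import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

a, b = 10.0, 0.0     # z ~ 1 when x = 1, z = ½ when x = 0
c, d = 10.0, -0.2    # hnew ~ 1 when x = 1, hnew = D = tanh(-0.2) ~ -0.2 when x = 0
e = -20.0            # sufficiently negative readout weight

h = 0.0
for t, ch in enumerate("1000", start=1):
    x = float(ch)
    z = sigmoid(a * x + b)
    h_new = np.tanh(c * x + d)
    h = (1 - z) * h + z * h_new
    print(t, round(h, 2), round(sigmoid(e * h), 2))
# h ~ 1.0, 0.4, 0.1, -0.05 and y^ ~ 0.0, 0.0, 0.12, 0.72 -- tracking the targets 0 0 0 1

If further 0s follow, h keeps halving its distance to D and so stays negative, so y^ stays on the 1 side, as required.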

Summary

In this post, we discussed recurrent neural networks with and without an explicit ‘forget’ mechanism. We did so in the context of a simply described prediction problem that the simpler RNN is incapable of solving. The RNN with the ‘forget’ mechanism is able to solve this problem.

This post will be useful to readers who’d like to understand how simple RNNs work, how an enhanced version with a forgetting mechanism works (GRU in particular), and how the latter improves upon the former.

Further Reading

  1. https://colah.github.io/posts/2015-08-Understanding-LSTMs/
  2. https://www.superdatascience.com/blogs/recurrent-neural-networks-rnn-the-vanishing-gradient-problem