Introduction to Recurrent Networks in TensorFlow

Recurrent networks like LSTM and GRU are powerful sequence models. I will explain how to create recurrent networks in TensorFlow and use them for sequence classification and labelling tasks.

If you are not familiar with recurrent networks, I suggest you take a look at Christopher Olah’s great article first. I also assume some basic knowledge of TensorFlow; the official tutorials are a good place to start.

Defining the Network

To use recurrent networks in TensorFlow, we first need to define the network architecture, consisting of one or more layers, the cell type, and possibly dropout between the layers. In TensorFlow, we build recurrent networks out of so-called cells that wrap each other.

import tensorflow as tf

num_units = 200
num_layers = 3
dropout = tf.placeholder(tf.float32)

cells = []
for _ in range(num_layers):
  cell = tf.contrib.rnn.GRUCell(num_units)  # Or LSTMCell(num_units)
  cell = tf.contrib.rnn.DropoutWrapper(
      cell, output_keep_prob=1.0 - dropout)
  cells.append(cell)
cell = tf.contrib.rnn.MultiRNNCell(cells)
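
To make the idea of cells wrapping each other more concrete, here is a rough plain-Python sketch of what the stack computes for a single time step. It is not TensorFlow code: make_cell, with_output_dropout, and stack are hypothetical helpers, and the tanh cell is a simplified stand-in for a GRU. The dropout wrapper uses the usual rescaling by 1 / keep_prob and only touches the output, not the state.

```python
import math
import random


def make_cell(num_inputs, num_units, rng):
    """Hypothetical minimal tanh cell: returns a step function (x, h) -> (out, new_h)."""
    w_in = [[rng.uniform(-0.5, 0.5) for _ in range(num_units)]
            for _ in range(num_inputs)]
    w_rec = [[rng.uniform(-0.5, 0.5) for _ in range(num_units)]
             for _ in range(num_units)]

    def step(x, h):
        new_h = []
        for j in range(num_units):
            total = sum(x[i] * w_in[i][j] for i in range(num_inputs))
            total += sum(h[k] * w_rec[k][j] for k in range(num_units))
            new_h.append(math.tanh(total))
        return new_h, new_h  # For this simple cell the output is the new state.

    return step


def with_output_dropout(step, keep_prob, rng):
    """Wrap a cell so its output (not its state) is randomly zeroed and rescaled."""
    def wrapped(x, h):
        out, new_h = step(x, h)
        out = [v / keep_prob if rng.random() < keep_prob else 0.0 for v in out]
        return out, new_h
    return wrapped


def stack(steps):
    """Chain cells: the output of layer i becomes the input of layer i + 1."""
    def stacked(x, states):
        new_states = []
        for step, h in zip(steps, states):
            x, new_h = step(x, h)
            new_states.append(new_h)
        return x, tuple(new_states)
    return stacked


rng = random.Random(0)
num_units = 4
layers = [make_cell(3, num_units, rng), make_cell(num_units, num_units, rng)]
layers = [with_output_dropout(s, keep_prob=1.0, rng=rng) for s in layers]
cell_step = stack(layers)
zero_state = tuple([0.0] * num_units for _ in layers)
output, state = cell_step([1.0, 0.5, -0.5], zero_state)
```

With keep_prob set to 1.0, as you would do at test time, the wrapper passes the output through unchanged, so the output of the stack equals the state of the top layer.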

Simulating Time Steps

We can now add the operations to the graph that simulate the recurrent network over the time steps of the input. We do this using TensorFlow’s dynamic_rnn() operation. It takes a tensor holding the input sequences and returns the output activations and the last hidden state as tensors.

# Batch size x time steps x features.
data = tf.placeholder(tf.float32, [None, None, 28])
output, state = tf.nn.dynamic_rnn(cell, data, dtype=tf.float32)
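
As a mental model, the following plain-Python sketch shows roughly what simulating the time steps means: the same step function is applied along the time axis while the state is carried forward. The names simple_step and unroll are made up for this illustration, and the one-unit cell is a toy, not a GRU or LSTM.

```python
import math


def simple_step(x, h):
    """Hypothetical one-unit cell: mixes the sum of the inputs with the old state."""
    new_h = [math.tanh(sum(x) + 0.5 * h[0])]
    return new_h, new_h  # Output and state coincide for this toy cell.


def unroll(step, batch, initial_state):
    """Apply `step` along the time axis of every sequence in the batch."""
    outputs, last_states = [], []
    for sequence in batch:          # Batch axis.
        h = list(initial_state)
        sequence_outputs = []
        for x in sequence:          # Time axis.
            out, h = step(x, h)
            sequence_outputs.append(out)
        outputs.append(sequence_outputs)
        last_states.append(h)
    return outputs, last_states


# Batch size x time steps x features: two sequences, three steps, two features.
batch = [
    [[0.1, 0.2], [0.0, 0.3], [0.5, -0.1]],
    [[1.0, 0.0], [0.2, 0.2], [-0.3, 0.4]],
]
outputs, last_states = unroll(simple_step, batch, initial_state=[0.0])
```

Here outputs has the shape batch size x time steps x units, and last_states holds one state per sequence. For a cell whose output is its state, the last output frame of each sequence equals its final state.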

Sequence Classification

For classification, you might only care about the output activation at the last time step. We transpose so that the time axis is first and use tf.gather() for selecting the last frame. We can’t just use output[-1] because unlike Python lists, TensorFlow doesn’t support negative indexing yet.

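output, _ = tf.nn.dynamic_rnn(cell, data, dtype=tf.float32)
# Make the time axis first: time steps x batch size x units.
output = tf.transpose(output, [1, 0, 2])
# The time dimension of the placeholder is None, so read its size at run time.
last = tf.gather(output, tf.shape(output)[0] - 1)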
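
The transpose-and-select step can be pictured with nested Python lists. This is only an illustration of the indexing, with to_time_major as a hypothetical helper and small integers standing in for activations.

```python
def to_time_major(batch_major):
    """Swap the batch and time axes of a nested list [batch][time][units]."""
    return [list(frames) for frames in zip(*batch_major)]


# Batch size x time steps x units: two sequences, three steps, two units.
output = [
    [[1, 2], [3, 4], [5, 6]],
    [[7, 8], [9, 10], [11, 12]],
]
time_major = to_time_major(output)
# Select the last frame by its explicit index, as tf.gather() does above.
last = time_major[len(time_major) - 1]
```

After the transpose, each entry of time_major holds one frame for every sequence in the batch, so last has the shape batch size x units.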

The code below adds a softmax classifier on top of the last activation and defines the cross entropy loss function. Here is the complete gist for sequence classification.

# For classification, target is batch size x classes.
out_size = target.get_shape()[1].value
logit = tf.contrib.layers.fully_connected(
    last, out_size, activation_fn=None)
prediction = tf.nn.softmax(logit)
loss = tf.losses.softmax_cross_entropy(target, logit)
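
For intuition about what the loss computes from the logits, here is a small plain-Python sketch for a single example. It assumes the common trick of subtracting the largest logit before exponentiating, which is one way to keep the computation numerically stable; it is not a copy of TensorFlow's implementation, and the function names are my own.

```python
import math


def log_softmax(logits):
    """Log-probabilities, shifted by the largest logit to avoid overflow."""
    peak = max(logits)
    shifted = [v - peak for v in logits]
    log_norm = math.log(sum(math.exp(v) for v in shifted))
    return [v - log_norm for v in shifted]


def softmax_cross_entropy(target, logits):
    """Cross entropy between a one-hot target and unnormalized logits."""
    return -sum(t * lp for t, lp in zip(target, log_softmax(logits)))


logits = [2.0, 1.0, 0.1]
target = [1.0, 0.0, 0.0]
loss = softmax_cross_entropy(target, logits)
probabilities = [math.exp(lp) for lp in log_softmax(logits)]

# Very large logits would overflow a naive exp(), but stay finite here.
large_loss = softmax_cross_entropy([1.0, 0.0, 0.0], [1000.0, 0.0, 0.0])
```

Working on the logits directly, rather than taking the logarithm of the softmax output, is what avoids evaluating log(0) when a probability underflows.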

For now we assume sequences to be equal in length. Please refer to my other post on handling sequences of different length.

Sequence Labelling

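For sequence labelling, we want a prediction for each time step. However, we share the weights of the softmax layer across all time steps. How do we do that? Conceptually, by flattening the first two dimensions of the output tensor, so that time steps look the same as examples in the batch to the weight matrix, and reshaping back to the desired shape afterwards. In the code below, fully_connected() takes care of this for us, since it applies its weights along the last dimension when the input has more than two dimensions.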

out_size = target.get_shape()[2].value
logit = tf.contrib.layers.fully_connected(
    output, out_size, activation_fn=None)
prediction = tf.nn.softmax(logit)
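
The weight sharing across time steps can be spelled out with nested lists. The sketch below flattens batch and time into one axis, applies a single dense layer, and restores the original shape. The helpers dense and shared_dense and the example weights are hypothetical and chosen only so the numbers are easy to follow.

```python
def dense(rows, weights, bias):
    """Apply the same affine map to every row: [n][units] -> [n][classes]."""
    return [
        [sum(x * w for x, w in zip(row, column)) + b
         for column, b in zip(zip(*weights), bias)]
        for row in rows
    ]


def shared_dense(output, weights, bias):
    """Flatten batch and time, apply one dense layer, and restore the shape."""
    time_steps = len(output[0])
    flat = [frame for sequence in output for frame in sequence]
    flat_logit = dense(flat, weights, bias)
    return [flat_logit[i:i + time_steps]
            for i in range(0, len(flat_logit), time_steps)]


# Batch size x time steps x units: two sequences, two steps, three units.
output = [
    [[1.0, 0.0, 2.0], [0.5, 0.5, 0.5]],
    [[0.0, 1.0, 0.0], [2.0, 2.0, 2.0]],
]
weights = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # Units x classes.
bias = [0.0, 0.5]
logit = shared_dense(output, weights, bias)
```

The result has the shape batch size x time steps x classes, and every frame went through the same weights, exactly as if each sequence had been processed on its own.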

Let’s say we predict a class for each frame, so we keep using cross entropy as our loss function. Here we have a prediction and target for every time step. We thus compute the cross entropy for every time step and sequence in the batch, and then average along these two dimensions. Here is the complete gist for sequence labelling.

flat_target = tf.reshape(target, [-1] + target.shape.as_list()[2:])
flat_logit = tf.reshape(logit, [-1] + logit.shape.as_list()[2:])
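# With its default settings, this already returns the mean over all
# flattened frames as a scalar, so the reduce_mean below leaves it unchanged.
loss = tf.losses.softmax_cross_entropy(flat_target, flat_logit)
loss = tf.reduce_mean(loss)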
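
The averaging over sequences and time steps can be checked by hand on a tiny example. The following plain-Python sketch flattens targets and logits in the same way and averages the per-frame cross entropy; cross_entropy and mean_sequence_loss are names I made up for the illustration.

```python
import math


def cross_entropy(target, logits):
    """Cross entropy of one frame, computed from unnormalized logits."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return -sum(t * (v - log_norm) for t, v in zip(target, logits))


def mean_sequence_loss(target, logit):
    """Flatten batch and time into one axis, then average the frame losses."""
    flat_target = [frame for sequence in target for frame in sequence]
    flat_logit = [frame for sequence in logit for frame in sequence]
    losses = [cross_entropy(t, l) for t, l in zip(flat_target, flat_logit)]
    return sum(losses) / len(losses)


# One sequence with two frames and two classes; the logits are uniform.
target = [[[1.0, 0.0], [0.0, 1.0]]]
logit = [[[0.0, 0.0], [0.0, 0.0]]]
loss = mean_sequence_loss(target, logit)
```

Since uniform logits assign probability one half to each of the two classes, every frame contributes log(2), and so does the average.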

Conclusion

That’s all. We have learned how to construct recurrent networks in TensorFlow and use them for sequence learning tasks. Please ask any questions below if you couldn’t follow.

Updated 2016-08-17: TensorFlow 0.10 moved the recurrent network operations from tf.models.rnn into the tf.nn package, where they now live alongside the other neural network operations. Cells can now be found in tf.nn.rnn_cell.

Updated 2016-05-20: TensorFlow 0.8 introduced dynamic_rnn(), which uses a symbolic loop instead of creating a subgraph for each time step. This results in a more compact graph. The function also expects and returns tensors directly, so we do not need to convert to and from Python lists anymore.

Updated 2017-06-07: TensorFlow 1.0 moved recurrent cells into tf.contrib.rnn. From TensorFlow 1.2 on, recurrent cells reuse their weights, so we need to create multiple separate GRUCells in the first code block. Moreover, I switched to the existing implementation of the cross entropy loss, which is numerically stable and has a more efficient gradient computation.

You can use this post under the open CC BY-SA 3.0 license and cite it as:

@misc{hafner2016tfrnnintro,
  author = {Hafner, Danijar},
  title = {Introduction to Recurrent Networks in TensorFlow},
  year = {2016},
  howpublished = {Blog post},
  url = {https://danijar.com/introduction-to-recurrent-networks-in-tensorflow/}
}