Day 12: Generating Long Sequences with Sparse Transformers


[Apr 23, 2019] Model very long sequences in a compute- and memory-efficient way

I … can … do … THIS

TL;DR

This paper presents variations on the Transformer's architecture and initialization, recomputation of attention matrices (to save memory) and fast attention kernels. Together, these make it possible to model sequences tens of thousands of timesteps long while reducing the time and memory cost of attention from $O(n^2)$ to $O(n\sqrt{n})$.

The Problem

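The generic problem for any autoregressive sequence model is to model the joint distribution of a sequence x = (x_1, …, x_n) as a product of conditional probability distributions, each parameterized by a network θ:

$$ p(x) = \prod_{i=1}^{n} p(x_i \mid x_1, \ldots, x_{i-1}; \theta) $$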

Images, text and audio can all be treated the same way, as sequences of discrete tokens (typically raw bytes).

The training objective is to maximize the log-probability of the data with respect to θ.
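As a minimal sketch of this objective, the snippet below scores a byte sequence with the chain rule, summing log p(x_i | x_1, …, x_{i-1}). The `cond_prob` callable is a hypothetical stand-in for the network θ; here it is a uniform model over 256 byte values, so every byte costs exactly 8 bits. Training would adjust θ to make `log2p` larger (equivalently, bits per byte smaller).

```python
import math

def sequence_log2_prob(seq, cond_prob):
    """Chain rule: log2 p(x) = sum_i log2 p(x_i | x_1, ..., x_{i-1})."""
    total = 0.0
    for i, token in enumerate(seq):
        p = cond_prob(seq[:i], token)  # p(x_i | x_<i; theta)
        total += math.log2(p)
    return total

def uniform_byte_model(prefix, token):
    """Hypothetical stand-in for a trained network: all 256 byte values equally likely."""
    return 1.0 / 256

data = b"hello world"
log2p = sequence_log2_prob(data, uniform_byte_model)
bits_per_byte = -log2p / len(data)
print(bits_per_byte)  # 8.0
```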

The Transformer and Attention

A Transformer in decoder-only mode can be used for sequence generation.

However, the self-attention portion of the network must compute n weightings for each of n elements (n² in total), which quickly becomes intractable as the sequence length grows.
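To make the quadratic growth concrete, here is a back-of-the-envelope count, assuming a naive implementation that materializes the full n × n attention matrix for one head in one layer, stored in half precision (2 bytes per entry). These are illustrative assumptions, not figures from the paper.

```python
def dense_attention_cost(n, bytes_per_entry=2):
    """Entries and bytes of one full n x n attention matrix (one head, one layer)."""
    entries = n * n
    return entries, entries * bytes_per_entry

for n in (1024, 4096, 16384):
    entries, nbytes = dense_attention_cost(n)
    print(f"n={n:>6}: {entries:>12,} weights, {nbytes / 2**20:,.0f} MiB in fp16")
```

At n = 16,384, a single head in a single layer already needs 512 MiB just for its attention weights, before counting anything stored for the backward pass.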

How can we solve this problem so that long sequences can be used?

Factorized Self-Attention

Let’s first look at what a Transformer with traditional (full) attention learns. Its layers learn specialized sparse structures (see the white pixels in the image below):

In a), we can see that some of the early layers in the network learn locally connected patterns that resemble a convolution around the predicted pixel.

In b) (layers 19 and 20), the network learned to split the attention into a row attention and a column attention, effectively factorizing the global attention calculation.

In c) we can see that several attention layers showed global, data-dependent access patterns.

In d), typical layers among layers 64–128 exhibited high sparsity, with positions activating rarely and only for specific input patterns.

A key insight here is that, since most layers had sparse attention patterns across most data points, it is plausible that some form of sparsity could be introduced into the attention computation without significantly affecting performance.

However, several layers (Figure 2c) exhibited global patterns, and others exhibited data-dependent sparsity (Figure 2d). Both would be difficult to model with a predetermined sparsity pattern, so imposing one might hurt performance.

Learned attention patterns from a 128-layer network on CIFAR-10 trained with full attention. White highlights denote attention weights for a head while generating a given pixel, and black denotes the auto-regressive mask.

The new attention scheme is somewhat different from the traditional one, as we can see in the image below.

The top row indicates, for an example 6×6 image, which positions two attention heads receive as input when computing a given output. The bottom row shows the connectivity matrix (not to scale) between all such outputs (rows) and inputs (columns).

Sparsity in the connectivity matrix can lead to significantly faster computation. In (b) and (c), full connectivity between elements is preserved when the two heads are computed sequentially.

2D factorized attention schemes evaluated by the authors in comparison to the full attention of a standard Transformer (a)

The authors restricted their investigation to a class of sparse attention patterns that have connectivity between all positions over several steps of attention. To that we turn now.

Deep into the algorithm

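A self-attention layer maps a matrix of input embeddings X to an output matrix according to a connectivity pattern S = {S_1, …, S_n}:

$$ \text{Attend}(X, S) = \big( a(\mathbf{x}_i, S_i) \big)_{i \in \{1, \ldots, n\}} $$

where S_i denotes the set of indices of the input vectors to which the ith output vector attends. Attention is computed as follows, consistent with the scaled dot-product attention of the Transformer architecture:

$$ a(\mathbf{x}_i, S_i) = \operatorname{softmax}\left( \frac{(W_q \mathbf{x}_i) K_{S_i}^{T}}{\sqrt{d}} \right) V_{S_i}, \quad K_{S_i} = \big( W_k \mathbf{x}_j \big)_{j \in S_i}, \quad V_{S_i} = \big( W_v \mathbf{x}_j \big)_{j \in S_i} $$

Full (dense) autoregressive attention corresponds to S_i = {j : j ≤ i}.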
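The attend operation can be sketched in plain Python as follows. This is an illustrative, unoptimized version, assuming each S_i is given as an explicit list of input indices and that W_q, W_k and W_v are small nested-list matrices; the helper names (`matvec`, `softmax`, `attend`) are hypothetical. Passing S_i = {j : j ≤ i} gives ordinary causal attention; a sparser S_i only changes which keys and values each output sees.

```python
import math

def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def softmax(scores):
    m = max(scores)                      # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]

def attend(X, S, Wq, Wk, Wv):
    """Attend(X, S): output i attends only to the inputs listed in S[i]."""
    d = len(Wk)                          # dimension of queries and keys
    out = []
    for i, x_i in enumerate(X):
        q = matvec(Wq, x_i)
        keys = [matvec(Wk, X[j]) for j in S[i]]      # K_{S_i}
        values = [matvec(Wv, X[j]) for j in S[i]]    # V_{S_i}
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
        weights = softmax(scores)
        out.append([sum(w * v[c] for w, v in zip(weights, values))
                    for c in range(len(values[0]))])
    return out

identity = [[1.0, 0.0], [0.0, 1.0]]
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
n = len(X)
S_full = [list(range(i + 1)) for i in range(n)]   # S_i = {j : j <= i}
Y = attend(X, S_full, identity, identity, identity)
print(Y[0])  # [1.0, 0.0]: the first output can only see itself
```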

Instead of attending to all the values before the predicted one, the algorithm the authors propose does this more efficiently. It uses p separate attention heads, where the mth head attends to a subset of the previous positions, A_i^{(m)} ⊂ {j : j ≤ i}, and sets S_i = A_i^{(m)}. In particular, the authors look for efficient subsets, where the number of positions each head attends to scales with the pth root of n:

$$ |A_i^{(m)}| \propto \sqrt[p]{n} $$

Here m indexes the head, i is the output position, p is the number of heads and n is the number of inputs.

In addition, the authors looked for choices of A so that all inputs are connected to all future output positions across the p attention steps.

In other words, for every pair j ≤ i the authors choose the sets A such that i can attend to j through a path of locations with maximum length p + 1. Specifically, if (j, a, b, c, …, i) is the path of indices, then:

$$ j \in A_a^{(1)}, \quad a \in A_b^{(2)}, \quad b \in A_c^{(3)}, \quad \ldots $$

That is, all inputs are connected to all future output positions across the p attention steps.

These two criteria allow us to keep the ability of Transformers to propagate signals from arbitrary input positions to arbitrary output positions in a constant number of steps, while reducing the total effective computation to $O(n \sqrt[p]{n})$.
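A quick worked comparison shows the size of the savings. This sketch counts query-key pairs, assuming each factorized head attends to about n^(1/p) positions per output and ignoring constant factors (such as the number of heads), as the O-notation does. The function name is hypothetical.

```python
def attention_pairs(n, p=None):
    """Approximate query-key pairs: n * n for dense attention,
    n * n**(1/p) for a factorized head (constant factors ignored)."""
    if p is None:
        return n * n
    return n * round(n ** (1 / p))

n = 16384
dense = attention_pairs(n)
factorized = attention_pairs(n, p=2)
print(dense, factorized, dense // factorized)  # 268435456 2097152 128
```

For n = 16,384 and p = 2, each output looks at about 128 positions instead of up to 16,384, a 128-fold reduction in attention pairs.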

Two-dimensional factorized attention

Strided

Factorized attention in two dimensions is trickier than in one dimension. A reasonable approach, when predicting a pixel in an image, is to attend roughly to the row and the column of the pixel being predicted. The authors call this strided attention, and we can see it in action in the following image.

Flexible two-dimensional factorized attention

Note: if you are like me, you will ask: ‘Why does the first representation of this attention method show it attending to five pixels to the left, while the second representation shows only three? Is it the same method?’ I don’t understand this either; if any of you does, please let me know!

The mathematical representation of this approach is the following. The first head attends to the l previous pixels (in this example, 5 pixels):

$$ A_i^{(1)} = \{t, t+1, \ldots, i\}, \quad t = \max(0, i - l) $$

The second head attends to one pixel every l pixels. If l is the width of the image in pixels, the second head attends to the column above the pixel to be predicted:

$$ A_i^{(2)} = \{ j : (i - j) \bmod l = 0 \} $$
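The two strided heads can be written as index sets and checked against the connectivity criterion. The sketch below assumes 0-indexed positions and an image flattened in raster order with the stride l equal to the row width; `strided_heads` and `reachable_in_two_steps` are hypothetical helper names. The check confirms that for every j ≤ i there is an intermediate position a with j in A_a^(1) and a in A_i^(2), i.e. a path (j, a, i).

```python
def strided_heads(i, l):
    """Index sets of the two strided heads for output position i (0-indexed)."""
    t = max(0, i - l)
    a1 = set(range(t, i + 1))                            # the previous l positions
    a2 = {j for j in range(i + 1) if (i - j) % l == 0}   # every l-th earlier position
    return a1, a2

def reachable_in_two_steps(i, j, l):
    """True if some a has j in A_a^(1) and a in A_i^(2), i.e. a path (j, a, i)."""
    _, a2_i = strided_heads(i, l)
    return any(j in strided_heads(a, l)[0] for a in a2_i)

n, l = 36, 6   # a 6 x 6 image in raster order; stride = row width
a1, a2 = strided_heads(20, l)
print(sorted(a1))  # [14, 15, 16, 17, 18, 19, 20]
print(sorted(a2))  # [2, 8, 14, 20]
assert all(reachable_in_two_steps(i, j, l) for i in range(n) for j in range(i + 1))
```

For pixel 20 of a 6 × 6 image (row 3, column 2), the second head picks out 14, 8 and 2, exactly the pixels above it in the same column.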

Fixed

The strided approach works well for data whose structure aligns with the stride, but not for data without a periodic structure, such as text. There, spatial coordinates do not necessarily correlate with the positions that will be most relevant in the future, so it is better to use a fixed attention pattern.

Fixed two-dimensional factorized attention

The mathematical representation of this approach is the following. The first head attends to all the previous pixels in the same chunk, where the size of each chunk is l:

$$ A_i^{(1)} = \{ j : \lfloor j / l \rfloor = \lfloor i / l \rfloor \} $$

The second head attends to a fixed-size portion at the end of each of the previous chunks. In the example presented in the figure above, this portion has size 1 (this size is c, a hyperparameter):

$$ A_i^{(2)} = \{ j : j \bmod l \in \{t, t+1, \ldots, l\} \}, \quad t = l - c $$

As explained by the authors themselves:

(…) if the stride is 128 and c = 8, then all future positions greater than 128 can attend to positions 120–128, all positions greater than 256 can attend to 248–256, and so forth.

The authors used c ∈ {8, 16, 32} and l ∈ {128, 256} for their experiments. Additionally, they found that letting the network attend to different c-sized blocks in different heads was preferable to having all heads attend to the same sub-block.
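The fixed pattern can be sketched the same way. This version assumes 0-indexed positions with causal masking applied to both heads, so with l = 128 and c = 8 the summary cells of the first two chunks are 120 to 127 and 248 to 255 (the quote above counts them as 120–128 and 248–256). The helper names are hypothetical, and the connectivity check uses small values (l = 4, c = 1) to keep it fast.

```python
def fixed_heads(i, l, c):
    """Index sets of the two fixed heads for output position i (0-indexed, causal)."""
    a1 = {j for j in range(i + 1) if j // l == i // l}   # earlier positions in i's chunk
    a2 = {j for j in range(i + 1) if j % l >= l - c}     # last c cells of every chunk
    return a1, a2

def connected(i, j, l, c):
    """j reaches i directly via head 1, or via a summary cell a in A_i^(2)."""
    a1_i, a2_i = fixed_heads(i, l, c)
    return j in a1_i or any(j in fixed_heads(a, l, c)[0] for a in a2_i)

a1, a2 = fixed_heads(300, l=128, c=8)
print(min(a1), max(a1))   # 256 300
print(len(a2))            # 16
assert sorted(a2) == list(range(120, 128)) + list(range(248, 256))
assert all(connected(i, j, 4, 1) for i in range(16) for j in range(i + 1))
```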

Sparse Transformer

Factorized-attention heads

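Standard dense attention performs a linear transformation of the attend function:

$$ \text{attention}(X) = W_p \cdot \text{attend}(X, S) $$

where W_p denotes the post-attention weight matrix.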

However, remember that in the examples we saw earlier there is more than one attention computation (head). How do we combine them? There are a few approaches.

The first one interleaves the attention types across residual blocks, using one type per block, either sequentially or at a ratio chosen as a hyperparameter. The second one uses a single head that attends to all the positions the separate heads would attend to (a merged head). The third one is multi-head attention (as in Attention Is All You Need), where the heads are computed in parallel and their outputs are concatenated. The authors chose this last method for their experiments.
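The three ways of combining heads can be sketched at the level of index sets and outputs. This is a simplified illustration, assuming each head's pattern is a set of positions and each head's output is a list of per-position feature vectors; the function names are hypothetical, and the rule "block r uses pattern r mod p" is one way to implement the interleaved variant.

```python
def interleaved_pattern(r, patterns):
    """Interleaving: residual block r uses the single pattern A^(r mod p)."""
    return patterns[r % len(patterns)]

def merged_pattern(patterns):
    """Merged head: one head attends to every position any factorized head uses."""
    merged = set()
    for pattern in patterns:
        merged |= pattern
    return merged

def concat_heads(head_outputs):
    """Multi-head: concatenate each position's per-head outputs along features."""
    return [[x for head_vec in per_position for x in head_vec]
            for per_position in zip(*head_outputs)]

# Hypothetical strided index sets for output position 20 with stride 6.
A1 = {14, 15, 16, 17, 18, 19, 20}
A2 = {2, 8, 14, 20}
print([min(interleaved_pattern(r, [A1, A2])) for r in range(4)])  # [14, 2, 14, 2]
print(sorted(merged_pattern([A1, A2])))  # [2, 8, 14, 15, 16, 17, 18, 19, 20]
print(concat_heads([[[1, 2], [3, 4]], [[5], [6]]]))  # [[1, 2, 5], [3, 4, 6]]
```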

Scaling to hundreds of layers

To make the Transformer easier to train, the authors adopted a few architectural changes:

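First, they adopted Kaiming He et al.’s pre-activation residual block, defining a network of N layers in the following way:

$$ H_0 = \text{embed}(X, W_e) $$
$$ H_k = H_{k-1} + \text{resblock}(H_{k-1}) $$
$$ y = \operatorname{softmax}\big( \text{norm}(H_N) W_{out} \big) $$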

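where each resblock is computed as follows:

$$ a(H) = \text{dropout}\big( \text{attention}(\text{norm}(H)) \big) $$
$$ b(H) = \text{dropout}\big( \text{ff}(\text{norm}(H + a(H))) \big) $$
$$ \text{resblock}(H) = a(H) + b(H) $$

Here norm denotes layer normalization, and ff(x) = W_2 f(W_1 x + b_1) + b_2 is a feed-forward network with f the GeLU activation.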

An important observation is that each resblock receives a gradient directly from the output layer, since the output H_N is the sum of the input embedding and N applications of the functions a and b.
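The pre-activation residual structure can be sketched for a single position's feature vector. This is a toy version under stated assumptions: `toy_attention` and `toy_ff` are hypothetical stand-ins for the real sublayers, the layer norm has no learned gain or bias, and dropout is omitted (evaluation mode). It checks the property behind the observation above: H_N is exactly H_0 plus the sum of every block's a + b, so each block's contribution reaches the output through an identity path.

```python
import math

def layer_norm(h, eps=1e-5):
    """Zero-mean, unit-variance normalization of one feature vector (no gain/bias)."""
    mean = sum(h) / len(h)
    var = sum((x - mean) ** 2 for x in h) / len(h)
    return [(x - mean) / math.sqrt(var + eps) for x in h]

def add(u, v):
    return [x + y for x, y in zip(u, v)]

def resblock(h, attention, ff):
    """resblock(H) = a(H) + b(H), with dropout omitted (evaluation mode)."""
    a = attention(layer_norm(h))        # a(H) = attention(norm(H))
    b = ff(layer_norm(add(h, a)))       # b(H) = ff(norm(H + a(H)))
    return add(a, b)

def forward(h0, blocks):
    """H_k = H_{k-1} + resblock(H_{k-1}); also returns every block's contribution."""
    h, contributions = h0, []
    for attention, ff in blocks:
        delta = resblock(h, attention, ff)
        contributions.append(delta)
        h = add(h, delta)
    return h, contributions

def toy_attention(x):   # hypothetical stand-in for the attention sublayer
    return [0.1 * v for v in x]

def toy_ff(x):          # hypothetical stand-in for the feed-forward sublayer
    return [0.05 * v * v for v in x]

h0 = [1.0, -2.0, 0.5, 3.0]
hN, deltas = forward(h0, [(toy_attention, toy_ff)] * 3)
total = h0
for delta in deltas:
    total = add(total, delta)
print(total == hN)  # True: H_N = H_0 + sum over blocks of (a + b)
```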

Modeling diverse data types

The authors found that using learned embeddings, which encoded either the structure of the data or the factorized attention patterns, was important for the performance of their models.

They added either n_emb = d_data or n_emb = d_attn embeddings to each input location, where d_data refers to the number of dimensions of the data and d_attn is the number of dimensions of the factorized attention.

For images, the authors used data embeddings with d_data = 3, for the row, column and channel of each input. For audio and text, they used attention embeddings with d_attn = 2, where the index corresponds to each position’s row and column index in a matrix of width equal to the stride.
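Here is a minimal sketch of how such position embeddings could be looked up and added. It assumes a 32 × 32 RGB image flattened in raster order with the channel index varying fastest, small Gaussian-initialized tables (the 0.02 scale and d = 8 are arbitrary), and hypothetical helper names. For text and audio, `attention_position` maps a 1D position to the (row, column) of a matrix whose width equals the stride.

```python
import random

def image_position(p, width, channels=3):
    """(row, column, channel) of flattened index p; channel varies fastest."""
    return p // (channels * width), (p // channels) % width, p % channels

def attention_position(p, stride):
    """(row, column) of position p in a matrix whose width equals the stride."""
    return p // stride, p % stride

def make_table(rows, d, rng):
    """A learned embedding table, here just randomly initialized."""
    return [[rng.gauss(0.0, 0.02) for _ in range(d)] for _ in range(rows)]

height, width, channels, d = 32, 32, 3, 8
rng = random.Random(0)
row_emb = make_table(height, d, rng)
col_emb = make_table(width, d, rng)
chan_emb = make_table(channels, d, rng)

def embed(p, token_vec):
    """Add the row, column and channel embeddings (n_emb = d_data = 3) to a token."""
    r, c, ch = image_position(p, width, channels)
    return [t + a + b + e
            for t, a, b, e in zip(token_vec, row_emb[r], col_emb[c], chan_emb[ch])]

print(image_position(100, width))     # (1, 1, 1)
print(attention_position(300, 128))   # (2, 44)
h = embed(100, [0.0] * d)
```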

Other improvements

Gradient checkpointing: recomputing the attention and feed-forward blocks during the backward pass, instead of storing their activations, is especially useful for self-attention layers, where long sequences entail high memory usage relative to the cost of computing them. This enables the authors to train on sequences as long as 16,384.

Efficient block-sparse attention kernels: the sparse attention masks for strided and fixed attention can be computed by slicing out parts of the query, key and value matrices and computing the products in blocks. Also, the upper triangle of the attention matrix is never computed, halving the number of operations.
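The idea of never computing the upper triangle can be illustrated with a block-wise causal attention in plain Python. This is not the authors' GPU kernel; it is a sketch assuming a single head, a block size that need not divide n, and hypothetical helper names. Only key blocks at or below the diagonal are visited, masking is applied only inside diagonal blocks, and the result matches dense causal attention.

```python
import math
import random

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def weighted_values(scored, V):
    """Softmax over (key index, score) pairs, then the weighted sum of values."""
    m = max(s for _, s in scored)
    weights = [(j, math.exp(s - m)) for j, s in scored]
    z = sum(w for _, w in weights)
    return [sum(w * V[j][c] for j, w in weights) / z for c in range(len(V[0]))]

def causal_attention_dense(Q, K, V):
    """Reference: every query i scores every key j <= i."""
    d = len(Q[0])
    return [weighted_values([(j, dot(Q[i], K[j]) / math.sqrt(d)) for j in range(i + 1)], V)
            for i in range(len(Q))]

def causal_attention_blocked(Q, K, V, block):
    """Same result, but key blocks above the diagonal are never visited."""
    n, d = len(Q), len(Q[0])
    num_blocks = (n + block - 1) // block
    out, blocks_computed = [], 0
    for qb in range(num_blocks):
        rows = range(qb * block, min((qb + 1) * block, n))
        scored = {i: [] for i in rows}
        for kb in range(qb + 1):                 # diagonal and below only
            blocks_computed += 1
            for i in rows:
                for j in range(kb * block, min((kb + 1) * block, n)):
                    if j <= i:                   # masking is only needed in the diagonal block
                        scored[i].append((j, dot(Q[i], K[j]) / math.sqrt(d)))
        out.extend(weighted_values(scored[i], V) for i in rows)
    return out, blocks_computed

def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
n, d, block = 10, 4, 4
Q, K, V = (random_matrix(n, d, rng) for _ in range(3))
dense = causal_attention_dense(Q, K, V)
blocked, blocks_computed = causal_attention_blocked(Q, K, V, block)
print(blocks_computed)  # 6 of the 9 blocks in a 3 x 3 grid
```

With n = 10 and a block size of 4 there are 3 × 3 = 9 blocks, of which only 6 are computed; as the number of blocks grows, the fraction skipped approaches one half.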

Mixed-precision training: the weights are stored in single-precision floating point, but the network activations and gradients are computed in half precision, which greatly accelerates training.

Training

The authors list a series of implementation details that were important for the model’s performance; if you are interested, please refer to the paper.

Experiments

As explained before, the authors experimented with image, audio and text data.

Model performance on the different data types
Model timings and results for different attention methods

CIFAR-10

The model achieved 2.80 bits per dim (average cross-entropy), improving the state of the art by ~2%.
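Bits per dim is the average negative log-likelihood per sub-pixel, expressed in bits. Here is a short sketch of the conversion, assuming the loss is available as a per-image total in nats (as natural-log cross-entropy usually is); the NLL value is made up to correspond to 2.80 bits per dim. For CIFAR-10, 2.80 bits per dim over 32 × 32 × 3 = 3,072 values amounts to about 8,602 bits per image.

```python
import math

def bits_per_dim(total_nll_nats, num_dims):
    """Negative log-likelihood per dimension, converted from nats to bits."""
    return total_nll_nats / (num_dims * math.log(2))

dims_per_image = 32 * 32 * 3      # CIFAR-10: 3,072 sub-pixel values per image
# Hypothetical per-image NLL in nats, chosen to correspond to 2.80 bits per dim.
nll_nats = 2.80 * dims_per_image * math.log(2)
print(round(bits_per_dim(nll_nats, dims_per_image), 2))  # 2.8
```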

The authors tried the different attention methods to see which worked better. In this case, strided attention worked best, which makes sense since the column of pixels the model attends to is not fixed but moves with the pixel being predicted.

Text (EnWik8)

This dataset consists of the first 10⁸ bytes of Wikipedia and contains a great deal of variability in periodic structure. The best model achieved 0.99 bits per byte with the longest context used during evaluation, 12,160 tokens (which suggests that the model uses long-term dependencies for its predictions). This result matches that of a model trained with more than double the parameters. Here, fixed attention worked better than strided attention, since the data is naturally one-dimensional.

Model performance on EnWik8 with different context lengths

ImageNet 64×64

The model achieved a loss of 3.44 bits per dimension, compared with the previous state of the art of 3.52. Additionally, the authors generated some unconditional samples, which show that the model learns long-term structure in most images and show no visible sparsity patterns, suggesting that the sparsity introduced into the attention matrices does not greatly impact performance.

Unconditional image samples generated from ImageNet. Images show long-term structure.

Classical music from raw audio

The authors achieved 1.97 bits per byte on sequences of length 65,536.

The samples show global coherence and can be found here.

Performance on the classical music dataset. Because GPU memory is fixed (16 GB), there is a trade-off between the number of parameters and the sequence length.

This same model was used to create MuseNet, a beautiful music-generation project that successfully generates music in many different styles. The project can be found here.



Source: Artificial Intelligence on Medium
