Self-attention enables transformer networks to track relationships between distant tokens (such as text characters) in long sequences, but the computation required grows quadratically with sequence length. New work aims to streamline the process by rating each token’s relevance to the task at hand.
What’s new: Sainbayar Sukhbaatar and colleagues at Facebook proposed Expire-Span, a method that lets attention forget tokens once they’re no longer useful.
Key insight: Depending on the task, some tokens affect a model’s performance more than others. For instance, in predicting the sentiment of the sentence, “Then she cried,” “cried” is more important than “then.” By forgetting less relevant tokens, attention can process longer sequences with less computation.
How it works: The authors modified a transformer’s attention layers. They trained the model in typical fashion to predict the next character in a sequence using the enwik8 dataset of text from English Wikipedia. Given the first token, it predicted the next. Then, using the first two tokens, it predicted the next, and so on.
- To each attention layer, the authors added a small vanilla neural network that predicted how many future steps attention should retain each token. It assigned a span value to each new token, decremented that value by 1 after each prediction, and deleted the token from memory when the value reached 0.
- The loss function penalized large span values to keep the model from assigning arbitrarily high ones (otherwise, it could predict that every token should be retained until the whole sequence had been processed). In this way, the model learned to keep only the tokens most useful to an accurate prediction.
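The mechanism above can be sketched in a few lines of NumPy. This is a minimal illustration, not the authors’ implementation: the head that predicts spans, the maximum span, the ramp width that keeps the mask differentiable, and all variable names (`w`, `b`, `max_span`, `ramp`, `alpha`) are assumptions for the sake of the example.

```python
import numpy as np

def expire_span_mask(hidden, t, w, b, max_span=16.0, ramp=2.0):
    """Soft retention mask over past tokens 0..len(hidden)-1 at step t.

    A tiny linear head (standing in for the paper's span predictor)
    maps each token's hidden state to a span e_i in (0, max_span).
    A token is kept while its age (t - i) is below e_i; a linear ramp
    of width `ramp` keeps the mask differentiable for training.
    """
    e = max_span / (1.0 + np.exp(-(hidden @ w + b)))   # predicted spans
    ages = t - np.arange(len(hidden))                  # steps since each token appeared
    mask = np.clip((e - ages) / ramp + 1.0, 0.0, 1.0)  # 1 = retain, 0 = expired (drop)
    return mask, e

def span_penalty(e, alpha=1e-3):
    # Auxiliary loss term: penalizing predicted spans pushes the model
    # to forget tokens that don't help the next-token prediction.
    return alpha * e.sum()

rng = np.random.default_rng(0)
h = rng.normal(size=(8, 4))        # 8 past tokens, hidden size 4
w = rng.normal(size=4)
mask, e = expire_span_mask(h, t=8, w=w, b=0.0)
print(mask)
```

Tokens whose mask reaches 0 can be dropped from the attention memory entirely, which is where the memory and speed savings reported below come from.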
Results: The authors evaluated Expire-Span on total memory usage, training time per batch, and bits per byte (a measure of how well the model predicted the next token; lower is better). On enwik8, it achieved 1.03 bits per byte, while Adaptive-Span achieved 1.04 bits per byte and Compressive Transformer achieved 1.05 bits per byte. The authors’ model used roughly 25 percent less GPU memory than the other two approaches (15GB versus 20GB and 21GB respectively). It also took less time to train (408ms per batch of 512 tokens compared to 483ms and 838ms).
Why it matters: Forgetting the least relevant information enables transformers to process longer sequences in less time and memory.
We’re thinking: Q: What do you do if a transformer forgets too much? A: Give it an Optimus Primer.