Course:CPSC522/Scaling Memory for Transformers
Scaling Memory for Transformers
This article examines memory enhancement for Transformers in Natural Language Processing (NLP). It first covers the Recurrent Memory Transformer (RMT), which augments each input segment with memory tokens placed at its beginning and end and passes these tokens from one segment to the next. It then introduces the Receptance Weighted Key Value (RWKV) model, which combines the parallelizable training of Transformers with the memory-efficient inference of Recurrent Neural Networks (RNNs). Together, they address the scalability and efficiency limits of standard Transformer architectures.
First paper: Recurrent Memory Transformer
Second paper: RWKV: Reinventing RNNs for the Transformer Era
Principal Author: Amirhossein Abaskohi
Collaborators:
Abstract
The Transformer, a neural network architecture central to Natural Language Processing (NLP), has a time complexity that is quadratic in sequence length, because its self-attention mechanism compares every pair of tokens in the sequence. This makes long sequences expensive to process. This article explores two responses to this problem: the Recurrent Memory Transformer (RMT), which adds memory tokens to the Transformer architecture, and the Receptance Weighted Key Value (RWKV) model, a novel architecture that combines the strengths of Transformers and Recurrent Neural Networks (RNNs). Both target scalability and efficiency and mark a step towards more efficient models for sequence processing tasks.
Builds on
This article builds on Transformer models, specifically addressing the limitations caused by their quadratic time complexity on long sequences. It applies ideas from NLP research on memory-efficient Transformers, which are crucial for handling long texts. It also draws on Recurrent Neural Networks (RNNs) to tackle the scalability challenges faced by Transformers, bridging these two neural network paradigms.
Related Pages
In the context of memory enhancement for Transformers, this article is closely related to Neural Networks, Recurrent Neural Networks, Long Short-Term Memory Networks (LSTMs), NLP, and Transformer models. NLP provides the setting of language-based tasks, while Transformer models are the core architecture under consideration. The article draws parallels with LSTMs in terms of memory efficiency, discusses broader implications for Neural Networks, and explores how recurrence can be integrated into Transformers to address scalability and memory challenges.
Content
Introduction
Transformers have become the dominant architecture for NLP tasks in recent years. Models like BERT[2], GPT-3[3], and T5[4] have achieved state-of-the-art performance on a wide range of NLP benchmarks. However, Transformers are limited by their quadratic time complexity, which makes them difficult to apply to long sequences. Since applications such as information retrieval for search engines require models to work over long documents, handling extended sequences efficiently is important. The self-attention mechanism in Transformers examines every pairwise interaction between input tokens: a sequence of T tokens requires T × T attention scores, so doubling the context quadruples the cost. This restricts most Transformer implementations to sequence lengths of a few thousand tokens.
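To make the quadratic cost concrete, here is a minimal pure-Python sketch of single-head, non-causal scaled dot-product attention over lists of vectors. It is illustrative only: there is no batching, masking, or learned projection, and the function name `self_attention` is our own. The point is that the `scores` matrix holds one entry per pair of tokens, so T tokens produce T × T scores.

```python
import math

def self_attention(q, k, v):
    """Single-head, non-causal scaled dot-product attention on lists of vectors.

    Returns the outputs and the full T x T score matrix, whose size grows as T^2.
    """
    t, d = len(q), len(q[0])
    # One score per (query, key) pair: T * T entries in total.
    scores = [[sum(qa * kb for qa, kb in zip(q[i], k[j])) / math.sqrt(d)
               for j in range(t)] for i in range(t)]
    out = []
    for row in scores:
        top = max(row)                               # subtract max for numerical stability
        weights = [math.exp(s - top) for s in row]   # unnormalised softmax weights
        z = sum(weights)
        out.append([sum(w * v[j][c] for j, w in enumerate(weights)) / z
                    for c in range(len(v[0]))])
    return out, scores

x = [[float(i), 1.0] for i in range(4)]
out, scores = self_attention(x, x, x)
print(len(scores) * len(scores[0]))   # 16 scores for 4 tokens
```

For 4 tokens the sketch builds 16 scores; for 4,096 tokens it would build about 16.8 million per head and per layer.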
To let Transformers ingest more context and handle longer sequences, recent work has focused on augmenting them with memory. With memory modules and mechanisms for propagating information across the sequence, Transformer models can store information from long contexts and use it when needed. This eases the complexity limitations of Transformers while retaining their performance benefits.
This article examines two recent techniques that exemplify this direction: the Recurrent Memory Transformer (RMT)[5] and the Receptance Weighted Key Value (RWKV) model[6]. RMT uses token-based memory to make the cost of processing a sequence grow linearly rather than quadratically with its length, and it has been used to process sequences of over 1 million tokens on a single GPU. RWKV goes further by combining the parallelizable training of Transformers with the efficient inference of Recurrent Neural Networks (RNNs), making it a more efficient and scalable architecture. It is the first non-Transformer model to scale to tens of billions of parameters, matching the performance of similarly sized Transformers. RWKV also mitigates a bottleneck of RMT: RMT reserves a fixed number of positions in every segment for memory tokens, which limits how much information is carried between segments and leaves fewer positions for actual input tokens. RWKV carries information between steps in a recurrent state instead of in reserved tokens, so none of its input positions are spent on memory. This improves both memory efficiency and processing speed.
Recurrent Memory Transformer (RMT)
One way to address the quadratic time complexity of Transformers on long sequences is to run the Transformer on separate segments of the input and pass information between segments so that knowledge propagates across the whole sequence. This does not remove the quadratic cost of self-attention within each segment, but because the segment length is fixed, the total cost grows linearly with the number of segments. RMT adopts this approach, which has its roots in models such as Transformer-XL[7], one of the first models to use it. Transformer-XL extends the original Transformer architecture with a state re-use cache for segment-level recurrence and with relative position encoding. Understanding Transformer-XL makes the design of RMT much easier to follow.
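A rough count of attention scores shows why segmenting helps even though each segment is still quadratic. The sketch below is a back-of-the-envelope model, not code from either paper: the helper names are hypothetical, memory tokens are assumed to sit at both ends of each segment as in RMT, and a shorter final segment is counted as full length, so the segmented figure is an upper bound.

```python
import math

def full_attention_pairs(seq_len: int) -> int:
    """Query-key pairs scored by one full self-attention layer: T * T."""
    return seq_len * seq_len

def segmented_attention_pairs(seq_len: int, seg_len: int, num_mem: int = 0) -> int:
    """Upper bound on pairs scored when attention runs segment by segment.

    Each segment holds seg_len tokens plus num_mem memory tokens at both its
    start and its end (an RMT-style layout, assumed for illustration).
    """
    num_segments = math.ceil(seq_len / seg_len)  # grows linearly with seq_len
    width = seg_len + 2 * num_mem                # fixed attention window per segment
    return num_segments * width * width          # quadratic only in the window size

# Example sizes chosen for illustration only.
print(full_attention_pairs(8192))                # 67108864
print(segmented_attention_pairs(8192, 512, 10))  # 4528384
```

Doubling the sequence length quadruples the first figure but only doubles the second.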
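Transformer-XL addresses the challenge of processing sequences with long-range dependencies by splitting the input sequence into segments and processing them sequentially. Its distinguishing feature is a state re-use cache: the hidden states computed for the previous segment are cached for each Transformer layer. The input to each layer is the concatenation of the last m states from the cached memory with the output of the previous Transformer layer for the current segment τ:
$$\tilde{H}^{n}_{\tau} = [\,\mathrm{SG}(M^{n}_{-m:}) \circ H^{n}_{\tau}\,]$$
Here $H^{n}_{\tau}$ is the output of the previous layer for segment τ (the input to layer n), $M^{n}_{-m:}$ denotes the last m cached states for layer n, ∘ denotes concatenation along the sequence dimension, and SG is a stop-gradient, so no gradient flows into the cached memory. Keys and values are computed from $\tilde{H}^{n}_{\tau}$, while queries come only from the current segment.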
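The caching scheme can be sketched for a single layer as follows. This is a simplified illustration under stated assumptions: hidden states are plain Python values, `layer` is a hypothetical callable standing in for a Transformer layer (it receives the current segment's states as queries and the extended context as keys and values), relative position encoding is omitted, and the stop-gradient has no counterpart because nothing here is differentiated.

```python
def xl_forward(segments, layer, m):
    """Run one layer over consecutive segments with a Transformer-XL-style cache.

    segments: list of segments, each a list of hidden states (inputs to this layer)
    layer:    hypothetical callable(queries, context) -> one output per query
    m:        maximum number of cached states prepended to each segment
    """
    memory = []    # cached inputs of this layer from earlier segments
    outputs = []
    for seg in segments:
        # Keys/values see [memory ; current segment]; queries come only from the
        # current segment. Real training would wrap `memory` in a stop-gradient.
        context = memory + list(seg)
        outputs.append(layer(list(seg), context))
        # Keep at most the last m states for the next segment ([-m:] with m = 0
        # would keep everything, hence the explicit check).
        memory = context[-m:] if m > 0 else []
    return outputs

# A toy "layer" that reports how many states each query can attend to.
print(xl_forward([[1, 2], [3, 4], [5, 6]], lambda q, ctx: [len(ctx)] * len(q), m=3))
```

With two-token segments and m = 3, the visible context grows from 2 to 4 to 5 states, so later segments can attend to tokens from earlier ones.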
RMT draws on Transformer-XL and on the Memory Transformer[8], which adds global memory tokens to store representations. These memory tokens are usually prepended to the input sequence. In decoder-only architectures, however, memory tokens at the start of the sequence cannot gather information from later tokens because of the causal attention mask, while memory tokens placed only at the end cannot be read by the tokens that precede them. RMT resolves this by placing memory tokens, which are real-valued vectors processed alongside the sequence tokens, at both the beginning and the end of each segment. The memory produced for the previous segment is thus available both at the start and at the end of the next segment:
$$\tilde{H}^{0}_{\tau} = [\,H^{mem}_{\tau} \circ H^{0}_{\tau} \circ H^{mem}_{\tau}\,], \qquad \bar{H}^{N}_{\tau} = \mathrm{Transformer}(\tilde{H}^{0}_{\tau}), \qquad [\,H^{read}_{\tau} \circ H^{N}_{\tau} \circ H^{write}_{\tau}\,] := \bar{H}^{N}_{\tau}$$
Here $H^{0}_{\tau}$ holds the embeddings of the tokens in segment τ, $H^{mem}_{\tau}$ holds the memory tokens, and N is the number of Transformer layers. The memory tokens serve as both read and write memory, allowing the sequence tokens to interact with the memory state.
In RMT, the memory tokens are divided into two groups that control how information flows between segments:
- Read Memory: The starting group of memory tokens acts as read memory. Sequence tokens attend to it to access the memory state carried over from the previous segment, which improves the model's contextual understanding.
- Write Memory: The ending group acts as write memory. Because it follows the segment tokens, it can attend to the whole segment and update the memory representations based on it. Its outputs form $H^{write}_{\tau}$, the updated memory tokens for segment τ.
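RMT processes input segments sequentially and connects them by passing the memory token outputs of the current segment to the next one:
$$H^{mem}_{\tau+1} := H^{write}_{\tau}, \qquad \tilde{H}^{0}_{\tau+1} = [\,H^{mem}_{\tau+1} \circ H^{0}_{\tau+1} \circ H^{mem}_{\tau+1}\,]$$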
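The segment-level recurrence of RMT can be sketched in a few lines of Python. This is an illustration, not the authors' implementation: `transformer` is a hypothetical callable standing in for the whole N-layer model (it must return one output per input position), embeddings are plain values, the initial memory would be learned parameters in practice, and training details such as BPTT are not shown.

```python
def rmt_forward(segments, transformer, init_memory):
    """Run RMT-style recurrence over a list of segments.

    segments:    list of segments, each a list of token embeddings
    transformer: hypothetical callable standing in for the N-layer Transformer;
                 it must return one output per input position
    init_memory: initial memory embeddings (learned parameters in practice)
    """
    num_mem = len(init_memory)
    memory = list(init_memory)
    outputs = []
    for seg in segments:
        # [memory ; segment ; memory]: the same memory fills the read and write slots.
        h = transformer(memory + list(seg) + memory)
        n = len(h)
        # h[:num_mem] are the read-memory outputs, which are not used further.
        outputs.append(h[num_mem:n - num_mem])  # outputs for the segment tokens
        memory = h[n - num_mem:]                # write memory feeds the next segment
    return outputs, memory

def running_sum(xs):
    """Toy causal 'transformer': each position outputs the sum of all inputs so far."""
    out, total = [], 0
    for x in xs:
        total += x
        out.append(total)
    return out

print(rmt_forward([[1, 2], [3]], running_sum, [0]))  # ([[1, 3], [6]], [9])
```

With the toy model, the output for the second segment (6) already includes the first segment's tokens, because that information can only arrive through the write memory passed between segments.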
RMT is trained with Backpropagation Through Time (BPTT). In Transformer-XL, gradients are stopped at the cached memory between segments during the backward pass; RMT instead lets gradients flow across segments through the memory tokens. The number of previous segments to backpropagate through is a hyperparameter, which ranges from 0 to 4 in the reported experiments. Unrolling more segments increases computation and GPU memory usage, although techniques such as gradient checkpointing can reduce the memory cost.
Receptance Weighted Key Value (RWKV) Model
The RWKV architecture is named after its four fundamental elements: Receptance (R), Weight (W), Key (K), and Value (V). Receptance acts as the receiver of past information, Weight is a trainable positional decay vector, and Key and Value play roles analogous to keys and values in standard attention. RWKV combines strengths of RNNs and Transformers and consists of a series of stacked residual blocks, each comprising a time-mixing and a channel-mixing sub-block with recurrent structure. It replaces standard dot-product self-attention with a linear attention mechanism[9] in which interactions are computed per channel rather than between every pair of tokens. This avoids the quadratic time complexity of regular self-attention.
RWKV also incorporates recurrence and time-decay factors that help capture temporal relationships in sequential data. The overall structure resembles the causal convolutions in WaveNet[10] and quasi-recurrent networks[11]. These components allow RWKV to process information like an RNN while retaining Transformer-style parallelized training.
The recurrence in each block is achieved through linear interpolation between the current input and the input at the previous time step, referred to as "time-shift mixing" or "token shift." The amount of token shift can be set independently for each linear projection of the input embedding: for R, K, and V in time-mixing, and for R and K in channel-mixing. The time-dependent update, denoted $wkv_t$, resembles the Attention Free Transformer (AFT)[12], with one distinction: the pairwise position-bias matrix of AFT is replaced by a channel-wise decay vector $w$ multiplied by the relative position. A pairwise matrix has one entry per pair of positions and therefore scales quadratically with sequence length, whereas $w$ has only one entry per channel. In addition, a vector $u$ is introduced to attend separately to the current token, compensating for potential degeneration of $w$. The time-mixing sub-block is defined as follows, where $x_t$ is the input at step t and the $\mu$ vectors are learned token-shift mixing coefficients:
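- Receptance (R): $r_t = W_r \cdot (\mu_r \odot x_t + (1 - \mu_r) \odot x_{t-1})$
- Key (K): $k_t = W_k \cdot (\mu_k \odot x_t + (1 - \mu_k) \odot x_{t-1})$
- Value (V): $v_t = W_v \cdot (\mu_v \odot x_t + (1 - \mu_v) \odot x_{t-1})$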
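The token-shift projections can be sketched in pure Python as follows. The helper names are hypothetical, matrices are lists of rows, ⊙ becomes an element-wise product, and we assume a zero vector stands in for $x_{t-1}$ at the first position. The decay vector $w$ does not appear here; it only enters the $wkv_t$ computation.

```python
def matvec(w, x):
    """Matrix-vector product with w given as a list of rows."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]

def token_shift(x_t, x_prev, mu):
    """Channel-wise interpolation mu * x_t + (1 - mu) * x_{t-1}."""
    return [m * a + (1.0 - m) * b for m, a, b in zip(mu, x_t, x_prev)]

def time_mix_projections(xs, mu_r, mu_k, mu_v, w_r, w_k, w_v):
    """Return the r_t, k_t and v_t vectors for every position in xs."""
    prev = [0.0] * len(xs[0])   # assumed zero "previous token" at the first step
    rs, ks, vs = [], [], []
    for x in xs:
        rs.append(matvec(w_r, token_shift(x, prev, mu_r)))
        ks.append(matvec(w_k, token_shift(x, prev, mu_k)))
        vs.append(matvec(w_v, token_shift(x, prev, mu_v)))
        prev = x                 # becomes x_{t-1} for the next step
    return rs, ks, vs

identity = [[1.0, 0.0], [0.0, 1.0]]
rs, ks, vs = time_mix_projections([[2.0, 4.0], [6.0, 8.0]],
                                  [0.5, 0.5], [1.0, 1.0], [0.0, 0.0],
                                  identity, identity, identity)
print(rs)  # [[1.0, 2.0], [4.0, 6.0]]
```

Setting a μ entry to 1 passes the current token through unchanged, 0 uses only the previous token, and values in between blend the two per channel.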
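The weighted sum $wkv_t$ combines the values of all tokens seen so far, decayed by the channel-wise vector $w$ according to their distance from the current token, while the current token is weighted with the bonus $u$. The output of the time-mixing sub-block gates this sum with the receptance:
$$wkv_t = \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} \odot v_i + e^{u + k_t} \odot v_t}{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} + e^{u + k_t}}, \qquad o_t = W_o \cdot (\sigma(r_t) \odot wkv_t)$$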
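The $wkv_t$ formula can be evaluated directly, as in this deliberately naive sketch (the function name is hypothetical). It follows the summation term by term with 0-based positions, which leaves the exponent $-(t-1-i)w$ unchanged, and it assumes $w \ge 0$ so that older tokens are down-weighted. Its cost is quadratic in the number of positions, so it only shows what is computed, not how RWKV computes it efficiently.

```python
import math

def wkv_direct(ks, vs, w, u):
    """Evaluate wkv_t for every position straight from the summation formula.

    ks, vs: per-position key and value vectors (length d each)
    w:      channel-wise decay vector (assumed non-negative)
    u:      channel-wise bonus applied to the current token
    """
    out = []
    for t in range(len(ks)):
        row = []
        for c in range(len(w)):
            # Past tokens i < t are decayed by (t - 1 - i) * w[c].
            weights = [math.exp(-(t - 1 - i) * w[c] + ks[i][c]) for i in range(t)]
            num = sum(wt * vs[i][c] for i, wt in enumerate(weights))
            den = sum(weights)
            # The current token receives the bonus u[c] instead of a decay.
            cur = math.exp(u[c] + ks[t][c])
            row.append((num + cur * vs[t][c]) / (den + cur))
        out.append(row)
    return out

print(wkv_direct([[0.0]] * 3, [[1.0], [3.0], [5.0]], [0.0], [0.0]))  # [[1.0], [2.0], [3.0]]
```

With all keys zero and $w = u = 0$, every token gets equal weight, so $wkv_t$ reduces to the running mean of the values. A larger $u$ emphasises the current token, and a larger $w$ makes tokens further in the past fade faster.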
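In the channel-mixing sub-block, the following formulas are applied:
- Receptance (R): $r'_t = W'_r \cdot (\mu'_r \odot x_t + (1 - \mu'_r) \odot x_{t-1})$
- Key (K): $k'_t = W'_k \cdot (\mu'_k \odot x_t + (1 - \mu'_k) \odot x_{t-1})$
- Output ($o'_t$): $o'_t = \sigma(r'_t) \odot (W'_v \cdot \max(k'_t, 0)^2)$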
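A single channel-mixing step can be sketched the same way. The names are hypothetical and matrices are lists of rows. $W'_k$ may project to a wider hidden layer with $W'_v$ projecting back, which this sketch supports as long as the matrix shapes agree. $\max(k'_t, 0)^2$ is a squared ReLU applied element-wise.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def matvec(w, x):
    """Matrix-vector product with w given as a list of rows."""
    return [sum(a * b for a, b in zip(row, x)) for row in w]

def channel_mix(x_t, x_prev, mu_r, mu_k, w_r, w_k, w_v):
    """One channel-mixing step: sigmoid(r'_t) * (W'_v . max(k'_t, 0)^2)."""
    # Token shift: blend the current and previous inputs channel by channel.
    xr = [m * a + (1.0 - m) * b for m, a, b in zip(mu_r, x_t, x_prev)]
    xk = [m * a + (1.0 - m) * b for m, a, b in zip(mu_k, x_t, x_prev)]
    r = matvec(w_r, xr)
    k = matvec(w_k, xk)
    k_sq = [max(z, 0.0) ** 2 for z in k]   # squared ReLU
    v = matvec(w_v, k_sq)
    # The sigmoid of the receptance gates each output channel.
    return [sigmoid(a) * b for a, b in zip(r, v)]

identity = [[1.0, 0.0], [0.0, 1.0]]
print(channel_mix([-2.0, 1.0], [0.0, 0.0], [1.0, 1.0], [1.0, 1.0],
                  identity, identity, identity))
```

In the example, the negative key channel is zeroed by the squared ReLU, so the first output is exactly 0 regardless of the gate.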
Notably, applying the sigmoid function σ[1] to the Receptance in both time-mixing and channel-mixing makes it act as a "forget gate" that removes unnecessary historical information.
RWKV can be run in an efficient parallel form, referred to as the "time-parallel mode." This mode draws on the parallel processing techniques common in Transformer architectures.
In this mode, the time complexity of processing a batch of sequences in a single layer is $O(BTd^2)$, where B is the number of sequences in the batch, T is the maximum number of tokens, and d is the number of channels. Most of this cost comes from matrix multiplications with the weight matrices $W_\lambda$, $\lambda \in \{r, k, v, o\}$.
Updating the attention scores $wkv_t$, by contrast, requires a serial scan with complexity $O(BTd)$.
The matrix multiplications, like those in standard Transformers, parallelize efficiently. The element-wise $wkv_t$ computation depends on time, but it can still be parallelized along the other two dimensions (batch and channel)[13].
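For a sense of scale, take illustrative sizes (chosen here, not from the paper) of one sequence, 4,096 tokens, and 1,024 channels, and ignore constant factors:

$$B = 1,\; T = 2^{12},\; d = 2^{10}: \qquad BTd^2 = 2^{32} \approx 4.3 \times 10^{9}, \qquad BTd = 2^{22} \approx 4.2 \times 10^{6}$$

The serial $wkv_t$ scan involves fewer operations than the projections by a factor of d, about a thousand here, so it accounts for only a small fraction of the arithmetic in time-parallel mode.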
Another advantage of RWKV is that it can also run in a "time-sequential mode," like a recurrent neural network (RNN). This is particularly useful for autoregressive decoding, the standard inference procedure in language modeling, where each token is generated and then fed back as input to the next step.
RWKV's RNN-like structure gives a convenient recursive formulation for decoding: each output token depends only on the latest state, which has a constant size regardless of the sequence length. Standard self-attention does not have this property.
In this time-sequential mode, RWKV acts as an RNN decoder, so its per-token processing time and memory footprint stay constant as the sequence grows, which makes it efficient on long sequences. Standard self-attention, by contrast, typically relies on a key-value (KV) cache that grows linearly with the sequence length, which increases memory demands and slows each decoding step as the sequence gets longer.
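The recurrent form follows from the $wkv_t$ formula by keeping two running sums per channel: $a$, the decayed sum of $e^{k_i} v_i$, and $b$, the decayed sum of $e^{k_i}$. At every step both are multiplied by $e^{-w}$ before the new token is added, so the state holds a fixed 2d numbers per layer no matter how long the sequence is:

$$wkv_t = \frac{a_{t-1} + e^{u + k_t} \odot v_t}{b_{t-1} + e^{u + k_t}}, \qquad a_t = e^{-w} \odot a_{t-1} + e^{k_t} \odot v_t, \qquad b_t = e^{-w} \odot b_{t-1} + e^{k_t}$$

The sketch below (function name hypothetical) assumes $w \ge 0$. For readability, it omits the running-maximum rescaling that practical implementations commonly use to avoid overflow in the exponentials.

```python
import math

def wkv_recurrent(ks, vs, w, u):
    """Compute wkv_t token by token from a fixed-size state (a, b) per channel."""
    d = len(w)
    a = [0.0] * d   # decayed sum of e^{k_i} * v_i over past tokens
    b = [0.0] * d   # decayed sum of e^{k_i} over past tokens
    out = []
    for k, v in zip(ks, vs):
        row = []
        for c in range(d):
            cur = math.exp(u[c] + k[c])                   # current token with bonus u
            row.append((a[c] + cur * v[c]) / (b[c] + cur))
            decay = math.exp(-w[c])
            a[c] = decay * a[c] + math.exp(k[c]) * v[c]   # add current token, no bonus
            b[c] = decay * b[c] + math.exp(k[c])
        out.append(row)
    return out

print(wkv_recurrent([[0.0]] * 3, [[1.0], [3.0], [5.0]], [0.0], [0.0]))  # [[1.0], [2.0], [3.0]]
```

Evaluating the summation formula directly on the same keys and values gives the same results up to floating-point rounding. This equivalence is what lets RWKV train in time-parallel mode and decode in time-sequential mode with the same weights.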
Because its per-token cost does not depend on the sequence length, RWKV is a promising architecture for tasks that involve long sequences, offering efficiency and practicality suited to real-world applications.
Experiments show that RWKV scales well, up to 14 billion parameters, while performing strongly on NLP benchmarks. It performs comparably to similarly sized Transformers on tasks such as Winogrande[14], PIQA[15], and LAMBADA[16], while requiring lower training and inference costs. Tests also show that increasing the context length lowers language modeling perplexity, suggesting that RWKV can make effective use of long-range information.
RWKV Contribution over RMT
Both RWKV and RMT were developed to let Transformer-style models handle long sequences efficiently. They share the idea of carrying memory forward step by step, but they implement it very differently, which gives each model distinct advantages.
In RMT, memory is carried forward by splitting the input sequence into segments and attaching memory tokens to each segment. This works well in many cases, but it can perform suboptimally on shorter sequences, where the memory tokens bring little benefit. RWKV, in contrast, works across a wide range of sequence lengths, performing nearly on par with standard Transformers in various settings and excelling on longer sequences.
The two models also differ in how they address the quadratic time complexity of Transformers. RMT processes the text in segments connected by memory tokens; this makes the total cost linear in the number of segments, but self-attention within each segment is still quadratic in the segment length, so the quadratic term is bounded rather than removed. RWKV instead uses a linear attention mechanism, so its cost grows linearly with the sequence length at any length.
Furthermore, during training RWKV backpropagates through time across every token in its context, so each token's representation is learned with access to its full history. In RMT, gradients cross segment boundaries only through the memory tokens, which can make RWKV the more powerful model on longer sequences.
Conclusion
A key challenge in scaling up language models is balancing efficiency and performance. Transformers have become dominant due to their representational capacity but are limited by quadratic time complexity. RNNs are more efficient but have difficulty learning complex dependencies. Recent innovations like RMT[5] and RWKV[6] offer a middle ground, improving efficiency while retaining benefits such as parallel training. RMT incorporates token-based memory to achieve linear scaling in the number of segments, while RWKV is a hybrid of Transformers and RNNs that uses linear attention to reduce computational complexity and memory requirements. Both exemplify how architectural advances can reconcile the rapid growth in model size with practical constraints such as hardware limitations and carbon footprint. Although Transformers currently dominate NLP, their quadratic time complexity becomes increasingly prohibitive as contexts grow. Techniques like RWKV and RMT offer a path towards efficient yet accurate large language models for tasks that require complex contextual understanding and reasoning over long sequences.
Annotated Bibliography
- [1] Vaswani A, Shazeer N, Parmar N, Uszkoreit J, Jones L, Gomez AN, Kaiser Ł, Polosukhin I. Attention is all you need. Advances in Neural Information Processing Systems. 2017;30.
- [2] Devlin J, Chang MW, Lee K, Toutanova K. BERT: Pre-training of deep bidirectional transformers for language understanding. In: Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers). Association for Computational Linguistics; 2019. p. 4171-4186. doi:10.18653/v1/N19-1423.
- [3] Brown T, Mann B, Ryder N, Subbiah M, Kaplan JD, Dhariwal P, Neelakantan A, Shyam P, Sastry G, Askell A, et al. Language models are few-shot learners. Advances in Neural Information Processing Systems. 2020;33:1877-1901.
- [4] Raffel C, Shazeer N, Roberts A, Lee K, Narang S, Matena M, Zhou Y, Li W, Liu PJ. Exploring the limits of transfer learning with a unified text-to-text transformer. The Journal of Machine Learning Research. 2020 Jan 1;21(1):5485-551.
- [5] Bulatov A, Kuratov Y, Burtsev M. Recurrent memory transformer. Advances in Neural Information Processing Systems. 2022 Dec 6;35:11079-91.
- [6] Peng B, Alcaide E, Anthony Q, Albalak A, Arcadinho S, Cao H, Cheng X, Chung M, Grella M, GV KK, He X. RWKV: Reinventing RNNs for the Transformer Era. arXiv preprint arXiv:2305.13048. 2023 May 22.
- [7] Dai Z, Yang Z, Yang Y, Carbonell J, Le QV, Salakhutdinov R. Transformer-XL: Attentive language models beyond a fixed-length context. arXiv preprint arXiv:1901.02860. 2019 Jan 9.
- [8] Burtsev MS, Kuratov Y, Peganov A, Sapunov GV. Memory transformer. arXiv preprint arXiv:2006.11527. 2020 Jun 20.
- [9] Li R, Su J, Duan C, Zheng S. Linear attention mechanism: An efficient attention for semantic segmentation. arXiv preprint arXiv:2007.14902. 2020 Jul 29.
- [10] Oord AV, Dieleman S, Zen H, Simonyan K, Vinyals O, Graves A, Kalchbrenner N, Senior A, Kavukcuoglu K. WaveNet: A generative model for raw audio. arXiv preprint arXiv:1609.03499. 2016 Sep 12.
- [11] Bradbury J, Merity S, Xiong C, Socher R. Quasi-recurrent neural networks. arXiv preprint arXiv:1611.01576. 2016 Nov 5.
- [12] Zhai S, Talbott W, Srivastava N, Huang C, Goh H, Zhang R, Susskind J. An attention free transformer. arXiv preprint arXiv:2105.14103. 2021 May 28.
- [13] Lei T, Zhang Y, Wang SI, Dai H, Artzi Y. Simple recurrent units for highly parallelizable recurrence. In: Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing. Brussels, Belgium: Association for Computational Linguistics; 2018. p. 4470-4481.
- [14] Sakaguchi K, Bras RL, Bhagavatula C, Choi Y. WinoGrande: An adversarial Winograd schema challenge at scale. Communications of the ACM. 2021 Aug 24;64(9):99-106.
- [15] Bisk Y, Zellers R, Gao J, Choi Y. PIQA: Reasoning about physical commonsense in natural language. In: Proceedings of the AAAI Conference on Artificial Intelligence. 2020 Apr 3;34(05):7432-7439.
- [16] Paperno D, Kruszewski G, Lazaridou A, Pham QN, Bernardi R, Pezzelle S, Baroni M, Boleda G, Fernández R. The LAMBADA dataset: Word prediction requiring a broad discourse context. arXiv preprint arXiv:1606.06031. 2016 Jun 20.