Learning to (Learn at Test Time): RNNs with Expressive Hidden States
ArXiv: 2407.04620
🎯 Pitch
The paper introduces Test-Time Training (TTT) layers. These make the hidden state a small, learnable model (W_t) that is updated by gradient steps online as a sequence is processed, which gives sequence-modeling layers whose cost is linear in sequence length (constant per token). By compressing context through learned inner-loop updates (TTT-Linear and TTT-MLP), the approach keeps RNN-style efficiency while substantially improving long-context utilization: perplexity keeps improving as context grows, whereas modern RNNs plateau. This makes it a promising path toward scalable long-context models.
1. Executive Summary
This paper introduces Test-Time Training (TTT) layers, a family of sequence-modeling layers with linear complexity in sequence length. Their hidden state is not a fixed vector but the weights W_t of a small model, updated by gradient descent while the sequence is processed, even at test time (Eq. (2), Figure 1). The central contribution is a practical framework that aims to combine the hardware and asymptotic efficiency of RNN-style recurrence with better long-context utilization. It targets the plateau that modern RNNs (e.g., Mamba) show beyond ~16k tokens in the authors' evaluation (Figure 2 right).
2. Context and Motivation
- Problem / gap
  - Self-attention is effective in long-context modeling but has quadratic cost in context length, because each token attends over all previous tokens (Figure 3, “Self-attention” row; discussion in §1 and §2).
  - RNN layers have linear cost but must compress all past context into a fixed-size hidden state, which can limit long-context performance (§1, §2).
- Why important
  - The paper argues that the practical advantage of linear-time models only becomes meaningful at long context (they reference “after 8k” in their setting; §1, Figure 12 discussion).
  - In long contexts, perplexity should keep improving as more conditioning information becomes available. Their evaluation shows this happens for Transformers but not for Mamba after ~16k (Figure 2 right; §1).
- Prior approaches and shortcomings (as positioned by the paper)
  - Classic RNNs (e.g., LSTMs) were previously observed not to scale like Transformers (cited as Kaplan et al.; §1).
  - Modern RNNs (Mamba) improve scaling at moderate context but still show a long-context plateau in the token-index perplexity diagnostic (Figure 2; §1).
  - Linear-attention / “fast weight” style methods exist; the paper later shows a formal equivalence between a particular TTT instantiation and linear attention (Theorem 1, §2.6).
- How this paper positions itself
  - It reframes sequence-modeling layers through a unifying “hidden state + update rule + output rule” lens (Figure 3).
  - It proposes a framework where the hidden state can be an arbitrary learnable model (not just a vector or matrix). This model is trained online via a self-supervised objective whose form is itself learned in an outer loop (§2.1–§2.3).
  - It contributes systems techniques (mini-batch token updates and a “dual form”) to make this tractable on accelerators (§2.4–§2.5).
3. Technical Approach
3.1 Reader orientation (approachable technical breakdown)
- The system is a new kind of recurrent sequence layer whose internal memory is a small model (e.g., a linear map or a 2-layer MLP) that is trained online as the sequence is processed.
- It solves long-context modeling by compressing past tokens into the model’s weights using self-supervised learning updates, while keeping per-token recurrence rather than attention’s growing KV cache (Figure 1, Figure 3; Eq. (2)).
3.2 Big-picture architecture (diagram in words)
- The input token representation x_t enters a TTT layer.
- The layer maintains a hidden state = learner state, mainly the current weights W_t (and potentially optimizer state; §2.6).
- For each token (or token mini-batch):
  - Compute a self-supervised loss (multi-view reconstruction) using learnable projections (θ_K, θ_V) (Eq. (4), §2.3).
  - Take (mini-batch) gradient-descent-style updates to W (Eq. (6), §2.4).
  - Produce the layer output using a test-view projection θ_Q and the updated model f(·; W_t) (Eq. (5), §2.3).
- The TTT layer is used inside a standard LM training setup (next-token prediction outer objective). It is optionally placed in a Mamba-style backbone (temporal convolutions + gating) rather than a Transformer block (§2.7, Figure 13).
3.3 Roadmap for the deep dive
- First, define the TTT hidden state as model weights and the basic update/output rules (Eq. (1)–(3), Figure 1).
- Next, explain how the self-supervised task is made learnable via multi-view projections optimized in the outer loop (Eq. (4)–(5), Figure 5).
- Then cover the two efficiency techniques:
  - Mini-batch TTT (parallelize gradients within token blocks) (§2.4, Figures 6–7).
  - The dual form (rewrite computations as matmuls, avoiding materialization of per-token gradients and intermediate weights) (§2.5, Eq. (7)–(8)).
- Finally, connect the framework to known mechanisms:
  - Equivalence to linear attention (Theorem 1, Table 1).
  - Equivalence to self-attention via a nonparametric learner (Theorem 2).
- Close with concrete instantiations: TTT-Linear and TTT-MLP and their implementation choices (§2.7).
3.4 Detailed, sentence-based technical breakdown
Framing sentence (type of paper + core idea).
This is an algorithm + systems + empirical scaling paper whose core idea is to turn the recurrent hidden state into a trainable model updated by self-supervised gradient steps during the forward pass, so that the layer “learns at test time” by construction (Eq. (2), Figure 1, §2.1).
3.4.1 Unifying view: sequence layer = hidden state + update rule + output rule
- The paper treats any autoregressive sequence layer as maintaining a hidden state s_t that evolves via an update rule and produces outputs via an output rule (Figure 3 top).
- In this lens:
  - A “naive RNN” keeps a fixed-size vector state and updates it with a parametric recurrence (Figure 3 bottom).
  - Self-attention keeps an ever-growing list state (the KV cache), so per-token cost grows with the time index t (Figure 3 bottom).
  - A “naive TTT” also keeps a fixed-size state, but that state is the parameter vector/matrix of a model f (Figure 3 bottom).
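The hidden-state / update-rule / output-rule lens can be sketched as one generic scan. This is a minimal sketch in plain Python that assumes scalar tokens to keep it short. `run_layer` and the three toy layers are hypothetical illustrations, not the paper's code. The point is the shape of the state: constant size for the RNN and the naive TTT, growing for self-attention.

```python
import math
from typing import Callable, List, Tuple, TypeVar

S = TypeVar("S")

def run_layer(state: S,
              update: Callable[[S, float], S],
              output: Callable[[S, float], float],
              xs: List[float]) -> Tuple[List[float], S]:
    """Generic autoregressive layer: s_t = update(s_{t-1}, x_t), z_t = output(s_t, x_t)."""
    zs = []
    for x in xs:
        state = update(state, x)
        zs.append(output(state, x))
    return zs, state

# Naive RNN: fixed-size state, fixed parametric recurrence (toy weights 0.5 and 0.8).
rnn_update = lambda s, x: math.tanh(0.5 * s + 0.8 * x)
rnn_output = lambda s, x: s

# Self-attention: the state is the growing list of past tokens (toy KV cache with k = v = x).
def attn_update(cache: List[float], x: float) -> List[float]:
    return cache + [x]

def attn_output(cache: List[float], x: float) -> float:
    weights = [math.exp(k * x) for k in cache]
    return sum(w * v for w, v in zip(weights, cache)) / sum(weights)

# Naive TTT: the state is the weight w of a scalar learner f(x; w) = w * x,
# updated by one gradient step on the toy reconstruction loss (w * x - x)^2.
ETA = 0.1
ttt_update = lambda w, x: w - ETA * 2.0 * (w * x - x) * x
ttt_output = lambda w, x: w * x

xs = [0.5, -1.0, 2.0, 0.3]
rnn_z, rnn_s = run_layer(0.0, rnn_update, rnn_output, xs)
attn_z, attn_s = run_layer([], attn_update, attn_output, xs)
ttt_z, ttt_s = run_layer(0.0, ttt_update, ttt_output, xs)
print(len(attn_s))  # 4: the attention state grows with t; rnn_s and ttt_s stay scalars
```

With an uncorrupted input the toy reconstruction task is trivial (w drifts toward 1). That is why the paper corrupts the input or, later, learns the views.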
3.4.2 Core TTT mechanism: hidden state as weights W_t, update rule as learning
- Hidden state definition. The hidden state at time t is W_t, the parameters of a “learner model” f (Figure 1; §2.1).
- Output rule. The layer output at token t is a prediction produced by the current learner:
  - Base form: z_t = f(x_t; W_t) (Eq. (1)).
  - After introducing views (below): z_t = f(θ_Q x_t; W_t) (Eq. (5)).
- Update rule. The recurrence update is a gradient step on a self-supervised loss: W_t = W_{t-1} - η ∇ℓ(W_{t-1}; x_t) (Eq. (2)).
- Interpretation. Because this update also happens while a sequence is processed during inference, the layer performs training at test time on the test sequence (§2.1).
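For a linear learner f(u; W) = W u, the gradient in Eq. (2) has a closed form. The short derivation below uses standard matrix calculus and is not quoted from the paper. It treats a generic squared loss with input u and target y, then plugs in the naive reconstruction task (u = x̃_t, y = x_t):

```latex
\begin{aligned}
\ell(W) &= \lVert W u - y \rVert^2 = (W u - y)^\top (W u - y), \\
\nabla_W \ell(W) &= 2\,(W u - y)\,u^\top, \\
W_t &= W_{t-1} - 2\eta\,\bigl(W_{t-1}\tilde{x}_t - x_t\bigr)\,\tilde{x}_t^\top
  \quad\text{(Eq. (2) with } u = \tilde{x}_t,\ y = x_t\text{, i.e. the Eq. (3) loss)}.
\end{aligned}
```

Each update is therefore a rank-one outer-product correction to W. That is the structure the mini-batch and dual-form techniques later exploit.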
3.4.3 Self-supervised objective: from naive reconstruction to learned multi-view reconstruction
- Naive reconstruction. A straightforward self-supervised task is to reconstruct x_t from a corrupted version x̃_t: ℓ(W; x_t) = || f(x̃_t; W) - x_t ||^2 (Eq. (3)).
- Learned task via multi-view projections.
  - The paper introduces views: learnable low-rank linear projections that define what the learner sees and what it must predict (§2.3).
  - A training view: θ_K x_t.
  - A label view: θ_V x_t.
  - A test view: θ_Q x_t.
  - The inner-loop loss becomes: ℓ(W; x_t) = || f(θ_K x_t; W) - θ_V x_t ||^2 (Eq. (4)).
  - The output becomes: z_t = f(θ_Q x_t; W_t) (Eq. (5)).
- Outer vs. inner parameters (critical distinction).
  - W_t is not an outer-loop parameter; it is a sequence-specific hidden state updated per example/sequence.
  - θ_K, θ_V, θ_Q (and later θ_init, θ_lr) are outer-loop parameters trained with the standard LM objective (next-token prediction) (§2.2–§2.3, Table 2).
- Code-level picture.
  - Figure 5 shows a conceptual implementation in which Task contains θ_K, θ_V, θ_Q, and Learner contains the inner-loop model and optimizer. TTT_Layer.forward() iterates through tokens, calling train() then predict().
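The train-then-predict loop described for Figure 5 can be sketched in plain Python. This is a minimal sketch, not the paper's code. It assumes a linear learner f(u; W) = W u, plain online GD with a fixed η, W_0 = 0, no LayerNorm/residual, and small hand-written matrix helpers in place of a tensor library. The class names mirror the Figure 5 description (Task, Learner, TTT_Layer), but their fields and method signatures here are hypothetical.

```python
from typing import List

Vec = List[float]
Mat = List[List[float]]

def matvec(M: Mat, v: Vec) -> Vec:
    return [sum(m * x for m, x in zip(row, v)) for row in M]

class Task:
    """Outer-loop parameters: the three views theta_K, theta_V, theta_Q."""
    def __init__(self, theta_K: Mat, theta_V: Mat, theta_Q: Mat):
        self.theta_K, self.theta_V, self.theta_Q = theta_K, theta_V, theta_Q

    def loss(self, W: Mat, x: Vec) -> float:
        # Eq. (4): || f(theta_K x; W) - theta_V x ||^2 with f(u; W) = W u
        pred = matvec(W, matvec(self.theta_K, x))
        return sum((p - y) ** 2 for p, y in zip(pred, matvec(self.theta_V, x)))

class Learner:
    """Inner-loop state: the weights W (the hidden state) plus a plain-GD 'optimizer'."""
    def __init__(self, task: Task, dim: int, eta: float):
        self.task, self.eta = task, eta
        self.W = [[0.0] * dim for _ in range(dim)]  # W_0 = 0 here; the paper learns W_0 (theta_init)

    def train(self, x: Vec) -> None:
        # Eq. (2): W_t = W_{t-1} - eta * grad, with grad = 2 (W u - y) u^T
        u, y = matvec(self.task.theta_K, x), matvec(self.task.theta_V, x)
        r = [p - t for p, t in zip(matvec(self.W, u), y)]
        self.W = [[w - self.eta * 2.0 * ri * uj for w, uj in zip(row, u)]
                  for row, ri in zip(self.W, r)]

    def predict(self, x: Vec) -> Vec:
        # Eq. (5): z_t = f(theta_Q x; W_t)
        return matvec(self.W, matvec(self.task.theta_Q, x))

class TTT_Layer:
    def __init__(self, task: Task, dim: int, eta: float = 0.1):
        self.task, self.dim, self.eta = task, dim, eta

    def forward(self, xs: List[Vec]) -> List[Vec]:
        learner = Learner(self.task, self.dim, self.eta)  # fresh hidden state per sequence
        out = []
        for x in xs:
            learner.train(x)                 # update rule: W_{t-1} -> W_t
            out.append(learner.predict(x))   # output rule with the updated W_t
        return out

# Usage with hypothetical 2-d views: identity key/query views, label view swaps coordinates.
I = [[1.0, 0.0], [0.0, 1.0]]
swap = [[0.0, 1.0], [1.0, 0.0]]
layer = TTT_Layer(Task(theta_K=I, theta_V=swap, theta_Q=I), dim=2, eta=0.1)
zs = layer.forward([[1.0, 2.0], [0.5, -1.0], [1.0, 2.0]])
```

In real training, θ_K, θ_V, and θ_Q would be updated by the outer loop through backpropagation. This sketch only runs the inner loop for fixed views.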
3.4.4 Why learning updates are plausible as “compression”
- The paper motivates the update rule as a compression heuristic: tokens that generate large gradients cause larger updates, so the hidden model's weights preferentially encode information that is “learn-worthy” under the self-supervised objective (§2.1).
- Empirically, they show the TTT loss improves over time within test sequences:
  - In Figure 4 (expanded in Figure 14), ℓ(W_t; x_t) is lower than ℓ(W_{t-1}; x_t) after one gradient step, and performance relative to the initialization W_0 improves further along the sequence (Figure 4 caption; §2.1).
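For the linear learner, the one-step improvement on the current token can be checked in closed form. The derivation below is standard algebra, not taken from the paper. Write u = θ_K x_t, y = θ_V x_t, and r = W_{t-1} u - y, and take one step of Eq. (2):

```latex
\begin{aligned}
W_t u - y &= \bigl(W_{t-1} - 2\eta\, r\, u^\top\bigr) u - y
           = r - 2\eta \lVert u \rVert^2 r
           = \bigl(1 - 2\eta \lVert u \rVert^2\bigr)\, r, \\
\ell(W_t; x_t) &= \bigl(1 - 2\eta \lVert u \rVert^2\bigr)^2\, \ell(W_{t-1}; x_t).
\end{aligned}
```

So whenever r ≠ 0, the loss on the current token strictly decreases exactly when 0 < η‖u‖² < 1, and η = 1/(2‖u‖²) drives it to zero. This says nothing about the second observation in Figure 4, that W_{t-1} beats W_0 on x_t more and more as t grows. That is an empirical property of real sequences, not a consequence of this algebra.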
3.4.5 Making TTT trainable end-to-end: backprop through inner-loop updates
- Training a model containing TTT layers uses standard outer-loop next-token prediction, but gradients must flow through the inner-loop update computations.
- The paper notes that although the forward pass uses a gradient operator ∇ internally, this is still a differentiable computation. Backpropagating “through gradients” is described as taking “gradients of gradients,” which connects to meta-learning ideas (§2.2).
- The paper's terminology:
  - Inner loop: updates to W inside the TTT layer via ∇ℓ (§2.2, Table 2).
  - Outer loop: standard LM training that updates θ_rest and also the TTT task parameters θ_K, θ_Q, θ_V (and others) (§2.2–§2.3, Table 2).
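The phrase “gradients of gradients” can be made concrete with a scalar toy. This is a minimal sketch under stated assumptions: a scalar learner f(u; w) = w u, one inner step from w_0, and scalar views k = θ_K x, v = θ_V x, q = θ_Q x with hypothetical values. The outer derivative of the output z = w_1 q with respect to k passes through the inner gradient 2(w_0 k - v)k, so computing it means differentiating a gradient. The hand-derived result is checked against a central finite difference. An autodiff framework (the paper trains in JAX) would produce the same quantity automatically.

```python
def inner_step(w0: float, k: float, v: float, eta: float) -> float:
    """One inner-loop GD step on l(w) = (w*k - v)^2, whose gradient is 2*(w*k - v)*k."""
    return w0 - eta * 2.0 * (w0 * k - v) * k

def output(w0: float, k: float, v: float, q: float, eta: float) -> float:
    """Outer-visible output z = f(q; w_1), with w_1 produced by the inner step."""
    return inner_step(w0, k, v, eta) * q

def dz_dk_analytic(w0: float, k: float, v: float, q: float, eta: float) -> float:
    # d/dk of the inner gradient 2*(w0*k - v)*k is 2*(2*w0*k - v): a derivative of a gradient.
    return -eta * 2.0 * (2.0 * w0 * k - v) * q

def dz_dk_numeric(w0: float, k: float, v: float, q: float, eta: float,
                  h: float = 1e-6) -> float:
    # central finite difference of the full inner-step-then-predict computation
    return (output(w0, k + h, v, q, eta) - output(w0, k - h, v, q, eta)) / (2.0 * h)

w0, k, v, q, eta = 0.3, 1.5, -0.7, 2.0, 0.1
print(dz_dk_analytic(w0, k, v, q, eta), dz_dk_numeric(w0, k, v, q, eta))  # both ≈ -0.64
```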
3.4.6 Systems/efficiency technique #1: mini-batch TTT (token blocks)
- Problem: Online GD is strictly sequential. The gradient that produces W_t is evaluated at W_{t-1}, so the updates are hard to parallelize (§2.4).
- Key observation: Gradient descent variants can be written as W_t = W_{t-1} - η G_t = W_0 - η Σ_{s=1}^t G_s (Eq. (6)). Once the G_t are available, W_t can therefore be obtained via prefix sums (“cumsum”).
- Mini-batch gradient descent across tokens.
  - Divide the sequence into mini-batches of tokens of size b.
  - For tokens within a mini-batch, compute gradients with respect to the same reference weights (the weights at the start of the mini-batch), which enables parallel gradient computation (§2.4, Figure 6).
  - This yields a speed–quality trade-off:
    - Smaller b is closer to online GD (more steps, better perplexity).
    - Larger b is closer to batch GD (more parallelism, worse perplexity).
  - They choose b = 16 for all experiments (Figure 7; §2.4).
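Mini-batch TTT with prefix sums can be sketched for the linear learner. This is a minimal sketch under several assumptions: f(u; W) = W u, the Eq. (4) loss with views already applied (keys `ks`, labels `vs`, and queries `qs` are passed in directly), a fixed η, and plain-Python matrix helpers. The function names are hypothetical. Inside each mini-batch, every G_s is taken at the mini-batch's starting weights, and each W_t is recovered by a running sum as in Eq. (6). With b = 1 this reduces to online GD.

```python
from typing import List

Vec = List[float]
Mat = List[List[float]]

def matvec(M: Mat, v: Vec) -> Vec:
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def grad(W: Mat, k: Vec, v: Vec) -> Mat:
    """Gradient of ||W k - v||^2 with respect to W: 2 (W k - v) k^T."""
    r = [p - t for p, t in zip(matvec(W, k), v)]
    return [[2.0 * ri * kj for kj in k] for ri in r]

def minibatch_ttt(ks: List[Vec], vs: List[Vec], qs: List[Vec],
                  W0: Mat, eta: float, b: int) -> List[Vec]:
    W_start = [row[:] for row in W0]
    zs = []
    for start in range(0, len(ks), b):
        idx = range(start, min(start + b, len(ks)))
        # All gradients in this mini-batch use the same reference weights W_start,
        # so they do not depend on each other and could be computed in parallel.
        Gs = [grad(W_start, ks[t], vs[t]) for t in idx]
        running = [[0.0] * len(W0[0]) for _ in W0]  # prefix sum of G_s within the mini-batch
        for t, G in zip(idx, Gs):
            running = [[a + g for a, g in zip(ra, rg)] for ra, rg in zip(running, G)]
            W_t = [[w - eta * a for w, a in zip(rw, ra)] for rw, ra in zip(W_start, running)]
            zs.append(matvec(W_t, qs[t]))  # output uses the per-token W_t
        W_start = W_t  # the last W_t of this mini-batch seeds the next one
    return zs

# Usage on hypothetical toy data (queries reuse the keys for brevity).
ks = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
vs = [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [-0.5, 0.5]]
W0 = [[0.0, 0.0], [0.0, 0.0]]
zs = minibatch_ttt(ks, vs, ks, W0, eta=0.1, b=2)
```

A real implementation would vectorize the per-mini-batch gradients and the cumulative sum. The dual form described next goes further and never materializes them.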
3.4.7 Systems/efficiency technique #2: the dual form (matmul-friendly rewrite)
- Problem: Even with mini-batches, the “primal” computation involves many per-token outer products and materializes every G_t and W_t. This is inefficient on accelerators and heavy in memory I/O (§2.5).
- Dual-form idea: Do not materialize the intermediate gradients G_1…G_b or the intermediate weights W_1…W_{b-1} inside a mini-batch. Instead, compute:
  - the end-of-batch weights W_b, and
  - the batch of outputs z_1…z_b,
  using a small number of large matrix multiplications (§2.5, Figure 6).
- Concrete derivation in the simplified linear case (§2.5).
  - With f(x; W) = W x and a simple squared reconstruction loss, they show W_b = W_0 - 2η (W_0 X - X) X^T, where X = [x_1, …, x_b] (derivation around Eq. (7)).
  - For outputs, they derive a masked formulation:
    - Each output is z_t = W_t x_t = W_0 x_t - 2η Σ_{s≤t} (W_0 x_s - x_s) x_s^T x_t. In matrix form this is Z = W_0 X - 2η (W_0 X - X) mask(X^T X), where mask(·) zeroes the entries (s, t) with s > t (Eq. (8)). This is analogous to causal masking in attention, but with zeros instead of -∞.
- Complexity trade-off:
  - The dual form uses O(b d^2) for the end-of-batch weights and adds O(b^2 d) to compute all outputs in the batch (§2.5).
  - The paper argues this is acceptable because b is small (16) and typical d is a few hundred. The dual form also improves wall-clock time substantially on TPU: “more than 5× faster” in JAX (§2.5).
- Extension to MLP learners.
  - Appendix A generalizes the dual form to multi-layer MLPs using standard forward/backprop quantities such as ∇_{Z^k} ℓ together with masked matmuls.
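The dual form for the simplified linear case can be checked numerically against the primal computation. This is a minimal sketch under several assumptions: f(x; W) = W x, the naive reconstruction loss ||W x - x||^2 (no views), a single mini-batch stored as the columns of X, and plain-Python matrices. The helper names are hypothetical. The dual function implements z_t = W_0 x_t - 2η Σ_{s≤t} (W_0 x_s - x_s) x_s^T x_t in masked-matmul form, which is our expansion for this case. It should be compared against Eq. (8) before reuse.

```python
from typing import List, Tuple

Mat = List[List[float]]

def matmul(A: Mat, B: Mat) -> Mat:
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A: Mat) -> Mat:
    return [list(col) for col in zip(*A)]

def sub(A: Mat, B: Mat) -> Mat:
    return [[a - b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

def scale(c: float, A: Mat) -> Mat:
    return [[c * a for a in row] for row in A]

def matvec(M: Mat, v: List[float]) -> List[float]:
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def primal(W0: Mat, X: Mat, eta: float) -> Tuple[Mat, Mat]:
    """Materialize every G_s (all taken at W0) and every W_t; columns of X are tokens."""
    d = len(W0)
    running = [[0.0] * d for _ in range(d)]
    W_t, Z_cols = W0, []
    for x in transpose(X):
        r = [p - t for p, t in zip(matvec(W0, x), x)]          # W0 x_s - x_s
        G = [[2.0 * ri * xj for xj in x] for ri in r]           # G_s = 2 (W0 x_s - x_s) x_s^T
        running = [[a + g for a, g in zip(ra, rg)] for ra, rg in zip(running, G)]
        W_t = sub(W0, scale(eta, running))
        Z_cols.append(matvec(W_t, x))
    return W_t, transpose(Z_cols)

def dual(W0: Mat, X: Mat, eta: float) -> Tuple[Mat, Mat]:
    """Same results from a few matmuls, without per-token G_s or W_t."""
    b = len(X[0])
    E = sub(matmul(W0, X), X)                                  # W0 X - X  (d x b)
    W_b = sub(W0, scale(2.0 * eta, matmul(E, transpose(X))))  # W0 - 2 eta (W0 X - X) X^T
    gram = matmul(transpose(X), X)                             # X^T X  (b x b)
    masked = [[gram[s][t] if s <= t else 0.0 for t in range(b)] for s in range(b)]
    Z = sub(matmul(W0, X), scale(2.0 * eta, matmul(E, masked)))
    return W_b, Z

# Hypothetical toy data: d = 2, b = 3 (three token columns).
W0 = [[0.5, 0.1], [0.0, 0.2]]
X = [[1.0, 0.0, 2.0], [0.5, -1.0, 1.0]]
Wp, Zp = primal(W0, X, eta=0.05)
Wd, Zd = dual(W0, X, eta=0.05)
```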
3.4.8 Theoretical equivalences: how TTT unifies RNN/attention constructions
- Equivalence to linear attention (parametric learner).
  - Theorem 1 (§2.6) states that with a linear inner model f(x) = W x, batch GD, η = 1/2, and W_0 = 0, the TTT output matches linear attention.
  - The proof rewrites the batch gradient update to show that W_t becomes a running sum of outer products of projected values and keys, and that the output takes the standard linear-attention form (§2.6).
  - Table 1 empirically verifies the equivalence (“Linear attn. improved” and “TTT equivalence” both have perplexity 15.23, diff 0).
- Equivalence to self-attention (nonparametric learner).
  - Theorem 2 (§2.6) shows that if the learner is a Nadaraya–Watson estimator with an exponential kernel κ(x, x') ∝ exp((θ_K x)^T θ_Q x'), then the induced TTT layer corresponds to self-attention.
  - This reframes attention as a particular “learner” that stores all past data points and predicts via kernel-weighted averaging (§2.6; Appendix B elaborates on Nadaraya–Watson).
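Both equivalences can be checked numerically on toy data. This is a minimal sketch with several assumptions: keys k_s = θ_K x_s, values v_s = θ_V x_s, and queries q_t = θ_Q x_t are already computed (random toy vectors here); linear attention is unnormalized; and softmax attention has no 1/√d scaling. The helper names are hypothetical. For Theorem 1, the gradient of ||W k - v||^2 at W_0 = 0 is -2 v k^T, so with η = 1/2 batch GD gives W_t = Σ_{s≤t} v_s k_s^T. For Theorem 2, Nadaraya–Watson regression with kernel exp(k_s^T q_t) is a softmax-weighted average of the stored values.

```python
import math
import random
from typing import List

Vec = List[float]

def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))

def ttt_batch_gd(ks: List[Vec], vs: List[Vec], qs: List[Vec], eta: float = 0.5) -> List[Vec]:
    """Linear learner, W_0 = 0, batch GD: every gradient is taken at W_0 = 0."""
    W = [[0.0] * len(ks[0]) for _ in range(len(vs[0]))]  # W_t = W_0 - eta * sum_{s<=t} G_s
    zs = []
    for k, v, q in zip(ks, vs, qs):
        G = [[-2.0 * vi * kj for kj in k] for vi in v]     # 2 (0 k - v) k^T = -2 v k^T
        W = [[w - eta * g for w, g in zip(rw, rg)] for rw, rg in zip(W, G)]
        zs.append([dot(row, q) for row in W])
    return zs

def linear_attention(ks: List[Vec], vs: List[Vec], qs: List[Vec]) -> List[Vec]:
    """Unnormalized causal linear attention: z_t = sum_{s<=t} v_s (k_s . q_t)."""
    return [[sum(vs[s][i] * dot(ks[s], qs[t]) for s in range(t + 1))
             for i in range(len(vs[0]))] for t in range(len(qs))]

def nadaraya_watson(ks: List[Vec], vs: List[Vec], qs: List[Vec]) -> List[Vec]:
    """Kernel regression over stored (k_s, v_s), s <= t, with kernel exp(k_s . q_t)."""
    zs = []
    for t, q in enumerate(qs):
        weights = [math.exp(dot(ks[s], q)) for s in range(t + 1)]
        total = sum(weights)
        zs.append([sum(w * vs[s][i] for s, w in enumerate(weights)) / total
                   for i in range(len(vs[0]))])
    return zs

def softmax_attention(ks: List[Vec], vs: List[Vec], qs: List[Vec]) -> List[Vec]:
    """Causal softmax attention (no 1/sqrt(d) scaling), using max-subtraction for stability."""
    zs = []
    for t, q in enumerate(qs):
        scores = [dot(ks[s], q) for s in range(t + 1)]
        m = max(scores)
        e = [math.exp(sc - m) for sc in scores]
        zs.append([sum(ei * vs[s][i] for s, ei in enumerate(e)) / sum(e)
                   for i in range(len(vs[0]))])
    return zs

random.seed(0)
T, d = 5, 3
ks = [[random.gauss(0, 1) for _ in range(d)] for _ in range(T)]
vs = [[random.gauss(0, 1) for _ in range(d)] for _ in range(T)]
qs = [[random.gauss(0, 1) for _ in range(d)] for _ in range(T)]
```

Each pair (TTT batch GD vs. linear attention; Nadaraya–Watson vs. softmax attention) should agree up to floating-point rounding. The difference between the two learners is the key point: the parametric one compresses history into a fixed-size W, while the nonparametric one keeps every (k_s, v_s).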
3.4.9 Final instantiations: TTT-Linear and TTT-MLP, plus stabilizers
- Two proposed variants (§2.7).
  - TTT-Linear: the learner is f_lin(x) = W x with square W.
  - TTT-MLP: the learner is a 2-layer MLP with hidden dimension 4× the input dimension and GELU activation (§2.7).
- Stability structure inside f.
  - They wrap learners with residual + LayerNorm: f(x) = x + LN(f_res(x)), where f_res is linear or an MLP (§2.7).
  - Table 1 shows that adding “LN and residual in f” yields a large perplexity improvement (15.27 → 14.05, improvement −1.22) on their 125M ablation path.
- Learnable initialization W_0.
  - They learn W_0 as an outer-loop parameter θ_init to improve training stability (§2.7).
  - Table 1 indicates it slightly hurts perplexity in isolation (15.23 → 15.27) but is needed for stable training of later improvements.
- Learnable inner learning rate η.
  - They learn a token-dependent gate: η(x) = η_base σ(θ_lr · x), with η_base = 1 (TTT-Linear) and 0.1 (TTT-MLP) (§2.7).
- Backbone choice.
  - Instead of directly swapping attention into a Transformer block, their strongest versions use a Mamba-style backbone with temporal convolutions and gating (§2.7, Figure 13).
  - They also report ablations using a Transformer backbone (Figures 10–11 show both “(M)” and “(T)” variants).
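The learner wrappers and the learned step-size gate can be sketched as plain functions. This is a minimal sketch with several assumptions: LayerNorm without learnable scale/shift, the tanh approximation of GELU, and toy weights. Function names such as `ttt_mlp_learner` and `eta_gate` are hypothetical. The sketch shows only the forward computation of f and η(x), not the inner-loop update.

```python
import math
from typing import List

Vec = List[float]
Mat = List[List[float]]

def matvec(M: Mat, v: Vec) -> Vec:
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def layer_norm(v: Vec, eps: float = 1e-5) -> Vec:
    # normalization only; a real LayerNorm usually also has a learnable scale and shift
    mu = sum(v) / len(v)
    var = sum((x - mu) ** 2 for x in v) / len(v)
    return [(x - mu) / math.sqrt(var + eps) for x in v]

def gelu(x: float) -> float:
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def ttt_linear_learner(x: Vec, W: Mat) -> Vec:
    # f(x) = x + LN(f_res(x)) with f_res(x) = W x, W square
    return [a + b for a, b in zip(x, layer_norm(matvec(W, x)))]

def ttt_mlp_learner(x: Vec, W1: Mat, W2: Mat) -> Vec:
    # f(x) = x + LN(f_res(x)) with f_res a 2-layer MLP: d -> 4d -> d, GELU in between
    h = [gelu(a) for a in matvec(W1, x)]
    return [a + b for a, b in zip(x, layer_norm(matvec(W2, h)))]

def eta_gate(x: Vec, theta_lr: Vec, eta_base: float) -> float:
    # eta(x) = eta_base * sigmoid(theta_lr . x)
    s = sum(t * a for t, a in zip(theta_lr, x))
    return eta_base / (1.0 + math.exp(-s))

# Hypothetical toy values with d = 4.
d = 4
x = [0.5, -1.0, 0.25, 2.0]
W = [[0.1 * (i + j) for j in range(d)] for i in range(d)]
W1 = [[0.05 * ((i + j) % 3 - 1) for j in range(d)] for i in range(4 * d)]
W2 = [[0.05 * ((i - j) % 5 - 2) for j in range(4 * d)] for i in range(d)]
theta_lr = [0.1, 0.0, -0.2, 0.05]
print(ttt_linear_learner(x, W), ttt_mlp_learner(x, W1, W2))
print(eta_gate(x, theta_lr, eta_base=1.0), eta_gate(x, theta_lr, eta_base=0.1))
```

Because of the sigmoid, η(x) always lies strictly between 0 and η_base. The outer loop can therefore learn to take small inner steps on some tokens and larger ones on others.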
3.4.10 Worked micro-example (single-step intuition without full LM scale)
To make the mechanism concrete, consider one time step of a simplified TTT layer.
- Let the token embedding be x_t ∈ R^d.
- Choose projections θ_K, θ_V, θ_Q so that:
  - training input: x̂_t = θ_K x_t
  - training label: y_t = θ_V x_t
  - test input: x̄_t = θ_Q x_t
- Let the learner be linear: f(u; W) = W u.
- The self-supervised loss at time t is ℓ(W; x_t) = || W x̂_t - y_t ||^2 (Eq. (4) specialized).
- A single inner-loop update (online GD form) is W_t = W_{t-1} - η ∇_W || W x̂_t - y_t ||^2 (Eq. (2)).
- The output of the sequence model at this time is z_t = W_t x̄_t (Eq. (5) specialized).
- Intuition: if the current token has structure that helps predict y_t from x̂_t, the update adjusts W so that future tokens, processed with the updated W, can reuse that learned structure. The weights act like an adaptive memory.
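Plugging small numbers into these formulas makes the step concrete. This is a hand-checkable sketch with hypothetical values: d = 2, θ_K = θ_Q = I, θ_V a coordinate swap, W_{t-1} = 0, and η = 0.1. For the linear learner, the gradient is 2(W x̂_t - y_t) x̂_t^T, and one step scales the residual by 1 - 2η‖x̂_t‖². Here 2η‖x̂_t‖² = 2 · 0.1 · 5 = 1, so this step fits the token's label exactly. Because θ_Q = θ_K, the output z_t then reproduces y_t.

```python
from typing import List

Vec = List[float]
Mat = List[List[float]]

def matvec(M: Mat, v: Vec) -> Vec:
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def sq_norm(v: Vec) -> float:
    return sum(a * a for a in v)

theta_K = [[1.0, 0.0], [0.0, 1.0]]   # identity training view (hypothetical)
theta_V = [[0.0, 1.0], [1.0, 0.0]]   # label view swaps the two coordinates (hypothetical)
theta_Q = [[1.0, 0.0], [0.0, 1.0]]   # identity test view (hypothetical)
W_prev = [[0.0, 0.0], [0.0, 0.0]]    # W_{t-1}
eta = 0.1
x_t = [1.0, 2.0]

x_hat = matvec(theta_K, x_t)   # training input  [1, 2]
y_t = matvec(theta_V, x_t)     # training label  [2, 1]
x_bar = matvec(theta_Q, x_t)   # test input      [1, 2]

residual = [p - y for p, y in zip(matvec(W_prev, x_hat), y_t)]   # W x_hat - y = [-2, -1]
grad = [[2.0 * r * xj for xj in x_hat] for r in residual]         # 2 (W x_hat - y) x_hat^T
W_t = [[w - eta * g for w, g in zip(rw, rg)] for rw, rg in zip(W_prev, grad)]

loss_before = sq_norm(residual)                                    # 5.0
loss_after = sq_norm([p - y for p, y in zip(matvec(W_t, x_hat), y_t)])
z_t = matvec(W_t, x_bar)                                           # ≈ [2.0, 1.0]
print(W_t, loss_before, loss_after, z_t)
```

With a smaller η (for example 0.05), the same step would only halve the residual. This is the usual trade-off between fitting the current token and not overwriting what W already stores.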
4. Key Insights and Innovations
- (1) Hidden state as an explicit learner trained online (TTT layers).
  - Novelty: Instead of a vector/matrix state updated by a fixed recurrence, the hidden state is the weights of a model updated by gradient-based learning on a self-supervised loss (Eq. (2), Figure 1).
  - Significance: This increases the expressive capacity of what an RNN-style layer can store, with the aim of compressing long context better (§1, §2.1).
- (2) Learnable self-supervised task via multi-view reconstruction.
  - Novelty: The reconstruction task is not handcrafted. It is parameterized by θ_K, θ_V, θ_Q and optimized in the outer loop for next-token prediction (§2.3, Eq. (4)–(5), Table 2).
  - Significance: This aligns inner-loop adaptation with the final LM objective, instead of relying on human-designed corruptions/tasks (§2.3).
- (3) Mini-batch token updates to trade off parallelism and adaptation quality.
  - Novelty: The paper introduces mini-batch gradient computations across tokens inside the sequence. These expose parallel computation while retaining multi-step adaptation across mini-batches (§2.4, Figures 6–7).
  - Significance: This is a practical bridge between purely online updates (slow, sequential) and batch updates (parallel, but every gradient is taken at the same starting weights). The empirically chosen sweet spot is b = 16 (Figure 7).
- (4) Dual-form rewrite to make TTT matmul-friendly on accelerators.
  - Novelty: Within a token mini-batch, they derive an equivalent computation that avoids materializing per-token gradients and intermediate weights, using masked matmuls (Eq. (7)–(8), §2.5; Appendix A for MLPs).
  - Significance: The paper reports a >5× training speedup in JAX on TPU for the dual form compared with the primal form (§2.5).
- (5) Unifying theoretical lens: TTT spans linear attention and (via learners) self-attention.
  - Novelty: Theorem 1 shows linear attention is a special case of TTT, and Table 1 verifies equal perplexity under an improved implementation. Theorem 2 shows self-attention corresponds to a nonparametric learner choice (§2.6).
  - Significance: This reframes common sequence layers as points in a broader design space, motivating “expressive hidden states” as “richer learners.”
5. Experimental Analysis
Evaluation methodology
- Primary metric: perplexity (Ppl) on validation/evaluation data, following protocols aligned with the Mamba paper (§3; the Figure 2 caption references Kaplan-style evaluation).
- Datasets:
  - The Pile for standard 2k and 8k context evaluations (§3, §3.1).
  - The Books3 subset of the Pile for long-context evaluations from 1k to 32k in ×2 steps (§3, §3.2; Figures 11, 15, 16).
- Baselines compared:
  - A strong Transformer baseline (“Transformer++” style, based on the Llama architecture) (§3, Appendix C).
  - Mamba as a modern RNN baseline (§3).
- Model scales: four sizes are used, 125M, 350M, 760M, and 1.3B parameters (§3 protocols).
  - Mamba sizes are slightly different: 130M, 370M, 790M, 1.4B (noted in §3 protocols).
- Training setup (“Chinchilla recipe” as used in Mamba):
  - Optimizer: AdamW with β = (0.9, 0.95) (Appendix C).
  - LR schedule: cosine decay to 1e-5, with linear warmup for 10% of steps (Appendix C).
  - Weight decay 0.1, grad clip 1.0, no dropout, mixed precision (Appendix C).
  - Training tokens per model size (Table 3): 2.5B, 7B, 15B, 26B.
  - Peak LRs (Table 3): 3e-3, 1.5e-3, 1.25e-3, 1e-3.
  - Block counts / embed dims / heads (Table 3): e.g., 1.3B uses 24 blocks, d = 2048, 32 heads.
- Context/batch policy: total tokens per training batch are fixed at 0.5M regardless of context length (§3.2 footnote 12; Appendix C reiterates).
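The reported schedule (linear warmup over the first 10% of steps, then cosine decay to 1e-5) can be written as a small function. This sketch rests on several assumptions: warmup ramps linearly from peak_lr/warmup up to the peak; the cosine runs from the peak at the end of warmup down to exactly 1e-5 at the last step; and `total_steps` is a hypothetical input. The paper's implementation may differ in these details. The peak LR of 3e-3 is the first value in the Table 3 list.

```python
import math

def lr_at(step: int, total_steps: int, peak_lr: float,
          final_lr: float = 1e-5, warmup_frac: float = 0.1) -> float:
    warmup = max(1, int(warmup_frac * total_steps))
    if step < warmup:
        return peak_lr * (step + 1) / warmup  # linear warmup, reaching peak_lr at step warmup - 1
    # cosine decay from peak_lr (first post-warmup step) to final_lr (last step)
    progress = (step - warmup) / max(1, total_steps - 1 - warmup)
    return final_lr + 0.5 * (peak_lr - final_lr) * (1.0 + math.cos(math.pi * min(1.0, progress)))

total = 1000
schedule = [lr_at(s, total, peak_lr=3e-3) for s in range(total)]
```

Only the shape is specified by the recipe. The number of steps follows from the token budget per model size and the fixed 0.5M-token batch.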
Main quantitative results (numbers grounded where available)
- Long-context utilization diagnostic (token-index perplexity):
  - Figure 2 (right) shows perplexity vs. token index up to 32k.
  - Reported qualitative outcome: TTT-Linear and TTT-MLP keep reducing perplexity as the token index increases (similar to Transformer), while Mamba plateaus after 16k (Figure 2 right; §1).
- Scaling trends on the Pile (2k and 8k contexts):
  - Figure 10 summarizes Ppl vs. FLOPs for 2k and 8k.
  - Reported conclusions (§3.1):
    - At 2k, TTT-Linear (M), Mamba, and Transformer are comparable (lines overlap).
    - At 8k, both TTT-Linear (M) and TTT-MLP (M) outperform Mamba, and the gap widens with longer context (§3.1).
- Long context on Books (up to 32k):
  - Figure 11 (2k and 32k points) and Figure 15 (full set) show that at 32k, TTT-Linear (M) and TTT-MLP (M) outperform Mamba (§3.2).
  - Figure 16 shows an alternate view. For models trained from scratch, perplexity can worsen when context gets “too large,” and the best context length increases with model size. This trend is less present with Transformer finetuning (“TF finetune”) (§3.2, Figure 16 caption).
- Ablation path from linear attention to TTT-Linear (concrete perplexities).
  - Table 1 reports, for 125M models trained on the Pile with their recipe:
    - Linear attn. improved: 15.23.
    - TTT equivalence: 15.23 (diff 0).
    - + learnable W_0: 15.27 (+0.04).
    - + LN and residual in f: 14.05 (−1.22).
    - + mini-batch TTT: 12.35 (−1.70).
    - + learnable η: 11.99 (−0.36).
    - + Mamba backbone: 11.09 (−0.90), which they identify as the final TTT-Linear result used in Figure 10 (Table 1 caption).
Wall-clock / latency evaluation
- Figure 12 reports inference latency on an NVIDIA A100 80GB PCIe:
  - Prefill (forward) latency per token at batch size 16 grows with context for Transformer, consistent with attention's quadratic total cost. For TTT-Linear, TTT-MLP, and Mamba it stays roughly constant per token as context grows (Figure 12; §3.3).
  - Decode (generation) uses the primal form because it is sequential (§3.3). The figure shows roughly constant per-token latency for the non-Transformer methods across the tested context lengths, while Transformer latency grows (Figure 12).
- They also note TPU training iteration time on a v5e-256 pod at 2k context:
  - Transformer baseline: 0.30 s/iter; TTT-Linear: 0.27 s/iter (about 10% faster), without extra systems optimization (§3.3).
  - However, they do not report comparable TPU timing for Mamba, because the Mamba implementation is GPU-focused (§3.3).
Do experiments support the claims?
- Supportive evidence:
  - The long-context token-index perplexity diagnostic in Figure 2 (right) directly targets the claim that some RNNs stop benefiting from additional context. TTT variants behave more like Transformer there.
  - Table 1 provides a grounded ablation chain showing which components matter most (mini-batch TTT and LN/residual inside f).
  - Multiple model sizes (125M to 1.3B) and two datasets (Pile, Books) provide breadth (§3).
- Caveats evident in the paper's own reporting:
  - The paper explicitly notes that its FLOPs–perplexity plots do not show a clean linear scaling-law fit (Figures 10–11; §3.1), which limits extrapolation claims.
  - TTT-MLP has wall-clock challenges due to memory I/O despite favorable FLOPs comparisons (Figure 12; Abstract; §4.2).
6. Limitations and Trade-offs
- Wall-clock efficiency, especially for more expressive learners (TTT-MLP).
  - The paper repeatedly emphasizes that while TTT-MLP may be “effective in terms of FLOPs,” its structure increases wall-clock time much more than FLOPs would suggest, and attributes this to memory I/O (Abstract; §4.2; Figure 12).
- Speed–quality trade-off controlled by the token mini-batch size b.
  - Mini-batching is necessary for parallelism but reduces the “gradient channel” dependency within a mini-batch (§2.4).
  - Figure 7 shows perplexity degrades as b increases, and they fix b = 16 as a compromise (§2.4).
- Training stability and reliance on particular stabilizers.
  - Learning W_0 (θ_init) is reported as crucial for stability, even though it slightly hurts perplexity in isolation (Table 1 caption; §2.7).
  - The learner needs the LN + residual wrapper for stability and performance (Table 1; §2.7).
- Backbone dependence / architectural confounding.
  - Their strongest TTT results use a Mamba-style backbone with temporal convolutions and gating (§2.7, Figure 13).
  - Improvements are therefore not purely attributable to the TTT mechanism; Table 1 explicitly shows a gain from “+ Mamba backbone.”
- Long-context training-from-scratch difficulties.
  - On Books, Figure 16 indicates that for all methods trained from scratch, perplexity can worsen once context becomes too large, and the optimal context depends on model size (Figure 16 caption).
  - This suggests that simply increasing context length in training is not uniformly beneficial under the fixed recipe.
- Evaluation scope constraints acknowledged by the paper.
  - No hybrid architectures (mixing attention and TTT) are explored, to keep baselines clean (§3 protocols).
  - They did not train at extremely long contexts (millions/billions of tokens), citing academic resource constraints (§5).
- Systems implementation gaps.
  - Training is run on TPUs in JAX; their GPU kernel work is limited to inference, and they explicitly did not build a full GPU training kernel (§3.3).
7. Implications and Future Directions
- How this work changes the landscape (as implied by the paper's framework and results)
  - It proposes expanding the design space of sequence layers from “choose a recurrence vs. attention” to “choose a learner (model + optimizer + task) as memory,” with attention and linear attention appearing as special cases (Figures 8–9; Theorems 1–2).
  - Empirically, it suggests that the long-context weaknesses observed in modern RNNs (their Mamba plateau result) may be mitigated by making the hidden state more expressive and adaptively learned during inference (Figure 2 right).
- Follow-up research directions explicitly suggested
  - Richer outer-loop task parameterizations: explore other families of self-supervised tasks beyond linear multi-view projections (θ_K, θ_V, θ_Q) (§5).
  - Systems optimization: better kernels, pipeline parallelism through time, and multi-device processing for very long sequences (§5).
  - Longer contexts + larger models: explore regimes beyond 32k and potentially very long contexts, where the paper expects TTT's advantages to grow (§5).
  - More ambitious learners f: larger inner models (possibly convolutional nets for video/agents) when context becomes extremely long (§5).
  - Multi-level nested learning: if f is itself attention, the layer can be seen as nesting additional inner loops (discussion in §5, connected to Theorem 2 in §2.6).
- Practical applications / downstream use cases (grounded in the paper's framing)
  - Settings that need efficient long-context conditioning, where quadratic attention is costly but a fixed-size RNN state is too restrictive, could benefit if the wall-clock issues are resolved (Abstract; §1; §4.2).
  - The paper points to very long sequential domains such as video streams and embodied agents as future targets, where online adaptation is natural (§5; related work in §4.1 references the video-stream TTT motivation).
- Repro/Integration Guidance (based on the paper's reported choices)
  - If you want a linear-time layer with relatively mature efficiency, TTT-Linear is the more practical instantiation in this paper (Abstract; Figure 12 shows TTT-MLP is heavier).
  - If your priority is long-context utilization and you can tolerate more systems complexity, TTT-MLP is positioned as having “larger potential in long context” but is currently bottlenecked by memory I/O (Abstract; §3.3; §4.2).
  - To match the paper's training setup, use:
    - AdamW (β = (0.9, 0.95)), cosine decay to 1e-5, 10% warmup, weight decay 0.1, grad clip 1.0, no dropout, mixed precision (Appendix C);
    - model-scale hyperparameters as in Table 3 (blocks, embed dim, heads, steps, peak LR, tokens);
    - TTT token mini-batch size b = 16 (Figure 7; §2.4);
    - learner stabilizers: LN + residual inside f, learnable W_0, and learnable gated η(x) (§2.7; Table 1).
  - When comparing to Transformers at long context, the paper highlights that finetuning a short-context Transformer is a stronger real-world baseline than training from scratch at long context. For contexts ≥4k they include TF finetune, following a described recipe (20% more tokens) (§3.2; Appendix C).
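For bookkeeping, the reported choices above can be collected into one configuration object. This is a sketch only. The `TTTReproConfig` class and its field names are hypothetical, and every value is copied from this summary (Appendix C, Table 3, Figure 7, §2.7). It assumes the Table 3 token budgets and peak LRs are listed in the same size order (125M, 350M, 760M, 1.3B). Table 3 entries not quoted here, such as block counts, widths, and heads for the smaller models and step counts, are left out rather than guessed.

```python
from dataclasses import dataclass, field
from typing import Dict, Tuple

@dataclass(frozen=True)
class TTTReproConfig:
    # Outer-loop optimizer and schedule (Appendix C)
    optimizer: str = "AdamW"
    betas: Tuple[float, float] = (0.9, 0.95)
    final_lr: float = 1e-5            # cosine decay target
    warmup_frac: float = 0.10         # linear warmup over 10% of steps
    weight_decay: float = 0.1
    grad_clip: float = 1.0
    dropout: float = 0.0
    mixed_precision: bool = True
    tokens_per_batch: float = 0.5e6   # "0.5M" as reported, fixed regardless of context length
    # Inner loop (Figure 7, §2.7)
    ttt_mini_batch: int = 16
    eta_base: Dict[str, float] = field(
        default_factory=lambda: {"TTT-Linear": 1.0, "TTT-MLP": 0.1})
    learnable_W0: bool = True
    ln_residual_in_f: bool = True
    # Per-size (training tokens, peak LR) as listed from Table 3
    per_size: Dict[str, Tuple[float, float]] = field(default_factory=lambda: {
        "125M": (2.5e9, 3e-3),
        "350M": (7e9, 1.5e-3),
        "760M": (15e9, 1.25e-3),
        "1.3B": (26e9, 1e-3),
    })

cfg = TTTReproConfig()
tokens, peak_lr = cfg.per_size["1.3B"]
```

A frozen dataclass keeps the reported values from being mutated by accident. Overrides for new experiments can be made explicit with `dataclasses.replace`.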