GENERALIZATION THROUGH MEMORIZATION: NEAREST NEIGHBOR LANGUAGE MODELS

ArXiv: 1911.00172

🎯 Pitch

This paper introduces kNN-LM, a plug-in retrieval augmentation that linearly interpolates a pretrained neural LM’s next-token distribution with a k-nearest-neighbors distribution over a large datastore of context→token examples, improving predictions at inference time with no additional training. By explicitly memorizing rare and long-tail patterns via nearest-neighbor lookup, kNN-LM yields large perplexity gains (new state-of-the-art on WIKITEXT-103), enables effective use of massive corpora without retraining, and provides a simple, practical path for domain adaptation and better factual recall.


1. Executive Summary

This paper introduces kNN-LM, a simple way to improve an existing neural language model (LM) at inference time by interpolating its predicted next-token distribution with a k-nearest-neighbors distribution computed from a large datastore of past context→next-token examples (Eq. (1)–(3), Figure 1). Using a strong Transformer LM on WIKITEXT-103, retrieving neighbors from the original training set improves test perplexity from 18.65 to 16.12, and further to 15.79 when combined with a continuous cache—without additional training (Table 1). The method also enables “more data without training” (Table 3, Figure 2) and practical domain adaptation by swapping the datastore (Table 4, Figure 5).

2. Context and Motivation

  • Problem / gap addressed
    • Neural LMs must (1) map a prefix (context) to a fixed-size representation and (2) use that representation to predict the next word/token (Introduction).
    • The paper targets cases where the LM struggles on rare/long-tail patterns (e.g., names, factual knowledge, near-duplicate sentences), even if it can represent context similarity well (Section 1, Section 6; examples in Figure 6 and Appendix Tables 6–9).

  • Why it matters
    • Long-tail events can dominate practical LM errors (qualitatively shown via rare patterns and factual knowledge; Section 6).
    • Scaling LM training to massive datasets is expensive; the paper explores whether retrieval over large corpora can substitute for training on them (Section 4.2; Table 3, Figure 2).

  • Prior approaches and shortcomings (as discussed in-paper)
    • Cache models retrieve recent hidden states within the test document to help copy rare tokens (Section 2, “Related Cache Models”), but Transformers can already copy recent context via self-attention, limiting gains (Section 4.1; Table 1 discussion).
    • Nearest-neighbor “online” LMs (Grave et al., 2017a as cited) retrieve from previous hidden states for domain adaptation, whereas this work stores training data for explicit memorization (Section 2).

  • Positioning of this paper
    • Core hypothesis: representation learning (“similarity”) is easier than next-token prediction, so a pretrained LM can act as a similarity encoder f(c) while a kNN component handles “memorization” explicitly (Section 1; Section 6 conjecture).

3. Technical Approach

3.1 Reader orientation (approachable technical breakdown)

  • The system is a retrieval-augmented language model that keeps a large memory of training contexts and uses nearest-neighbor search to adjust next-token probabilities at inference time (Figure 1; Section 2).
  • It solves next-token prediction by combining (interpolating) the pretrained LM’s distribution with a distribution formed from the targets of the most similar stored contexts, with no further training (Eq. (2)–(3); Section 2).

3.2 Big-picture architecture (diagram in words)

  • Pretrained Transformer LM: given a context x, produces (a) a next-token distribution pLM(y|x) and (b) a vector representation f(x) used for retrieval (Section 2; Section 3 “kNN-LM”).
  • Datastore builder: runs a forward pass over some text collection and stores pairs (key, value) where key = f(context) and value = next token (Eq. (1); Figure 1; Section 2 “Datastore”, Section 3 “kNN-LM”).
  • FAISS nearest-neighbor index: supports approximate search over up to billions of keys (Section 2 “Implementation”; Section 3 “kNN-LM”).
  • kNN probability module: converts neighbor distances into a distribution over candidate next tokens pkNN(y|x) by distance-weighted aggregation (Eq. (2)).
  • Interpolation module: outputs final distribution p(y|x) = λ pkNN + (1−λ) pLM with λ tuned on validation data (Eq. (3); Section 3; Figure 5).

3.3 Roadmap for the deep dive

  • Explain how the datastore is constructed and what exactly is stored (Eq. (1); Figure 1).
  • Explain inference-time retrieval and how it becomes a probability distribution (Eq. (2)).
  • Explain interpolation with the base LM and the role of λ (Eq. (3); Figure 5).
  • Detail the concrete engineering choices that make billion-scale retrieval feasible (FAISS setup; Section 3).
  • Summarize key hyperparameters (model architecture, context length, retrieval k, quantization, clustering) and how they affect performance (Section 3; Section 5; Figure 4; Table 5).

3.4 Detailed, sentence-based technical breakdown

This is an empirical systems-and-method paper whose core idea is to turn a pretrained LM into a similarity encoder and add an explicit, non-parametric memory via nearest-neighbor retrieval, then blend retrieval-based and neural predictions.

System/data pipeline diagram in words (what happens first, second, third)

  1. Train (or obtain) a base LM:
    • The method assumes an already-trained autoregressive LM that can compute both a next-token distribution and intermediate hidden representations (Section 2, Section 3 “Model Architecture”).
    • In experiments, the base model matches Baevski & Auli (2019): a decoder-only Transformer with 16 layers, 16 attention heads per layer, hidden size 1024, feedforward size 4096, and 247M parameters (Section 3 “Model Architecture”).

  2. Build the datastore with one forward pass over a text collection:
    • For each training example (ci, wi) (context and its next token), compute a vector key ki = f(ci) from an intermediate representation of the pretrained LM, and store the value vi = wi (Eq. (1); Section 2 “Datastore”; Figure 1).
    • The datastore is thus a set of key-value pairs (K, V) = {(f(ci), wi)} with one entry per token in the chosen corpus (Eq. (1); Figure 1; Section 2 “Implementation” notes this can be “up to billions of examples”).
    • For their main WIKITEXT-103 setup, during this forward pass each target token gets at least 1536 tokens of prior context; for other corpora, at least 512 tokens of prior context (Section 3 “kNN-LM”).

  3. Index the datastore for fast nearest-neighbor lookup:
    • Because exact search over billions of 1024-d vectors is too slow, they use FAISS to cluster keys and search within clusters, and to reduce memory by storing compressed vectors (Section 2 “Implementation”).
    • Concretely, they learn 4096 cluster centroids from 1M randomly sampled keys, quantize keys to 64 bytes, and at inference the index probes 32 cluster centroids while searching for nearest neighbors (Section 3 “kNN-LM”).

  4. Run inference with retrieval augmentation:
    • Given a test context x, compute the LM’s standard next-token distribution pLM(y|x) and compute the query representation f(x) (Section 2 “Inference”).
    • Retrieve the k nearest neighbor entries N from the datastore by distance between stored keys ki and query f(x) (Section 2; Figure 1). In experiments, k = 1024 (Section 3 “kNN-LM”).
    • Convert neighbor distances into a probability distribution over tokens:
      • Weight each neighbor by exp(−d(ki, f(x))) and aggregate weights for neighbors whose stored target token equals y (Eq. (2)).
      • Tokens not present among retrieved neighbors have zero probability under pkNN (Eq. (2) description).

  5. Interpolate retrieval and neural predictions:
    • Output the final next-token distribution as a convex combination of retrieval and neural distributions:
      p(y|x) = λ pkNN(y|x) + (1 − λ) pLM(y|x) (Eq. (3)).
    • Tune λ on a validation set; for example, Figure 5 reports an optimal λ ≈ 0.25 for WIKITEXT-103, while domain adaptation prefers a larger λ ≈ 0.65 (Figure 5; Section 5 “Interpolation Parameter”).
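
The datastore construction and retrieval steps above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the paper's implementation: `toy_encoder` is a hypothetical stand-in for the LM's context representation f(·), the search is exact brute force rather than FAISS, and all function names are invented for the example.

```python
from typing import Callable, List, Sequence, Tuple

Vector = Tuple[float, ...]
Datastore = List[Tuple[Vector, str]]  # (key = f(context), value = next token)


def toy_encoder(context: Sequence[str]) -> Vector:
    """Hypothetical stand-in for f(c): hashes the last two tokens into 4 numbers.

    A real system would use a hidden state of the pretrained LM instead.
    """
    feats: List[float] = []
    for tok in list(context)[-2:]:
        feats += [float(len(tok)), float(sum(map(ord, tok)) % 97)]
    return tuple([0.0] * (4 - len(feats)) + feats)


def build_datastore(tokens: Sequence[str],
                    f: Callable[[Sequence[str]], Vector]) -> Datastore:
    """One entry per token with a non-empty prefix: key f(prefix), value next token."""
    return [(f(tokens[:i]), tokens[i]) for i in range(1, len(tokens))]


def squared_l2(a: Vector, b: Vector) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def retrieve(datastore: Datastore, query: Vector, k: int) -> List[Tuple[float, str]]:
    """Exact k-nearest-neighbor search; returns (distance, value), closest first."""
    scored = [(squared_l2(key, query), value) for key, value in datastore]
    return sorted(scored, key=lambda pair: pair[0])[:k]


corpus = "Obama was born in Hawaii . Obama was born in Hawaii .".split()
store = build_datastore(corpus, toy_encoder)
neighbors = retrieve(store, toy_encoder(corpus[:4]), k=2)
```

In this toy corpus both stored contexts ending in "born in" map to the same key as the query, so the two retrieved values are both "Hawaii". Turning such neighbors into pkNN and interpolating with pLM is a separate step (Eq. (2)–(3)).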

Key definitions (selective, paper-specific)

  • Datastore: a memory containing one entry per token from some corpus, storing (key = context representation, value = next token) (Eq. (1); Figure 1).
  • f(c): the pretrained LM’s function mapping a token prefix (context) to a fixed-length vector used for retrieval (Section 2).
  • pkNN(y|x): a retrieval-based next-token distribution formed by distance-weighted voting among retrieved neighbors’ target tokens (Eq. (2)).
  • λ: interpolation weight balancing memorized (retrieval) vs parametric (LM) predictions (Eq. (3); Figure 5).
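
For reference, the three equations cited throughout this summary can be written out as follows. This is a transcription for convenience using the symbols defined above, with D denoting the corpus of (context, next-token) pairs and N the set of retrieved neighbors.

```latex
\begin{align}
(\mathcal{K}, \mathcal{V}) &= \{\, (f(c_i), w_i) \mid (c_i, w_i) \in \mathcal{D} \,\} \tag{1} \\
p_{\mathrm{kNN}}(y \mid x) &\propto \sum_{(k_i, v_i) \in \mathcal{N}} \mathbb{1}_{y = v_i}\, \exp\!\big(-d(k_i, f(x))\big) \tag{2} \\
p(y \mid x) &= \lambda\, p_{\mathrm{kNN}}(y \mid x) + (1 - \lambda)\, p_{\mathrm{LM}}(y \mid x) \tag{3}
\end{align}
```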

Distance function and similarity weighting

  • Neighbors are selected by a distance function d(·,·) in embedding space; experiments primarily use squared L2 distance, and interpret the weighting as an RBF-kernel-like similarity via the exponential (Section 2 “Inference”).
  • They report that L2 distance worked better than inner product distance for FAISS retrieval in preliminary experiments (Section 2 “Implementation”).
  • They also report that computing squared L2 distances with full-precision keys improves WIKITEXT-103 perplexity from 16.5 to 16.06 compared to using FAISS distances on quantized keys directly (Section 5 “Precision of Similarity Function”).
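
The precision comparison above amounts to re-scoring retrieved candidates with exact distances. The sketch below assumes that an approximate index has already returned candidate ids and that full-precision keys are kept alongside it; `rescore_exact` and the toy data are hypothetical and not taken from the paper's code.

```python
import math
from typing import Dict, List, Sequence, Tuple


def squared_l2(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def rescore_exact(candidate_ids: List[int],
                  full_keys: Dict[int, Tuple[float, ...]],
                  query: Sequence[float]) -> List[Tuple[int, float, float]]:
    """Recompute exact squared L2 for each candidate and attach exp(-d) weights.

    Returns (key_id, distance, weight) triples, closest first.
    """
    rescored = sorted((squared_l2(full_keys[i], query), i) for i in candidate_ids)
    return [(i, d, math.exp(-d)) for d, i in rescored]


# Toy full-precision keys; the candidate order mimics an imperfect approximate ranking.
full_keys = {0: (1.0, 0.0), 1: (0.0, 2.0), 2: (1.0, 1.0)}
ranked = rescore_exact([1, 2, 0], full_keys, query=(1.0, 0.25))
```

Only the retrieved candidates are re-scored, so the extra cost grows with k rather than with the datastore size.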

What representation is used as the key/query f(·), and why

  • They test multiple internal Transformer states (Figure 3) as the context representation function f(·) and evaluate on WIKITEXT-103 validation perplexity (Table 5).
  • Best-performing choice in their sweep:
    • FFN input after layer norm in the final Transformer layer gives 16.06 validation perplexity (Table 5).
  • Table 5 also shows:
    • Using the model output as key helps but is weaker (17.07), and layer-normalized output is slightly better (17.01).
    • Using representations “before layer norm” tends to be worse than after layer norm (e.g., FFN input before LN: 17.06 vs after LN: 16.06) (Table 5).
  • Their interpretation: self-attention may carry more of the representation burden while the feedforward layer is closer to the prediction step; empirically, the FFN-input-after-LN state works best (Section 5 “Key Function”).

Retrieval hyperparameters and their empirical effects

  • Number of neighbors k:
    • They use k = 1024 for main experiments (Section 3 “kNN-LM”).
    • Figure 4 shows validation perplexity on WIKITEXT-103 improves monotonically as k increases from 1 up to 1024, suggesting potential gains for even larger k (Figure 4; Section 5).
    • They also note that even k = 8 suffices to achieve a new state of the art (Section 5, “Number of Neighbors per Query”).

  • Interpolation weight λ:
    • Tuned on validation (Section 3 “kNN-LM”).
    • Figure 5 demonstrates different optimal λ depending on setting:
      • WIKITEXT-103 in-domain best around λ = 0.25.
      • Domain adaptation best around λ = 0.65, i.e., rely more on pkNN (Figure 5; Section 5).

  • Datastore size:
    • When increasing datastore size (by sampling from WIKI-3B), perplexity improves monotonically and does not saturate at ~3B tokens (Figure 2a; Section 4.2).
    • As datastore size grows, the tuned optimal λ increases, meaning the model leans more on retrieval (Figure 2b).
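
Tuning λ on validation data can be sketched as a simple grid search over perplexity. The sketch assumes the probabilities that pkNN and pLM assign to each gold validation token have already been computed; the numbers below are hypothetical, and `tune_lambda` is an illustrative name rather than anything from the paper.

```python
import math
from typing import List, Sequence, Tuple


def perplexity(probs: Sequence[float]) -> float:
    """Exponentiated average negative log-likelihood of the gold tokens."""
    return math.exp(-sum(math.log(p) for p in probs) / len(probs))


def tune_lambda(p_knn: Sequence[float], p_lm: Sequence[float],
                grid: Sequence[float]) -> Tuple[float, float]:
    """Return (best_lambda, best_perplexity) over a grid of interpolation weights."""
    results: List[Tuple[float, float]] = []
    for lam in grid:
        mixed = [lam * a + (1.0 - lam) * b for a, b in zip(p_knn, p_lm)]
        results.append((perplexity(mixed), lam))
    best_ppl, best_lam = min(results)
    return best_lam, best_ppl


# Hypothetical probabilities of the gold next token under each distribution.
# pkNN is exactly zero when the gold token is absent from the retrieved neighbors.
p_knn_gold = [0.90, 0.00, 0.60, 0.00]
p_lm_gold = [0.30, 0.20, 0.10, 0.40]
grid = [i / 20 for i in range(20)]  # 0.00 .. 0.95; lambda = 1 would hit log(0)
best_lam, best_ppl = tune_lambda(p_knn_gold, p_lm_gold, grid)
```

The grid stops short of λ = 1 because pkNN gives zero probability to any token not among the retrieved neighbors, which would make the perplexity infinite.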

Concrete configurations reported (and what is not specified)

  • Model architecture (specified): 16-layer decoder-only Transformer, 16 heads, hidden dim 1024, FFN dim 4096, 247M parameters (Section 3).
  • Context lengths (specified):
    • The model processes 3072 tokens per example for WIKITEXT-103.
    • It processes 1024 tokens per example for other corpora (Section 3 “Model Architecture”).
    • For evaluation: score 512 tokens per test example with extra prior context of up to 2560 tokens for WIKITEXT-103 and up to 512 extra tokens for other corpora (Section 3 “Evaluation”).
  • Tokenization/vocabulary (specified):
    • WIKITEXT-103: word-level vocabulary of 250K (Section 3 “Data”).
    • Other datasets: BPE with 29K subword vocab from BERT (Section 3 “Data”).
  • Retrieval/index parameters (specified):
    • Build index with 1M sampled keys, 4096 cluster centroids, keys quantized to 64 bytes, probe 32 centroids, retrieve k = 1024 neighbors (Section 3 “kNN-LM”).
  • Optimization/training hyperparameters (not specified in provided text):
    • The paper says it uses the “exact architecture and optimization described by Baevski & Auli (2019)” (Section 3), but the excerpt does not list the optimizer, learning rate schedule, batch size, or training hardware, so those details are not reported here.
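
The cluster-and-probe idea behind the index parameters above can be illustrated with a toy inverted-file search. This is a simplified sketch and not FAISS: it skips centroid learning and quantization, uses three hand-picked centroids instead of 4096, and every name in it is hypothetical.

```python
from typing import Dict, List, Sequence, Tuple

Vec = Tuple[float, ...]


def squared_l2(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nearest_centroids(vec: Vec, centroids: List[Vec], n: int) -> List[int]:
    """Indices of the n centroids closest to vec."""
    order = sorted(range(len(centroids)), key=lambda c: squared_l2(vec, centroids[c]))
    return order[:n]


def build_inverted_lists(keys: List[Vec], centroids: List[Vec]) -> Dict[int, List[int]]:
    """Assign every key id to the list of its single nearest centroid."""
    lists: Dict[int, List[int]] = {c: [] for c in range(len(centroids))}
    for key_id, key in enumerate(keys):
        lists[nearest_centroids(key, centroids, 1)[0]].append(key_id)
    return lists


def ivf_search(query: Vec, keys: List[Vec], centroids: List[Vec],
               lists: Dict[int, List[int]], nprobe: int, k: int) -> List[Tuple[float, int]]:
    """Search only the nprobe closest clusters; returns (distance, key_id) pairs."""
    candidates = [i for c in nearest_centroids(query, centroids, nprobe) for i in lists[c]]
    return sorted((squared_l2(keys[i], query), i) for i in candidates)[:k]


centroids = [(0.0, 0.0), (10.0, 10.0), (0.0, 10.0)]  # assumed to be learned already
keys = [(0.5, 0.2), (9.0, 9.5), (0.1, 9.0), (1.0, 1.0), (10.0, 9.0)]
lists = build_inverted_lists(keys, centroids)
hits = ivf_search((0.4, 0.4), keys, centroids, lists, nprobe=1, k=2)
```

Probing more clusters (a larger `nprobe`) examines more candidates, trading speed for recall; in the paper's setup the index probes 32 of 4096 clusters.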

Worked micro-example illustrating Eq. (2)–(3)

Suppose at test time the context is x = "Obama was born in" (motivated by Figure 1’s illustration). The model computes a query vector q = f(x).

  1. Retrieve k nearest neighbors (ki, vi) from the datastore by distance d(ki, q).
  2. Suppose that three of the retrieved neighbors are relevant, with these targets:
    • Neighbor A: vA = "Hawaii" with distance dA = 0.1
    • Neighbor B: vB = "Illinois" with distance dB = 0.7
    • Neighbor C: vC = "Hawaii" with distance dC = 0.2
  3. Compute unnormalized weights w = exp(−d):
    • wA = exp(−0.1), wB = exp(−0.7), wC = exp(−0.2)
  4. Aggregate by token identity (Eq. (2)):
    • Score for "Hawaii" is wA + wC
    • Score for "Illinois" is wB
    • Any token not appearing among neighbors has score 0
  5. Normalize to get pkNN(y|x) (implicit in Eq. (2)’s proportionality).
  6. Combine with the base LM distribution using Eq. (3):
    • If λ = 0.25 (near the WIKITEXT-103 optimum; Figure 5), then the final probability for "Hawaii" is
      p("Hawaii"|x) = 0.25*pkNN("Hawaii"|x) + 0.75*pLM("Hawaii"|x).

This micro-example matches the mechanism shown in Figure 1: nearest contexts vote for targets, and interpolation merges retrieval with the neural LM.
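
The arithmetic of this micro-example can be checked with a few lines of Python. The neighbor distances and λ come from the example above; the base-LM probability `p_lm_hawaii = 0.20` is a hypothetical value chosen only to complete the calculation.

```python
import math
from collections import defaultdict

# (target token, distance) for the three neighbors in the example.
neighbors = [("Hawaii", 0.1), ("Illinois", 0.7), ("Hawaii", 0.2)]

# Eq. (2): sum exp(-d) per target token, then normalize.
scores = defaultdict(float)
for token, dist in neighbors:
    scores[token] += math.exp(-dist)
total = sum(scores.values())
p_knn = {token: s / total for token, s in scores.items()}

# Eq. (3): interpolate with the base LM.
lam = 0.25
p_lm_hawaii = 0.20  # hypothetical pLM("Hawaii"|x), not from the paper
p_final_hawaii = lam * p_knn["Hawaii"] + (1 - lam) * p_lm_hawaii

print(round(p_knn["Hawaii"], 3), round(p_knn["Illinois"], 3), round(p_final_hawaii, 3))
# prints 0.776 0.224 0.344
```

With only these three neighbors, "Hawaii" receives about 78% of the retrieval mass, and interpolation lifts its final probability from the assumed 0.20 to about 0.34.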

4. Key Insights and Innovations

  • (1) Token-level kNN augmentation of a pretrained LM with no retraining
    • Novelty: the method stores one datastore entry per training token (context representation → next token) and queries it at inference (Eq. (1), Figure 1), rather than training a new parametric component.
    • Significance: yields large perplexity improvements on a strong baseline without updating model parameters (Table 1).

  • (2) Explicit memorization improves generalization, especially in the long tail
    • The retrieval component is particularly helpful for rare patterns like factual knowledge, names, and near duplicates (Section 6; Figure 6 and Appendix Tables 6–9).
    • The method operationalizes the paper’s hypothesis that the LM’s similarity space is strong even when its parametric classifier is weak on rare events (Introduction; Section 6).

  • (3) “More data without training”: retrieval over a huge corpus can beat training on it
    • Training on WIKI-100M (100M tokens) plus kNN retrieval over WIKI-3B (≈2.87B tokens) outperforms training the same architecture directly on WIKI-3B (Table 3; Section 4.2).
    • This reframes scaling: learn representations on a smaller dataset, then expand memory via a larger datastore (Section 4.2; Figure 2).

  • (4) Simple domain adaptation by swapping datastores
    • A Wikipedia-trained model performs poorly on Books, but adding a Books datastore sharply reduces perplexity (Table 4), and Figure 5 shows domain adaptation prefers higher reliance on kNN (larger λ).

  • (5) Practical guidance on which internal representation to retrieve on
    • The paper empirically identifies the final-layer FFN input after layer norm as the best-performing retrieval key among tested options (Table 5; Figure 3), which is a concrete and reusable design detail.

5. Experimental Analysis

Evaluation methodology

  • Datasets (Section 3 “Data”)
    • WIKITEXT-103: 103M train tokens; 250K dev and 250K test tokens; word-level vocab 250K.
    • BOOKS: Toronto Books Corpus, 0.7B tokens; whole books held out for validation/test.
    • WIKI-3B: ≈2.87B tokens; whole articles held out for validation/test.
    • WIKI-100M: random 100M-token subset of WIKI-3B (complete articles).

  • Metrics
    • Perplexity (lower is better), computed as exponentiated negative log-likelihood on held-out data (Section 3 “Evaluation”).

  • Baselines and comparisons
    • Base LM is the Transformer LM from Baevski & Auli (2019), with matching architecture and optimization (Section 3).
    • They also compare to:
      • Transformer-XL augmentation and Phrase Induction augmentation as reported numbers (Table 1).
      • A Continuous Cache method (Grave et al., 2017c) (Table 1).
    • Additional analyses compare to:
      • Interpolating with n-gram LMs (Figure 7; Section 6).
      • A memorizing Transformer trained without dropout (Figure 8; Section 6).

Main quantitative results (with specific numbers)

  • WIKITEXT-103 improvements (Table 1)
    • Base LM test perplexity: 18.65 (dev 17.96).
    • + kNN-LM: test perplexity 16.12 (dev 16.06), a 2.53 reduction vs 18.65.
    • + Continuous Cache alone: test 18.27 (dev 17.67), a smaller improvement.
    • + kNN-LM + Continuous Cache: test 15.79 (dev 15.81), the best result in Table 1 and a 2.86 improvement over the base model.

  • BOOKS domain (Table 2)
    • Base LM test perplexity: 11.89 (dev 14.75).
    • + kNN-LM: test 10.89 (dev 14.20), improving by 1.00 on test.

  • Scaling via datastore vs training on more data (Table 3; Figure 2)
    • Train on WIKI-3B (no datastore): test 15.17 (dev 16.11).
    • Train on WIKI-100M (no datastore): test 19.59 (dev 20.99).
    • Train on WIKI-100M + datastore WIKI-3B: test 13.73 (dev 14.61), outperforming the LM trained on all WIKI-3B tokens (15.17) (Table 3; Section 4.2).
    • Figure 2a shows perplexity improves monotonically with datastore size and does not saturate up to ~3B tokens; Figure 2b shows the tuned optimal λ increases with datastore size.

  • Domain adaptation (Table 4; Figure 5)
    • Model trained on BOOKS: test 11.89 (in-domain).
    • Model trained on WIKI-3B evaluated on BOOKS: test 34.84 (very poor transfer).
    • WIKI-3B model + BOOKS datastore: test 20.47 (dev 24.85), a 14.37 reduction vs 34.84 (Table 4; Section 4.3).
    • Figure 5 further shows that increasing λ helps domain adaptation (right axis), with the best value around λ = 0.65, while in-domain WIKITEXT-103 performance prefers a smaller λ ≈ 0.25 (left axis shows a U-shape).

Do the experiments support the claims?

  • Claim: kNN augmentation improves a strong LM without retraining
    • Supported by Table 1 and Table 2: substantial perplexity reductions with the same number of trainable parameters (247M) and no additional training described for kNN-LM.

  • Claim: scaling via retrieval can substitute for training on more data
    • Supported by Table 3 and Figure 2: training on 100M tokens plus retrieval over 3B tokens yields better perplexity than training directly on 3B tokens.

  • Claim: especially helpful on rare patterns / factual knowledge
    • Supported qualitatively by Section 6 and examples (Figure 6; Appendix Tables 6–9) where pkNN assigns very high probability to the correct target in rare/near-duplicate contexts.
    • This is qualitative evidence; the excerpt does not include a quantitative breakdown by rarity bucket, so the strength of this claim is mostly illustrative in the provided content.

Ablations / robustness checks reported

  • Key representation ablation: multiple internal states compared (Table 5; Figure 3).
  • Number of neighbors: monotonic improvement with k (Figure 4).
  • Interpolation parameter sensitivity: explored and differs by setting (Figure 5).
  • Precision ablation: full-precision squared L2 improves perplexity compared to quantized-distance computation (Section 5).
  • Alternative explanations tested (Section 6):
    • n-gram interpolation yields little improvement (~0.2 perplexity points; Figure 7).
    • A memorizing Transformer without dropout reaches zero training loss but generalizes poorly (best validation perplexity 28.59), and interpolating it with the original LM improves validation perplexity by only 0.1, far less than kNN-LM’s 1.9 improvement (Section 6; Figure 8).

6. Limitations and Trade-offs

  • Inference-time computational overhead
    • kNN-LM adds retrieval cost: for WIKITEXT-103, validation inference took ~25 minutes when retrieving 1024 keys (Section 3 “Computational Cost”).
    • Retrieval requires building and storing a large index; for WIKITEXT-103, building the cache with 103M entries takes roughly 2 hours on a single CPU (Section 3).

  • Memory/storage footprint scales linearly with datastore size
    • The datastore has one entry per token (Section 2 “Implementation”); the paper notes corpora can be billions of examples, implying large storage.
    • They mitigate via quantization to 64 bytes per key (Section 3), but the excerpt does not report total datastore size in GB/TB, so exact memory requirements are not given.
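
Although total sizes are not reported, the linear scaling can be made concrete with back-of-the-envelope arithmetic from the figures quoted in this summary. These are rough estimates for keys only, not numbers from the paper: they ignore stored values and index overhead, and the full-precision comparison assumes float32 (4 bytes per dimension), which is an assumption.

```python
def key_storage_gb(num_entries: int, bytes_per_key: int) -> float:
    """Storage for keys alone, in gigabytes (10**9 bytes)."""
    return num_entries * bytes_per_key / 1e9


# WIKITEXT-103: 103M entries.
quantized_wt103 = key_storage_gb(103_000_000, 64)    # 64-byte quantized keys
full_wt103 = key_storage_gb(103_000_000, 1024 * 4)   # assumes float32, 1024-d keys

# WIKI-3B: about 2.87B entries with 64-byte quantized keys.
quantized_wiki3b = key_storage_gb(2_870_000_000, 64)

print(f"{quantized_wt103:.1f} GB, {full_wt103:.1f} GB, {quantized_wiki3b:.1f} GB")
# prints 6.6 GB, 421.9 GB, 183.7 GB
```

Under these assumptions, quantization cuts key storage by a factor of 64 relative to float32 keys, yet a 3B-token datastore would still need on the order of a couple of hundred gigabytes for keys alone.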

  • Approximate nearest-neighbor search may affect accuracy
    • FAISS uses clustering and compressed vectors (Section 2 “Implementation”), and for non-WIKITEXT-103 datasets they use FAISS L2 distances directly on quantized keys “for faster evaluation” (Section 3).
    • This introduces a trade-off between speed and distance precision; Section 5 reports that more precise distance computation can improve perplexity (from 16.5 to 16.06 on WIKITEXT-103).

  • Hyperparameter tuning and sensitivity
    • The method introduces new hyperparameters: k, λ, key type/which layer state, FAISS clustering/probe settings, quantization precision (Section 5; Figure 4–5; Table 5).
    • λ is setting-dependent (Figure 5), which may complicate deployment across domains.

  • Dependence on the base LM’s representation quality
    • The approach assumes f(·) embeds contexts so that “similar contexts are near,” because retrieval quality directly depends on representation similarity (Section 1, Section 6).
    • The paper’s own analysis argues that representation learning generalizes better than memorization-in-parameters (Section 6), but if f(·) is poor for a domain or language, retrieval may be less effective (this limitation is implied by the method’s reliance; the excerpt does not provide counterexamples).

  • Scope of evaluation
    • The reported evaluations are on English corpora and perplexity metrics (Section 3, Tables 1–4). The excerpt does not include evaluations on downstream tasks (e.g., QA) or other languages, so generality beyond these settings is not established here.

7. Implications and Future Directions

  • How this work changes the landscape (within the paper’s scope)
    • It makes a strong, practical case that explicit memory plus similarity search can significantly improve token prediction over a high-capacity Transformer LM without retraining (Table 1; Section 6).
    • It suggests an alternative scaling strategy: train representations on smaller corpora and expand knowledge via retrieval over large corpora (Section 4.2; Table 3; Figure 2).

  • Follow-up research directions suggested by the paper
    • The conclusion proposes “explicitly training similarity functions” and “reducing the size of the datastore” (Section 8).
    • Section 5’s results on which internal representation works best also suggest exploring representation learning objectives that better align with retrieval.

  • Practical applications / downstream use cases (based on provided content)
    • Improved general language modeling on standard benchmarks by augmenting an existing LM with a datastore built from its training corpus (Table 1; Section 4.1).
    • Domain adaptation without retraining by maintaining a per-domain datastore and tuning λ appropriately (Table 4; Figure 5; Section 4.3).
    • Efficient use of large datasets: ingest large corpora into a datastore rather than fully retraining a model on them (Section 4.2; Table 3).

  • Repro/Integration Guidance (grounded in the paper’s details)
    • When to prefer this method:
      • Prefer kNN-LM when you already have a strong pretrained autoregressive LM and want better perplexity, especially on rare patterns, without additional training (Table 1; Section 6).
      • Prefer swapping/adding datastores when you need domain adaptation but want to avoid fine-tuning the LM parameters (Table 4; Figure 5).
    • Key implementation choices that matter:
      • Use FAISS for scalable search; build the index with clustering and quantization as described (4096 centroids learned from 1M keys, 64-byte quantization, probe 32 centroids; Section 3).
      • Use the final-layer FFN input after layer norm as the retrieval key/query representation, since it performs best among tested options (Table 5).
      • Retrieve many neighbors if feasible (k up to 1024 improves monotonically; Figure 4), and tune λ on validation, expecting larger λ to help more in domain adaptation (Figure 5).
    • Operational considerations:
      • Plan for a one-time forward pass over the datastore corpus and index build time (e.g., ~2 hours CPU for 103M entries on WIKITEXT-103; Section 3), plus ongoing inference-time retrieval overhead (Section 3).