GROKKING: GENERALIZATION BEYOND OVERFITTING ON SMALL ALGORITHMIC DATASETS
ArXiv: 2201.02177
🎯 Pitch
This paper introduces and systematically studies “grokking,” a striking training dynamic on small, algorithmically generated binary-operation datasets where models first memorize the training set and only after vastly more optimization suddenly achieve perfect generalization. By providing a reproducible, low-compute testbed and analyzing factors (data fraction, optimizers, weight decay, embedding structure), it offers a practical window into how overparameterized networks transition from memorization to rule-learning—insight crucial for understanding and improving generalization in deep learning.
1. Executive Summary
This work studies how overparameterized neural networks generalize when trained on small, algorithmically generated datasets formed from incomplete binary operation tables (e.g., modular arithmetic or group operations) where inputs are discrete symbols with no built-in structure. It identifies a striking training dynamic—grokking—where validation accuracy can stay at chance level long after the model has reached near-perfect training accuracy, and then suddenly rise to perfect generalization after vastly more optimization steps (Figure 1, Figure 4). The setting is designed to make questions about memorization vs. true rule-learning experimentally tractable and reproducible on modest compute (e.g., a single GPU, per the introduction).
2. Context and Motivation
- Problem / gap addressed
- Modern neural networks often generalize well despite being large enough to memorize (interpolate) their training data, which clashes with classical intuitions about overfitting.
- Many commonly studied natural datasets do not show cleanly separated phases of memorization and later generalization, making it hard to isolate mechanisms.
- This work focuses on a controlled regime where memorization and generalization dynamics can be observed “in great detail”: small, fully specified algorithmic worlds with a known ground-truth rule.
- Why it matters
- Understanding generalization in overparameterized models is a central theoretical and practical issue in deep learning.
- A small, fast-to-run testbed makes it easier to:
- reproduce unusual generalization curves (Figure 1),
- systematically vary data fraction and optimization details (Figure 2),
- probe hypotheses about why generalization emerges late (Appendix A.5).
- Prior approaches and shortcomings (as positioned here)
- Prior algorithmic reasoning benchmarks often emphasize performance in an unlimited data regime or generalization across input lengths (Appendix A.3), rather than the data-limited regime where memorization is easy but rule discovery is not.
- Work on double descent studies loss vs. model/optimization capacity, whereas here the key phenomenon is a second descent (in loss) and late improvement in validation accuracy as a function of optimization time, far after training loss becomes very small (Section 3.1; Figure 4; Appendix A.3 discussion).
- How this work positions itself
- It treats small binary-operation datasets as a “fertile ground” for studying generalization beyond memorization, with a focus on late generalization dynamics (grokking) and how optimization/regularization choices affect data efficiency.
3. Technical Approach
3.1 Reader orientation (approachable technical breakdown)
- The system is a small decoder-only Transformer trained to complete symbolic equations like “a ◦ b = ?”, where each symbol is just a token ID.
- It solves a “fill in the missing entries” problem for a partially observed binary operation table by predicting the correct output symbol `c` for unseen input pairs `(a, b)`.
3.2 Big-picture architecture (diagram in words)
- Dataset generator creates all possible equations for a chosen binary operation and then splits them into train/validation by sampling a fraction for training (Appendix A.1.1).
- Tokenizer / sequence builder encodes each equation as a 5-token sequence: `[a] [op] [b] [=] [c]` (Section 2; Appendix A.1.1).
- Transformer model (decoder-only, causal mask) reads the prefix and predicts the answer token, with loss/accuracy computed only on the answer position (Appendix A.1.2).
- Optimizer / training loop runs for up to a fixed budget of gradient updates, using variants such as `Adam`, `AdamW`, full-batch vs. minibatch, weight decay, and noise injections (Appendix A.1.2; Section 3.3).
- Evaluation measures training accuracy (on seen equations) and validation accuracy (on unseen equations) across optimization steps and across training-set fractions (Section 3; Figure 1; Figure 2).
3.3 Roadmap for the deep dive
- First, I explain the data construction as completion of a binary operation table (what the model sees and what it must infer).
- Second, I detail the model and objective (decoder-only Transformer; where loss is applied).
- Third, I describe the training procedure and hyperparameters, including the special settings used to highlight grokking.
- Fourth, I walk through the experimental knobs (dataset fraction, optimizer/regularization variants, noise) and what is measured.
- Fifth, I summarize the additional probes: embedding visualizations and outlier/noise experiments.
3.4 Detailed, sentence-based technical breakdown
This is an empirical study of training dynamics in a controlled supervised-learning setup, with the core idea that small algorithmic datasets let one separate “memorize the training set” from “learn the underlying rule,” sometimes revealing a late generalization transition called grokking (Section 3.1; Figure 1).
- Task formulation as binary operation completion
- A task is defined by a binary operation table \( \circ \) over a finite set of discrete symbols, where each equation has the form \(a \circ b = c\) (Introduction; Section 2; Appendix A.1.1).
- The key constraint is that the symbols have no internal structure: each distinct element \(a\), \(b\), or \(c\) is represented by its own token, so the model cannot exploit, for example, decimal representations of numbers or explicit permutation notation (Introduction).
- Training on a subset of equations is framed as “filling in the blanks” of the full operation table, analogous to solving a Sudoku-like completion problem (Introduction; Figure 1 right).
- What happens first/second/third (data-to-model pipeline in words)
- First, for a chosen operation (e.g., modular addition mod \(p\), modular division mod \(p\), or multiplication in the group \(S_5\)), the procedure enumerates all valid input pairs \((x,y)\) in the domain and computes the corresponding output \(x \circ y\) to create a complete set of equations (Appendix A.1.1).
- Second, for each training run, it randomly samples a fraction of these equations as the training set, and assigns the remaining equations to the validation set (Appendix A.1.1).
- Third, each equation is serialized into a fixed token sequence ⟨x⟩⟨op⟩⟨y⟩⟨=⟩⟨x◦y⟩ (Appendix A.1.1), i.e., five tokens `[x] [op] [y] [=] [x◦y]`.
- Fourth, the Transformer processes the token sequence with causal masking, and the training objective computes loss/accuracy only on the answer part (the final token), so the model is trained to predict the correct output symbol given the preceding tokens (Appendix A.1.2).
- Fifth, training proceeds for a preset number of gradient updates (the “optimization budget”), while tracking training vs. validation accuracy over time to detect memorization and any later generalization transition (Section 3.1; Figure 1; Figure 4).
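The first three pipeline steps can be sketched in a few lines of Python. This is a minimal illustration for division mod 97, not the authors' code: the function name, the token-ID layout (residues map to themselves, with two extra IDs for the operator and the equals sign), and the seeding scheme are all assumptions made here.

```python
import random

def build_division_dataset(p=97, train_fraction=0.5, seed=0):
    """Enumerate x / y (mod p) for all x and nonzero y, then split at random."""
    # Hypothetical ID layout: residues 0..p-1 are their own IDs,
    # followed by one ID for the operator and one for "=".
    op_token, eq_token = p, p + 1
    equations = []
    for x in range(p):
        for y in range(1, p):            # y = 0 has no inverse mod p
            c = (x * pow(y, -1, p)) % p  # division = multiply by inverse of y
            equations.append([x, op_token, y, eq_token, c])
    rng = random.Random(seed)
    rng.shuffle(equations)
    n_train = int(train_fraction * len(equations))
    # Training equations first, everything else is held out for validation.
    return equations[:n_train], equations[n_train:]

train, val = build_division_dataset()
```

With `p = 97` this yields 97 × 96 = 9312 equations in total, so a 50% training fraction leaves 4656 equations on each side of the split.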
- Model architecture (as specified)
- The model is a standard decoder-only Transformer with causal attention masking (Appendix A.1.2).
- The configuration used throughout is:
- layers: 2
- width (hidden size): 128
- attention heads: 4
- parameters: about \(4 \cdot 10^{5}\) non-embedding parameters (Appendix A.1.2)
- The tokenization is at the symbol level described above, where “a”, “◦”, “b”, “=”, and “c” are each separate tokens (Section 2; Appendix A.1.1).
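As a sanity check on the quoted parameter count, the usual back-of-the-envelope estimate for a Transformer stack can be computed directly. The sketch assumes a feed-forward expansion factor of 4 and ignores biases and LayerNorm parameters; neither detail is stated in the summary above, and the function name is hypothetical.

```python
def non_embedding_params(n_layers: int, d_model: int, ffn_mult: int = 4) -> int:
    """Approximate weight count of a Transformer stack (no biases, no LayerNorm)."""
    attention = 4 * d_model * d_model                # Q, K, V and output projections
    feed_forward = 2 * ffn_mult * d_model * d_model  # d -> ffn_mult*d -> d
    return n_layers * (attention + feed_forward)

count = non_embedding_params(n_layers=2, d_model=128)
print(count)  # 393216, i.e. roughly 4e5
```

The number of attention heads does not enter the estimate, because splitting the width across heads leaves the projection sizes unchanged.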
- Optimization setup and hyperparameters (as specified)
- The default optimizer for most experiments is `AdamW` with (Appendix A.1.2):
- learning rate \(10^{-3}\)
- weight decay \(1\)
- \(\beta_1 = 0.9\), \(\beta_2 = 0.98\)
- linear learning rate warmup over the first 10 updates
- minibatch size \(512\) or half of the training dataset size (whichever is smaller)
- optimization budget \(10^{5}\) gradient updates
- To better capture late generalization time-vs-data-fraction effects, some experiments increase the optimization budget to \(5 \cdot 10^{5}\) updates (Section 3.1.1; Appendix A.1.2).
- To emphasize how late grokking can occur, Section 3.1 uses `Adam` without weight decay and increases the budget to \(10^{6}\) updates (Appendix A.1.2; Figure 1 discussion).
- Experimental knobs studied
- Dataset fraction (data efficiency curves): The main independent variable is the fraction of all equations included in training; generalization is measured on the remaining equations (Section 3.1.1; Section 3.2; Figure 2 right).
- Optimization algorithm / noise / regularization: The ablations include full-batch vs. minibatch, `Adam` vs. `AdamW`, residual dropout, gradient noise, Gaussian weight noise, and different learning rates (Section 3.3; Appendix A.1.2; Figure 2 left).
- Operation choice: The operations include modular arithmetic and group operations on \(S_5\), as enumerated in Appendix A.1.1 (also referenced in Section 3.2).
- Concrete example of the learning target
- If the domain is residues mod \(p=97\), one example equation has the template `[x] [op] [y] [=] [x ◦ y]`, where each residue (0–96) is its own token (Figure 1 caption; Introduction).
- In modular division, the operation is “division mod 97” (Figure 1 caption), and the model’s job is to output the correct residue token for the quotient given the two input residue tokens and the operator token.
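A worked instance may help make the target concrete. The operands 5 and 3 are chosen here purely for illustration and do not come from the paper. Division mod 97 is multiplication by the modular inverse of the divisor:

\[
3^{-1} \equiv 65 \pmod{97} \quad \text{since} \quad 3 \cdot 65 = 195 = 2 \cdot 97 + 1,
\]
\[
5 / 3 \;\equiv\; 5 \cdot 65 \;=\; 325 \;=\; 3 \cdot 97 + 34 \;\equiv\; 34 \pmod{97}.
\]

The corresponding token sequence would be `[5] [/] [3] [=] [34]`, and the result can be checked by multiplying back: \(34 \cdot 3 = 102 \equiv 5 \pmod{97}\).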
- How grokking is detected in this setup
- Training typically reaches near-perfect training accuracy quickly (on the order of \(10^{3}\)–\(10^{4}\) updates in several settings), meaning the model can memorize/interpolate the training subset (Figure 1 caption; Figure 2 left caption; Section 3.1.1).
- Grokking refers to cases where validation accuracy remains near chance for a long time after memorization, and then rises sharply toward perfect generalization much later in optimization (Section 3.1; Figure 1).
- The modular division example is the canonical demonstration: training accuracy becomes close to perfect at \(<10^{3}\) steps, but validation accuracy only reaches that level near \(10^{6}\) steps, with little evidence of generalization until around \(10^{5}\) steps (Figure 1 caption; Section 3.1; Figure 4 shows the corresponding loss dynamics).
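One simple way to quantify the gap between memorization and generalization from logged curves is to compare the first step at which each accuracy crosses a threshold. The sketch below is an assumption about how such a measurement could be made, not the paper's tooling. The helper names are hypothetical, and the logged values are made up to resemble the modular division curves.

```python
def first_step_at_or_above(steps, accuracies, threshold=0.99):
    """Return the first logged step whose accuracy meets the threshold, else None."""
    for step, acc in zip(steps, accuracies):
        if acc >= threshold:
            return step
    return None

def generalization_delay(steps, train_acc, val_acc, threshold=0.99):
    """Ratio of validation to training time-to-threshold; None if either is never reached."""
    t_train = first_step_at_or_above(steps, train_acc, threshold)
    t_val = first_step_at_or_above(steps, val_acc, threshold)
    if t_train is None or t_val is None:
        return None
    return t_val / t_train

# Made-up log shaped like a grokking run; these are not numbers from the paper.
steps     = [10, 100, 1_000, 10_000, 100_000, 1_000_000]
train_acc = [0.02, 0.40, 0.995, 1.0, 1.0, 1.0]
val_acc   = [0.01, 0.01, 0.01, 0.02, 0.10, 0.995]
delay = generalization_delay(steps, train_acc, val_acc)
print(delay)  # 1000.0
```

A ratio close to 1 would indicate ordinary training where both curves rise together, while a ratio of several orders of magnitude is the signature described above.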
4. Key Insights and Innovations
- (1) A clear, reproducible “late generalization” regime on small algorithmic datasets
- What is new here is the strength and clarity of the decoupling: validation performance can stay at chance long after the model has already achieved near-perfect training accuracy, and then improve dramatically later (Section 3.1; Figure 1).
- This creates an experimental regime where memorization and generalization are temporally separated, which is harder to see cleanly in many natural datasets (Introduction framing).
- (2) Grokking: sudden transition from chance to perfect generalization well past overfitting
- Grokking is the name given to the observed phenomenon where generalization can “turn on” late in training, after a long period of apparent non-generalization (Section 3.1; Figure 1).
- The associated loss behavior can show a second descent of validation loss after it initially rises during overfitting (Figure 4; Section 3.1).
- (3) Compute–data tradeoff: smaller datasets can still reach perfect generalization but require much more optimization
- Instead of converging to a lower final generalization level when trained on less data (a common intuition when models can interpolate), these tasks can maintain 100% converged performance over a range of dataset sizes, but the time to reach it increases rapidly as the training fraction decreases (Section 3.1.1; Figure 1 center discussion).
- (4) Regularization/optimization details matter strongly; weight decay is especially effective
- Among tested interventions, weight decay substantially improves data efficiency on at least the showcased group task (Section 3.3; Figure 2 left).
- The work also reports benefits from optimization noise sources (minibatches, explicit Gaussian noise), consistent with the idea that noise may encourage solutions that generalize better (Section 3.3).
- (5) Learned embeddings can reflect underlying algebraic structure
- Visualizations (t-SNE of output-layer weight vectors) sometimes reveal recognizable structure: e.g., circular structure for modular addition and clustered coset-like structure for \(S_5\) (Section 3.4; Figure 3).
- The structure is described as more apparent in networks optimized with weight decay (Section 3.4).
5. Experimental Analysis
- Evaluation methodology
- Data split: For each operation, all valid equations are generated; a random fraction is used for training and the rest for validation (Appendix A.1.1).
- Metric: Accuracy on the answer token (and loss) is tracked separately on training vs. validation sets (Appendix A.1.2; Figure 1 caption; Figure 4 caption).
- Model: 2-layer decoder-only Transformer, width 128, 4 heads, \(\approx 4 \cdot 10^{5}\) non-embedding parameters (Appendix A.1.2).
- Runs / seeds: Most experiments are repeated with 3 random seeds; the time-to-generalization experiments in Section 3.1.1 aggregate over 7 seeds (Appendix A.1.2).
- Main quantitative behaviors (with cited specifics)
- Late generalization example (modular division mod 97, 50% train fraction):
- Training accuracy becomes close to perfect at \(<10^{3}\) optimization steps (Figure 1 caption).
- Validation accuracy reaches near-perfect only around \(10^{6}\) steps, with little evidence of generalization until around \(10^{5}\) steps (Figure 1 caption; Section 3.1).
- The corresponding validation loss rises between roughly \(10^{2}\) and \(10^{5}\) optimization steps and only then begins a second descent (Figure 4 caption).
- Time-to-generalization vs. dataset fraction:
- On the \(S_5\) product task, near 25–30% of data, decreasing the training fraction by 1% increases the median time to reach 99% validation accuracy by roughly 40–50% (Section 3.1.1; Figure 1 center described).
- Across tasks where generalization occurs, the number of steps until training accuracy first reaches 99% generally trends down as dataset size decreases and stays around \(10^{3}\)–\(10^{4}\) steps (Section 3.1.1).
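To get a feel for how steep the \(S_5\) trend is, one can compound the quoted per-point increase over the whole 25–30% range. This assumes the 40–50% growth factor stays constant across those five percentage points, which is an extrapolation made here for illustration and not a figure reported in the paper:

\[
1.4^{5} \approx 5.4, \qquad 1.5^{5} \approx 7.6 .
\]

Under that assumption, dropping from 30% to 25% of the data would multiply the median time to 99% validation accuracy by roughly a factor of 5 to 8.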
- Optimization/regularization ablations (Figure 2 left):
- Weight decay improves generalization “the most” among compared interventions within a fixed budget of \(10^{5}\) steps on learning the product in abstract group \(S_5\) (Figure 2 left caption; Section 3.3).
- Poor hyperparameter choices can severely limit generalization (Figure 2 left caption).
- Learning rate must be tuned within a relatively narrow window (within 1 order of magnitude) for generalization to happen in their experiments (Section 3.3).
- Breadth of tasks tested
- Modular arithmetic operations mod \(p=97\) include \(x+y\), \(x-y\), \(x/y\) (with \(y\neq 0\)), mixtures conditioned on parity of \(y\), and several polynomial-like expressions (Appendix A.1.1).
- Group operations include multiplication in \(S_5\) and variants like \(x\cdot y\cdot x^{-1}\) and \(x\cdot y\cdot x\) (Appendix A.1.1).
- Do the experiments support the claims?
- The modular division curves (Figure 1 and Figure 4) directly support the existence of late generalization well after training accuracy saturates, matching the defined grokking phenomenon.
- The data-fraction/time-to-99% results (Section 3.1.1; Figure 1 center narrative) support the compute–data tradeoff claim: smaller datasets can require dramatically more optimization to reach the same validation threshold.
- The ablation results (Figure 2 left; Section 3.3) support the claim that optimization details, especially weight decay, change data efficiency, though the material summarized here reports these effects qualitatively rather than with full numeric tables.
- Ablations, failure cases, robustness checks
- Failure to generalize on some operations: For example, \(x^{3}+xy^{2}+y \pmod{97}\) does not generalize within the allowed optimization budget even up to 95% of the data, and the model behaves like it is memorizing random labels (Section 3.2).
- Outlier/noise robustness (Appendix A.4; Figure 6):
- Introducing a small number of outliers (random-label examples) up to 1000 does not noticeably impact generalization performance in the shown setting.
- A large number of outliers significantly hinders generalization (Figure 6 caption; Appendix A.4).
- Training still reaches 100% training accuracy across these settings, suggesting the model can interpolate even the noisy training set (Appendix A.4).
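The outlier experiment can be mimicked on any tokenized equation list. The summary above only describes outliers as “random-label examples”, so the sketch assumes that an outlier is a training equation whose answer token is replaced by a uniformly chosen different symbol. The function name, the guarantee that the label always changes, and the toy addition-mod-7 data are illustrative choices, not details from Appendix A.4.

```python
import random

def inject_outliers(train_equations, k, n_symbols, seed=0):
    """Return a copy of the training set with k answers replaced by wrong symbols."""
    rng = random.Random(seed)
    corrupted = [list(eq) for eq in train_equations]  # leave the input untouched
    for idx in rng.sample(range(len(corrupted)), k):
        correct = corrupted[idx][-1]
        wrong = rng.randrange(n_symbols - 1)
        if wrong >= correct:  # skip over the correct answer so the label changes
            wrong += 1
        corrupted[idx][-1] = wrong
    return corrupted

# Toy table for addition mod 7 (token IDs 7 and 8 stand in for the operator and "=").
toy = [[x, 7, y, 8, (x + y) % 7] for x in range(7) for y in range(7)]
noisy = inject_outliers(toy, k=5, n_symbols=7)
```

Sweeping `k` while holding the rest of the training setup fixed would reproduce the shape of the experiment: a few corrupted labels versus many.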
6. Limitations and Trade-offs
- Scope limitation: highly structured but synthetic tasks
- The datasets are intentionally artificial: discrete symbols with no internal structure, generated from known binary operations (Introduction; Appendix A.1.1).
- This improves interpretability and control but limits direct conclusions about natural data distributions.
- Compute dependence for small data fractions
- A central trade-off is that generalization may require very large numbers of optimization steps when the training dataset is small (Section 3.1.1; Figure 1 center discussion).
- The modular division example demonstrates that validation improvement can occur on the order of \(10^{6}\) steps, far beyond the steps needed for memorization (Figure 1 caption), making this behavior expensive if scaled.
- Hyperparameter sensitivity
- Generalization depends on optimization details; learning rate needs tuning within a narrow window (Section 3.3), and suboptimal hyperparameters can prevent generalization within the budget (Figure 2 left caption).
- This means “does grokking happen?” can be contingent on training setup rather than purely on task structure.
- Not all operations grok within explored budgets
- Some operations appear effectively random to the model at the tested scale/budgets (e.g., \(x^{3}+xy^{2}+y \pmod{97}\)), and the approach fails to produce generalization (Section 3.2).
- This limits how universal the phenomenon is across algorithmic tasks, at least with this architecture and training regime.
- Mechanistic explanation remains incomplete
- The paper provides suggestive evidence and hypotheses (e.g., flatter minima, noise effects) but does not fully pin down a causal mechanism for grokking (Discussion; Appendix A.5).
7. Implications and Future Directions
- How this changes the landscape
- It provides a compact, reproducible testbed where generalization can be studied as a distinct phase emerging after memorization, with dramatic separation in time (Figure 1).
- This shifts part of the generalization discussion from “why do big networks generalize at all?” to “why does optimization sometimes first find a memorizing solution and only much later a generalizing one?”
- Follow-up research directions suggested by the results
- Generalization measures and loss landscape geometry: Appendix A.5 reports that an approximate sharpness measure \(\phi\) (computed following the method of Keskar et al. (2016)) correlates negatively with validation accuracy on an \(S_5\) composition objective, with Spearman correlation \(-0.79548\) and \(p < 0.000014\). This motivates deeper study of whether grokking coincides with moving into flatter regions (Appendix A.5; Figure 7).
- Noise and regularization as levers: Since weight decay and noise sources appear beneficial (Section 3.3), future work can test which aspects (implicit bias of optimizers, flatness, parameter norms) best predict late generalization.
- Broader task families: The paper notes some operations do not generalize within budget (Section 3.2), suggesting a direction of mapping which algebraic/symmetry properties make grokking likely.
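For readers who want to run a similar check on their own runs, the Spearman statistic quoted for the sharpness measure is simply a correlation between ranks. The sketch below implements the textbook tie-free formula in plain Python. The input numbers are invented for illustration and are not measurements from the paper, and the helper names are hypothetical.

```python
def ranks(values):
    """Rank values from 1 (smallest) upward; assumes there are no ties."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0] * len(values)
    for rank, i in enumerate(order, start=1):
        result[i] = rank
    return result

def spearman(xs, ys):
    """Spearman rank correlation for tie-free data: 1 - 6*sum(d^2) / (n*(n^2 - 1))."""
    n = len(xs)
    d_squared = sum((rx - ry) ** 2 for rx, ry in zip(ranks(xs), ranks(ys)))
    return 1 - 6 * d_squared / (n * (n * n - 1))

# Toy values: sharper minima paired with lower validation accuracy.
sharpness = [0.1, 0.4, 0.2, 0.9, 0.7]
val_accuracy = [0.99, 0.60, 0.95, 0.10, 0.30]
rho = spearman(sharpness, val_accuracy)
print(rho)  # -1.0
```

The toy data is perfectly anti-monotone, so it gives exactly -1; real measurements with ties would need an implementation that averages tied ranks.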
- Practical applications / downstream use cases (within the paper’s framing)
- The primary “application” is methodological: these datasets act as a controlled environment to probe generalization, not as an end task (Introduction; Discussion).
- Embedding visualizations (Figure 3) hint at a potential diagnostic tool: inspecting learned representations to see whether the model has discovered underlying structure (Section 3.4; Discussion).
- Repro/Integration Guidance (when to use this approach)
- Prefer this setup when you want a clean separation between memorization and rule-learning and the ability to sweep data fraction and optimization settings quickly (Introduction; Sections 3.1–3.3).
- If the goal is to induce generalization in this regime, the reported interventions to prioritize are:
- using `AdamW` with substantial weight decay (e.g., weight decay \(=1\) in the paper’s default setting) (Appendix A.1.2; Section 3.3; Figure 2 left),
- ensuring learning rate is within a workable range (Section 3.3),
- and allowing sufficiently large optimization budgets when training fractions are small (Section 3.1.1; Figure 1 center narrative).
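The default settings reported in Appendix A.1.2 can be collected into a small framework-independent sketch. The numeric values come from the summary above. The dictionary keys, helper names, the use of integer division for “half of the training set”, and the exact shape of the warmup (0-indexed updates, constant rate afterwards) are assumptions made for illustration.

```python
def batch_size(n_train: int) -> int:
    """Minibatch size: 512 or half the training set, whichever is smaller."""
    return min(512, n_train // 2)

def learning_rate(update: int, base_lr: float = 1e-3, warmup_updates: int = 10) -> float:
    """Linear warmup to base_lr over the first warmup_updates updates, then constant."""
    return base_lr * min(1.0, (update + 1) / warmup_updates)

# Default settings as listed in the text; key names are illustrative only.
config = {
    "optimizer": "AdamW",
    "lr": 1e-3,
    "weight_decay": 1.0,
    "betas": (0.9, 0.98),
    "max_updates": 10**5,  # raised for small training fractions
}

print(batch_size(4656))   # 512
print(batch_size(600))    # 300
```

These values would be passed to whichever training framework is in use; the sketch deliberately avoids committing to a specific library API.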