The geometry of hidden representations of large transformer models

ArXiv: 2302.00294

🎯 Pitch

This paper shows that hidden representations in large self-supervised transformers undergo a consistent depth-wise pattern—an early intrinsic-dimension expansion, a strong compression into a low-ID region that best encodes semantic structure, and a final decoding phase—and links these geometric changes to local neighbor rearrangements. The key contribution is an unsupervised, geometry-based rule to select intermediate layers (relative minima or low-ID plateaus) that maximize semantic content, enabling better downstream features and retrieval without labeled data.


1. Executive Summary

This paper characterizes how hidden representations in large self-supervised transformers change across depth by measuring their intrinsic dimension (ID) and how local nearest-neighbor relationships get rearranged. Across protein language models (ESM-2) and an image generative transformer (iGPT), representations consistently go through an early expansion to higher ID, then a strong compression to a lower-ID region where semantic structure is best expressed, and finally a decoding/reconstruction phase near the output. The main practical takeaway is an unsupervised heuristic for selecting “best” intermediate layers for downstream tasks: choose layers at (or near) a relative minimum in the ID profile (or the low-ID plateau in protein models).

2. Context and Motivation

  • Problem / gap addressed
    • Self-supervised transformers produce a sequence of internal representations (one per layer), but it is unclear where in depth the model best encodes semantic information (e.g., image class, protein remote homology), especially because the last layer is optimized for reconstruction (masked-token recovery or next-token prediction), not necessarily for semantic abstraction.
    • Prior observations indicate last-layer representations can reflect abstract properties in some settings, but for self-supervised reconstruction objectives, there are reasons to expect the “most semantic” representation to occur in intermediate layers.

  • Why it matters
    • In many domains (proteins, images), labels can be scarce or expensive.
    • If we can identify the most semantically useful layer without supervision, we can:
      • improve retrieval (e.g., nearest-neighbor homology search),
      • choose better features for linear probes / downstream learners,
      • understand “what computation happens where” in deep transformers.

  • Prior approaches and shortcomings (as positioned here)
    • Earlier work on CNNs showed geometric quantities (notably intrinsic dimension) vary strongly across layers, and neighbor structure changes sharply near supervised classification layers.
    • For transformers, there are known results that intermediate features can be useful (e.g., iGPT linear probes), but there is not a simple geometry-based, task-agnostic indicator for layer selection.

  • How this paper positions itself
    • It proposes a geometry-first analysis: track ID (a global manifold complexity measure) and neighborhood overlap (a local structural stability measure) across depth.
    • It tests this in two very different modalities/tasks:
      • proteins: ESM-2 family trained with masked language modeling (MLM),
      • images: iGPT family trained autoregressively to predict the next pixel token.

3. Technical Approach

3.1 Reader orientation (approachable technical breakdown)

  • The system is an analysis pipeline that takes hidden activations from each layer of a pretrained transformer and computes geometric/statistical summaries of those activations.
  • It solves the problem of finding which layers encode the most semantic structure by measuring (i) how many degrees of freedom the representation manifold seems to have (ID) and (ii) how neighborhoods of examples change across layers (overlap).

3.2 Big-picture architecture (diagram in words)

  • Input datasets (proteins or images)
    → Pretrained transformer (ESM-2 or iGPT)
    → Layerwise representation extraction (one vector per example per layer via pooling)
    → Distance computations (Euclidean distances between example vectors within a layer)
    → Two analyses per layer:
      1) Intrinsic dimension via TwoNN (Two Nearest Neighbors),
      2) Neighborhood overlap:
        • between consecutive layers, and
        • between a layer’s neighbors and ground-truth label partitions (when labels exist).
    → Layer selection heuristic: pick layers around ID minima / low-ID region.

3.3 Roadmap for the deep dive

  • Explain how representations are extracted from transformers (so distances are well-defined across variable-length proteins).
  • Define intrinsic dimension and the TwoNN estimator used here.
  • Define neighborhood overlap between layers and with labels, and what it measures operationally.
  • Walk through the experimental setup: models, datasets, hyperparameters of the analysis (e.g., k in kNN overlap).
  • Tie the measured curves (ID and overlap) to the paper’s claims about phases of computation and semantic content.

3.4 Detailed, sentence-based technical breakdown

This is an empirical representation-geometry analysis paper: the core idea is to treat the set of layer activations for a dataset as a point cloud and quantify how its geometry changes with depth.

Representation extraction (what vectors are analyzed).
  • Each transformer processes an input sequence into a sequence of hidden states of shape roughly (#tokens × d), where d is the embedding size (constant across depth for these models).
  • For proteins, sequences have variable length l, so the paper needs a way to compare two proteins with different l using Euclidean distances.
  • The paper extracts hidden representations after the first normalization layer of each transformer block (Section 2.3) and then applies average pooling along the sequence dimension to map a sequence representation in \(\mathbb{R}^{l \times d}\) to a single vector in \(\mathbb{R}^{d}\):
    \[ \text{pool}(f_i(x)) \;=\; \frac{1}{l}\sum_{j=1}^{l} (f_i(x))_j \in \mathbb{R}^d. \]
  • This yields, for each layer/block index \(i\), a dataset of \(N\) vectors in \(\mathbb{R}^d\) (one per example), on which Euclidean distances are computed.
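The pooling step above can be sketched in a few lines of NumPy. This is only an illustration: the function name `mean_pool` and the padding mask are assumptions introduced here (the summary does not say how padded batches are handled), and the toy numbers are made up.

```python
import numpy as np

def mean_pool(hidden, mask):
    """Average token vectors over the real (unpadded) positions.

    hidden: (batch, max_len, d) hidden states taken from one layer.
    mask:   (batch, max_len), 1 for real tokens and 0 for padding
            (the mask is an assumption of this sketch).
    Returns a (batch, d) array with one vector per sequence.
    """
    weights = mask.astype(hidden.dtype)[..., None]   # (batch, max_len, 1)
    summed = (hidden * weights).sum(axis=1)          # sum over real tokens only
    lengths = weights.sum(axis=1)                    # (batch, 1): the l in 1/l
    return summed / lengths

# Two toy sequences of lengths 3 and 2, padded to length 3, with d = 2.
hidden = np.array([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                   [[2.0, 0.0], [4.0, 2.0], [9.0, 9.0]]])  # last token is padding
mask = np.array([[1, 1, 1],
                 [1, 1, 0]])
pooled = mean_pool(hidden, mask)
print(pooled)  # rows are [3, 4] and [3, 1]
```

Dividing by the true length rather than the padded length keeps the pooled vector independent of how sequences happen to be batched.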

Intrinsic dimension (ID): what it is and how TwoNN estimates it.
  • Intrinsic dimension here means: although each vector lives in a high-dimensional ambient space \(\mathbb{R}^d\), the dataset may concentrate near a lower-dimensional manifold; ID estimates that manifold’s effective dimension.
  • The paper uses the TwoNN estimator (Section 2.1), which depends only on each point’s first two nearest-neighbor distances:
    • For point \(x_i\), let \(r_{i1}\) and \(r_{i2}\) be the distances to its 1st and 2nd nearest neighbors.
    • Define the ratio
      \[ \mu_i = \frac{r_{i2}}{r_{i1}}. \]
    • Under a locally constant density assumption, \(\mu_i\) follows a Pareto distribution whose shape parameter is the ID \(d_{\text{ID}}\):
      \[ p(\mu \mid d_{\text{ID}}) = d_{\text{ID}} \,\mu^{-d_{\text{ID}}-1}. \]
    • Estimation is done by linear regression using the cumulative distribution function form \(F(\mu)=1-\mu^{-d_{\text{ID}}}\) (Appendix B describes the regression procedure).
  • Practical detail: the paper reports TwoNN estimates on a dataset decimated by a factor of 4 (Appendix B), after checking scale dependence.
  • The local-density assumption is checked using PAk (Point Adaptive kNN): they find density can be treated as constant within about the first 6 neighbors on average, and TwoNN uses only the first two (Appendix B), supporting validity at that scale.
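A minimal sketch of a TwoNN-style estimate follows, assuming the regression is a least-squares fit through the origin of \(-\log(1-F(\mu))\) against \(\log\mu\), which follows from \(F(\mu)=1-\mu^{-d_{\text{ID}}}\). The function name `two_nn_id`, the brute-force distance matrix, and the `discard_fraction` parameter (dropping the largest ratios before fitting) are assumptions of this sketch, not details confirmed by the summary; the decimation step is not shown.

```python
import numpy as np

def two_nn_id(X, discard_fraction=0.1):
    """Estimate intrinsic dimension from ratios of 2nd to 1st neighbor distances.

    discard_fraction must be > 0: the largest ratios are dropped so that
    the empirical CDF never reaches 1 inside the logarithm.
    """
    X = np.asarray(X, dtype=float)
    sq = (X ** 2).sum(axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * (X @ X.T)   # squared distances
    d2 = np.maximum(d2, 0.0)                            # clip round-off
    np.fill_diagonal(d2, np.inf)                        # a point is not its own neighbor
    two_smallest = np.sqrt(np.partition(d2, 1, axis=1)[:, :2])
    r1, r2 = two_smallest[:, 0], two_smallest[:, 1]
    valid = r1 > 0                                      # skip exact duplicates
    mu = np.sort(r2[valid] / r1[valid])
    n = len(mu)
    F = np.arange(1, n + 1) / n                         # empirical CDF of mu
    keep = int(n * (1.0 - discard_fraction))
    x = np.log(mu[:keep])
    y = -np.log(1.0 - F[:keep])
    return float((x @ y) / (x @ x))                     # slope of a fit through the origin

# Synthetic check: a 3-dimensional cube embedded isometrically in 20 dimensions.
rng = np.random.default_rng(0)
Z = rng.uniform(size=(2000, 3))
Q, _ = np.linalg.qr(rng.normal(size=(20, 3)))           # orthonormal columns
est = two_nn_id(Z @ Q.T)
print(round(est, 2))  # expected to be close to 3
```

The full distance matrix costs memory quadratic in the number of points, so for datasets the size of those in the paper a nearest-neighbor index would be a more practical choice.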

Neighborhood overlap: what it measures and how it is computed.
  • To quantify how “stable” local neighborhoods are from layer to layer, the paper uses neighborhood overlap \(\chi^{l,m}_k\) (Section 2.2).
  • For each layer \(l\), build a \(k\)-nearest neighbor graph using Euclidean distances in that layer’s representation:
    • Define adjacency \(A^l_{ij}=1\) if \(x^l_j\) is among the \(k\) nearest neighbors of \(x^l_i\), else \(0\).
    • Overlap between layers \(l\) and \(m\) is the average fraction of neighbors shared:
      \[ \chi^{l,m}_k = \frac{1}{N}\sum_{i=1}^N \frac{1}{k}\sum_{j=1}^N A^l_{ij} A^m_{ij}. \]
  • Semantic probe variant (requires labels): build a “ground-truth adjacency” \(A^{gt}_{ij}=1\) if \(y_i=y_j\), else \(0\), and compute \(\chi^{l,gt}_k\) to measure how label-consistent a layer’s neighborhoods are.
  • The paper notes the qualitative trends are robust across \(k\) (Appendix Fig. S5), and uses:
    • proteins (SCOPe): \(k=10\),
    • images (ImageNet subset): \(k=30\).
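As a rough illustration of the layer-to-layer overlap \(\chi^{l,m}_k\), the sketch below compares the \(k\)-nearest-neighbor sets of the same examples in two representations. The helper names `knn_indices` and `neighborhood_overlap` are hypothetical, and the brute-force neighbor search is a simplification chosen for brevity.

```python
import numpy as np

def knn_indices(X, k):
    """Indices of the k nearest neighbors of every row (self excluded)."""
    sq = (X ** 2).sum(axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * (X @ X.T)
    np.fill_diagonal(d2, np.inf)
    return np.argpartition(d2, k - 1, axis=1)[:, :k]    # unordered k smallest

def neighborhood_overlap(X_l, X_m, k):
    """Average fraction of k-nearest neighbors shared between two layers.

    X_l and X_m hold the same N examples in the same row order.
    """
    nn_l = knn_indices(X_l, k)
    nn_m = knn_indices(X_m, k)
    shared = [len(set(a) & set(b)) for a, b in zip(nn_l, nn_m)]
    return float(np.mean(shared)) / k

rng = np.random.default_rng(0)
layer_a = rng.normal(size=(200, 5))
same_geometry = 2.0 * layer_a              # uniform rescaling keeps every neighbor set
unrelated = rng.normal(size=(200, 5))      # independent points: neighbors are reshuffled

chi_same = neighborhood_overlap(layer_a, same_geometry, k=10)
chi_unrelated = neighborhood_overlap(layer_a, unrelated, k=10)
print(chi_same, chi_unrelated)  # the first is 1.0; the second is small
```

Counting shared set members per row is equivalent to summing \(A^l_{ij}A^m_{ij}\) over \(j\), since each adjacency row has exactly \(k\) ones.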

System/data pipeline diagram in words (explicit flow).
  1. Choose a pretrained model (ESM-2 or iGPT) and a dataset subset (ProteinNet, SCOPe, or ImageNet subset).
  2. For each input example, run a forward pass and record hidden states at each block after the first normalization layer (and also consider other extraction points in Appendix Fig. S4 as a robustness check).
  3. Average-pool hidden states across the token dimension to get one vector per example per layer.
  4. For each layer separately, compute Euclidean distances among pooled vectors to:
    • find first/second nearest neighbors per point for TwoNN ID estimation,
    • find \(k\) nearest neighbors per point for neighborhood overlap.
  5. Assemble:
    • an ID vs relative depth curve (relative depth = block index / number of blocks),
    • an overlap curve \(\chi^{l,l+1}_k\) vs relative depth,
    • and when labels exist, a semantic overlap curve \(\chi^{l,gt}_k\) vs relative depth.
  6. Interpret low-ID regions / ID minima as candidate layers that maximize semantic content, and validate by showing peaks in \(\chi^{l,gt}_k\) align with those regions.

Model and dataset configurations (as provided).
  • The paper analyzes pretrained checkpoints; training hyperparameters (optimizer, learning rate schedule, batch size, training tokens, compute budget) are not provided in the included content, so they cannot be reported here.
  • Architectures used (Appendix Table 1):
    • ESM-2 (35M): 12 blocks, \(d=480\), 20 heads.
    • ESM-2 (650M): 33 blocks, \(d=1280\), 20 heads.
    • ESM-2 (3B): 36 blocks, \(d=2560\), 40 heads.
    • iGPT-S (76M): 24 blocks, \(d=512\), 8 heads.
    • iGPT-M (455M): 36 blocks, \(d=1024\), 8 heads.
    • iGPT-L (1.4B): 48 blocks, \(d=1536\), 16 heads.
  • Dataset subsets (Section 2.3):
    • ProteinNet: 25,299 protein sequences (training set) used for ID curves and \(\chi^{l,l+1}_k\).
    • SCOPe (Astral SCOPe v2.08): after filtering and constraints, 10,256 sequences across 288 superfamilies, used for remote homology semantic overlap \(\chi^{l,gt}_k\).
    • ImageNet subset: 90,000 images, constructed by sampling 300 classes × 300 images/class from the training set, used for iGPT analysis.
  • Compute environment (Section 2.3):
    • 2× Intel Xeon Gold 6226 CPUs, 256 GB RAM, 2× Nvidia V100 GPUs (32 GB each).

Worked micro-example (to make the overlap concrete).
  • Consider a single query point \(x_i\) with \(k=4\) neighbors.
  • If at layer \(l\) its neighbor set is \(\{a,b,c,d\}\) and at layer \(l+1\) it is \(\{b,c,d,e\}\), then 3 out of 4 neighbors are shared, so the per-point overlap is \(3/4=0.75\).
  • \(\chi^{l,l+1}_k\) averages this fraction across all \(N\) points, giving a scalar in \([0,1]\) that reports how much the local neighbor structure changes between the two layers (1 means no change).

Design choices and why they matter (as discussed in the paper).
  • Average pooling is chosen to enable Euclidean comparisons across variable-length sequences (proteins) and to make image computations feasible (full token-space distance matrices would be too large); Appendix D shows pooling shifts ID downward but preserves qualitative shape (Fig. S3).
  • Layer extraction point choice is tested: extracting after attention vs after normalization yields consistent ID and semantic overlap trends (Appendix Fig. S4).
  • Using relative depth normalizes comparisons across models with different numbers of blocks.

4. Key Insights and Innovations

  • (1) A consistent “phase” structure in transformer representation geometry across modalities
    • The ID profile reveals distinct phases: an early expansion to high ID, a compression to low ID, and a decoding/reconstruction region near the output (Section 3.1; Fig. 1).
    • This is presented as a shared computational pattern across protein MLM and image autoregressive prediction.

  • (2) Linking global geometry (ID) with local structure change (neighbor overlap)
    • The paper shows that the early high-ID region coincides with substantial neighbor rearrangements (low \(\chi^{l,l+1}_k\)), while the compressed region shows stable neighborhoods (high \(\chi^{l,l+1}_k\)) (Section 3.1; Fig. 2).
    • This differs from CNN observations where major neighbor rearrangements occur mostly in the very last layers under supervised objectives.

  • (3) Semantic content peaks at low-ID intermediate regions
    • For proteins, remote homology neighborhood consistency \(\chi^{l,gt}_k\) rises through the early peak and saturates at a high value in the low-ID plateau, then drops near the output (Section 3.2; Fig. 4 left).
    • For images, class-label consistency \(\chi^{l,gt}_k\) peaks near a relative depth \(\approx 0.4\), aligning with an ID local minimum (Section 3.2; Fig. 4 right).

  • (4) An unsupervised layer-selection strategy
    • The practical heuristic is: choose intermediate layers around an ID minimum (or low-ID plateau) to maximize semantic usefulness, without needing labels (Abstract; Sections 3.2 and 4).
    • The paper motivates this as identifying the “compressed code” analogous to an autoencoder bottleneck.

5. Experimental Analysis

Evaluation methodology: datasets, metrics, setup

  • Metrics
    • Intrinsic dimension (ID) per layer via TwoNN (Section 2.1).
    • Neighborhood overlap:
      • across consecutive layers: \(\chi^{l,l+1}_k\) (Section 2.2),
      • with ground-truth labels: \(\chi^{l,gt}_k\) (Section 2.2).
  • Datasets
    • ProteinNet (25,299 sequences) for ID and \(\chi^{l,l+1}_k\) in pLMs.
    • SCOPe-derived remote homology benchmark (10,256 sequences, 288 superfamilies) for semantic overlap in proteins.
    • ImageNet subset (90,000 images, 300 classes) for semantic overlap in iGPT.
  • Key analysis hyperparameters
    • \(k=10\) for SCOPe remote homology overlap; family-exclusion is used so the probe focuses on remote homology (Section 3.2).
    • \(k=30\) for ImageNet label overlap (Section 2.2).
    • Robustness to \(k\) is shown in Appendix Fig. S5.
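The label-based metric \(\chi^{l,gt}_k\) listed above reduces to the average fraction of each point's \(k\) nearest neighbors that share its label. A minimal sketch on synthetic data, with hypothetical helper names and without the family-exclusion step used for SCOPe, might look like this:

```python
import numpy as np

def knn_indices(X, k):
    """Indices of the k nearest neighbors of every row (self excluded)."""
    sq = (X ** 2).sum(axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * (X @ X.T)
    np.fill_diagonal(d2, np.inf)
    return np.argpartition(d2, k - 1, axis=1)[:, :k]

def label_overlap(X, y, k):
    """Average fraction of each point's k nearest neighbors with the same label."""
    nn = knn_indices(X, k)                  # (N, k) neighbor indices
    same = y[nn] == y[:, None]              # (N, k) booleans: neighbor shares the label
    return float(same.mean())

# Synthetic data (not from the paper): two tight, well-separated clusters.
rng = np.random.default_rng(0)
X = np.concatenate([rng.normal(0.0, 0.1, size=(50, 4)),
                    rng.normal(10.0, 0.1, size=(50, 4))])
y = np.repeat([0, 1], 50)

chi_true = label_overlap(X, y, k=10)                     # labels match the clusters
chi_shuffled = label_overlap(X, rng.permutation(y), k=10)  # labels carry no structure
print(chi_true, chi_shuffled)  # the first is 1.0; the second is near chance level
```

With balanced classes, the shuffled-label value gives a chance baseline against which a layer's overlap can be read.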

Main quantitative results (with numbers and where they appear)

ID profiles across depth (Section 3.1; Fig. 1).
  • ESM-2 protein language models (Fig. 1 left):
    • Early peak (roughly first third of the network) with maximum ID approximately:
      • ~20 (ESM-2 35M),
      • ~25 (ESM-2 650M),
      • ~32 (ESM-2 3B).
    • Then a sharp contraction to a plateau with ID roughly 5–7.
    • Then a final ascent toward values close to the input ID (interpreted as tied to MLM reconstruction).
  • iGPT image transformers (Fig. 1 right):
    • Early peak grows with model size, reaching:
      • ~25 (iGPT-S),
      • ~28 (iGPT-M),
      • ~32 (iGPT-L).
    • After an initial decrease around relative depth 0.3–0.4, iGPT-M and iGPT-L show a local minimum around ID ~22.
    • In later layers, iGPT-M and iGPT-L exhibit a second shallow peak near the end; iGPT-S does not show this second peak.

Neighbor rearrangement across depth (Section 3.1; Fig. 2).
  • pLMs (Fig. 2 left):
    • Early layers (first ~40%): \(\chi^{l,l+1}_k \approx 0.5\), indicating strong neighborhood changes per layer.
    • Plateau region: overlap rises; in larger models, >90% of neighbors are shared between consecutive layers (i.e., \(\chi^{l,l+1}_k \gtrsim 0.9\)).
    • Smaller model (35M) shows less stable plateau neighborhoods, linked in-text to higher perplexity.
  • iGPTs (Fig. 2 right):
    • Early stage: \(\chi^{l,l+1}_k \sim 0.7\) (more rearrangement).
    • Middle stage: \(\chi^{l,l+1}_k \sim 0.9\) (more stability).
    • Last layers: “significant neighborhood rearrangement” is observed again during reconstruction.

Training-time evolution (Section 3.1; Fig. 3).
  • ESM-2 (650M) checkpoints shown at [0, \(10^4\), \(3\cdot 10^4\), \(10^5\), \(5\cdot 10^5\)] training steps (Fig. 3 left caption):
    • The early ID peak forms quickly.
    • Later, the plateau compresses to lower ID, while final layers increase toward input-like ID.
  • iGPT-L checkpoints at [0, 0.13, 0.26, 0.52, 1]·\(10^6\) iterations (Fig. 3 right caption):
    • Early peak appears early in training, then a clearer local minimum develops, and later a second peak emerges.

Semantic overlap peaks at low-ID regions (Section 3.2; Fig. 4).
  • Proteins (SCOPe remote homology; Fig. 4 left):
    • \(\chi^{l,gt}_k\) is near-absent at positional embedding.
    • It grows during the early ID peak and reaches a stationary maximum around \(\chi^{l,gt}_k \sim 0.8\) in the plateau.
    • It drops sharply in the final ascent; the last hidden layer yields about 0.4 instead of 0.8 (as stated in Section 3.2).
    • Appendix Fig. S2 reports an additional applied result: for ProtT5-XL-U50, doing 1-NN homology search on a plateau layer improves accuracy by ~6% compared to the last layer.
  • Images (ImageNet labels; Fig. 4 right):
    • Peak semantic overlap occurs around relative depth ~0.4, aligned with the ID minimum in larger models.
    • Peak \(\chi^{l,gt}_k\) values reported:
      • 0.15 (iGPT-S),
      • 0.27 (iGPT-M),
      • 0.35 (iGPT-L).

Do the experiments support the claims?

  • The alignment between (i) the ID drop/minimum and (ii) peak semantic overlap is shown directly in Fig. 1 + Fig. 4 for both modalities, supporting the central claim that semantic structure is strongest in compressed intermediate representations.
  • The neighbor-overlap curves (Fig. 2) provide an additional mechanistic correlate: fast neighborhood churn during ID expansion, stability during compression, and renewed churn during decoding.
  • The training-time curves (Fig. 3) strengthen the interpretation by showing the ID peak forms early and compression develops later, consistent with a staged emergence of semantic structure.

Ablations / robustness checks / failure cases reported

  • Robustness to neighborhood size \(k\): qualitative overlap trends persist across \(k\in\{1,2,5,10,20,50\}\) (Appendix Fig. S5).
  • Robustness to extraction point: post-layernorm vs post-attention yields consistent ID and semantic overlap trends (Appendix Fig. S4).
  • Impact of average pooling: pooling lowers estimated ID but preserves curve shape qualitatively (Appendix Fig. S3).
  • The paper notes that the iGPT second peak is absent in iGPT-S and appears in later training for iGPT-L (Fig. 3), indicating some shape features depend on scale/training progression.

6. Limitations and Trade-offs

  • Representation reduction via average pooling
    • Pooling is necessary for variable-length proteins and computational feasibility for images, but it can quantitatively reduce ID (Appendix Fig. S3).
    • The paper argues the shape remains robust, yet the exact ID magnitudes may not match those from full token-level representations.

  • Dependence on Euclidean distances in pooled space
    • All geometry is computed using Euclidean distances between pooled vectors; if alternative sequence-comparison methods are used, results might shift (the paper points to other approaches in Discussion and cites [31] and [22] as explored alternatives, without analyzing them here).

  • Scope of semantic probes
    • Protein semantics are proxied by SCOPe remote homology overlap; image semantics by ImageNet class overlap.
    • These are specific notions of “semantic content,” and other tasks might highlight different layers.

  • Partial analysis of NLP / LLMs
    • The paper includes only a preliminary appendix-style exploration for Llama2-70B on SST sentiment (Appendix E; Fig. S6), where the ID profile is more complex (three peaks, two minima), but the “first ID minimum aligns with best class overlap” still holds.
    • This suggests generality but is not as thoroughly validated as proteins/images in the main sections.

  • Missing training configuration details in the provided content
    • Because the study largely analyzes pretrained models, the included text does not report optimizer settings, learning rate schedules, batch sizes, or training token counts for ESM-2/iGPT; this limits mechanistic attribution to specific training dynamics beyond what checkpoint curves show.

7. Implications and Future Directions

  • How this changes the landscape
    • It provides a concrete, quantitative bridge between:
      • global manifold complexity (ID),
      • local structural stability (neighbor overlap),
      • and semantic usefulness (label consistency), across very different transformer modalities.
    • It reframes large self-supervised transformers as behaving like autoencoders in representation space: expand → compress to an abstract code → decode for reconstruction (Discussion; the iGPT-L second peak is interpreted as making the model resemble a symmetric autoencoder).

  • Follow-up research suggested by the paper
    • Compare intermediate low-ID representations to representations from supervised models trained on the same datasets to better interpret the iGPT second peak and the semantics/geometry link (Discussion).
    • Investigate why pLMs end with an ID ascent while iGPTs can show end-of-network compression, considering objective differences (MLM vs next-token pixels) and preprocessing effects (Discussion).

  • Practical applications / downstream use cases
    • Unsupervised layer selection: use the ID profile to pick intermediate layers for:
      • protein retrieval (remote homology search),
      • image classification via linear probes or other downstream learners, without needing labels to tune the choice.
    • Representation monitoring: track ID/overlap as diagnostics when training or adapting transformers to new domains.

  • Repro/Integration Guidance (when to prefer this method)
    • Prefer this geometry-based method when:
      • you have a pretrained transformer and little/no labeled data,
      • you need to choose which layer(s) to export for embeddings,
      • you can compute kNN distances on a representative sample.
    • Concretely, based on the paper’s findings:
      • for iGPT-like models, inspect the ID curve and select layers near the relative-depth \(\approx 0.4\) ID minimum (Fig. 1 right + Fig. 4 right),
      • for ESM-2-like pLMs, select layers in the low-ID plateau immediately after the first peak (Fig. 1 left + Fig. 4 left), since semantic overlap saturates there.
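One way to turn this guidance into a first-pass filter is to list the interior relative minima of a per-block ID profile together with their relative depths. The sketch below is an assumption-laden illustration: `candidate_layers` is a hypothetical helper, the ID values are invented for the example rather than taken from the paper, and using `<=` means every block on a flat plateau is reported as a candidate.

```python
import numpy as np

def candidate_layers(id_profile):
    """Return (block index, relative depth) for interior relative minima of the ID curve.

    id_profile[i] is the ID estimated at block i + 1; the first and last
    blocks are never reported because they have only one neighbor.
    """
    ids = np.asarray(id_profile, dtype=float)
    n = len(ids)
    depth = np.arange(1, n + 1) / n                  # block index / number of blocks
    interior = np.arange(1, n - 1)
    is_min = (ids[interior] <= ids[interior - 1]) & (ids[interior] <= ids[interior + 1])
    return [(int(i) + 1, float(depth[i])) for i in interior[is_min]]

# Made-up 10-block profile: early peak, a dip, and a shallower late dip.
toy_profile = [18, 24, 30, 27, 22, 24, 26, 25, 27, 28]
candidates = candidate_layers(toy_profile)
print(candidates)  # [(5, 0.5), (8, 0.8)]
```

When several minima are returned, the paper's results point toward the one following the early ID peak; the final choice still deserves a look at the curve itself, since noisy estimates can produce spurious small dips.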