How Do Transformers Learn Topic Structure:
Towards a Mechanistic Understanding

Yuchen Li1, Yuanzhi Li1,2, Andrej Risteski1

1Carnegie Mellon University, 2Microsoft Research

contact (email)  /  paper (arXiv)  /  bibtex  /  slides  /  Twitter summary  /  code (GitHub)

TL;DR

We develop a mathematical understanding of the training process of transformers through the lens of learning topic structure. We prove that either the embedding layer or the attention layer alone can learn to encode the topic structure.

In particular, our main conceptual idea is that the optimization process of self-attention can be approximately broken down into two stages (when all model parameters are initialized to small values near 0). This idea is theoretically reasonable, empirically verifiable, and makes the analysis tractable.

Experiments on synthetic and real-world data verify our theory.

Motivation

While the empirical successes of transformers across many domains are indisputable, accurate understanding of their learning mechanics is still largely lacking. Their capabilities have been probed on benchmarks which include a variety of structured and reasoning tasks, but mathematical understanding is lagging substantially behind. Recent lines of work have begun studying representational aspects of this question: that is, the size/depth/complexity of attention-based networks required to perform certain tasks. However, there is no guarantee that the learning dynamics will converge to the constructions proposed. Therefore, theoretical understanding of transformers requires characterizing the optimization process.

motivation.png
Methodology

We use synthetic data as a sandbox for a scientific approach to understanding deep learning models. Real data are messy and complex (containing semantics, syntax, etc.). To study them in a more formal manner, we focus on one aspect at a time by studying a simple synthetic setting. The benefit is that variables can be controlled, which helps single out each factor under study. The long-term research agenda is to progressively study more realistic data distributions.

Data distribution

We study semantic structure, understood through the lens of word co-occurrences and their topical structure. Precisely, if we fit topics to a real-life corpus like Wikipedia using a Latent Dirichlet Allocation (LDA) model [6], we find that a pretrained BERT model produces token embeddings that are more similar (in terms of inner product or cosine similarity) if they belong to the same topic, and less similar if they belong to different topics. Inspired by these observations, we study LDA-generated data as a sandbox to understand, both through experiments on such synthetic data and through theoretical results, the process by which the embeddings and attention learn the topical structure. We find that the above observations from Wikipedia data are even more pronounced on synthetic LDA data.

data.png
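
Below is a minimal sketch of the LDA generative process we use as the synthetic sandbox: sample a per-document topic mixture from a Dirichlet prior, then generate each token by first drawing a topic and then a word from that topic's word distribution. The vocabulary size, topic count, document length, and Dirichlet parameters here are illustrative placeholders, not the exact values used in the paper.

import numpy as np

rng = np.random.default_rng(0)

V = 1000      # vocabulary size (assumed)
T = 10        # number of topics (assumed)
doc_len = 64  # tokens per document (assumed)
alpha = 0.1   # document-topic Dirichlet concentration (assumed)
beta = 0.01   # topic-word Dirichlet concentration (assumed)

# Topic-word distributions: one categorical distribution over the vocabulary per topic.
topic_word = rng.dirichlet(beta * np.ones(V), size=T)   # shape (T, V)

def sample_document():
    # Per-document topic mixture, then each token: draw a topic, then a word.
    theta = rng.dirichlet(alpha * np.ones(T))            # shape (T,)
    topics = rng.choice(T, size=doc_len, p=theta)
    words = np.array([rng.choice(V, p=topic_word[t]) for t in topics])
    return words, topics

docs = [sample_document()[0] for _ in range(5)]
print(docs[0][:10])  # first ten token ids of the first synthetic document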
Model architecture

We study a simplified one-layer transformer.

architecture.png
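
As a rough illustration, a minimal sketch of what such a simplified one-layer (single-head) attention model could look like: a learned word embedding, a single softmax self-attention layer with query/key/value matrices, and a linear decoder, with the attention parameters initialized near 0 as in our analysis. The exact simplifications used in the paper may differ, so treat the dimensions and architectural details below as assumptions.

import torch
import torch.nn as nn
import torch.nn.functional as F

class OneLayerAttention(nn.Module):
    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        # Attention parameters initialized near zero (small random values).
        self.W_Q = nn.Parameter(1e-3 * torch.randn(dim, dim))
        self.W_K = nn.Parameter(1e-3 * torch.randn(dim, dim))
        self.W_V = nn.Parameter(1e-3 * torch.randn(dim, dim))
        self.decoder = nn.Linear(dim, vocab_size)

    def forward(self, tokens):                    # tokens: (batch, seq_len)
        x = self.embed(tokens)                    # (batch, seq_len, dim)
        q, k, v = x @ self.W_Q, x @ self.W_K, x @ self.W_V
        attn = F.softmax(q @ k.transpose(-2, -1) / x.shape[-1] ** 0.5, dim=-1)
        return self.decoder(attn @ v)             # (batch, seq_len, vocab_size)

model = OneLayerAttention(vocab_size=1000, dim=128)
logits = model(torch.randint(0, 1000, (2, 64)))
print(logits.shape)  # torch.Size([2, 64, 1000])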
Main theoretical results

We analyze the optimization process of transformers trained on data involving "semantic structure", e.g. topic modeling. Through theoretical analysis and experiments, we show that for same-topic words, the embeddings are more similar and the average pairwise attention is larger than for words from different topics.

theory.png
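
The following sketch shows how these two predictions can be checked empirically: compare the average embedding similarity and the average pairwise attention score within a topic against the corresponding cross-topic averages. The matrices below are random placeholders standing in for trained parameters, word_topic is an assumed word-to-topic assignment, and the attention score is computed in a simplified position-free way.

import numpy as np

rng = np.random.default_rng(0)
V, T, dim = 1000, 10, 128
word_topic = rng.integers(0, T, size=V)          # assumed topic label per word
E = rng.standard_normal((V, dim))                # stand-in for trained embeddings
W_Q = rng.standard_normal((dim, dim)) * 1e-2     # stand-in for trained query matrix
W_K = rng.standard_normal((dim, dim)) * 1e-2     # stand-in for trained key matrix

# Cosine similarity between all pairs of word embeddings.
E_norm = E / np.linalg.norm(E, axis=1, keepdims=True)
sim = E_norm @ E_norm.T                          # (V, V)

# Pre-softmax attention scores between all word pairs (position-free view).
scores = (E @ W_Q) @ (E @ W_K).T / np.sqrt(dim)  # (V, V)

same = word_topic[:, None] == word_topic[None, :]
off_diag = ~np.eye(V, dtype=bool)

for name, M in [("embedding similarity", sim), ("attention score", scores)]:
    within = M[same & off_diag].mean()
    across = M[~same].mean()
    print(f"{name}: within-topic {within:.4f} vs. cross-topic {across:.4f}")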
Main conceptual idea

In particular, we observe that with carefully chosen initialization and learning rate, the optimization process of self-attention can be approximately broken down into two stages: in stage 1, only the value matrix changes significantly; in stage 2, the key and query matrices catch up, even though all components are jointly optimized throughout via standard SGD or Adam. This observation might be of independent interest for future work on understanding the learning dynamics of transformers.

two_stage.png
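
A minimal sketch of how this two-stage behavior can be monitored: starting from near-zero initialization, log how far the value matrix versus the key/query matrices have moved from their initial values during joint training. The toy model, data, and loss below are illustrative placeholders rather than the paper's exact training setup.

import torch
import torch.nn.functional as F

dim, V_size, steps = 64, 200, 2000
torch.manual_seed(0)

E = torch.randn(V_size, dim) * 0.1               # fixed stand-in embeddings
W_Q = torch.nn.Parameter(1e-3 * torch.randn(dim, dim))
W_K = torch.nn.Parameter(1e-3 * torch.randn(dim, dim))
W_V = torch.nn.Parameter(1e-3 * torch.randn(dim, dim))
init = {n: p.detach().clone() for n, p in [("Q", W_Q), ("K", W_K), ("V", W_V)]}
opt = torch.optim.Adam([W_Q, W_K, W_V], lr=1e-3)

for step in range(steps):
    tokens = torch.randint(0, V_size, (32, 16))  # toy batch (placeholder data)
    x = E[tokens]                                # (32, 16, dim)
    attn = F.softmax((x @ W_Q) @ (x @ W_K).transpose(-2, -1) / dim ** 0.5, dim=-1)
    out = attn @ (x @ W_V)
    # Toy reconstruction-style loss: predict each token's own embedding.
    loss = F.mse_loss(out, x)
    opt.zero_grad()
    loss.backward()
    opt.step()
    if step % 500 == 0:
        # Distance of each matrix from its initialization (Frobenius norm).
        moved = {n: (p.detach() - init[n]).norm().item()
                 for n, p in [("Q", W_Q), ("K", W_K), ("V", W_V)]}
        print(step, {n: round(v, 4) for n, v in moved.items()})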
Future directions

We initiated the study of understanding training dynamics of transformers in the presence of semantic structure captured by a topic model. Interesting directions for future work include extending the analysis to data distributions that capture “syntactic” structure, e.g. through simple sandboxes like PCFGs. When both the model and the data distribution are complex, it remains a daunting challenge to “disentangle” how the many different aspects of the data (e.g. semantic and syntactic elements) are learned through the different parts of the model architecture (e.g. attention, positional encodings, and embeddings).

future.png
Acknowledgements

We thank Bingbin Liu, Yusha Liu, and Tanya Marwah for proofreading and providing constructive comments, Yewen Fan for helpful suggestions on empirically obtaining the two-stage optimization process, and Emmy Liu and Graham Neubig for insightful discussions on the connections with empirical observations. Andrej Risteski and Yuchen Li acknowledge support by NSF awards IIS-2211907 and CCF-2238523. Andrej Risteski also acknowledges support by Amazon Research Award “Causal + Deep Out-of-Distribution Learning”.

References
  1. Jorge Pérez et al., 2021. Attention is Turing Complete.
  2. Chulhee Yun et al., 2020. Are Transformers universal approximators of sequence-to-sequence functions?
  3. Shunyu Yao et al., 2021. Self-Attention Networks Can Process Bounded Hierarchical Languages.
  4. Haoyu Zhao et al., 2023. Do Transformers Parse while Predicting the Masked Word?
  5. Bingbin Liu et al., 2023. Transformers learn shortcuts to automata.
  6. David Blei et al., 2003. Latent Dirichlet Allocation (LDA).
  7. Charlie Snell et al., 2021. Approximating how single head attention learns.
  8. Sébastien Bubeck et al., 2023. Sparks of AGI.
  9. avikdas.com/2020/01/28/the-balanced-parentheses-problem.html