Sparsity allows scaling model parameters without proportionally
increasing computational cost. While mixture of experts (MoE) models are
made increasingly sparse, individual experts typically remain large and
dense. Here, we demonstrate that further increasing sparsity by
shrinking each expert to consist of a single neuron and selecting a tiny
fraction of many available neurons can improve compute efficiency and
interpretability. Counterintuitively, the key to achieving both is
removing the nonlinearity typically applied to the experts, resulting in
a network of sparsely gated linear neurons
(sgatlin). In an isoflop comparison, we find that replacing
all transformer feedforward layers with sgatlin improves
perplexity in language models across different compute budgets. At the
same time, the sparsity and linearity of the resulting feedforward
circuits present new opportunities for model interpretability. In a
small-scale case study, we demonstrate that feedforward circuits in
sgatlin can be interpreted without having to train
additional replacement models. We find that they form semantically
structured clusters and are causally implicated in factual recall. Our
findings paint a possible path towards compute-efficient and
interpretable transformer feedforward layers.
We highlight the main results below; the full paper is available on arXiv.
Our sparsely gated linear neuron layer (sgatlin) consists
of two components: a gating network that efficiently computes a sparse
vector of gating weights, and a large pool of linear neurons
(the experts) that are linearly combined according to
the gating weights.
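As a rough illustration of these two components, here is a minimal, dependency-free Python sketch. It is not the paper's implementation. It makes three assumptions: the gating network is simplified to one dense score per neuron followed by a plain top-k, the raw scores serve directly as gating weights, and each linear neuron has one input vector and one output vector. All names (`gate_keys`, `w_in`, `w_out`, `sgatlin_forward`) are hypothetical.

```python
def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def top_k_gates(scores, k):
    """Keep the k largest scores as gating weights; every other gate is implicitly zero."""
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    return {i: scores[i] for i in top}


def sgatlin_forward(x, gate_keys, w_in, w_out, k):
    """Sketch of a sparsely gated linear neuron layer for one position x."""
    # Gating network (assumption: dense scoring with one key per neuron).
    scores = [dot(key, x) for key in gate_keys]
    gates = top_k_gates(scores, k)

    # Pool of linear neurons: neuron i computes w_in[i] . x with no
    # nonlinearity, is scaled by its gate, and writes along w_out[i].
    y = [0.0] * len(w_out[0])
    for i, g in gates.items():
        activation = g * dot(w_in[i], x)
        for j, w in enumerate(w_out[i]):
            y[j] += activation * w
    return y, gates


# Toy example with 3 neurons in a 2-dimensional model (made-up values).
gate_keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
w_in = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w_out = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
y, gates = sgatlin_forward([2.0, 1.0], gate_keys, w_in, w_out, k=2)
print(gates)  # indices and gating weights of the selected neurons
print(y)      # layer output for this position
```

Only the `k` selected neurons are ever evaluated, so the cost of the neuron pool depends on `k` rather than on the pool size. The dense scoring step in this sketch does still scale with the pool size.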
sgatlin uses the product key top-k operation introduced by
We compare language modelling perplexity of transformers with different
feedforward architectures: dense layers, MoE layers and our
sgatlin layer. In a compute-matched (isoflop) comparison
with varying model sizes of up to 4B parameters on the SlimPajama 627B
dataset
sgatlin linearly combines a small number of
linear neurons, effectively applying a low-rank linear circuit
to each position in the sequence.
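To make the low-rank view concrete, the sketch below builds the circuit matrix for one position as a sum of gated rank-one outer products and checks that applying it matches the gated sum over the selected neurons. With k selected neurons the matrix has rank at most k. The sketch assumes each linear neuron is parameterised by an input vector and an output vector; the names (`circuit_matrix`, `gated_sum`) and all values are hypothetical.

```python
def circuit_matrix(gates, w_in, w_out):
    """W = sum_i g_i * outer(w_out[i], w_in[i]) over the selected neurons."""
    d_out, d_in = len(w_out[0]), len(w_in[0])
    W = [[0.0] * d_in for _ in range(d_out)]
    for i, g in gates.items():
        for r in range(d_out):
            for c in range(d_in):
                W[r][c] += g * w_out[i][r] * w_in[i][c]
    return W


def matvec(W, x):
    """Apply a matrix (list of rows) to a vector."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def gated_sum(x, gates, w_in, w_out):
    """Direct evaluation: sum of gated linear neuron outputs."""
    y = [0.0] * len(w_out[0])
    for i, g in gates.items():
        activation = g * sum(a * b for a, b in zip(w_in[i], x))
        for j, w in enumerate(w_out[i]):
            y[j] += activation * w
    return y


# Toy pool of 4 neurons in a 3-dimensional model (made-up values).
w_in = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]]
w_out = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]]
gates = {0: 2.0, 3: 0.5}  # sparse gating weights: k = 2 neurons selected
x = [1.0, 2.0, 3.0]

W = circuit_matrix(gates, w_in, w_out)  # 3x3 matrix of rank at most 2
assert matvec(W, x) == gated_sum(x, gates, w_in, w_out)
```

In practice one would not materialise `W` during the forward pass. It is mainly useful as an explicit object to inspect when analysing a circuit.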
Do these feedforward circuits share reusable and interpretable
structure?
Instrumental to answering this question is the insight that each circuit is uniquely identified by the sparse gating weights that created it. We can therefore analyse feedforward circuits in terms of how similar they are to other feedforward circuits by comparing their corresponding gating weights.
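One way such a comparison could look is sketched below: each circuit is represented by its sparse gating weights (neuron index mapped to weight), and neighbours are ranked by cosine similarity. The choice of cosine similarity, the function names, and the toy gating weights are assumptions made for illustration, not details taken from the paper.

```python
import math


def cosine(g1, g2):
    """Cosine similarity between two sparse gating vectors stored as dicts."""
    shared = g1.keys() & g2.keys()
    num = sum(g1[i] * g2[i] for i in shared)
    norm1 = math.sqrt(sum(v * v for v in g1.values()))
    norm2 = math.sqrt(sum(v * v for v in g2.values()))
    return num / (norm1 * norm2) if norm1 and norm2 else 0.0


def nearest_circuits(query, bank, n):
    """Return the labels of the n circuits in `bank` most similar to `query`."""
    ranked = sorted(bank, key=lambda label: cosine(query, bank[label]), reverse=True)
    return ranked[:n]


# Hypothetical gating weights, labelled by the token that elicited each circuit.
bank = {
    "cat": {3: 1.0, 7: 0.5},
    "bird": {3: 0.8, 9: 0.6},
    "the": {1: 1.0, 2: 1.0},
}
query = {3: 1.0, 7: 0.4}  # hypothetical circuit for a reference token

neighbours = nearest_circuits(query, bank, n=2)
print(neighbours)
```

Because the gating vectors are sparse, only the neurons shared between two circuits contribute to the similarity, so circuits with disjoint neuron sets score zero.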
We investigate the extent to which a particular feedforward circuit can
be understood by relating it to similar circuits in a small transformer
trained on the TinyStories dataset.
Some of the reference tokens elicit highly stereotyped circuits that simply activate for exactly the same token across all neighbours. Intriguingly, however, for several of the tokens in the reference sequence, neighbour circuits activate in different but semantically similar contexts. For instance, among the neighbours of the feedforward circuit activated for the dog token are circuits that activate for other animals.
Overall, we find that the feedforward circuits activated across sequences are reused and form semantically meaningful structure. Below, we show a UMAP embedding of the feedforward circuits in the penultimate layer for the 100 most frequent input tokens in a random subset of 128 unseen stories. For a subset of tokens, we colour the points in this unsupervised embedding according to the input token that activated the corresponding circuit.
Feedforward circuits of sgatlin in the penultimate layer
for the 100 most frequent tokens in a random subset of 128 unseen
stories, embedded into 2D using UMAP. The resulting embedding
forms clusters that are partially explained by the semantics of the
corresponding input tokens. For instance, names used in the stories
cluster together and pronouns form clusters in their vicinity.
Read the full paper on arXiv.