Member of Technical Staff at OpenAI. Formerly Staff Research Scientist at Google DeepMind. Ex-Physicist.

The Training team @OpenAI is hiring researchers in London 🚀 Our twin missions are to train better LLMs and to serve them more cheaply. Get in touch if you are excited to collaborate on architecture design, reliable scaling, and faster optimization.
First day at OpenAI London 😁!
Proud to be a part of NFNets, a new ImageNet SOTA:
- does not use BatchNorm, LayerNorm, GroupNorm, anyNorm!
- 86.5% top-1 w/o extra data
- 89.2% top-1 w/ pre-training
- 8.7x faster than EffNet-B7 to reach the same test accuracy
arxiv.org/abs/2102.06171 code: dpmd.ai/nfnets 1/4
Announcing RecurrentGemma! github.com/google-deepmind/r…
- A 2B model with open weights, based on Griffin
- Replaces the transformer with a mix of gated linear recurrences and local attention
- Competitive with Gemma-2B on downstream evals
- Higher throughput when sampling long sequences
Replying to @srush_nlp
I think there is a well established answer. We initialize and optimize deep networks in such a way that the model explores simple functions before complex ones. Although overfit functions exist in weight space they are usually much harder to find.
I wish "ResNets" and "Transformers" were called "ResConvs" and "ResAttn". ResNet should be an umbrella term for any deep network with a repeating pattern of skip connections and residual branches.
I ambushed a theory workshop with a tutorial on scaling LLMs: piped.video/GfAT2zkB6-U?si=HVuy… It covers transformers, a simple model of how TPUs work, how to train models that don't fit on a single device, scaling plots, and how training and inference differ.
Incredibly excited to announce Hawk and Griffin (arxiv.org/abs/2402.19427), two recurrent language models with:
1) finite-sized state + fast inference
2) efficient training on device
3) excellent performance
The Stochastic Gradient Descent we use in practice, SGD with Random Shuffling, is not a Stochastic Differential Equation when the learning rate is small. Instead, it follows the path of gradient flow on a regularized loss: arxiv.org/abs/2101.12176 (Mea Culpa at ICLR 2021)
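For readers who want the formula: as I recall it from the linked paper (please check the constants against arXiv 2101.12176), with learning rate ε, full-batch loss C(ω), and Ĉ_k(ω) the loss on the k-th of m minibatches in one epoch, gradient descent and Random Shuffling SGD follow gradient flow on these modified losses:

```latex
\tilde{C}_{\mathrm{GD}}(\omega) = C(\omega) + \frac{\epsilon}{4}\,\lVert \nabla C(\omega) \rVert^2,
\qquad
\tilde{C}_{\mathrm{SGD}}(\omega) = C(\omega) + \frac{\epsilon}{4m} \sum_{k=0}^{m-1} \lVert \nabla \hat{C}_k(\omega) \rVert^2 .
```

Because the minibatch gradients average to ∇C over an epoch, the SGD penalty equals the GD penalty plus (ε/4m) Σ_k ‖∇Ĉ_k(ω) − ∇C(ω)‖², a term that grows with the variance of the minibatch gradients.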
ConvNets Match Vision Transformers at Scale: arxiv.org/abs/2310.16764 We scale NFNet pre-training on JFT-4B from 0.4 to 110k TPU-v4 core hours. After fine-tuning, our largest model achieves 90.4% ImageNet Top-1, competitive with ViTs pre-trained for similar compute budgets. 1/3
RecurrentGemma-9B is out! kaggle.com/models/google/rec… huggingface.co/google/recurr…
- Uses the Griffin architecture, combining linear recurrence with local attention
- Downstream evals comparable to Mistral and Gemma
- Faster inference, especially for long sequences or large batch sizes
1/n
Replying to @fchollet
The lesson we took from working on Griffin (arxiv.org/pdf/2402.19427) is that current model performance is bottlenecked by the channel mixing component (i.e., the MLP), not the sequence mixing component (i.e., attention vs. recurrence).
Want to discuss how SGD implicitly regularises NN training, or how to train ResNets without BatchNorm? Come join our two posters @iclr_conf today (Tuesday) at 5-7pm UK time: SGD/Implicit Regularization: iclr.cc/virtual/2021/poster/… Norm-Free ResNets: iclr.cc/virtual/2021/poster/…
Replying to @srush_nlp
You could give them a take-home 😀: form a dataset of 5 MNIST 1's and 0's. Train a small MLP to catastrophically overfit (100% train accuracy, <55% test accuracy). Spoiler: this is possible if you initialize the first weight matrix too large and use a very small learning rate.
Our models achieve remarkable performance, but they can also leak sensitive information about individual training examples. In recent work, we significantly improve the performance of networks trained with strict differential privacy (DP) guarantees: arxiv.org/abs/2204.13650 1/n
Arrived at ICML 2022! If you want to discuss hyper-parameter tuning and generalization in SGD, ResNet initialization, differential privacy in deep learning, or whatever you think I'd find fun, get in touch :)
Alternatively, you could take a sensibly initialized ResNet with norms in the right places and plot the output function at initialization. Then grow depth and width. The output function at initialization should not become more complex with increasing size.
My two takeaways:
1) The most important thing when training deep networks without tricks is to get the initialization scheme right!
2) Theoretical work can have tangible practical benefits, but this is much more likely when theorists and practitioners collaborate closely. 4/4
Transformers without skip connections or normalization layers 😀 A very successful internship project which I thought was much too hard! Well done @bobby_he
Can deep transformers be trained without skip connections nor normalisation layers? Our ICLR 2023 paper shows you how, using wide NN signal propagation ideas. We hope this can potentially pave the way to more efficient deep LLMs! (1/9) Paper: arxiv.org/abs/2302.10322
Building on ideas from SSMs and LSTMs, Griffin matches transformer performance without global attention, achieving faster inference on long sequences. arxiv.org/abs/2402.19427 See @sohamde_'s great thread for more details:
Just got back from vacation, and super excited to finally release Griffin - a new hybrid LLM mixing RNN layers with Local Attention - scaled up to 14B params! arxiv.org/abs/2402.19427 My co-authors have already posted about our amazing results, so here's a 🧵on how we got there!
Replying to @tomgoldsteincs
In many cases if you took a published model "w/o early stopping" on ImageNet and trained it for 10x longer, test accuracy would fall. People increase augmentation/regularization strength until they can make full use of their compute budget.
If you properly ablated all the differences between Vision Transformers and up-to-date ResNets (eg NFNets/EffNet-V2), constraining a practical measure of compute cost like wall-clock time, I'm pretty sure you would end up with a ResNet and a modified training pipeline!
Is attention really behind the success of Vision Transformers? We think maybe... Patches Are All You Need? 🤷 Check out ConvMixer, the first model that achieves 82%+ ImageNet top-1 accuracy while also fitting into a tweet! arxiv.org/abs/2201.09792 With @zicokolter. 1/4
Replying to @agihippo
Griffin author here. We had the core results (i.e., RNN-attention hybrids matching Transformers up to 7B, with matched training speed and faster inference) since June 2023, but didn't get publication approval until Feb 2024. So Google gave itself a healthy head start!
We provide efficient JAX code for RecurrentGemma, which can also be used for general Griffin models. This includes a memory-efficient implementation of the linear recurrence in Pallas, with which we match the training speed of transformers on TPU: github.com/google-deepmind/r…
We don't care about any specific credentials. We just want phenomenal people. We are a small team so currently looking to hire experienced researchers able to work independently (eg L4 minimum, ideally L5+). As we grow I hope to start hiring interns/residents in London as well.
Replying to @ZyphraAI @NVIDIAAI
Impressive results, but why are you comparing to Gemma-1 and not Gemma-2 (MMLU 71%)? It would also be interesting to see an inference speed comparison with RecurrentGemma!
A huge thank you to our brilliant team, especially @botev_mg @sohamde_ Anushan Fernando @GeorgeMuraru @haroun_ruba @LeonardBerrada Finally, I'm sure there will be a few rough edges in the initial release, we will try to address these as fast as we can!
Finally my intern @stanislavfort studied how averaging per-example gradients over multiple data augmentations reduces variance and speeds up convergence: arxiv.org/abs/2105.13343 This trick is particularly valuable in private training, where we pay a price for every gradient 8/n
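A minimal sketch of why averaging over augmentations helps, under loud simplifying assumptions: the "gradient" is a scalar, and each augmented view adds independent unit-variance noise to a fixed signal (the hypothetical `toy_grad` helper). Averaging over K views before clipping keeps exactly one clipped contribution per example, so the per-example sensitivity is unchanged while the variance falls roughly as 1/K.

```python
import random
import statistics


def toy_grad(signal, rng):
    """Hypothetical per-example gradient for one random augmentation:
    a fixed signal plus independent unit-variance augmentation noise."""
    return signal + rng.gauss(0.0, 1.0)


def averaged_grad(signal, k, rng):
    """Average the per-example gradient over k augmented views of one example."""
    return sum(toy_grad(signal, rng) for _ in range(k)) / k


def clip(g, max_norm):
    """Clip a scalar gradient to magnitude max_norm, as DP-SGD does per example."""
    return g * min(1.0, max_norm / abs(g)) if g != 0 else g


def grad_variance(k, n_trials=2000, seed=0):
    """Empirical variance of the augmentation-averaged gradient."""
    rng = random.Random(seed)
    return statistics.variance([averaged_grad(0.5, k, rng) for _ in range(n_trials)])


for k in (1, 4, 16):
    print(k, round(grad_variance(k), 3))

# In private training the average (not each view) is clipped, so each
# example still contributes a single bounded term to the noisy sum.
rng = random.Random(1)
private_contribution = clip(averaged_grad(0.5, 16, rng), max_norm=1.0)
```

Real DP-SGD clips the norm of a gradient vector; the scalar version only keeps the order of operations (augment, average, then clip) visible.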
Mixture-of-experts models (which improve the MLPs but don't touch sequence mixing) significantly outperform dense models FLOP for FLOP. Meanwhile, the gap between different sequence mixing mechanisms (i.e., Transformer vs. Griffin or Mamba) is pretty small.
In RecurrentGemma, we provide two 2B model checkpoints:
- A pretrained model, trained for 2T tokens
- An instruction-tuned model for dialogue
We train on the same data as Gemma-2B for fewer tokens, achieving comparable performance. Technical report: storage.googleapis.com/deepm…
Pretty sure we already have 😂
🌶️ hot take 🌶️ > we should normalize training on the test set yes, you read that right. no, I'm not joking. and, yes... I have taken ML 101 👉 here's why this is crucial for future multimodal LLM research [1/n] 🧵
A great collaboration with @ajmooch, @sohamde_ and Karen Simonyan. This work is the third in a series of papers which first study and then remove batch normalization layers from ResNets. See: De and Smith, arxiv.org/abs/2002.10444 Brock et al., arxiv.org/abs/2101.08692 2/4
We also provide a reference implementation in PyTorch, though we recommend the JAX code for best performance. We hope RecurrentGemma provides an alternative to pre-trained transformers, and enables researchers to explore the capabilities of this exciting new model family.
Alex is phenomenal + incredibly productive. One to watch!
We present Griffin: A hybrid model mixing a gated linear recurrence with local attention. This combination is extremely effective: it preserves all the efficient benefits of linear RNNs and the expressiveness of transformers. Scaled up to 14B! arxiv.org/abs/2402.19427
Looking forward to presenting our ICML oral "Resurrecting Recurrent Neural Networks for Long Sequences" in the Wednesday morning poster session. Find us at #521 to learn how to make linear RNNs match SSMs and outperform Transformers on the Long Range Arena! icml.cc/virtual/2023/oral/25…
Replying to @srush_nlp
In defence of your blog, we are confident linear scans are faster on TPU, but we aren't sure about GPU. That said, I do think the field has oversold the importance of parallel scans, since recent RNNs/SSMs are diagonal and therefore memory bound, not FLOPs bound.
Here is the throughput of RecurrentGemma-9B vs Gemma-7B on a single TPU-v4, using a prefill of 2k tokens. Note the y-axis is logarithmic! RecurrentGemma can perform inference at much larger batch sizes, achieving 80x higher throughput when sampling 2k tokens! 4/n
We note that the inference benefits of RecurrentGemma-9B over Gemma-7B are particularly large because Gemma-7B used multi-head attention to maximize downstream performance, whereas Gemma-2B used multi-query attention for faster inference. 5/n
RecurrentGemma-9B was trained for 2T tokens, whereas Gemma-7B was trained for 6T tokens. The token budget of Mistral-7B is not public. Confusingly, both Gemma-7B and RecurrentGemma-9B have roughly 8.5B total parameters (after accounting for embeddings)! 2/n
Replying to @giffmana
compute, data, good signal prop
The associated online lectures are how I got into ML: videolectures.net/david_mack…
A very unique textbook "Information theory, inference and learning algorithms" by Sir MacKay combining Information Theory with Machine Learning. Nicely written. Chapter 27 on Laplace's method is 2 great pages! Book PDF available at: inference.org.uk/itila/book.…
Thank you to Ferenc for this nice summary of our recent analysis of finite learning rates in Random Shuffling SGD!
My note on Smith et al (2021): On the Origin of Implicit Regularization in Stochastic Gradient Descent - a cool paper about modeling the behaviour of SGD just accepted to ICLR inference.vc/notes-on-the-or…
This is correct 👍😂 Griffin is a hybrid of the speed of a Hawk/RNN and the strength of a Lion/Transformer
A huge thank you to our incredible team! This is one of the most enjoyable projects I have ever worked on, and I learned so much from all of you. @sohamde_ @botev_mg @GeorgeMuraru @_albertgu @LeonardBerrada @yutianc @ArnaudDoucet1 @davidmbudden @yeewhye @NandoDF @caglarml
Andy is *the* most productive researcher I have ever collaborated with. Really happy with how this project turned out!
Normalizer-Free ResNets: Our ICLR2021 paper w/ @sohamde_ & @SamuelMLSmith We show how to train deep ResNets w/o *any* normalization to ImageNet test accuracies competitive with ResNets, and EfficientNets at a range of FLOP budgets, while training faster. arxiv.org/abs/2101.08692
Replying to @tri_dao
It's surprisingly under-discussed that the decision is not whether to train the whole model in BF16 or FP32; the decision is which operations can be in BF16 and which can't.
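To make the point concrete, here is a small sketch that emulates bfloat16 rounding in pure Python (an assumption for illustration: round-to-nearest-even on the top 16 bits of a float32, with NaN/inf ignored; real frameworks do this in hardware). It shows the kind of operation that can't live in BF16: a running accumulation of small updates.

```python
import struct


def to_bf16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even).

    bfloat16 is the top 16 bits of an IEEE-754 float32: 1 sign bit,
    8 exponent bits, 7 mantissa bits. NaN/inf are not handled here.
    """
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even on the dropped bits
    bits &= 0xFFFF0000                   # keep only the bfloat16 bits
    return struct.unpack(">f", struct.pack(">I", bits))[0]


def accumulate(start, step, n, in_bf16):
    """Add `step` to `start` n times, rounding the running total to bf16 if asked."""
    total = to_bf16(start) if in_bf16 else start
    for _ in range(n):
        total = total + step
        if in_bf16:
            total = to_bf16(total)
    return total


print(accumulate(1.0, 1e-3, 1000, in_bf16=True))   # 1.0
print(accumulate(1.0, 1e-3, 1000, in_bf16=False))  # about 2.0
```

Near 1.0 the gap between adjacent bfloat16 values is 2^-7 ≈ 0.0078, so every 1e-3 increment rounds away and the BF16 total never moves, while the higher-precision total reaches about 2.0. Matmul inputs tolerate this rounding far better than running sums such as weight updates do.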
Replying to @jxmnop
The reason this mistake breaks deep ResNets is well known. Lots of papers showed why but I think ours is the simplest: arxiv.org/abs/2002.10444 In short, deep ResNets are trainable if the activations on branches are much smaller than the activations on the skip connection.
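A back-of-the-envelope sketch of the argument, assuming each residual branch f preserves the variance of its input and is uncorrelated with the skip path. Then x_{l+1} = x_l + α·f(x_l) has Var(x_{l+1}) = (1 + α²)·Var(x_l), and the branch-to-skip variance ratio in each block is α². As I understand the linked paper, this is the motivation for SkipInit, a learnable scalar α on each residual branch initialized at or near zero; the helper names below are illustrative only.

```python
def residual_variance(depth, alpha, var0=1.0):
    """Variance of x after `depth` blocks of x <- x + alpha * f(x).

    Assumes each residual branch f preserves the variance of its input and is
    uncorrelated with the skip path, so each block scales the variance by 1 + alpha**2.
    """
    var = var0
    for _ in range(depth):
        var *= 1.0 + alpha ** 2
    return var


def branch_to_skip_ratio(alpha):
    """Variance on the residual branch relative to the skip path in any one block."""
    return alpha ** 2


depth = 100
for alpha in (1.0, 0.1, 0.0):
    print(alpha, residual_variance(depth, alpha), branch_to_skip_ratio(alpha))
```

With α = 1 the variance doubles in every block (2^100 after 100 blocks); with α = 0.1 it grows only to about 2.7, because the skip path dominates each block.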
Replying to @_arohan_
This was coincidental timing. The switch from async to sync training was because "ImageNet in One Hour" figured out the correct learning rate vs batch size scaling for SGD. (arxiv.org/abs/1706.02677)
We're comparing to the strongest Mamba-3B model, which was released after the paper came out and trained for 600B tokens: nitter.app/tri_dao/status/1734612… The results reported in the original Mamba paper on 300B tokens are significantly weaker.
With @_albertgu, we’re collaborating with @togethercompute and @cartesia and releasing a Mamba 3B model trained on 600B tokens on the SlimPajama dataset (Mamba-3B-SlimPJ). It’s among the strongest 3B models, matching the performance of strong Transformers (BTLM-3B). 1/
Here is the sampling latency of RecurrentGemma-9B vs Gemma-7B at batch size 1 on a single TPU-v4. We use a prefill of 4k tokens. (This is actually the worst case scenario for RecurrentGemma, which particularly excels at large batch inference) 3/n
Replying to @agihippo
The people in the Brain residency were extremely talented, but the number of very successful residents tells me that there are a huge number of people with the talent to be an "elite" AI researcher if they had access to the insider knowledge of a top industry lab
Replying to @SiaAhmadi1
In prior work, small has sometimes meant "infinitesimal" 😀
During inference, both Hawk and Griffin achieve lower latency and higher throughput than MQA Transformers, especially when sampling long sequences. The plot below shows the max throughputs of our 1B models when sampling from an empty prompt:
Replying to @cloneofsimo
Citation count is a bit meaningless for papers relevant at scale, since the labs using them don't publish much, but this paper is well known! Other good ones: arxiv.org/abs/1907.04164 arxiv.org/abs/2006.15081 (more informative than our better known arxiv.org/abs/1711.00489)
A huge thank you to @giffmana for converting ViT compute budget estimates in arxiv.org/abs/2106.04560 and arxiv.org/abs/2305.13035 from TPU-v3 to TPU-v4 to help us compare NFNets and ViTs directly. 3/3
Replying to @RogerGrosse
I think I strongly disagree with this. I’ve seen a lot of progress come from hill climbing on one set of benchmarks, finding/making new ones, hill climbing again. What I’ve seen is that subfields with bad benchmarks make less progress.
I.e., giving up all-to-all mixing doesn't give you the speed-up you'd expect during training (and hurts on a tail of rare tasks during eval). This is fundamentally because we have to devote so much compute to the MLPs that the efficiency of sequence mixing is less important.
Replying to @giffmana
To be fair, a lot of papers (not including ours) that claim to plot FLOPs are actually plotting 6*parameter count 😅
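For reference, a sketch of where "6 × parameter count" comes from: roughly 2 FLOPs per parameter per token in the forward pass (a multiply and an add) and about 4 in the backward pass. The optional attention term is a rough allowance, in the style of Kaplan et al. (2020), for the attention score and value matmuls that the 6N rule leaves out; treat it as an approximation, not an exact count.

```python
def training_flops(n_params, n_tokens, n_layers=0, n_ctx=0, d_model=0):
    """Rough training-FLOP estimate.

    6 * n_params per token: ~2 FLOPs per parameter forward, ~4 backward.
    The optional 6 * n_layers * n_ctx * d_model term approximates the attention
    matmuls, which the 6N rule ignores and which grow with context length.
    """
    per_token = 6 * n_params + 6 * n_layers * n_ctx * d_model
    return per_token * n_tokens


# e.g. 2B parameters trained for 2T tokens
print(f"{training_flops(2e9, 2e12):.2e}")  # 2.40e+22
```

Plotting the first term alone is exactly the "6 × parameter count" shortcut: it is fine when context length is small relative to model width, and increasingly optimistic for long contexts.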
Replying to @prfsanjeevarora
Theorists are like Physicists. They become incredibly valuable the moment they realize they are in the wrong field!
Replying to @roydanroy
I know this is just a joke, but as someone who has worked on both ML theory and designing/scaling performant models, I find the latter is just as intellectually interesting and challenging as the former!
Crucially, our recurrent layer is fast to compute in practice on TPU-v3, which ensures that both Hawk and Griffin match the training speed of highly-optimized Transformers. In fact, on long sequences our models train faster!
Hawk is a pure recurrent model based on the Real Gated Linear Recurrent Unit (RG-LRU), a novel gated linear recurrent layer proposed in this work. Griffin is a hybrid model which includes one local attention block for every two RG-LRU layers.
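A minimal single-channel sketch of one RG-LRU step, following my reading of the Griffin paper: recurrence gate r_t = σ(W_a x_t + b_a), input gate i_t = σ(W_x x_t + b_x), decay a_t = a^(c·r_t) with a = σ(Λ) and c = 8, and h_t = a_t·h_{t-1} + sqrt(1 − a_t²)·(i_t·x_t). Simplifying assumptions: the gates use hypothetical per-channel scalars instead of weight matrices, and the exact ordering of layers in `griffin_layer_pattern` (two RG-LRU, then one local attention) is assumed for illustration.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def rg_lru_step(x, h_prev, p, c=8.0):
    """One RG-LRU update for a single channel (scalar sketch).

    p holds hypothetical per-channel scalars: w_a, b_a (recurrence gate),
    w_x, b_x (input gate) and lam (learnable decay parameter).
    """
    r = sigmoid(p["w_a"] * x + p["b_a"])   # recurrence gate r_t
    i = sigmoid(p["w_x"] * x + p["b_x"])   # input gate i_t
    a = sigmoid(p["lam"])                   # base decay a in (0, 1)
    a_t = a ** (c * r)                      # gated decay a_t = a^(c * r_t)
    # sqrt(1 - a_t^2) shrinks the input when the state decays slowly (a_t near 1).
    return a_t * h_prev + math.sqrt(1.0 - a_t ** 2) * (i * x)


def run_rg_lru(xs, p):
    """Apply the recurrence along a sequence, starting from h = 0."""
    h, hs = 0.0, []
    for x in xs:
        h = rg_lru_step(x, h, p)
        hs.append(h)
    return hs


def griffin_layer_pattern(n_blocks):
    """Temporal-mixing layer per residual block: two RG-LRU layers, then one
    local attention layer, repeated (ordering assumed for illustration)."""
    return ["local_attention" if i % 3 == 2 else "rg_lru" for i in range(n_blocks)]


params = {"w_a": 1.0, "b_a": 0.0, "w_x": 1.0, "b_x": 0.0, "lam": 2.0}
print(run_rg_lru([1.0, 0.0, 0.0, 0.0], params))
print(griffin_layer_pattern(6))
```

The recurrence is diagonal (elementwise per channel), so everything except the final multiply-add can be computed for all timesteps in parallel.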
Replying to @srush_nlp
People pay way too much attention to the loss early in training! Lots of tuning decisions (eg larger learning rate/larger weight decay/smaller batch size) tend to make slower progress early in training and faster progress late in training.
An important caveat: recent RNNs (eg Mamba and Griffin) are memory bound. Parallel scans reduce FLOPs but they can't reduce memory accesses. The key trick isn't parallel scans, it's replacing dense recurrences with sparse ones, which moves most FLOPs into feedforward layers.
We see a power law between JFT held out loss and pre-training compute, just like the scaling laws for language modeling. We also find scaling model size and training epochs at the same rate is compute-optimal, as previously observed for Chinchilla (arxiv.org/abs/2203.15556). 2/3
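A sketch of how such a power law is usually read off: fit loss ≈ A·C^(−b) by least squares in log-log space. The data below is synthetic, generated from made-up constants purely to show that the fit recovers them; it is not the JFT measurements, and real scaling-law fits often add an irreducible-loss term that this simplest form omits.

```python
import math


def fit_power_law(compute, loss):
    """Least-squares fit of log(loss) = log(A) - b * log(compute). Returns (A, b)."""
    xs = [math.log(c) for c in compute]
    ys = [math.log(l) for l in loss]
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxx = sum((x - mx) ** 2 for x in xs)
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    slope = sxy / sxx
    intercept = my - slope * mx
    return math.exp(intercept), -slope


# Synthetic check with made-up constants: loss = 10 * compute ** -0.1
compute = [10.0 ** k for k in range(1, 8)]
loss = [10.0 * c ** -0.1 for c in compute]
A, b = fit_power_law(compute, loss)
print(A, b)
```

On the compute-optimal point: if compute is proportional to model size times training epochs and both are scaled at the same rate, each grows as the square root of compute.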
The way I see it, diagonalizing an RNN effectively parallelizes it anyway (regardless of what scan you use), since this moves almost all of the FLOPs out of the recurrence and into linear layers.
We scale Hawk to 7B parameters, and Griffin to 14B. Both models exhibit power law scaling, just like Transformers! Griffin achieves lower held out loss than a strong transformer baseline across all model sizes, while Hawk closes the gap as we scale training FLOPs.
@ylecun the work Elad posted interleaves fixed convolutions over the sequence with learnable MLPs. So it's not that learning doesn't happen, it's just that it doesn't happen in the part of the model which transmits information along the sequence
You hurt me Antonio 😅
My first project with the amazing @sohamde_ showed how to train deep ResNets without batch normalization by biasing the signal towards the skip path: arxiv.org/abs/2002.10444 These insights are crucial for private training, which is incompatible with batch normalization 6/n
You can train much larger image classifiers with DP than previously thought if you:
1) Increase batch size
2) Improve ResNet initialization
3) Reduce gradient variance from data augmentation
When fine-tuning, DP can achieve performance competitive with non-private networks. 3/n
Replying to @srush_nlp
The advantage of a diagonal linear recurrence is that you can pre-compute almost everything in parallel before performing a lightweight sequential scan. But so long as you don't add too many FLOPs in the recurrence, then yes, I think you could make the recurrence nonlinear.
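A minimal sketch of that split, assuming a diagonal recurrence h_t = a ⊙ h_{t-1} + B·x_t with per-channel decays `a` and a dense input projection `B` (both hypothetical). The dense projection has no dependency across time, so in practice it is one large matmul over all timesteps; only the cheap elementwise scan runs sequentially.

```python
def matvec(mat, vec):
    """Dense matrix-vector product on nested lists."""
    return [sum(m * v for m, v in zip(row, vec)) for row in mat]


def diagonal_rnn(xs, a, B):
    """h_t = a * h_{t-1} + B @ x_t with a diagonal (elementwise) recurrence.

    xs: list of input vectors; a: per-channel decays; B: dense input projection.
    """
    # Step 1 (parallel over t): all of the dense FLOPs live here.
    bs = [matvec(B, x) for x in xs]
    # Step 2 (sequential over t): one multiply and one add per channel per step.
    h = [0.0] * len(a)
    hs = []
    for b in bs:
        h = [ai * hi + bi for ai, hi, bi in zip(a, h, b)]
        hs.append(h)
    return hs


print(diagonal_rnn([[1.0, 0.0], [0.0, 1.0]], a=[0.5, 0.5], B=[[1.0, 0.0], [0.0, 1.0]]))
```

With d channels the scan does O(d) FLOPs per step while reading and writing O(d) values, which is why it ends up memory bound rather than FLOPs bound.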
@sohamde_ has given a great overall summary of the three projects here: nitter.app/sohamde_/status/136021… @ajmooch discusses some interesting aspects of the most recent paper here: nitter.app/ajmooch/status/1360220… 3/4
Our most recent work on training Normalizer-Free nets! We focus on developing performant architectures which train fast, and show that a simple technique (Adaptive Grad Clipping, or AGC) allows us to train with large batches and heavy augmentations and reach state-of-the-art.
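A sketch of unit-wise Adaptive Gradient Clipping as I understand it from the NFNets paper (arXiv 2102.06171): for each unit i (here, each row of a weight matrix), if ‖G_i‖ / max(‖W_i‖, ε) exceeds a threshold λ, the gradient row is rescaled so that ratio equals λ. The default ε = 1e-3 reflects my recollection of the paper; the λ value is only a placeholder, since it is a tuned hyper-parameter.

```python
import math


def _norm(v):
    return math.sqrt(sum(x * x for x in v))


def adaptive_grad_clip(weights, grads, clip=0.01, eps=1e-3):
    """Unit-wise AGC on nested lists: each row of weights/grads is one unit.

    Rescales a gradient row whenever ||g_row|| > clip * max(||w_row||, eps).
    """
    clipped = []
    for w_row, g_row in zip(weights, grads):
        w_norm = max(_norm(w_row), eps)   # eps stops zero-initialized weights freezing
        g_norm = _norm(g_row)
        max_norm = clip * w_norm
        if g_norm > max_norm:
            scale = max_norm / g_norm
            g_row = [g * scale for g in g_row]
        clipped.append(list(g_row))
    return clipped


print(adaptive_grad_clip([[3.0, 4.0], [1.0, 0.0]], [[0.3, 0.4], [0.001, 0.0]]))
```

Unlike fixed-threshold clipping, the allowed gradient size scales with the size of the weights it updates, which is what lets it cope with large batches and heavy augmentation.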
Also do/don't work is too simplistic. Griffin and Mamba match Transformers on eval loss and typical downstream evals + have significantly higher throughput during inference. But Transformers unsurprisingly are better at long ranged retrieval. The best model depends on the task.
Thank you to my collaborators @sohamde_, @LeonardBerrada, @_jamiedh and @BorjaBalle We are excited to see the community is already building on our ideas (e.g. arxiv.org/abs/2203.00324) in our shared mission to make high-accuracy differentially private training a reality 9/n
A reviewer for "Don't decay the learning rate, increase the batch size" asked me to compare to Hogwild. I said there was no need because people would stop using async training now. The reviewer wasn't particularly happy about it, but I was right! (arxiv.org/abs/1711.00489)
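The core idea of that paper, as I recall it: the scale of SGD noise is roughly g = ε(N/B − 1) ≈ εN/B, for learning rate ε, training-set size N and batch size B. Dividing ε by a factor k and multiplying B by k therefore give nearly the same noise scale, but the second needs k times fewer parameter updates. The numbers below are hypothetical.

```python
def noise_scale(lr, dataset_size, batch_size):
    """SGD noise scale g = lr * (N / B - 1), roughly lr * N / B when B << N."""
    return lr * (dataset_size / batch_size - 1.0)


N = 50_000  # hypothetical training-set size
base_lr, base_batch, factor = 0.1, 128, 5

decay_lr = noise_scale(base_lr / factor, N, base_batch)    # usual schedule: shrink the lr
grow_batch = noise_scale(base_lr, N, base_batch * factor)  # alternative: grow the batch
print(decay_lr, grow_batch)

# Growing the batch needs `factor` times fewer parameter updates per epoch.
print(N // base_batch, N // (base_batch * factor))
```

The two noise scales agree to within about 1% here; the match gets closer as B becomes small relative to N.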
Our work capturing the path of SGD (at small finite learning rate) as gradient flow on a modified loss also feels very related to this: arxiv.org/abs/2101.12176
Replying to @giffmana
I think the problem is Adam incorrectly claimed that there was a single reliable learning rate for everything, and then people unknowingly tuned the batch size/model width etc to make that learning rate work with their models, thus convincing themselves it was true.
One of the most talented and fun researchers I have ever worked with
🚀 Thrilled to announce: I'm now with ELLIS as a PI & MPI for Intelligent Systems as an Independent Group Leader! 🌟 Tübingen is such an amazing place. On a hunt for PhD candidates passionate about deep learning & optimization! Interested? Slide into my DMs! 🔍 @ELLISforEurope
Replying to @jxmnop
This is such an underrated paper! I remember talking to Joel literally one on one at his NeurIPS poster, and even shared the paper with @quocleix afterwards. When the scaling laws paper came out 2 years later I was so mad at myself!
Replying to @NandoDF @fchollet
Yes, in practice I would recommend always using some global attention, as it’s better for long range retrieval. But this is also partly a consequence of MLPs being the main bottleneck. A limitation of recurrences is that they don’t speed up training much in practice (1/n)
Replying to @wightmanr
Yes! You are right; in hindsight the previous paper is not completely clear on this point. I'll add an update to the recent draft soon.
My first project at Google found batch size is relatively unimportant for SGD if the learning rate is properly tuned: arxiv.org/abs/1711.00489 However the privacy accounting behind DP-SGD introduces noise whose scale falls as batch size rises, favoring large batch training 5/n
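A minimal sketch of the DP-SGD aggregation step (clip each per-example gradient to norm C, sum, add Gaussian noise with standard deviation σC, divide by the batch size). The noise is added to the sum, so its scale in the averaged gradient is σC/B and falls as the batch grows. This sketch deliberately omits privacy accounting, which in practice determines the σ needed for a given (ε, δ) and itself depends on the sampling rate and number of steps.

```python
import math
import random


def clip_by_norm(g, max_norm):
    """Scale vector g down so its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(x * x for x in g))
    scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [x * scale for x in g]


def dp_sgd_gradient(per_example_grads, max_norm, noise_multiplier, rng):
    """Clip each per-example gradient, sum, add Gaussian noise, then average."""
    batch_size = len(per_example_grads)
    total = [0.0] * len(per_example_grads[0])
    for g in per_example_grads:
        for j, x in enumerate(clip_by_norm(g, max_norm)):
            total[j] += x
    std = noise_multiplier * max_norm  # noise on the *sum*, independent of batch size
    noisy = [t + rng.gauss(0.0, std) for t in total]
    return [t / batch_size for t in noisy]


def averaged_noise_std(batch_size, max_norm, noise_multiplier):
    """Std of the injected noise after averaging: shrinks as 1 / batch_size."""
    return noise_multiplier * max_norm / batch_size


rng = random.Random(0)
print(dp_sgd_gradient([[3.0, 4.0], [0.3, 0.4]], max_norm=1.0, noise_multiplier=1.0, rng=rng))
print(averaged_noise_std(256, 1.0, 1.0), averaged_noise_std(4096, 1.0, 1.0))
```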
Replying to @arimorcos @j_foerst
People forget that Google had already deployed neural machine translation before the transformer paper using a deep LSTM: arxiv.org/abs/1609.08144
We extended this line of work with @ajmooch to introduce NFNets, highly performant un-normalized ResNets which are incredibly fast on TPU: arxiv.org/abs/2102.06171 We use both NF-ResNets and NFNets to get our strongest results with DP 7/n
Working with you @caglarml was definitely one of my highlights! I hope we keep collaborating in the future🤞
Yes, whoever hired this team did a really excellent job!
Of course, in reality the batch size vs learning rate scaling had first been found in the 90s, but forgotten.
This project is particularly special to me because it combines ideas from many of my and my colleagues' previous projects while tackling a completely new problem: a beautiful demonstration of the close connections between different areas of ML research. For example: 4/n
S4 actually benefits from parallel scans because h_t = W h_{t-1} + f(x_t), where W is a matrix. S4D/S5 diagonalized W, and it turns out this works just as well! f(x_t) can be pre-computed in parallel, so now you have O(1) recurrent FLOPs per memory access, hence memory bound.
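A sketch of why a diagonal recurrence admits a parallel scan at all: steps of the form h ↦ a·h + b compose associatively, (a1, b1) followed by (a2, b2) equals (a1·a2, a2·b1 + b2). The Hillis-Steele inclusive scan below uses that operator in log2(T) rounds, each of which could run fully in parallel; it is simulated sequentially in plain Python for a single channel, assumes h_0 = 0, and is checked against the ordinary sequential loop. (In JAX this pattern is what `jax.lax.associative_scan` is for, though the code here does not use it.)

```python
def combine(left, right):
    """Compose two steps of h -> a * h + b (apply `left` first, then `right`)."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)


def sequential_scan(a, b, h0=0.0):
    """Reference: h_t = a_t * h_{t-1} + b_t, one step at a time."""
    h, out = h0, []
    for at, bt in zip(a, b):
        h = at * h + bt
        out.append(h)
    return out


def parallel_scan(a, b):
    """Hillis-Steele inclusive scan over (a, b) pairs, assuming h_0 = 0.

    Each round's updates are independent of one another, so a real kernel
    would run them in parallel; h_t is the b-component of the t-th prefix.
    """
    elems = list(zip(a, b))
    offset = 1
    while offset < len(elems):
        elems = [
            combine(elems[t - offset], elems[t]) if t >= offset else elems[t]
            for t in range(len(elems))
        ]
        offset *= 2
    return [bt for _, bt in elems]


a = [0.5, 0.9, 0.1, 1.0, 0.3]
b = [1.0, 2.0, 3.0, 4.0, 5.0]
print(sequential_scan(a, b))
print(parallel_scan(a, b))
```

With a dense W, each combine would multiply d×d matrices; with a diagonal a it is a couple of elementwise operations per channel, so the work is dominated by reading and writing a and b.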
When training without extra data, we achieve the largest improvement in the SOTA on CIFAR-10 to date, of +9.7% under a "privacy budget" of (8, 1e-5)-DP When fine-tuning a pre-trained NFNet-F3, we achieve a remarkable 86.7% top-1 accuracy on ImageNet under (8, 8e-7)-DP 2/n
Replying to @rimfo @neu_rips
Yes, this is how we interpret our result. We show that the path taken by constant learning rate Random Shuffling SGD stays close to the path taken by Gradient Flow on the regularized objective.
Unfortunately Pallas's GPU backend is not very well optimized yet, so currently our training speed on GPU is a bit slower.
Replying to @LChoshen
The model conditioning improves early in training, so the largest stable learning rate rises; hence warmup. Consistent with this, better-conditioned models (e.g., ResNet-v2) often don't need warmup.
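A minimal linear warmup schedule, as one common way to track a stable learning rate that rises as conditioning improves. The peak learning rate and warmup length are hypothetical hyper-parameters, and any decay after warmup is omitted.

```python
def warmup_lr(step, peak_lr, warmup_steps):
    """Linear warmup to peak_lr over warmup_steps, then constant (no decay).

    step is 0-indexed, so the first update already uses a small non-zero lr.
    """
    if warmup_steps <= 0:
        return peak_lr
    return peak_lr * min(1.0, (step + 1) / warmup_steps)


print([round(warmup_lr(s, peak_lr=0.1, warmup_steps=5), 3) for s in range(7)])
```

For a better-conditioned model that needs no warmup, setting `warmup_steps=0` reduces this to a constant learning rate.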