Announcing FlashFFTConv: Efficient Convolutions for Long Sequences with Tensor Cores!
We speed up exact FFT convolutions by up to 7.93x over PyTorch, reduce memory footprint, and get up to 4.4x speedup end-to-end. Read on for more details:
Thanks @arankomatsuzaki and @_akhaliq for tweeting it out yesterday :)
The key idea: map the FFT onto tensor cores using a Monarch decomposition -- which allows kernel fusion for long sequences, and uses fast matmul units to compute the FFT (pic 2).
The FFT convolution allows us to compute the convolution in O(N log N) time, instead of O(N^2) from computing it directly (as in PyTorch nn.Conv1d). With advances in gated convolutions & gated SSMs, this means that we have a sub-quadratic alternative to attention that scales well!
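To make the O(N log N) vs O(N^2) comparison concrete, here is a minimal NumPy sketch of a causal long convolution computed both ways. This is not the FlashFFTConv kernel; the function names are made up for illustration, and the zero-padding to 2N is one common way to avoid circular wraparound.

```python
import numpy as np

def direct_causal_conv(u, k):
    """O(N^2) reference: y[t] = sum over s <= t of k[s] * u[t - s]."""
    n = len(u)
    y = np.zeros(n)
    for t in range(n):
        for s in range(t + 1):
            y[t] += k[s] * u[t - s]
    return y

def fft_causal_conv(u, k):
    """O(N log N) version: multiply in frequency space, then invert.

    Both inputs are zero-padded to length 2N so that the circular
    convolution computed by the FFT matches the causal one on the
    first N outputs.
    """
    n = len(u)
    fft_size = 2 * n
    u_f = np.fft.rfft(u, n=fft_size)
    k_f = np.fft.rfft(k, n=fft_size)
    return np.fft.irfft(u_f * k_f, n=fft_size)[:n]

# Both paths agree up to floating-point error.
rng = np.random.default_rng(0)
u = rng.standard_normal(256)
k = rng.standard_normal(256)
assert np.allclose(direct_causal_conv(u, k), fft_causal_conv(u, k))
```

The FFT path does three transforms and one pointwise multiply, and each of those steps reads and writes a full-length intermediate, which is where the I/O cost in an unfused implementation comes from.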
But the FFT convolution has a critical flaw for ML - it has low hardware utilization on GPUs. We're talking several times slower than FlashAttention/FlashAttention-v2 until you get to very long sequences.
There are two critical bottlenecks: the FFT convolution incurs a lot of expensive I/O to store intermediate outputs, and it doesn't use tensor cores (16x faster than general arithmetic on A100/H100)! Kernel fusion can partially address the I/O problems (32K & shorter), but you run into SRAM limits at long sequences.
FlashFFTConv addresses these drawbacks, and achieves speedup over FlashAttention-v2 at sequences as short as 2K.
The core idea is that a Monarch decomposition allows us to break up the FFT into smaller parts that can be computed using matrix-matrix multiply (pic 3).
This decomposition brings another benefit: Since the FFT is being split into smaller parts, you can avoid SRAM limitations even for long sequences -- we support sequences up to length 4 million.
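Here is a rough NumPy sketch of the two-part version of that idea: an FFT of length N = N1 * N2 written as two small dense matmuls with a pointwise twiddle correction in between. It follows the standard Cooley-Tukey index mapping and is only meant to show why the work lands on matmul units; the helper names are hypothetical, and the real kernels handle layout, precision, and fusion very differently.

```python
import numpy as np

def dft_matrix(n):
    """Dense n x n DFT matrix F[j, k] = exp(-2*pi*i*j*k / n)."""
    idx = np.arange(n)
    return np.exp(-2j * np.pi * np.outer(idx, idx) / n)

def fft_order2(x, n1, n2):
    """Length n1*n2 FFT as matmul -> twiddle -> matmul -> transpose."""
    n = n1 * n2
    assert len(x) == n
    # View the sequence as an (n1, n2) matrix: a[i1, i2] = x[n2 * i1 + i2].
    a = np.asarray(x, dtype=complex).reshape(n1, n2)
    # Small DFT down the columns: one (n1 x n1) @ (n1 x n2) matmul.
    b = dft_matrix(n1) @ a
    # Pointwise twiddle factors exp(-2*pi*i * k1 * i2 / n).
    twiddle = np.exp(-2j * np.pi * np.outer(np.arange(n1), np.arange(n2)) / n)
    # Small DFT along the rows: one (n1 x n2) @ (n2 x n2) matmul.
    c = (b * twiddle) @ dft_matrix(n2)
    # Output index is k1 + n1 * k2, so transpose before flattening.
    return c.T.reshape(-1)

# Matches NumPy's FFT on a length-1024 input split as 32 x 32.
rng = np.random.default_rng(0)
x = rng.standard_normal(1024)
assert np.allclose(fft_order2(x, 32, 32), np.fft.fft(x))
```

Recursing on either small DFT gives the three-part and four-part versions, with smaller matrices at each step.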
Now, there's a good old-fashioned systems tradeoff with this Monarch decomposition. You can recurse on the decomposition to break the FFT down into more parts -- compute the FFT in 2 parts, vs 3 parts or 4 parts. Higher-order decompositions reduce your FLOPs, but require more I/O between intermediate steps. This introduces a natural tradeoff, which we model using a cost model that takes both compute time and I/O time into account (pic 4).
(There's another tradeoff for higher-order decompositions -- if your sequence is too short, you end up with matmuls that are too small to fill up your tensor cores.)
Our cost model gives us a natural guide -- we can use an order-2 decomposition for sequences up to 4K, then order-3 for sequences up to 32K, and then order-4 for sequences up to 4M. Even longer sequences may require even higher-order breakdowns!
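To give a feel for the FLOPs vs I/O tradeoff, here is a toy cost model in plain Python. It is not the cost model from the paper: the FLOP and byte counts are simplified, the hardware rates are placeholder values chosen for illustration, and it ignores both tensor core underutilization for small matmuls and the fact that fused kernels keep some intermediates in SRAM. Do not expect it to reproduce the 4K / 32K / 4M thresholds.

```python
def matmul_flops(seq_len, order):
    """Approximate real FLOPs for an order-`order` decomposition.

    Assumes equal factors of size seq_len ** (1 / order); each of the
    `order` stages costs about seq_len * factor complex multiply-adds,
    counted as 8 real FLOPs each.
    """
    factor = seq_len ** (1.0 / order)
    return 8.0 * order * seq_len * factor

def io_bytes(seq_len, order, bytes_per_elem=4):
    """Approximate bytes moved (bytes_per_elem is an assumed element size).

    Counts one read of the input, one write of the output, and one
    write-plus-read round trip for each intermediate between stages.
    """
    passes = 2 + 2 * (order - 1)
    return passes * seq_len * bytes_per_elem

def toy_time(seq_len, order, flops_per_s=3.0e14, bytes_per_s=2.0e12):
    """Compute time plus I/O time, in seconds, under placeholder rates."""
    compute = matmul_flops(seq_len, order) / flops_per_s
    memory = io_bytes(seq_len, order) / bytes_per_s
    return compute + memory

def best_order(seq_len, orders=(2, 3, 4), **hardware):
    """Pick the order with the lowest toy cost for this sequence length."""
    return min(orders, key=lambda p: toy_time(seq_len, p, **hardware))

# Higher order means fewer FLOPs but more bytes moved.
assert matmul_flops(4096, 3) < matmul_flops(4096, 2)
assert io_bytes(4096, 3) > io_bytes(4096, 2)
```

Calling `best_order(n)` for a few values of `n` shows how the preferred order shifts as the compute term starts to dominate, and changing `flops_per_s` or `bytes_per_s` moves the crossover points.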
The upshot of all of this is end-to-end speedup and higher utilization. We see up to 4.4x speedup for long sequence models (where a bunch of the gain also comes from memory reduction).
With these gains, long convolutional models are now competitive with FlashAttention-v2 at sequences as short as 2K -- with more speedup for longer sequences.
FlashFFTConv is already being used to support several internal research projects, and we'll be releasing new long-sequence models trained with FlashFFTConv this week -- stay tuned for more!
With great collaborators @KumbongHermann, @exnx, and @HazyResearch. Hermann is applying to PhD programs this year -- he's a beast, you should hire him!
Shoutouts to @tri_dao, @BeidiChen for discussions on early versions of this work and always pushing the boundaries of ML & systems :)
Supported by resources from @StanfordAILab, @StanfordCRFM, @StanfordHAI. Developed in collaboration with the great folks at @togethercompute, and soon available for select models in their new inference API.
Check out our paper, blog, and code for more details:
Paper: arxiv.org/abs/2311.05908
Blog: hazyresearch.stanford.edu/bl…
Code: github.com/HazyResearch/flas…