i just had a deadline moved, so i'll just poast this here and see what you guys think. it's also related to the experiment i was running a few weeks ago, the one with the attention leak. tagging @clashluke because i had promised to tell him first. i'll make it one big post, nobody's gonna read it anyway.

basically, my idea comes from the obvious realization that self-attention is just a specific GNN layer (essentially a GAT layer, with dot-product scores instead of GAT's additive ones) over a fully connected graph. this holds because when you compute softmax(QK^T / sqrt(d)) you're building a soft, row-normalized adjacency matrix. like all message-passing GNNs, GATs are subject to expressivity limitations, with a theoretical upper bound given by the (1-)WL test for graph isomorphism. this implies that there are structures attention simply cannot disambiguate.

in a certain sense, the NLP community has already embraced this paradigm in one direction: if you sparsify the graph (i.e. you control its topology) or change its aggregation function, you get the various attention variants (sliding window, GQA, etc.). the direction the community doesn't seem to have explored is the other end of the spectrum: increasing the order of the topology BEYOND a fully connected graph. the reason i say this is that, as it turns out, if you manage to inject higher-order information into GNNs, you can go past the 1-WL bound (for order n you are bounded by n-WL instead). this means there's another possible axis along which to be compute-optimal. of course, depending on how you implement it, you pay some extra memory, but you could find a new optimal configuration on the memory-expressivity trade-off that wasn't available before.

i actually tried, in my very spare time, to implement this myself in two distinct ways.

1) i tried using the TopoX library from @ClaBat9 et al. (i had also bothered Claudio on linkedin about this idea some time ago, as well as @PetarV_93 and @s_scardapane) to basically replace the attention layer with a Simplicial Convolutional Network. the main problems here were lifting the graph to a simplicial complex (i used a vietoris-rips-like thresholding mechanism) and controlling the topology (i capped the total number of simplices). in my experiments with a small gpt-2 on shakespeare this didn't end up working well, but i'm convinced that doesn't necessarily mean much.

2) the second approach was more straightforward (and gave birth to my famous blunder with the attention leak). i noticed that QK^T is just a tensor contraction, which you can write as einsum(bhik, bhjk) -> bhij. so to represent higher-order interactions, you can increase the order of the contraction and contract into, for example, a 3rd-order tensor: einsum(bhik, bhjk, bhlk) -> bhijl. this of course also means V has to become a higher-order tensor (indexed by pairs of tokens j, l), so that contracting it against the 3-way scores gives regular per-token representations as outputs. i tried this too and, after fixing the causal mask, it didn't seem to yield much. also remember i'm doing this for maybe 1 hour a week if i'm lucky, so it's likely the code is buggy and everything.

in general, i think it's also interesting to make the topological direction more intrinsic to the models. after all, you can just see self-attention + MLP as a GNN composed with a single-node GNN (the MLP acts on each token independently). if we managed to move beyond the plain GNN, we could think of having actual topological representations instead of the point clouds we have now. a new direction i've been thinking about lately is not only building topology on top of the algebraic structure (in this case a vector space), but also changing the underlying space itself, e.g. to hyperbolic geometry, but i'm still too ignorant about it to really try and implement it.

if you got this far, you may as well tell me what you think. of course, take this as just some guy doing this as a fun hobby.
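to make the "attention is a GAT over a complete graph" point concrete, here's a minimal single-head numpy sketch (the function name `attention_as_graph` is just made up for illustration). it builds the soft adjacency softmax(QK^T / sqrt(d)) with a causal mask, then checks that the usual `A @ V` is exactly the same as explicit neighborhood aggregation h_i = sum_j A_ij v_j, i.e. one round of message passing where the edge weights are the attention scores.

```python
import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax; -inf entries become exactly 0
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_as_graph(Q, K, V, causal=True):
    """Single-head attention viewed as message passing on a weighted graph.

    Q, K, V have shape (n, d). Returns the soft adjacency A (n, n), the
    standard dense output A @ V, and the same output computed node by node.
    """
    n, d = Q.shape
    scores = np.einsum("ik,jk->ij", Q, K) / np.sqrt(d)  # QK^T / sqrt(d)
    if causal:
        # the edge j -> i exists only for j <= i (no looking at future tokens)
        future = np.triu(np.ones((n, n), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    A = softmax(scores, axis=-1)  # row-stochastic soft adjacency matrix
    out_dense = A @ V
    # explicit aggregation over each node's (weighted) neighborhood
    out_msg = np.stack([
        sum(A[i, j] * V[j] for j in range(n) if A[i, j] > 0)
        for i in range(n)
    ])
    return A, out_dense, out_msg
```

the only thing that makes this "attention" rather than a generic GNN is that the graph is complete (or causally masked) and the edge weights are recomputed from the node features at every layer.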
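for the WL bound, the textbook example (not something from my experiments) is two disjoint triangles vs a 6-cycle. both are 2-regular graphs on 6 nodes, so 1-WL color refinement (the test that bounds plain message passing) gives every node the same color in both, even though one graph has two triangles and the other has none. triangles are exactly the 2-simplices a higher-order layer would see. `wl_refine` and `count_triangles` below are hypothetical helper names; colors are refined jointly across both graphs so the histograms are comparable.

```python
import itertools
from collections import Counter

def wl_refine(graphs, iterations=3):
    """1-WL color refinement run jointly on a list of adjacency dicts."""
    colors = [{v: 0 for v in g} for g in graphs]  # start with a uniform color
    for _ in range(iterations):
        # new signature = (own color, sorted multiset of neighbor colors)
        sigs = [
            {v: (c[v], tuple(sorted(c[u] for u in g[v]))) for v in g}
            for g, c in zip(graphs, colors)
        ]
        # shared palette so that equal signatures get equal colors across graphs
        all_sigs = sorted({s for sg in sigs for s in sg.values()})
        palette = {s: i for i, s in enumerate(all_sigs)}
        colors = [{v: palette[sg[v]] for v in sg} for sg in sigs]
    return [Counter(c.values()) for c in colors]

def count_triangles(adj):
    """Number of 3-cliques (2-simplices) in an undirected graph."""
    return sum(
        1
        for a, b, c in itertools.combinations(sorted(adj), 3)
        if b in adj[a] and c in adj[a] and c in adj[b]
    )

two_triangles = {0: [1, 2], 1: [0, 2], 2: [0, 1], 3: [4, 5], 4: [3, 5], 5: [3, 4]}
hexagon = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}

h_tri, h_hex = wl_refine([two_triangles, hexagon])
print(h_tri == h_hex)                                     # 1-WL can't tell them apart
print(count_triangles(two_triangles), count_triangles(hexagon))
```

note this is about the bare graph view: positional encodings break the symmetry in a real transformer, so take it as an illustration of the bound, not a claim about any specific model.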
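here's a rough sketch of what a vietoris-rips-like lifting with a cap on the number of simplices could look like. this is NOT the TopoX API and not necessarily what i ran; `vr_lift`, `eps` and `max_triangles` are hypothetical names. it assumes the points are token embeddings, connects pairs closer than `eps`, adds a 2-simplex for every fully connected triple (clique complex), and keeps only the `max_triangles` triangles with the smallest diameter (i.e. the ones that appear earliest in the filtration).

```python
import itertools
import numpy as np

def vr_lift(X, eps, max_triangles=None):
    """Lift a point cloud X of shape (n, d) to edges and triangles at scale eps."""
    # pairwise Euclidean distances between all points
    D = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)
    n = X.shape[0]
    edges = [(i, j) for i, j in itertools.combinations(range(n), 2) if D[i, j] <= eps]
    edge_set = set(edges)
    triangles = []
    for i, j, k in itertools.combinations(range(n), 3):
        # a triple becomes a 2-simplex only if all three of its edges exist
        if (i, j) in edge_set and (i, k) in edge_set and (j, k) in edge_set:
            diameter = max(D[i, j], D[i, k], D[j, k])
            triangles.append((diameter, (i, j, k)))
    triangles.sort()  # smallest diameter first = born earliest in the filtration
    if max_triangles is not None:
        triangles = triangles[:max_triangles]  # cap on the number of simplices
    return edges, [t for _, t in triangles]

# unit square: at eps=1.0 only the sides connect, at eps=1.5 the diagonals too
X = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
print(vr_lift(X, 1.0))
print(vr_lift(X, 1.5, max_triangles=2))
```

in a causal LM you'd also have to restrict the lifting to tokens at or before the current position, otherwise the complex itself leaks future information. and the cubic loop over triples is exactly where the memory/compute cost shows up.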
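and here's a minimal numpy sketch of the trilinear contraction einsum(bhik, bhjk, bhlk) -> bhijl with a causal mask. a few assumptions that aren't in the post: the 1/sqrt(d) scale, the softmax taken jointly over all (j, l) pairs, and building the pair values as V_jl = V1_j * V2_l (elementwise product of two value projections), which is just one simple way to make V a higher-order tensor. `trilinear_attention` is a hypothetical name. the causal mask has to constrain BOTH extra indices (j <= i and l <= i); masking only one of them is an easy way to get an attention leak.

```python
import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax; -inf entries become exactly 0
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def trilinear_attention(Q, K1, K2, V1, V2, causal=True):
    """3-way attention. All inputs have shape (b, h, n, d)."""
    b, h, n, d = Q.shape
    # 3rd-order score tensor: one score per (query i, key j, key l) triple
    scores = np.einsum("bhik,bhjk,bhlk->bhijl", Q, K1, K2) / np.sqrt(d)
    if causal:
        idx = np.arange(n)
        i = idx[:, None, None]
        j = idx[None, :, None]
        l = idx[None, None, :]
        allowed = (j <= i) & (l <= i)  # both j and l must be in the past
        scores = np.where(allowed, scores, -np.inf)
    # normalize jointly over all (j, l) pairs for each query i
    A = softmax(scores.reshape(b, h, n, n * n), axis=-1).reshape(b, h, n, n, n)
    # pair values, one d-dim vector per (j, l) pair
    V_pair = np.einsum("bhjd,bhld->bhjld", V1, V2)
    # contract back down to one d-dim output per token
    out = np.einsum("bhijl,bhjld->bhid", A, V_pair)
    return out, A
```

a quick sanity check for leaks is to perturb the last token and verify that every earlier output stays the same. also note the score tensor is b·h·n³, which is the memory price mentioned above.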