A while back i got some intuition for it from the query key value perspective, which might help (theres also the gradient descent perspective which is good too). Scraped this from a chat with
@stochasticchasm a year ago so might be a bit dodgy.
Imagine u want to store a database of knowledge in ur state
ur gonna do this by storing key value pairs. For example on a high level, a key could be a song name, and the value associated with it is the song lyrics
when u want to extract those lyrics, u want to be able to query ur database full of key value pairs with the song name and get back the song lyrics
so ur gonna query the database using that key
im going to be assuming all keys stored in the database are orthogonal, which in a 768 dimension space roughly holds (keys are nearly orthogonal). Orthogonal meaning that if i query with a key, i get back that keys associated value and nothing else
now the reason linear attention fails is that in, for example, a 768 dimension space, u can have at most 768 orthogonal keys
so when ur sequence becomes longer, keys (and those key value pairs) start to interfere
so now when u query with the song name, u wont get back just the song lyrics, but the lyrics plus a linear combination of values for other keys, which could be unrelated
this causes the retrieval to fail as now ur getting the song lyrics and a bunch of noise with it
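a tiny numpy sketch of the above, assuming d = 4 instead of 768 and exactly orthogonal keys. the names store and query and all the numbers are made up for the demo, not from any library:

```python
import numpy as np

d = 4  # tiny key dimension standing in for 768 (assumption for the demo)

def store(state, k, v):
    # linear attention update: s = s + k^T @ v (outer product of key and value)
    return state + np.outer(k, v)

def query(state, q):
    # read-out: project the state onto the query
    return q @ state

keys = np.eye(d)                                    # d exactly orthogonal keys
values = np.arange(d * 3, dtype=float).reshape(d, 3)

state = np.zeros((d, 3))
for k, v in zip(keys, values):
    state = store(state, k, v)

clean = query(state, keys[0])                       # exactly values[0]

# a 5th pair has no free direction left, so its key has to reuse one
extra_value = np.array([100.0, 100.0, 100.0])
state = store(state, keys[0], extra_value)
noisy = query(state, keys[0])                       # values[0] plus the unrelated value
```

with the first d pairs the query gives back the stored value exactly, and after the extra pair it gives back the sum of two values, which is the interference described above.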
what retnet does is basically take ur state, which is all these key value pairs added together, and scale it all down by a fixed scalar factor at every step
so now when u query with the song name, u will get the song lyrics and noise, but if ur song lyrics were added in recently, they will have a stronger signal than the noise if that was farther back
so it prioritises recent key value pairs added in the state
the obvious issue is that if u want to query with the song name but that was stored a long time ago, its signal will be low
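rough sketch of the fixed decay, assuming a decay of 0.9 and the same key reused for an old and a new pair (both picked arbitrarily for the demo). the old value lives in output slot 0 and the new one in slot 1 so u can read their signal strengths off directly:

```python
import numpy as np

gamma = 0.9                                   # fixed scalar decay (hypothetical value)
k = np.array([1.0, 0.0])                      # same key used for both pairs
v_old = np.array([1.0, 0.0, 0.0])
v_new = np.array([0.0, 1.0, 0.0])

state = np.zeros((2, 3))
state = gamma * state + np.outer(k, v_old)    # old pair goes in first
for _ in range(20):
    # 20 later steps; writes with other keys are left out to keep it short
    state = gamma * state
state = gamma * state + np.outer(k, v_new)    # new pair goes in last

out = k @ state
old_signal, new_signal = out[0], out[1]       # gamma ** 21 vs 1.0
```

the recent pair comes back at full strength while the old one has been scaled by gamma once per step since it was written, so it ends up much weaker.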
mamba-2 and g-retnet basically make this scalar value dependent on the sequence
so the model can learn how much to reduce the signal of all previous key value pairs. So if ur now storing an important piece of information, ur model can choose to lower the signal (decay) of the state (all previous key value pairs) so that ur new info is stored with a strong signal
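same idea as a sketch, but with the decay passed in per step. in a real model alpha would be computed from the current token. here its just handed in by hand, and the function name step is made up for the demo:

```python
import numpy as np

def step(state, k, v, alpha):
    # alpha is a per-step scalar in [0, 1] that scales the whole previous state
    return alpha * state + np.outer(k, v)

k_a = np.array([1.0, 0.0])
k_b = np.array([0.0, 1.0])
v_a = np.array([1.0, 1.0])
v_b = np.array([5.0, -5.0])

state = np.zeros((2, 2))
state = step(state, k_a, v_a, alpha=1.0)   # nothing to forget yet
state = step(state, k_b, v_b, alpha=0.1)   # "important" token: squash the old state hard

old_read = k_a @ state                     # 0.1 * v_a
new_read = k_b @ state                     # v_b at full strength
```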
then rwkv6, gla, mamba turn this scalar into a vector
So u can imagine now the model can be more expressive with its decay, as it can lower the signal of only some aspects of the state for example
note that the vector doesnt mean each previous key value pair gets a different decay scalar. it means ur lowering the signal of every key of those key value pairs with the same decay, but now that decay is a vector so u can decay some parts of the key less and some more (talking abt parts of each key, all keys still get the same decay)
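small sketch of that last point, assuming the state has shape (key dim, value dim) and w is a made up decay vector. scaling the rows of the state is the same as multiplying every stored key by the same w:

```python
import numpy as np

w = np.array([1.0, 0.5])                 # per-key-dimension decay (hypothetical values)
k1 = np.array([1.0, 1.0])
k2 = np.array([1.0, -1.0])
v1 = np.array([2.0, 0.0])
v2 = np.array([0.0, 4.0])

state = np.outer(k1, v1) + np.outer(k2, v2)
decayed = w[:, None] * state             # scales the rows of s, i.e. the key dimension

# same thing seen pair by pair: every stored key is multiplied by the same w
per_pair = np.outer(w * k1, v1) + np.outer(w * k2, v2)
```

so key dimension 0 is kept fully and dimension 1 is halved, for both stored pairs alike.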
delta rule takes a whole different approach
it basically states that the ideal state update rule should selectively choose key value pairs to discard
meaning instead of the decay which acts the same on all key value pairs
we want to be able to target specific key value pairs to remove or lower the signal of more
quick clarification: linear attention is "s = s + k^T @ v" so ur state is just a sum of all key value pairs
the others are "s = w * s + k^T @ v" where w is a vector or scalar acting on the rows of s (key dimension)
the idea behind the delta rule is that if i have my state after 768 tokens, then now i can assume all keys in the state are orthogonal which is the ideal situation. Now if i want to store a new key (and its value) but its the same as one of the keys already in there, i ideally dont want to store both but instead take one out
now pay attention that just because the keys are the same or very similar, doesnt mean their values are the same
the new key being the same can be attributed to it needing to be the same bcs of the limitation of its space
like u can imagine when storing ur key value pairs, each key can choose out of the 768 choices of keys (orthogonal), then a new key comes in thats not related to any before, it needs to choose one of the 768 choices but theyre already all taken so it just chooses one
so now when u query for the song name, u get back the sum of 2 values that can be very different and thats not what u want bcs u want the value exactly ideally
so every step u have a new key value pair, and an old key value pair where those keys are the same (not in meaning but bcs of the limitations of the space). What the delta rule does is it queries the state with the new key, which means it extracts that old value. Then instead of adding in its new key value pair, it deletes that old value and adds the new one: it adds to the state "k^T @ (v_new - v_old)", which equals "k^T @ v_new - k^T @ v_old", and bcs ur state already has "k^T @ v_old" somewhere in there (bcs the keys are the same), that old key value pair gets deleted
the important thing here is that all other key value pairs in that state are unaffected
while in rnns like GLA, the decay is the same for all previous key value pairs at each step (it can differ from step to step, but within a step its the same for all previous keys and values)
so now when u query the state, u exactly get back that value instead of the sum of 2 different ones
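minimal sketch of the full replace, assuming unit norm, exactly orthogonal keys (otherwise the delete isnt exact). delta_write is a made up name for the demo:

```python
import numpy as np

def delta_write(state, k, v_new):
    v_old = k @ state                              # query the state with the new key
    return state + np.outer(k, v_new - v_old)      # s = s + k^T @ (v_new - v_old)

k_song = np.array([1.0, 0.0, 0.0])
k_other = np.array([0.0, 1.0, 0.0])
v_old = np.array([1.0, 2.0])
v_other = np.array([7.0, 8.0])
v_new = np.array([-3.0, 4.0])

state = np.outer(k_song, v_old) + np.outer(k_other, v_other)
state = delta_write(state, k_song, v_new)

song_read = k_song @ state                         # v_new, the old value is gone
other_read = k_other @ state                       # v_other, untouched
```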
in practice though u dont want to completely delete one, so what u do is interpolate meaning instead of deleting v_old and storing v_new, u delete v_old and store "beta * v_new + (1 - beta) * v_old" instead
where beta is a scalar thats dependent on the sequence (current value)
so the model based on the actual new keys and values content, can choose how much to store of each
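the interpolated version as a sketch, again assuming a unit norm key and a hand picked beta (in a real model beta comes from the current input). writing "beta * k^T @ (v_new - v_old)" leaves the state holding "beta * v_new + (1 - beta) * v_old" for that key:

```python
import numpy as np

def delta_write(state, k, v_new, beta):
    v_old = k @ state                                   # what the state currently returns for k
    return state + beta * np.outer(k, v_new - v_old)    # partial replace

k = np.array([0.0, 1.0])
v_old = np.array([10.0, 0.0])
v_new = np.array([0.0, 10.0])

state = np.outer(k, v_old)
state = delta_write(state, k, v_new, beta=0.25)
read = k @ state                                        # 0.25 * v_new + 0.75 * v_old
```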
also can be viewed as making ur state update dependent on the content of that state, while using decay only like GLA makes that state update not dependent on that state
its like delta rule looks into the state, picks out a key value pair, and decides to forget it, while with decay u just rip a piece off of each previous key value pair
of course, some issues
firstly the state update is too slow, as ur only touching one key value pair at a time. u ideally want to be able to get rid of like 5 at once for example, and not being able to can hurt length extrapolation
also not an issue in general, but in practice the model wont query the state with an exact key and retrieve an exact value, it would query with a superposition of many keys and retrieve a superposition of many values and work with that bcs its smart
so it will retrieve some info abt the song lyrics, some info abt the songs history, some info abt other stuff and use it all however it wants to make its prediction
but the model still needs to be able to retrieve exactly what it wants which with only linear attn it cant after a certain point
gated deltanet then solves the issue of the slow update by using mamba 2 style scalar data dependent decay
that means ur doing the delta rule on the state, then on top of that decaying all key value pairs in that state at once with the same value
this lets the model do stuff like a full delete of the state if it wants to
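sketch of decay plus delta rule in one step. the order here (decay first, then the delta write on the decayed state) is an assumption for the demo, and so are the function name and the hand picked alpha and beta:

```python
import numpy as np

def gated_delta_step(state, k, v_new, alpha, beta):
    state = alpha * state                               # scalar decay hits every stored pair at once
    v_old = k @ state                                   # what the decayed state returns for this key
    return state + beta * np.outer(k, v_new - v_old)    # then the usual delta write

k_math = np.array([1.0, 0.0])
k_space = np.array([0.0, 1.0])
v_math = np.array([3.0, 1.0, 4.0])
v_space = np.array([2.0, 7.0, 1.0])

state = np.outer(k_math, v_math)
# alpha = 0 is the "full delete": everything old is wiped before the new pair goes in
state = gated_delta_step(state, k_space, v_space, alpha=0.0, beta=1.0)

math_read = k_math @ state                              # all zeros now
space_read = k_space @ state                            # v_space
```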
Then again KDA and rwkv7 make the decay a vector (only worth it if u can make it fast)
also, note that the gated deltanet/KDA decay also helps with the fact that ur still storing interpolations of values which is better than storing both but worse than storing one of them if u want exact retrieval
meaning, imagine the model was doing a math question, then u asked it abt space. with the delta rule alone its not fast enough to make a large delete in its state, so it has to keep info abt the math even if it doesnt want to, so the decay helps it make a quick sweep
its still a fixed state size so of course it has to lose info
linear attention cant lose info, which is bad
decay on its own can lose info but has to lose the same amount of info from all key value pairs
delta rule can lose info but target specific key value pairs it wants to lose info from while keeping all others perfectly stored
and yes u lose some info with the interpolation, but not as much as u think, as u have many layers and its all a soup of info
so it could store parts of info in one key value pair, then another part in another, then retrieve both and mix or whatever
but the main idea is u specifically target key value pairs to lose info from