Thanks for your great work! Here are my concerns:
Say we get a batch of inputs with lengths L1, L2, ... How are the attention scores for these inputs computed simultaneously in the 'nopad' mode? That sounds great, but I couldn't figure out how it works while reading the source code.
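For context, here is a minimal sketch of how padding-free ("nopad") attention over a variable-length batch is commonly done: all tokens are concatenated into one flat tensor, and a cumulative-length vector (called `cu_seqlens` below, a name I'm assuming, not necessarily the one this repo uses) marks where each sequence starts and ends; attention is then restricted to each sequence's own span. This is only an illustration of the general technique, not this repository's actual implementation.

```python
import torch
import torch.nn.functional as F

def nopad_attention(q, k, v, cu_seqlens):
    """Causal attention over a packed batch with no padding.

    q, k, v:    [total_tokens, num_heads, head_dim] -- all sequences concatenated
    cu_seqlens: [batch_size + 1] cumulative offsets, e.g. [0, L1, L1+L2, ...]
    Returns:    [total_tokens, num_heads, head_dim]
    """
    out = torch.empty_like(q)
    scale = q.shape[-1] ** -0.5
    for i in range(cu_seqlens.numel() - 1):
        s, e = cu_seqlens[i].item(), cu_seqlens[i + 1].item()
        # Slice out one sequence: [L_i, H, D] -> [H, L_i, D]
        qi = q[s:e].transpose(0, 1)
        ki = k[s:e].transpose(0, 1)
        vi = v[s:e].transpose(0, 1)
        # Causal attention within this sequence only
        scores = (qi @ ki.transpose(-1, -2)) * scale            # [H, L_i, L_i]
        mask = torch.triu(torch.ones(e - s, e - s, dtype=torch.bool,
                                     device=q.device), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))
        oi = F.softmax(scores, dim=-1) @ vi                     # [H, L_i, D]
        out[s:e] = oi.transpose(0, 1)
    return out

# Example: two sequences of lengths 3 and 5, packed back to back
L1, L2, H, D = 3, 5, 4, 16
q = torch.randn(L1 + L2, H, D)
k = torch.randn(L1 + L2, H, D)
v = torch.randn(L1 + L2, H, D)
cu_seqlens = torch.tensor([0, L1, L1 + L2])
out = nopad_attention(q, k, v, cu_seqlens)   # [8, 4, 16]
```

The Python loop is only for readability; fused kernels (e.g. FlashAttention's varlen interface or custom Triton kernels) do the same per-sequence bookkeeping inside the kernel, so the whole packed batch is handled in a single launch without any padding tokens.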
Additionally, in the decoding phase, how do you handle different KV lengths? (The code suggests the KV cache has a regular shape of [B, num_heads, ...], which is confusing, because different prefixes produce KV caches of different lengths.)
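On the second question, one common way to decode with ragged KV lengths in a batch is to allocate the cache for the longest sequence and keep a per-sequence length vector, masking out each row's unused tail; paged/block-table caches are a more memory-efficient variant of the same idea. The sketch below (names like `kv_len` are hypothetical, not taken from the repo) illustrates only the length-masked variant.

```python
import torch
import torch.nn.functional as F

def decode_step_attention(q, k_cache, v_cache, kv_len):
    """One decode step with ragged KV lengths.

    q:       [B, H, D]            query for each sequence's single new token
    k_cache: [B, H, max_len, D]   pre-allocated cache; only kv_len[b] slots are valid
    v_cache: [B, H, max_len, D]
    kv_len:  [B]                  number of valid cached tokens per sequence
    Returns: [B, H, D]
    """
    B, H, max_len, D = k_cache.shape
    scale = D ** -0.5
    scores = torch.einsum("bhd,bhld->bhl", q, k_cache) * scale   # [B, H, max_len]
    # Mask positions beyond each sequence's true KV length
    pos = torch.arange(max_len, device=q.device)                 # [max_len]
    invalid = pos[None, :] >= kv_len[:, None]                    # [B, max_len]
    scores = scores.masked_fill(invalid[:, None, :], float("-inf"))
    probs = F.softmax(scores, dim=-1)
    return torch.einsum("bhl,bhld->bhd", probs, v_cache)

# Example: batch of 2, one sequence has 4 cached tokens, the other 7
B, H, D, max_len = 2, 4, 16, 8
q = torch.randn(B, H, D)
k_cache = torch.randn(B, H, max_len, D)
v_cache = torch.randn(B, H, max_len, D)
kv_len = torch.tensor([4, 7])
out = decode_step_attention(q, k_cache, v_cache, kv_len)   # [2, 4, 16]
```

A paged KV cache takes this further: instead of one contiguous [B, H, max_len, D] buffer, it stores fixed-size blocks plus a per-sequence block table, so ragged lengths don't waste memory. The masking logic is the same in spirit, just applied through the block table.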
I want to implement batched speculative decoding, and these details matter for that.
Thanks. Any details, code, or pseudocode would be appreciated.