kv-cache : separate recurrent vs non-recurrent impl #12799
What is the reasoning for using  On a more general note, I think it is not very usual the way 

The public API currently works with  I think what we need to do in a follow-up PR is:

At this point, a completely new recurrent-specific implementation can be added:  The existing recurrent cache implementation has to be rewritten from scratch, because it was hacked on top of the KV cache implementation by repurposing the K and V tensors for the state space requirements.

I'll try to update this. Just to make sure, you mean the current:

to become interfaces with different implementations based on the type of memory?

I don't fully understand the code, but I think 

There will need to be some top-level type which can contain multiple types of KV caches to ease supporting hybrid models. A shared interface for recurrent and non-recurrent state caches is useful to get to that point, at least for maintainability. The hardest part will be handling errors and properly keeping coherency between the different types of caches (because they don't necessarily roll back states in the same way). That is relevant mostly for hybrid models, though.

Yes, it will need to be rewritten, at least to be able to support proper state rollback. But even if it was repurposing the K and V tensors, there are still some things which I think will remain, since Mamba and RWKV do have 2 types of recurrent states per layer.
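To make the coherency concern above concrete, here is a minimal sketch of a top-level container that keeps several caches in step: if one child rejects an update, the children that already accepted it are rolled back. All names (`state_cache`, `bounded_cache`, `hybrid_cache`, `apply`, `rollback`) are hypothetical and chosen for illustration; none of them are part of llama.cpp, and real caches may need very different rollback mechanics.

```cpp
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

namespace hybrid_sketch {

// Hypothetical common interface for one kind of state cache.
struct state_cache {
    virtual ~state_cache() = default;
    virtual bool apply(int n_tokens) = 0; // returns false if the update cannot be stored
    virtual void rollback() = 0;          // undoes the last successful apply()
};

// Toy cache with a fixed capacity measured in tokens.
struct bounded_cache : state_cache {
    int used = 0;
    int last = 0;
    int capacity;

    explicit bounded_cache(int capacity_) : capacity(capacity_) {}

    bool apply(int n_tokens) override {
        if (used + n_tokens > capacity) {
            return false;
        }
        used += n_tokens;
        last  = n_tokens;
        return true;
    }

    void rollback() override {
        used -= last;
        last  = 0;
    }
};

// Top-level container: either every child accepts the update, or none keeps it.
struct hybrid_cache {
    std::vector<std::unique_ptr<state_cache>> children;

    bool apply(int n_tokens) {
        for (size_t i = 0; i < children.size(); ++i) {
            if (!children[i]->apply(n_tokens)) {
                // undo the children that already succeeded, in reverse order
                while (i-- > 0) {
                    children[i]->rollback();
                }
                return false;
            }
        }
        return true;
    }
};

} // namespace hybrid_sketch
```

The sketch assumes every cache can undo exactly one step; the comment above points out that real recurrent and non-recurrent caches do not necessarily roll back in the same way, which is what makes the actual design harder.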
The changes look good. While testing this, I noticed that the KV cache is always allocated on the CPU.
        
          
src/llama-kv-cache.cpp (outdated)

    //////////////////////////////////////////////
    // TODO: this should not mutate the KV cache !
    kv_cell & cell = const_cast<kv_cell &>(cells[i]);

Suggested change:

    - kv_cell & cell = const_cast<kv_cell &>(cells[i]);
    + kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);

Otherwise multi-user inference is broken for recurrent models. See #9126 (comment).
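A small sketch of why the index matters. It assumes that `cell_id` is the loop index offset by the start of the active cache window (called `head` here); the review thread does not show how `cell_id` is computed, so treat that as an assumption. `window_srcs` and the reduced `kv_cell` struct are hypothetical stand-ins, not the real llama.cpp definitions.

```cpp
#include <cstdint>
#include <vector>

// Reduced, hypothetical stand-in for the real cell type.
struct kv_cell {
    int32_t src = -1;
};

// Collects `src` for each cell in the window [head, head + n).
// use_cell_id = false reproduces the reported bug (cells[i]),
// use_cell_id = true applies the suggested fix (cells[cell_id]).
static std::vector<int32_t> window_srcs(const std::vector<kv_cell> & cells,
                                        uint32_t head, uint32_t n, bool use_cell_id) {
    std::vector<int32_t> out;
    for (uint32_t i = 0; i < n; ++i) {
        const uint32_t cell_id = head + i; // assumption: window offset + loop index
        const kv_cell & cell = use_cell_id ? cells[cell_id] : cells[i];
        out.push_back(cell.src);
    }
    return out;
}
```

When `head` is 0 both variants read the same cells, which would explain why the problem only shows up once the window starts elsewhere in the cache, as it can when several sequences are served at once.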
        
          
src/llama-kv-cache.cpp (outdated)

    //////////////////////////////////////////////
    // TODO: this should not mutate the KV cache !
    kv_cell & cell = const_cast<kv_cell &>(cells[i]);

Suggested change:

    - kv_cell & cell = const_cast<kv_cell &>(cells[i]);
    + kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);

Same, this should fix multi-user inference.
We should add a small multi-user test with a recurrent model to server/tests to be able to spot such regressions.
@slaren @compilade I think this should be good to merge - any additional comments?

I think that when we introduce the 

    // make the outputs have the same order they had in the user-provided batch
    // note: this is mostly relevant for recurrent models atm

It's also only relevant when using get_embeddings, because the buffer in that case has to be ordered to keep the API backward compatible. When purely using get_embeddings_ith, it's not required.
Unconditionally sorting is unnecessary and is likely slower. Also it seems like some assertions here break multi-user inference for recurrent models (since the line right after this block where n_outputs = n_outputs_all is assumed to have run before the sorting routine, but it hasn't).
The main reason to decide to always reorder is because otherwise we have to maintain the sbatch in the state of the context. This introduces some complexity that is hard to reason around so I decided to take the hit.
We should add a test that exercises this branch. What is a server scenario that would trigger the reordering?
I'll PR the n_outputs = n_outputs_all before the sorting fix.
> The main reason to decide to always reorder is because otherwise we have to maintain the `sbatch` in the state of the context. This introduces some complexity that is hard to reason around so I decided to take the hit.

Right. I've measured with perf, and it seems like only 0.05% of the time is spent sorting when running perplexity with mamba-130m.
The worst case is with https://huggingface.co/delphi-suite/v0-mamba-100k, but even then sorting only takes 1.3% of the CPU time when running the hellaswag benchmark with 128 parallel sequences.
It could also be possible to copy sbatch->out_ids somewhere in llama_context at the end of each decode to restore the previous behavior of lazily sorting. But the max speed gains are negligible.
> We should add a test that exercises this branch. What is a server scenario that would trigger the reordering?

Whenever a non-simple split (from sbatch) is used, there's a chance the outputs won't be in the same order as the input batch. That happens with recurrent models when there are multiple sequences decoded at once.
I think simply using a recurrent model (even https://huggingface.co/delphi-suite/v0-mamba-100k, although it might not be coherent enough) in the server tests would trigger this branch, assuming there's a test case which uses multiple sequences at once.
But the simplest way to test this is with llama-parallel, as in #13267 (review).
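For readers following the reordering discussion, here is a rough sketch of the idea: after a non-simple split, output row `j` belongs to batch entry `out_ids[j]`, and the rows are sorted back into user-provided batch order. The helper name `reorder_outputs`, the flat row-major buffer, and the choice of selection sort (which keeps the number of row swaps below the number of rows) are assumptions made for illustration, not a copy of the code under review.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper: rows holds out_ids.size() rows of row_size floats each,
// and row j was computed for batch entry out_ids[j]. After the call, both
// out_ids and the rows are in ascending batch order.
static void reorder_outputs(std::vector<float> & rows, size_t row_size,
                            std::vector<int32_t> & out_ids) {
    const size_t n = out_ids.size();
    for (size_t i = 0; i + 1 < n; ++i) {
        // find the smallest remaining batch index
        size_t j_min = i;
        for (size_t j = i + 1; j < n; ++j) {
            if (out_ids[j] < out_ids[j_min]) {
                j_min = j;
            }
        }
        if (j_min == i) {
            continue; // already in place
        }
        std::swap(out_ids[i], out_ids[j_min]);
        // swap the two output rows in the flat buffer
        std::swap_ranges(rows.begin() + i * row_size,
                         rows.begin() + (i + 1) * row_size,
                         rows.begin() + j_min * row_size);
    }
}
```

If the outputs are already in batch order, the loop performs no swaps, so the cost of sorting unconditionally is limited to the index comparisons.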
Overview

Attempting to make two separate classes for the 2 types of KV cache:

- `llama_kv_cache_unified : llama_kv_cache`
- `llama_kv_cache_recurrent : llama_kv_cache`

The main goal of this change is to simplify the logic in the primary `llama_kv_cache_unified` class so that we can more easily extend it with new features such as SWA. It also introduces a level of abstraction that would allow adding new types of KV cache implementations in the future.

Main changes

- The `llama_context` now operates with the abstract `llama_memory_i` interface.
- Add `llama_memory_params` and use it to implement `llama_model::create_memory()` for creating a model-specific cache
- `llama_kv_cache_recurrent` is currently mostly a copy of `llama_kv_cache_unified`, but should now be completely separated, so a new recurrent-specific implementation can be done
- Move KV cache shift and defrag code from `llama_context` to `llama_kv_cache_unified`
- The `llama_sbatch` -> `llama_ubatch` logic inside `llama_context::decode()` is now implemented by:
  - `llama_kv_cache::sbatch_init()`
  - `llama_kv_cache::ubatch_next()`

  The thinking is that certain KV cache implementations could require different types of micro-batching (e.g. same-sequence-length ubatch, single-sequence ubatch, etc.)
- Remove `llama_context::output_reorder()` - seemed to be relevant only for recurrent caches. The logic is now inlined in `llama_context::decode()`
- Remove `llama_context::sbatch`. Instead, create a new one for each decode

TODO before merge

- `llama_kv_cache` interface
- `llama_kv_cache_xxx` more private

Next PRs

- `llama_context_params.logits_all` logic - unnecessary complication, can be achieved with explicit request for logits for all tokens
- `infill` example - obsolete
- `llama_kv_cache_unified`

Resolve
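The class layout described in the overview can be sketched roughly as follows. This is a simplified illustration under stated assumptions, not the actual llama.cpp declarations: the names are shortened and placed in a `kv_sketch` namespace, the signatures of `sbatch_init()` and `ubatch_next()` are invented, `create_memory()` is a free function here rather than a `llama_model` member, and the recurrent variant uses a single-sequence split purely as an example of implementation-specific micro-batching.

```cpp
#include <algorithm>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

namespace kv_sketch {

struct token  { int id; int seq; };            // token id and the sequence it belongs to
struct sbatch { std::vector<token> pending; }; // tokens not yet assigned to a micro-batch
struct ubatch { std::vector<token> tokens; };  // one micro-batch

// Abstract memory interface, as seen by the context.
struct memory_i {
    virtual ~memory_i() = default;
};

// KV cache interface: each implementation decides how a batch is split.
struct kv_cache : memory_i {
    virtual sbatch sbatch_init(const std::vector<token> & batch) = 0;
    virtual ubatch ubatch_next(sbatch & sb, size_t n_ubatch) = 0;
};

struct kv_cache_unified : kv_cache {
    sbatch sbatch_init(const std::vector<token> & batch) override { return sbatch{batch}; }

    // simple split: take the next n_ubatch tokens in input order
    ubatch ubatch_next(sbatch & sb, size_t n_ubatch) override {
        const size_t n = std::min(n_ubatch, sb.pending.size());
        ubatch ub;
        ub.tokens.assign(sb.pending.begin(), sb.pending.begin() + n);
        sb.pending.erase(sb.pending.begin(), sb.pending.begin() + n);
        return ub;
    }
};

struct kv_cache_recurrent : kv_cache {
    sbatch sbatch_init(const std::vector<token> & batch) override { return sbatch{batch}; }

    // single-sequence split: only tokens of the first pending sequence
    ubatch ubatch_next(sbatch & sb, size_t n_ubatch) override {
        ubatch ub;
        if (sb.pending.empty()) {
            return ub;
        }
        const int seq = sb.pending.front().seq;
        std::vector<token> rest;
        for (const token & t : sb.pending) {
            if (t.seq == seq && ub.tokens.size() < n_ubatch) {
                ub.tokens.push_back(t);
            } else {
                rest.push_back(t);
            }
        }
        sb.pending = std::move(rest);
        return ub;
    }
};

// Hypothetical factory standing in for the model-specific cache creation.
inline std::unique_ptr<kv_cache> create_memory(bool recurrent) {
    if (recurrent) {
        return std::make_unique<kv_cache_recurrent>();
    }
    return std::make_unique<kv_cache_unified>();
}

} // namespace kv_sketch
```

With an interleaved two-sequence batch, the recurrent sketch emits tokens grouped by sequence, so the outputs no longer follow the input order. That is the situation in which the reordering step discussed in the review becomes necessary.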