llama : support Jamba hybrid Transformer-Mamba models #7531
Merged
Commits (61, all by compilade):
271104c  wip: llama : separate recurrent states from the KV cache
8db1e4d  llama : use std::find for seq_nodes in llama_rs_cache
0028010  llama : state checkpoints for recurrent models
0c8b3b2  llama : correctly handle more edge cases for the rs cache
d66849f  Merge branch 'master' into compilade/refactor-kv-cache
a09db95  llama : rename many llama_kv_cache_* functions
c460ff1  Merge branch 'master' into compilade/refactor-kv-cache
b6fafd1  llama : remove useless return value for some llama_cache_* functions
b7ec12e  Merge branch 'master' into compilade/refactor-kv-cache
3b57b55  Merge branch 'master' into compilade/refactor-kv-cache
7e13f19  llama : rethink recurrent state cell counts
cbc743e  llama : support Jamba
0fd13e9  Merge branch 'master' into compilade/refactor-kv-cache
61a88a1  llama : fix BERT inference without KV cache
ea2e63e  convert-hf : check for unprocessed Jamba experts
fc59407  convert-hf : support Mini-Jamba conversion
181dadf  llama : fix Jamba quantization sanity checks
3a414b0  llama : sequence-length-aware batch splitting
4e4c41e  Merge branch 'master' into compilade/refactor-kv-cache
3587a94  llama : use equal-sequence-length sub-batches for recurrent models
5d3c7b9  Merge branch 'master' into compilade/refactor-kv-cache
72eea49  llama : fix batch split output count for embeddings
18d1c14  llama : minimize swaps when reordering logits
61200ef  llama : fix edge case finding batch seq_id of split recurrent cell
eb589d5  llama : avoid copies for simple batch splits
8fb57ac  llama : use im2col and mul_mat to perform convolution for Mamba
17f6c1e  llama : fix .base() compilation error on Windows
fee3c1d  llama : allow doing the equivalent of SSM_CONV with SUM_ROWS and MUL
6840ac0  Merge branch 'master' into compilade/refactor-kv-cache
372482d  llama : rename llama_cache to llama_past
43d8d4b  examples : replace llama_kv_cache_seq_* with llama_past_seq_*
ff794f5  Merge branch 'master' into compilade/refactor-kv-cache
33425a7  mamba : fix non-contiguous usage of ggml_silu
10c3c41  Merge branch 'master' into compilade/refactor-kv-cache
9b38f8b  Merge branch 'master' into compilade/refactor-kv-cache
bc320ef  Merge branch 'master' into compilade/refactor-kv-cache
fcb889c  llama : session saving and reloading for hybrid models
a03e32a  Merge branch 'master' into compilade/refactor-kv-cache
9d3f44d  convert_hf : fix Jamba conversion
5f62db7  llama : fix mixed signedness comparison
375de5b  llama : use unused n_embd_k_gqa in k_shift
4bb4b22  llama : begin renaming llama_past back to llama_kv_cache
63ac36b  Merge branch 'master' into compilade/refactor-kv-cache
124c222  Merge branch 'master' into compilade/refactor-kv-cache
8006f3b  llama : remove implicit recurrent state rollbacks
691698e  Merge branch 'master' into compilade/refactor-kv-cache
e3fe612  llama : partially apply clang-format style
2bcaf64  Merge branch 'master' into compilade/refactor-kv-cache
908e655  convert : fix jamba conv1d shape squeezing
4682e21  Merge branch 'master' into compilade/refactor-kv-cache
20f8e43  graph : add back hybrid memory graph input
07c252f  model : add Jamba to Mamba-specific hparams printing
f716358  Merge branch 'master' into compilade/refactor-kv-cache
b0b280e  Merge branch 'master' into compilade/refactor-kv-cache
db5ff0c  jamba : remove redundant nullptr initializations
2f39cd7  model : remove unnecessary prefix for tensor loading constants
f7c7a92  model : use ggml_swiglu_split for Mamba
a60a24b  Merge branch 'master' into compilade/refactor-kv-cache
7f3955a  model : make falcon-h1 use shared mamba2 layer builder
452207f  memory : avoid referring to KV in recurrent cache logs
4d6a179  gguf-py : avoid adding duplicate tensor mappings for Jamba
Conversation
I'm looking at adding the missing Metal kernels for `SSM_CONV` and `SSM_SCAN`. I'm wondering if this part of the kernels where we copy `src0` -> `dst` could be extracted outside of the operation via `ggml_cpy` + `ggml_view` or `ggml_acc`? It would simplify the implementation.

Also, I still haven't understood the details of the computation, but if we find a way to express these ops via existing ops all together (e.g. using `ggml_conv`, `ggml_mul_mat`, ...), it would be preferred to do so, in order to reduce the amount of kernels that we have to write.
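For illustration only, here is roughly what extracting such a copy could look like at graph-build time. This is a sketch, not code from this PR; `conv_x`, `conv_state`, `d_conv`, `d_inner`, and `n_tokens` are placeholder names for the concatenated convolution input and the carried-over state:

```c
// Sketch: update the conv state with ggml_cpy + ggml_view instead of
// inside the custom SSM_CONV kernel.
// conv_x:     {d_conv - 1 + n_tokens, d_inner}  (old state columns ++ new tokens)
// conv_state: {d_conv - 1, d_inner}             (state to carry to the next batch)

// view of the last (d_conv - 1) columns of conv_x along the time dimension
struct ggml_tensor * last_cols = ggml_view_2d(ctx, conv_x,
        d_conv - 1, d_inner,
        conv_x->nb[1],
        n_tokens * ggml_element_size(conv_x));

// copy them back into the state buffer as a separate graph node,
// so the kernel itself no longer has to write the shifted state into dst
ggml_build_forward_expand(gf, ggml_cpy(ctx, last_cols, conv_state));
```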
---
Yes, this is definitely possible. I'll find a way to extract the copies outside.

For `SSM_SCAN`, I think there's a way to fully express it in terms of other ops, though it will use much more memory because of the big intermediate tensors, and new operators like `SOFT_PLUS` and `EXP` would be needed instead. But different lengths of simultaneous sequences might make a custom operator still necessary. I'll think about ways to make it simpler, especially since other recurrent architectures (like RWKV) will also need to work on multiple sequences per batch.

For simplifying `SSM_CONV`, I don't think `ggml_conv` supports working on independent 1D rolling windows with varying sequence lengths.

When working on a single sequence, though, it's quite simple to do the equivalent of `ggml_ssm_conv` with a self-overlapping view, as I did in my original implementation, which I described in more detail in #5328 (comment):

https://github.com/ggerganov/llama.cpp/blob/64fbce052373faf07a36b599528f8fe1cb1d62fb/llama.cpp#L6973-L6982

Setting `nb[2]` to the element size makes the view self-overlapping.

But this would create too many nodes in the compute graph when done with multiple sequences (unless they're always all the same length, in which case the 4th dimension could be used), so a custom operator is necessary.
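A sketch of that self-overlapping view, with shapes following the description above; a `conv_x` of shape `{d_conv - 1 + n_tokens, d_inner}` and a conv kernel `conv1d_w` of shape `{d_conv, d_inner}` are assumed placeholder names:

```c
// Sketch: single-sequence equivalent of ggml_ssm_conv.
// Window t starts one element after window t-1 because nb[2] is set to the
// element size, so the n_tokens windows overlap within conv_x.
struct ggml_tensor * conv_windows = ggml_view_3d(ctx, conv_x,
        d_conv, d_inner, n_tokens,
        conv_x->nb[1],               // stride between channels
        ggml_element_size(conv_x),   // stride between windows: 1 element
        0);

// depthwise conv: elementwise mul with the {d_conv, d_inner} kernel
// (broadcast over n_tokens), then reduce over the d_conv dimension
struct ggml_tensor * x = ggml_sum_rows(ctx, ggml_mul(ctx, conv_windows, conv1d_w));
// x: {1, d_inner, n_tokens}
```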
---
One idea that we might consider is to unfuse the `n_rs` dimension from the SSM ops and make them work per 1 recurrent state. Then, during inference and right before the SSM operations, we split the batch into same-sequence chunks and SSM them individually. After that we concat back the results into the full hidden state for the batch.

The main goal would be to simplify the SSM operators, and potentially express them as other existing ops if possible. But additionally, I'm considering a similar processing mode for the standard transformer KV cache in which we don't rely on a "unified" buffer for all the sequences, but instead each sequence has its own separate KV cache buffer. In that mode, we would do a similar same-sequence batch split before the attention. The main purpose of supporting this mode would be to achieve reproducible results during parallel decoding (currently, decoding the same sequence in parallel can yield slightly different results due to the unified KV cache).
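As a sketch of this flow at graph-build time (hypothetical helper and variable names; the per-state SSM is abstracted behind `build_ssm_for_one_state`), mainly to show the split-and-concat structure:

```c
// Sketch: apply the SSM ops per recurrent state (one sequence each),
// then concat the results back into the batch's hidden state.
// hidden: {n_embd, n_tokens}; tokens of each sequence assumed contiguous.
struct ggml_tensor * out = NULL;

for (int s = 0; s < n_seqs; ++s) {
    struct ggml_tensor * chunk = ggml_view_2d(ctx, hidden,
            n_embd, seq_len[s],
            hidden->nb[1],
            seq_start[s] * hidden->nb[1]);

    // SSM over a single recurrent state (hypothetical helper)
    struct ggml_tensor * y = build_ssm_for_one_state(ctx, chunk, s);

    // concat along the token dimension
    out = out == NULL ? y : ggml_concat(ctx, out, y, 1);
}
// note how the number of graph nodes grows with n_seqs
```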
Just throwing some thoughts that I have so far - will continue looking at the PR in the next days
Edit: I was writing this comment before I saw you posted - will take a look tomorrow
---
Yes, this would be doable, but it would make the number of compute graph nodes scale with the number of sequences. (EDIT: if the split is done when making ubatches, then the number of compute graph nodes can stay constant.)

Another way would be to make all sequences have the same number of new tokens in a ubatch, to allow using another dimension instead of having to loop when building the compute graphs. This would still allow batching multiple sequences with recurrent models, but without the need for new custom operators for each architecture, and still with a constant number of compute graph nodes.
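Concretely, the extra dimension could look something like this (a sketch under the assumption of equal-length sequences per ubatch; `n_seq_tokens` is a placeholder for the per-sequence token count):

```c
// Sketch: with equal-length sequences, the per-sequence loop becomes a
// tensor dimension, keeping the node count constant.
// hidden: {n_embd, n_seq_tokens * n_seqs}, tokens grouped by sequence
struct ggml_tensor * h3d = ggml_reshape_3d(ctx, hidden,
        n_embd, n_seq_tokens, n_seqs);

// recurrent ops can then batch over the 3rd (sequence) dimension in one
// node, instead of one subgraph per sequence
```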
The recurrent steps are simpler for ubatches with sequence lengths of 1, but prompt processing performance would be much slower than with a per-recurrent-architecture operator for longer sequences. Still thinking about ways to generalize this while keeping good performance.

For the transformer KV cache, if there's logic to make all sequences within a ubatch have the same number of new tokens, I think a mode to split batches sequence-wise will be simpler and could re-use much of the same code.

I also think there's a way to keep the unified KV cache (one buffer) and chunk it to make each sequence have its own independent contiguous reserved cells. Batching sequences together might still be possible, though, if the KQ mask gets another dimension (the number of sequences in the ubatch, and the number of new tokens per sequence instead of the batch size) so that these equal-sized "chunks" get processed independently in parallel. But this might not work (because the newly-calculated KV cells have to be copied to a bunch of not-regularly-spaced places), unless... unless maybe with some kind of `ggml_set_rows`? Not sure about the transposed V cache, though.

A sequence-wise processing mode is likely simpler, although it's not really parallel processing then (the model weights are all read at each ubatch).
---
No, it has to be split only for the attention so that the rest of the ops are still batched. Otherwise we will sacrifice a lot of performance.

Not sure how that would work. Adding dummy tokens sounds like too much overhead (at least in the case of the regular transformer). Any other ideas?

From a broad PoV, if we have an implementation that works with a single sequence and any batch size, then to extend it to multi-sequence batches we can split the batch into same-sequence tokens right before the attention and merge it back after the attention. Each split will do what we already do for the single-sequence solution, using a separate cache for each sequence. I didn't consider the number of nodes until you noted it - so that might be a problem indeed.
Looking forward to this!
---
It will sacrifice some performance, but only in the cases where a batch contains an unequal number of tokens for each affected sequence. So this should not affect large prompt processing or parallel text generation, as long as the two are not done in the same batch.

This is not about adding dummy tokens, but about making the number of new tokens in each ubatch the same per sequence. I think the overhead will be minimal, though there is still some.
Let me illustrate.
Let's say there's a batch with new tokens for 4 sequences of length 16, 7, 1, 1, respectively.
Splitting that into equal-length sequences would make 3 ubatches, like so:
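```
ubatch 1: 4 sequences × 1 token each   (remaining: 15, 6, 0, 0)
ubatch 2: 2 sequences × 6 tokens each  (remaining:  9, 0)
ubatch 3: 1 sequence  × 9 tokens       (remaining:  0)
```

(Each ubatch takes the smallest remaining length among the sequences that still have new tokens, so every ubatch holds the same number of new tokens per included sequence.)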
Each of these shapes is nice and rectangular, which is good for recurrent architectures because their operations can be more easily batched across sequences this way.

But I'm not yet sure if it would also benefit Transformers, which is why I'm thinking of initially only enabling the equal-length splitting for recurrent (or hybrid) model architectures.

Doing this with a constant number of graph nodes is pretty much what using same-length sequences (as illustrated above) allows, because the split into same-sequence tokens can then simply become another tensor dimension.
---
Aha, got it. Good idea. I'm also not sure if this can help Transformers, but it's something to think about 👍