feat(candle-nn) ConcatKvCache for 2-5x GPU speedup on autoregressive generation #3143
## Summary

Adds a new `ConcatKvCache` implementation that uses `Tensor::cat` instead of `slice_set` for KV-cache updates, providing 2-5x GPU performance improvements for autoregressive generation.

## Motivation

The standard `KvCache` uses pre-allocated buffers with `slice_set` updates, which performs suboptimally on GPU. In contrast, `Tensor::cat` uses optimized concatenation kernels. This PR adds `ConcatKvCache` as a new option alongside the existing `KvCache`, allowing developers to choose the best implementation for their use case.

## Changes
### 1. Added `ConcatKvCache` to `candle-nn/src/kv_cache.rs`

A new KV-cache implementation that uses `Tensor::cat` for append operations instead of `slice_set`, alongside the existing `KvCache`.

Key features:

- Simple runtime API: `new(dim)`, `append(k, v)`, `reset()`
- Benefits from the optimized `cat` kernels (PR #1855, "Optimize the cat operation on contiguous tensors")

### 2. Updated Qwen3 to use `ConcatKvCache`

Modified `candle-transformers/src/models/qwen3.rs` (3 lines changed).

### 3. Qwen3 MoE automatically benefits

No changes needed to `qwen3_moe.rs`: it imports `Qwen3Attention` and automatically inherits the performance improvement.

### 4. Updated Quantized-Qwen3 to use `ConcatKvCache`

Same updates as implemented in Qwen3.
## Easy Migration Path

`ConcatKvCache` is designed as a near drop-in replacement with an almost identical runtime API:

| | `KvCache` | `ConcatKvCache` |
|---|---|---|
| Constructor | `new(dim, max_len)` | `new(dim)` |
| Append | `append(k, v)` | `append(k, v)` |
| Reset | `reset()` | `reset()` |
| Append return type | `(Tensor, Tensor)` | `(Tensor, Tensor)` |

Migration effort per model: change 3 lines, get a 2-5x speedup.
## Performance Results

Benchmarked on Qwen3-0.6B:

Hardware:

### GPU Performance - Significant Improvement

### Quantized Model Performance

Also tested on Quantized Qwen3-0.6B (8-bit) to verify the optimization works across model types.

Key insight: speedup increases with sequence length, making this especially valuable for long-context applications.

### Speedup Growth Pattern

The performance advantage grows with sequence length:
This is because `slice_set`'s overhead compounds as the cache grows (larger strides), while `cat` maintains efficient sequential access patterns.

### CPU Performance - Neutral

CPU performance is essentially unchanged, confirming this optimization specifically targets GPU bottlenecks.
## Design Rationale

### Why add `ConcatKvCache` instead of modifying `KvCache`?

Different use cases have different optimal implementations, and candle-nn now offers several cache types: `KvCache`, `ConcatKvCache`, `RotatingKvCache`, and `ScatteredKvCache`. By keeping both implementations, developers can choose the right tool for their specific hardware and use case.
### Why is `cat` faster on GPU?

Both `KvCache` and `ConcatKvCache` use the optimized `copy2d` kernel from PR #1855, but they feed it different parameters: `ConcatKvCache` reaches it via `cat_contiguous`, while `KvCache` reaches it via `slice_set`. The difference lies in the access pattern each produces.
See `candle-core/src/tensor_cat.rs` for the optimized `cat_contiguous` implementation.

## When to Use Each Cache
Added documentation to `kv_cache.rs` covering when to use each of `ConcatKvCache`, `KvCache`, `RotatingKvCache`, and `ScatteredKvCache`.

## Testing
### Unit Tests

Added 4 comprehensive tests for `ConcatKvCache`.

### Integration Tests

### Benchmark Command

## Related PRs

- PR #1855: optimized `copy2d` kernel (both caches benefit from this)

## Breaking Changes
None. This PR adds `ConcatKvCache` as a new option alongside the existing caches; no existing APIs are modified.

## Checklist