Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions docs/features/index_cache.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# IndexCache

IndexCache reduces redundant top-k computation in DeepSeek-V3.2 (DSA) models by caching and reusing top-k indices across layers.

## Background

DeepSeek-V3.2 uses a DeepSeek Sparse Attention (DSA) mechanism where top-k token selection is computed per layer. For deep models with many layers, this computation can be expensive. IndexCache allows skipping redundant top-k computations by reusing indices from previous layers.

See: [IndexCache Paper](https://arxiv.org/abs/2603.12201)

## Usage

### CLI

```bash
vllm serve deepseek-ai/DeepSeek-V3.2 \
--hf-overrides '{"use_index_cache": true, "index_topk_freq": 4}' ...
```

### Configuration Reference

| Parameter | Type | Default | Description |
|----------------------|------|---------|--------------------------------------------------------------------------------------------------------------------------------------------------|
| `use_index_cache`    | bool | false   | Enable IndexCache. Must be set to `true` to use this feature.                                                                                    |
| `index_topk_freq`    | int  | 1       | Frequency (in layers) at which top-k is computed. 1 = compute on every layer (caching effectively disabled), 4 = compute on 1/4 of layers        |
| `index_topk_pattern` | str | null | Per-layer F/S pattern. Overrides index_topk_freq if set. Each character maps to one DSA layer: F = Full, S = Shared |

### Configuration Examples

**Using `index_topk_freq`** (compute every N layers):

```bash
vllm serve deepseek-ai/DeepSeek-V3.2 \
--hf-overrides '{"use_index_cache": true, "index_topk_freq": 4}' ...
```

**Using `index_topk_pattern`** (explicit per-layer control):

```bash
# custom pattern for 61 layers: F = compute, S = reuse
vllm serve deepseek-ai/DeepSeek-V3.2 \
--hf-overrides '{"use_index_cache": true, "index_topk_pattern": "FFSFSSSFSSFFFSSSFFFSFSSSSSSFFSFFSFFSSFFFFFFSFFFFFSFFSSSSSSFSF"}'
```

## How It Works

1. When IndexCache is enabled, layers marked with `"F"` (Full) calculate and store top-k indices
2. Subsequent layers marked with `"S"` (Shared) reuse the cached indices written by the most recent `"F"` layer instead of recomputing them
3. The cached indices are passed through the layer stack, reducing total computation

## Requirements

- DeepSeek-V3.2 or compatible DSA model
- `use_index_cache: true` via `--hf-overrides`
12 changes: 8 additions & 4 deletions vllm/model_executor/layers/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
skip_topk: bool = False,
) -> None:
super().__init__()
self.hidden_size = hidden_size
Expand All @@ -87,6 +88,11 @@ def __init__(
self.indexer_rope_emb = mla_modules.indexer_rotary_emb
self.is_sparse = mla_modules.is_sparse

# Whether to skip top-k token selection computation in this layer.
# When True, the indexer will not be called, and the layer will reuse
# the topk_tokens buffer written by a previous layer in the same pass.
# Refer: https://arxiv.org/abs/2603.12201 for more details.
self.skip_topk = skip_topk
if self.indexer is not None:
assert hasattr(self.indexer, "topk_tokens")
self.topk_tokens = self.indexer.topk_tokens
Expand Down Expand Up @@ -159,10 +165,8 @@ def forward(
positions, q[..., self.qk_nope_head_dim :], k_pe
)

if self.indexer and self.is_sparse:
_topk_indices = self.indexer(
hidden_states, q_c, positions, self.indexer_rope_emb
)
if self.indexer and self.is_sparse and not self.skip_topk:
self.indexer(hidden_states, q_c, positions, self.indexer_rope_emb)

if llama_4_scaling is not None:
q *= llama_4_scaling
Expand Down
22 changes: 21 additions & 1 deletion vllm/model_executor/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.model_executor.models.utils import (
extract_layer_index,
sequence_parallel_chunk,
)
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.torch_utils import direct_register_custom_op
Expand Down Expand Up @@ -963,6 +966,7 @@ def __init__(

self.is_v32 = hasattr(config, "index_topk")

_skip_topk = False
if self.is_v32:
self.indexer_rope_emb = get_rope(
qk_rope_head_dim,
Expand All @@ -980,6 +984,21 @@ def __init__(
topk_indices_buffer,
f"{prefix}.indexer",
)

# Enable IndexCache for DeepSeek models to reduce redundant top-k
# token selection computations in sparse attention.
use_index_cache = getattr(config, "use_index_cache", False)
if use_index_cache:
# IndexCache config
# Refer: https://arxiv.org/abs/2603.12201 for more details.
_index_topk_freq = getattr(config, "index_topk_freq", 1)
_index_topk_pattern = getattr(config, "index_topk_pattern", None)
layer_id = extract_layer_index(prefix)
if _index_topk_pattern is None:
_skip_topk = max(layer_id - 1, 0) % _index_topk_freq != 0
elif 0 <= layer_id < len(_index_topk_pattern):
_skip_topk = _index_topk_pattern[layer_id] == "S"
Comment thread
chaunceyjiang marked this conversation as resolved.

else:
self.indexer_rope_emb = None
self.indexer = None
Expand Down Expand Up @@ -1017,6 +1036,7 @@ def __init__(
cache_config,
quant_config,
prefix,
skip_topk=_skip_topk,
)

def forward(
Expand Down
Loading