diff --git a/csrc/causal_conv1d/op_host/causal_conv1d_tiling.cpp b/csrc/causal_conv1d/op_host/causal_conv1d_tiling.cpp index fa8bd23f114..95aff443c09 100644 --- a/csrc/causal_conv1d/op_host/causal_conv1d_tiling.cpp +++ b/csrc/causal_conv1d/op_host/causal_conv1d_tiling.cpp @@ -220,8 +220,29 @@ static ge::graphStatus GetShapeDtypeInfo(gert::TilingContext* context, CausalCon auto ciShapePtr = context->GetInputShape(CACHE_INDICES_INDEX); OP_CHECK_NULL_WITH_CONTEXT(context, ciShapePtr); auto ciShape = EnsureNotScalar(ciShapePtr->GetStorageShape()); - if(ciShape.GetDimNum() != 1){return ge::GRAPH_FAILED;} - if(ciShape.GetDim(0) != batch){return ge::GRAPH_FAILED;} + int64_t cacheIndicesMode = 0; + int64_t maxQueryLen = 1; + if (ciShape.GetDimNum() == 1) { + if (ciShape.GetDim(0) != batch) { + return ge::GRAPH_FAILED; + } + cacheIndicesMode = 0; + maxQueryLen = 1; + } else if (ciShape.GetDimNum() == 2) { + if (ciShape.GetDim(0) != batch) { + return ge::GRAPH_FAILED; + } + cacheIndicesMode = 1; + maxQueryLen = ciShape.GetDim(1); + if (maxQueryLen <= 0) { + return ge::GRAPH_FAILED; + } + if (inputMode == 1 && maxQueryLen < seqLen) { + return ge::GRAPH_FAILED; + } + } else { + return ge::GRAPH_FAILED; + } auto hisShapePtr = context->GetInputShape(HAS_INITIAL_STATE_INDEX); OP_CHECK_NULL_WITH_CONTEXT(context, hisShapePtr); @@ -289,6 +310,8 @@ static ge::graphStatus GetShapeDtypeInfo(gert::TilingContext* context, CausalCon tiling.set_stateLen(stateLen); tiling.set_numCacheLines(numCacheLines); tiling.set_batch(batch); + tiling.set_cacheIndicesMode(cacheIndicesMode); + tiling.set_maxQueryLen(maxQueryLen); return ge::GRAPH_SUCCESS; } diff --git a/csrc/causal_conv1d/op_host/causal_conv1d_tiling.h b/csrc/causal_conv1d/op_host/causal_conv1d_tiling.h index 28e74e5b99b..3199a128471 100644 --- a/csrc/causal_conv1d/op_host/causal_conv1d_tiling.h +++ b/csrc/causal_conv1d/op_host/causal_conv1d_tiling.h @@ -48,6 +48,9 @@ BEGIN_TILING_DATA_DEF(CausalConv1dTilingData) TILING_DATA_FIELD_DEF(int64_t, dimTileSize); TILING_DATA_FIELD_DEF(int64_t, blocksPerSeq); + + TILING_DATA_FIELD_DEF(int64_t, cacheIndicesMode); + TILING_DATA_FIELD_DEF(int64_t, maxQueryLen); END_TILING_DATA_DEF; struct CausalConv1dCompileInfo { uint64_t ubSize = 0; diff --git a/csrc/causal_conv1d/op_kernel/causal_conv1d.h b/csrc/causal_conv1d/op_kernel/causal_conv1d.h index 3407dd37ba3..d32cd7cf269 100644 --- a/csrc/causal_conv1d/op_kernel/causal_conv1d.h +++ b/csrc/causal_conv1d/op_kernel/causal_conv1d.h @@ -93,6 +93,10 @@ struct CausalConv1dTilingData { // Channel-wise tiling int64_t dimTileSize; int64_t blocksPerSeq; + + // cacheIndices shape mode: 0 -> [batch], 1 -> [batch, maxQueryLen] + int64_t cacheIndicesMode; + int64_t maxQueryLen; }; #endif // CAUSAL_CONV1D_TILING_DATA_H_ @@ -383,6 +387,8 @@ __aicore__ inline void CausalConv1d::Process() const int32_t seqLen = tilingData_->seqLen; const int32_t dimTileSize = static_cast(tilingData_->dimTileSize); const int32_t blocksPerSeq = static_cast(tilingData_->blocksPerSeq); + const int32_t cacheIndicesMode = static_cast(tilingData_->cacheIndicesMode); + const int32_t maxQueryLen = static_cast(tilingData_->maxQueryLen); const uint32_t blockIdx = GetBlockIdx(); const uint32_t blockNum = GetBlockNum(); @@ -417,16 +423,52 @@ __aicore__ inline void CausalConv1d::Process() continue; } - const int32_t cacheIdx = cacheIndicesGm.GetValue(seq); - if (cacheIdx == tilingData_->padSlotId) { - continue; + int32_t cacheIdx = 0; + if (cacheIndicesMode == 0) { + cacheIdx = cacheIndicesGm.GetValue(seq); + if (cacheIdx == tilingData_->padSlotId) { + continue; + } + } else { + if (maxQueryLen <= 0) { + continue; + } + cacheIdx = cacheIndicesGm.GetValue(static_cast(seq) * maxQueryLen); + if (cacheIdx == tilingData_->padSlotId) { + continue; + } + if (len > maxQueryLen) { + // guard for malformed metadata; fallback to final-state path only + len = maxQueryLen; + } } const bool hasInit = hasInitialStateGm.GetValue(seq); InitRing(cacheIdx, hasInit, start, len, c0, dimTileSize, dim, dbg); RunSeq(start, len, c0, dimTileSize, dim, dbg); - WriteBackState(cacheIdx, len, c0, dimTileSize, dim, dbg); + if (cacheIndicesMode == 0) { + WriteBackState(cacheIdx, len, c0, dimTileSize, dim, dbg); + } else { + const int32_t stateLen = tilingData_->stateLen; + LocalTensor ring = inBuf.Get(); + for (int32_t t = 0; t < len; ++t) { + const int32_t tokenCacheIdx = + cacheIndicesGm.GetValue(static_cast(seq) * maxQueryLen + t); + if (tokenCacheIdx == tilingData_->padSlotId) { + continue; + } + for (int32_t pos = 0; pos < (MAX_WIDTH - 1); ++pos) { + const int32_t tap = (MAX_WIDTH - 2) - pos; + const int32_t slot = (tap == 0) ? SlotCurr(t) : SlotHist(t, tap); + const int64_t stateOffset = + static_cast(tokenCacheIdx) * stateLen * dim + static_cast(pos) * dim + c0; + PipeBarrier(); + DataCopy(convStatesGm[stateOffset], ring[slot * MAX_BLOCK_DIM], dimTileSize); + PipeBarrier(); + } + } + } } ReleaseEvents(); diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_causal_conv1d.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_causal_conv1d.py index 34db8f00e89..18527805df0 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_causal_conv1d.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_causal_conv1d.py @@ -10,6 +10,63 @@ causal_conv1d_update_npu as causal_conv1d_update from vllm_ascend.utils import enable_custom_op + +def test_ascend_causal_conv1d_2d_state_indices_writeback(): + """Validate token-level state snapshots when cache indices are 2D [batch, max_query_len].""" + torch.random.manual_seed(0) + enable_custom_op() + + device = "npu" + dtype = torch.float16 + dim = 384 # keep aligned with kernel tiling candidates + width = 4 + state_len = width - 1 + + # Two variable-length sequences: [3, 2] + query_start_loc = torch.tensor([0, 3, 5], device=device, dtype=torch.int32) + x_cpu = torch.stack( + [torch.full((dim,), float(i + 1), dtype=torch.float32) for i in range(5)], + dim=0, + ) + x = x_cpu.to(device=device, dtype=dtype) + + # Zero weights isolate state writeback semantics from convolution math. + weight = torch.zeros((width, dim), device=device, dtype=dtype) + conv_states = torch.zeros((32, state_len, dim), device=device, dtype=dtype) + has_initial_state = torch.tensor([False, False], device=device, dtype=torch.bool) + cache_indices_2d = torch.tensor( + [ + [10, 11, 12], + [20, 21, PAD_SLOT_ID], + ], + device=device, + dtype=torch.int32, + ) + + _ = torch.ops._C_ascend.causal_conv1d_fn( + x, + weight, + None, + activation="silu", + conv_state=conv_states, + has_initial_state=has_initial_state, + non_spec_state_indices_tensor=cache_indices_2d, + non_spec_query_start_loc=query_start_loc, + pad_slot_id=PAD_SLOT_ID, + ) + + conv_cpu = conv_states.float().cpu() + expected = { + 10: torch.stack([torch.zeros(dim), torch.zeros(dim), x_cpu[0]], dim=0), + 11: torch.stack([torch.zeros(dim), x_cpu[0], x_cpu[1]], dim=0), + 12: torch.stack([x_cpu[0], x_cpu[1], x_cpu[2]], dim=0), + 20: torch.stack([torch.zeros(dim), torch.zeros(dim), x_cpu[3]], dim=0), + 21: torch.stack([torch.zeros(dim), x_cpu[3], x_cpu[4]], dim=0), + } + + for cache_line, ref in expected.items(): + validate_cmp(conv_cpu[cache_line], ref, torch.float32, device="cpu") + def validate_cmp(y_cal, y_ref, dtype, device='npu'): y_cal = y_cal.to(device) y_ref = y_ref.to(device) diff --git a/tests/ut/core/test_recompute_scheduler.py b/tests/ut/core/test_recompute_scheduler.py new file mode 100644 index 00000000000..0427f9d0a9f --- /dev/null +++ b/tests/ut/core/test_recompute_scheduler.py @@ -0,0 +1,49 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from tests.ut.base import TestBase +from vllm_ascend.core.recompute_scheduler import RecomputeScheduler + + +class TestRecomputeScheduler(TestBase): + + def _fake_scheduler_init(self, *args, **kwargs): + self.vllm_config = kwargs["vllm_config"] + # Baseline expectation from upstream-style init. + self.need_mamba_block_aligned_split = True + + def _build_vllm_config(self, model_type: str, mamba_cache_mode: str): + return SimpleNamespace( + speculative_config=None, + kv_transfer_config=None, + model_config=SimpleNamespace( + hf_text_config=SimpleNamespace(model_type=model_type), + ), + cache_config=SimpleNamespace( + mamba_cache_mode=mamba_cache_mode, + ), + ) + + @patch("vllm_ascend.core.recompute_scheduler.Scheduler.__init__", new=_fake_scheduler_init) + def test_all_mode_disables_block_aligned_split_for_qwen3_5(self): + cfg = self._build_vllm_config("qwen3_5", "all") + + scheduler = RecomputeScheduler(vllm_config=cfg) + + self.assertFalse(scheduler.need_mamba_block_aligned_split) + + @patch("vllm_ascend.core.recompute_scheduler.Scheduler.__init__", new=_fake_scheduler_init) + def test_non_all_mode_keeps_block_aligned_split_for_qwen3_5(self): + cfg = self._build_vllm_config("qwen3_5", "align") + + scheduler = RecomputeScheduler(vllm_config=cfg) + + self.assertTrue(scheduler.need_mamba_block_aligned_split) + + @patch("vllm_ascend.core.recompute_scheduler.Scheduler.__init__", new=_fake_scheduler_init) + def test_all_mode_keeps_non_hybrid_unchanged(self): + cfg = self._build_vllm_config("llama", "all") + + scheduler = RecomputeScheduler(vllm_config=cfg) + + self.assertTrue(scheduler.need_mamba_block_aligned_split) diff --git a/tests/ut/patch/platform/test_patch_mamba_config.py b/tests/ut/patch/platform/test_patch_mamba_config.py new file mode 100644 index 00000000000..10746069b2b --- /dev/null +++ b/tests/ut/patch/platform/test_patch_mamba_config.py @@ -0,0 +1,75 @@ +from types import SimpleNamespace +from unittest.mock import patch + +import torch + +from tests.ut.base import TestBase +from vllm_ascend.patch.platform.patch_mamba_config import verify_and_update_config + + +class _FakeModelCls: + + @staticmethod + def get_mamba_state_shape_from_config(_vllm_config): + # ssm: 256 bytes, conv: 128 bytes when dtype is fp16 + return [(128,), (64,)] + + @staticmethod + def get_mamba_state_dtype_from_config(_vllm_config): + return [torch.float16, torch.float16] + + +class TestPatchMambaConfig(TestBase): + + def _build_vllm_config(self, mamba_cache_mode: str): + cache_config = SimpleNamespace( + cache_dtype="auto", + block_size=None, + mamba_page_size_padded=None, + enable_prefix_caching=True, + mamba_cache_mode=mamba_cache_mode, + mamba_block_size=None, + ) + model_config = SimpleNamespace( + dtype=torch.float16, + architecture="FakeArch", + max_model_len=4096, + get_num_kv_heads=lambda _parallel: 1, + get_head_size=lambda: 1, + ) + parallel_config = SimpleNamespace() + return SimpleNamespace( + cache_config=cache_config, + model_config=model_config, + parallel_config=parallel_config, + ) + + @patch("vllm_ascend.patch.platform.patch_mamba_config.MambaModelConfig.verify_and_update_config") + @patch("vllm_ascend.patch.platform.patch_mamba_config.ModelRegistry.resolve_model_cls") + def test_prefix_caching_all_mode_uses_block_size(self, mock_resolve_model_cls, _mock_verify): + mock_resolve_model_cls.return_value = (_FakeModelCls, None) + vllm_config = self._build_vllm_config("all") + + verify_and_update_config.__func__(None, vllm_config) + + self.assertEqual(vllm_config.cache_config.mamba_block_size, vllm_config.cache_config.block_size) + + @patch("vllm_ascend.patch.platform.patch_mamba_config.MambaModelConfig.verify_and_update_config") + @patch("vllm_ascend.patch.platform.patch_mamba_config.ModelRegistry.resolve_model_cls") + def test_prefix_caching_align_mode_uses_block_size(self, mock_resolve_model_cls, _mock_verify): + mock_resolve_model_cls.return_value = (_FakeModelCls, None) + vllm_config = self._build_vllm_config("align") + + verify_and_update_config.__func__(None, vllm_config) + + self.assertEqual(vllm_config.cache_config.mamba_block_size, vllm_config.cache_config.block_size) + + @patch("vllm_ascend.patch.platform.patch_mamba_config.MambaModelConfig.verify_and_update_config") + @patch("vllm_ascend.patch.platform.patch_mamba_config.ModelRegistry.resolve_model_cls") + def test_prefix_caching_non_all_align_uses_max_model_len(self, mock_resolve_model_cls, _mock_verify): + mock_resolve_model_cls.return_value = (_FakeModelCls, None) + vllm_config = self._build_vllm_config("legacy") + + verify_and_update_config.__func__(None, vllm_config) + + self.assertEqual(vllm_config.cache_config.mamba_block_size, vllm_config.model_config.max_model_len) diff --git a/tests/ut/patch/worker/patch_common/test_patch_qwen3_5.py b/tests/ut/patch/worker/patch_common/test_patch_qwen3_5.py new file mode 100644 index 00000000000..5259214a8e8 --- /dev/null +++ b/tests/ut/patch/worker/patch_common/test_patch_qwen3_5.py @@ -0,0 +1,306 @@ +from types import SimpleNamespace +from unittest.mock import patch + +import torch + +from tests.ut.base import TestBase +from vllm.v1.attention.backends.utils import PAD_SLOT_ID +from vllm_ascend.patch.worker.patch_qwen3_5 import ( + AscendQwen3_5GatedDeltaNet, + _ensure_prefill_token_state_indices, + _is_mamba_all_prefix_mode, + _write_all_mode_token_conv_states, +) + + +class TestPatchQwen35AllModeHelpers(TestBase): + + def test_is_mamba_all_prefix_mode_true(self): + forward_context = SimpleNamespace( + vllm_config=SimpleNamespace( + cache_config=SimpleNamespace( + enable_prefix_caching=True, + mamba_cache_mode="all", + ) + ) + ) + self.assertTrue(_is_mamba_all_prefix_mode(forward_context)) + + def test_is_mamba_all_prefix_mode_false_when_not_all(self): + forward_context = SimpleNamespace( + vllm_config=SimpleNamespace( + cache_config=SimpleNamespace( + enable_prefix_caching=True, + mamba_cache_mode="align", + ) + ) + ) + self.assertFalse(_is_mamba_all_prefix_mode(forward_context)) + + def test_ensure_prefill_token_state_indices_returns_2d_directly(self): + state_indices = torch.tensor([[1, 2], [3, 4]], dtype=torch.long) + query_start_loc = torch.tensor([0, 2, 4], dtype=torch.long) + result = _ensure_prefill_token_state_indices(state_indices, query_start_loc) + self.assertTrue(torch.equal(result, state_indices)) + + def test_ensure_prefill_token_state_indices_expand_with_pad(self): + state_indices = torch.tensor([7, 9], dtype=torch.long) + query_start_loc = torch.tensor([0, 2, 5], dtype=torch.long) + + result = _ensure_prefill_token_state_indices(state_indices, query_start_loc) + + expected = torch.tensor( + [ + [7, 7, PAD_SLOT_ID], + [9, 9, 9], + ], + dtype=torch.long, + ) + self.assertTrue(torch.equal(result, expected)) + + def test_ensure_prefill_token_state_indices_shape_mismatch_fallback(self): + state_indices = torch.tensor([1, 2, 3], dtype=torch.long) + query_start_loc = torch.tensor([0, 2, 5], dtype=torch.long) + + result = _ensure_prefill_token_state_indices(state_indices, query_start_loc) + + self.assertEqual(result.shape, (3, 1)) + self.assertTrue(torch.equal(result.squeeze(1), state_indices)) + + def test_write_all_mode_token_conv_states_without_initial_state(self): + mixed_qkv_non_spec = torch.tensor( + [ + [1.0, 10.0], + [2.0, 20.0], + [3.0, 30.0], + ], + dtype=torch.float32, + ) + conv_state = torch.zeros((4, 2, 3), dtype=torch.float32) + token_state_indices = torch.tensor([[1, 2, 3]], dtype=torch.int32) + query_start_loc = torch.tensor([0, 3], dtype=torch.int32) + has_initial_state = torch.tensor([False], dtype=torch.bool) + + _write_all_mode_token_conv_states( + mixed_qkv_non_spec, + conv_state, + token_state_indices, + query_start_loc, + has_initial_state, + state_width=3, + ) + + self.assertTrue(torch.equal(conv_state[1, 0, :3], torch.tensor([0.0, 0.0, 1.0]))) + self.assertTrue(torch.equal(conv_state[2, 0, :3], torch.tensor([0.0, 1.0, 2.0]))) + self.assertTrue(torch.equal(conv_state[3, 0, :3], torch.tensor([1.0, 2.0, 3.0]))) + + def test_write_all_mode_token_conv_states_with_initial_state_and_pad(self): + mixed_qkv_non_spec = torch.tensor( + [ + [5.0, 50.0], + [6.0, 60.0], + ], + dtype=torch.float32, + ) + conv_state = torch.zeros((5, 2, 3), dtype=torch.float32) + conv_state[4, :, :3] = torch.tensor([[7.0, 8.0, 9.0], [70.0, 80.0, 90.0]]) + token_state_indices = torch.tensor([[4, PAD_SLOT_ID, 2]], dtype=torch.int32) + query_start_loc = torch.tensor([0, 2], dtype=torch.int32) + has_initial_state = torch.tensor([True], dtype=torch.bool) + + _write_all_mode_token_conv_states( + mixed_qkv_non_spec, + conv_state, + token_state_indices, + query_start_loc, + has_initial_state, + state_width=3, + ) + + self.assertTrue(torch.equal(conv_state[4, 0, :3], torch.tensor([8.0, 9.0, 5.0]))) + self.assertTrue(torch.equal(conv_state[2, 0, :3], torch.tensor([0.0, 0.0, 0.0]))) + + @patch("vllm_ascend.patch.worker.patch_qwen3_5.maybe_save_kv_layer_to_connector") + @patch("vllm_ascend.patch.worker.patch_qwen3_5.enable_sp", return_value=False) + @patch("vllm_ascend.patch.worker.patch_qwen3_5.GDNAttentionMetadata", new=SimpleNamespace) + @patch("vllm_ascend.patch.worker.patch_qwen3_5.fused_gdn_gating_patch") + @patch("vllm_ascend.patch.worker.patch_qwen3_5.fused_recurrent_gated_delta_rule") + @patch("vllm_ascend.patch.worker.patch_qwen3_5.torch.ops._C_ascend.causal_conv1d_fn", create=True) + @patch("vllm_ascend.patch.worker.patch_qwen3_5.get_forward_context") + def test_forward_core_all_mode_prefill_uses_token_level_indices( + self, + mock_get_forward_context, + mock_causal_conv1d_fn, + mock_fused_recurrent, + mock_fused_gating, + _mock_enable_sp, + _mock_save_kv, + ): + # Build a minimal forward context for all-mode prefill. + non_spec_query_start_loc = torch.tensor([0, 3], dtype=torch.int32) + non_spec_state_indices = torch.tensor([5], dtype=torch.int32) + attn_metadata = SimpleNamespace( + has_initial_state=torch.tensor([False], dtype=torch.bool), + spec_query_start_loc=None, + non_spec_query_start_loc=non_spec_query_start_loc, + spec_sequence_masks=None, + spec_token_indx=None, + non_spec_token_indx=None, + spec_state_indices_tensor=None, + non_spec_state_indices_tensor=non_spec_state_indices, + num_actual_tokens=3, + num_accepted_tokens=0, + num_prefills=1, + num_decodes=0, + num_spec_decodes=0, + ) + forward_context = SimpleNamespace( + attn_metadata={"pref": attn_metadata}, + virtual_engine=0, + vllm_config=SimpleNamespace( + cache_config=SimpleNamespace( + enable_prefix_caching=True, + mamba_cache_mode="all", + )) + ) + mock_get_forward_context.return_value = forward_context + + mock_fused_gating.return_value = ( + torch.zeros((1, 3, 2), dtype=torch.float32), + torch.zeros((1, 3, 2), dtype=torch.float32), + ) + mock_fused_recurrent.return_value = ( + torch.zeros((1, 3, 2), dtype=torch.float32), + torch.zeros((1, 2), dtype=torch.float32), + ) + + # Patch custom op entry to keep test purely CPU/UT level. + mock_causal_conv1d_fn.side_effect = lambda *args, **kwargs: args[0] + + class FakeSelf: + pass + + fake_self = FakeSelf() + fake_self.prefix = "pref" + fake_self.activation = "silu" + fake_self.A_log = torch.zeros((2,), dtype=torch.float32) + fake_self.dt_bias = torch.zeros((2,), dtype=torch.float32) + fake_self.conv1d = SimpleNamespace( + weight=torch.zeros((2, 1, 4), dtype=torch.float32), + bias=torch.zeros((2,), dtype=torch.float32), + ) + # cache layout: [num_cache_lines, state_len, dim] + fake_self.kv_cache = [( + torch.zeros((8, 3, 2), dtype=torch.float32), + torch.ones((8, 2), dtype=torch.float32), + )] + fake_self.rearrange_mixed_qkv = ( + lambda x: (None, None, None) if x is None else (x.unsqueeze(0), x.unsqueeze(0), x.unsqueeze(0)) + ) + + mixed_qkv = torch.tensor( + [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]], + dtype=torch.float32, + ) + b = torch.zeros_like(mixed_qkv) + a = torch.zeros_like(mixed_qkv) + core_attn_out = torch.zeros((3, 2), dtype=torch.float32) + + AscendQwen3_5GatedDeltaNet._forward_core(fake_self, mixed_qkv, b, a, core_attn_out) + + self.assertTrue(mock_fused_recurrent.called) + recurrent_kwargs = mock_fused_recurrent.call_args.kwargs + token_indices = recurrent_kwargs["ssm_state_indices"] + self.assertEqual(token_indices.shape, (1, 3)) + self.assertTrue(torch.equal(token_indices, torch.tensor([[5, 5, 5]], dtype=torch.int32))) + + @patch("vllm_ascend.patch.worker.patch_qwen3_5.maybe_save_kv_layer_to_connector") + @patch("vllm_ascend.patch.worker.patch_qwen3_5.enable_sp", return_value=False) + @patch("vllm_ascend.patch.worker.patch_qwen3_5.GDNAttentionMetadata", new=SimpleNamespace) + @patch("vllm_ascend.patch.worker.patch_qwen3_5.fused_sigmoid_gating_delta_rule_update") + @patch("vllm_ascend.patch.worker.patch_qwen3_5.fused_gdn_gating_patch") + @patch("vllm_ascend.patch.worker.patch_qwen3_5.fused_recurrent_gated_delta_rule") + @patch("vllm_ascend.patch.worker.patch_qwen3_5.causal_conv1d_update") + @patch("vllm_ascend.patch.worker.patch_qwen3_5.get_forward_context") + def test_forward_core_all_mode_decode_uses_recurrent_path( + self, + mock_get_forward_context, + mock_causal_conv1d_update, + mock_fused_recurrent, + mock_fused_gating, + mock_sigmoid_update, + _mock_enable_sp, + _mock_save_kv, + ): + non_spec_query_start_loc = torch.tensor([0, 2], dtype=torch.int32) + non_spec_state_indices = torch.tensor([3, 4], dtype=torch.int32) + attn_metadata = SimpleNamespace( + has_initial_state=torch.tensor([True, True], dtype=torch.bool), + spec_query_start_loc=None, + non_spec_query_start_loc=non_spec_query_start_loc, + spec_sequence_masks=None, + spec_token_indx=None, + non_spec_token_indx=None, + spec_state_indices_tensor=None, + non_spec_state_indices_tensor=non_spec_state_indices, + num_actual_tokens=2, + num_accepted_tokens=0, + num_prefills=0, + num_decodes=2, + num_spec_decodes=0, + ) + forward_context = SimpleNamespace( + attn_metadata={"pref": attn_metadata}, + virtual_engine=0, + vllm_config=SimpleNamespace( + cache_config=SimpleNamespace( + enable_prefix_caching=True, + mamba_cache_mode="all", + )) + ) + mock_get_forward_context.return_value = forward_context + + mock_causal_conv1d_update.side_effect = lambda x, *_args, **_kwargs: x + mock_fused_gating.return_value = ( + torch.zeros((1, 2, 2), dtype=torch.float32), + torch.zeros((1, 2, 2), dtype=torch.float32), + ) + mock_fused_recurrent.return_value = ( + torch.zeros((1, 2, 2), dtype=torch.float32), + torch.zeros((2, 2), dtype=torch.float32), + ) + + class FakeSelf: + pass + + fake_self = FakeSelf() + fake_self.prefix = "pref" + fake_self.activation = "silu" + fake_self.A_log = torch.zeros((2,), dtype=torch.float32) + fake_self.dt_bias = torch.zeros((2,), dtype=torch.float32) + fake_self.conv1d = SimpleNamespace( + weight=torch.zeros((2, 1, 4), dtype=torch.float32), + bias=torch.zeros((2,), dtype=torch.float32), + ) + fake_self.kv_cache = [( + torch.zeros((8, 3, 2), dtype=torch.float32), + torch.ones((8, 2), dtype=torch.float32), + )] + fake_self.rearrange_mixed_qkv = ( + lambda x: (None, None, None) if x is None else (x.unsqueeze(0), x.unsqueeze(0), x.unsqueeze(0)) + ) + + mixed_qkv = torch.tensor( + [[1.0, 10.0], [2.0, 20.0]], + dtype=torch.float32, + ) + b = torch.zeros_like(mixed_qkv) + a = torch.zeros_like(mixed_qkv) + core_attn_out = torch.zeros((2, 2), dtype=torch.float32) + + AscendQwen3_5GatedDeltaNet._forward_core(fake_self, mixed_qkv, b, a, core_attn_out) + + self.assertTrue(mock_fused_recurrent.called) + recurrent_kwargs = mock_fused_recurrent.call_args.kwargs + self.assertTrue(torch.equal(recurrent_kwargs["ssm_state_indices"], non_spec_state_indices)) + self.assertFalse(mock_sigmoid_update.called) diff --git a/vllm_ascend/core/recompute_scheduler.py b/vllm_ascend/core/recompute_scheduler.py index b83ced7e9a4..6f435772b8a 100644 --- a/vllm_ascend/core/recompute_scheduler.py +++ b/vllm_ascend/core/recompute_scheduler.py @@ -91,6 +91,14 @@ def __init__(self, *args, **kwargs): and self.vllm_config.kv_transfer_config and self.vllm_config.kv_transfer_config.is_kv_consumer ) + self.is_kv_producer = self.vllm_config.kv_transfer_config and self.vllm_config.kv_transfer_config.is_kv_producer + self.is_hybrid_model = ( + "qwen3_next" in self.vllm_config.model_config.hf_text_config.model_type + or "qwen3_5" in self.vllm_config.model_config.hf_text_config.model_type + ) + # Block-aligned chunk split is an align-mode constraint. + if self.is_hybrid_model and getattr(self.vllm_config.cache_config, "mamba_cache_mode", None) == "all": + self.need_mamba_block_aligned_split = False def add_request(self, request: Request) -> None: existing = self.requests.get(request.request_id) @@ -111,6 +119,10 @@ def add_request(self, request: Request) -> None: request.streaming_queue = deque() # Fill in placeholder tokens to enable full graph compatibility. Without # placeholders, graph matching may fail, forcing eager mode execution. + if self.is_kv_producer and self.is_hybrid_model and request.num_tokens > 1: + request.prompt_token_ids.pop() + request._all_token_ids.pop() + request.num_prompt_tokens -= 1 if self.is_mtp_kv_consumer: request.spec_token_ids = [PLACEHOLDER_TOKEN_ID] * self.num_spec_tokens self.waiting.add_request(request) @@ -118,6 +130,55 @@ def add_request(self, request: Request) -> None: if self.log_stats: request.record_event(EngineCoreEventType.QUEUED) + def _update_waiting_for_remote_kv(self, request: Request) -> bool: + """ + KV Connector: check if the request_id is finished_recving. + + The finished_recving_kv_req_ids list is populated + on the previous steps()'s update_from_output based + on the worker side connector. + + When the kv transfer is ready, we cache the blocks + and the request state will be moved back to WAITING from + WAITING_FOR_REMOTE_KV. + """ + assert self.connector is not None + if request.request_id not in self.finished_recving_kv_req_ids: + return False + + if request.request_id in self.failed_recving_kv_req_ids: + # Request had KV load failures; num_computed_tokens was already + # updated in _update_requests_with_invalid_blocks + if request.num_computed_tokens: + # Cache any valid computed tokens. + self.kv_cache_manager.cache_blocks(request, request.num_computed_tokens) + else: + # No valid computed tokens, release allocated blocks. + # There may be a local cache hit on retry. + self.kv_cache_manager.free(request) + + self.failed_recving_kv_req_ids.remove(request.request_id) + else: + # Now that the blocks are ready, actually cache them. + block_ids = self.kv_cache_manager.get_block_ids(request.request_id) + if len(block_ids) == 1: + num_computed_tokens = len(block_ids[0]) * self.block_size + # Handle the case where num request tokens less than one block. + num_computed_tokens = min(num_computed_tokens, request.num_tokens) + else: + num_computed_tokens = request.num_tokens + if num_computed_tokens == request.num_tokens: + num_computed_tokens -= 1 + # This will cache the blocks iff caching is enabled. + self.kv_cache_manager.cache_blocks(request, num_computed_tokens) + + # Update the request state for scheduling. + request.num_computed_tokens = num_computed_tokens + + # Return that we are ready. + self.finished_recving_kv_req_ids.remove(request.request_id) + return True + def schedule(self) -> RecomputeSchedulerOutput: # NOTE(woosuk) on the scheduling algorithm: # There's no "decoding phase" nor "prefill phase" in the scheduler. diff --git a/vllm_ascend/patch/platform/patch_mamba_config.py b/vllm_ascend/patch/platform/patch_mamba_config.py index d0a0f81e7a3..a75bcabf697 100644 --- a/vllm_ascend/patch/platform/patch_mamba_config.py +++ b/vllm_ascend/patch/platform/patch_mamba_config.py @@ -87,8 +87,12 @@ def verify_and_update_config(cls, vllm_config) -> None: "exactly equal.", mamba_padding_pct, ) - if cache_config.enable_prefix_caching and cache_config.mamba_cache_mode == "align": - cache_config.mamba_block_size = cache_config.block_size + if cache_config.enable_prefix_caching: + # Prefix caching needs block-level mamba states in both align/all modes. + if cache_config.mamba_cache_mode in ("align", "all"): + cache_config.mamba_block_size = cache_config.block_size + else: + cache_config.mamba_block_size = model_config.max_model_len else: cache_config.mamba_block_size = model_config.max_model_len diff --git a/vllm_ascend/patch/worker/patch_qwen3_5.py b/vllm_ascend/patch/worker/patch_qwen3_5.py new file mode 100644 index 00000000000..25fb055c13f --- /dev/null +++ b/vllm_ascend/patch/worker/patch_qwen3_5.py @@ -0,0 +1,394 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# from collections.abc import Iterable +# mypy: ignore-errors + + +import torch +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule +from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_update +from vllm.model_executor.models.qwen3_5 import Qwen3_5GatedDeltaNet +from vllm.v1.attention.backend import AttentionMetadata # type: ignore +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata +from vllm.v1.attention.backends.utils import PAD_SLOT_ID + +from vllm_ascend.attention.utils import maybe_save_kv_layer_to_connector +from vllm_ascend.ops.triton.fla.sigmoid_gating import fused_sigmoid_gating_delta_rule_update +from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch +from vllm_ascend.utils import enable_sp + + +def _is_mamba_all_prefix_mode(forward_context) -> bool: + vllm_config = getattr(forward_context, "vllm_config", None) + if vllm_config is None: + return False + cache_config = getattr(vllm_config, "cache_config", None) + if cache_config is None: + return False + return ( + getattr(cache_config, "enable_prefix_caching", False) + and getattr(cache_config, "mamba_cache_mode", None) == "all" + ) + + +def _ensure_prefill_token_state_indices( + state_indices: torch.Tensor, + query_start_loc: torch.Tensor, +) -> torch.Tensor: + if state_indices.dim() > 1: + return state_indices + seq_lens = query_start_loc[1:] - query_start_loc[:-1] + if seq_lens.numel() == 0: + return state_indices.unsqueeze(1) + if seq_lens.shape[0] != state_indices.shape[0]: + return state_indices.unsqueeze(1) + + max_query_len = int(seq_lens.max().item()) + token_state_indices = torch.full( + (state_indices.shape[0], max_query_len), + PAD_SLOT_ID, + dtype=state_indices.dtype, + device=state_indices.device, + ) + positions = torch.arange(max_query_len, device=state_indices.device).unsqueeze(0) + valid_mask = positions < seq_lens.unsqueeze(1) + return torch.where(valid_mask, state_indices.unsqueeze(1), token_state_indices) + + +def _write_all_mode_token_conv_states( + mixed_qkv_non_spec: torch.Tensor, + conv_state: torch.Tensor, + token_state_indices: torch.Tensor, + query_start_loc: torch.Tensor, + has_initial_state: torch.Tensor, + state_width: int, +) -> None: + if state_width <= 0: + return + + seq_lens = (query_start_loc[1:] - query_start_loc[:-1]).tolist() + starts = query_start_loc[:-1].tolist() + for seq_id, (start, seq_len) in enumerate(zip(starts, seq_lens)): + if seq_len <= 0: + continue + + seq_indices = token_state_indices[seq_id] + seq_tokens = mixed_qkv_non_spec[start : start + seq_len] + first_state_idx = int(seq_indices[0].item()) + + if bool(has_initial_state[seq_id].item()): + rolling_state = conv_state[first_state_idx, :, :state_width].clone() + else: + rolling_state = torch.zeros( + (conv_state.shape[1], state_width), + dtype=conv_state.dtype, + device=conv_state.device, + ) + + for token_pos in range(seq_len): + if state_width > 1: + rolling_state[:, :-1] = rolling_state[:, 1:] + rolling_state[:, -1] = seq_tokens[token_pos].to(conv_state.dtype) + + state_idx = int(seq_indices[token_pos].item()) + if state_idx == PAD_SLOT_ID: + continue + conv_state[state_idx, :, :state_width] = rolling_state + + +class AscendQwen3_5GatedDeltaNet(Qwen3_5GatedDeltaNet): + def _forward_core( + self, + mixed_qkv: torch.Tensor, + b: torch.Tensor, + a: torch.Tensor, + core_attn_out: torch.Tensor, + ): + # Core attention computation (called by custom op). + + # NOTE: The processing logic of Qwen3_5GatedDeltaNet is the same as Qwen3NextGatedDeltaNet. + # However, because the ops `torch_npu.npu_recurrent_gated_delta_rule` + # currently does not support `ssm_state` inputs in float32 format, + # we temporarily retain the current _forward_core implementation. + # Once the ops supports float32 `ssm_state`, this patch should be removed. + + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + + if attn_metadata is None: + # V1 profile run + return + + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, GDNAttentionMetadata) + has_initial_state = attn_metadata.has_initial_state + spec_query_start_loc = attn_metadata.spec_query_start_loc + non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc + spec_sequence_masks = attn_metadata.spec_sequence_masks + spec_token_indx = attn_metadata.spec_token_indx + non_spec_token_indx = attn_metadata.non_spec_token_indx + spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 + non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 + token_state_indices = None + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + num_actual_tokens = attn_metadata.num_actual_tokens + num_accepted_tokens = attn_metadata.num_accepted_tokens + mamba_all_prefix_mode = _is_mamba_all_prefix_mode(forward_context) + + if not enable_sp(): + mixed_qkv = mixed_qkv[:num_actual_tokens] + b = b[:num_actual_tokens] + a = a[:num_actual_tokens] + + # 1. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) + if spec_sequence_masks is not None: + if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: + mixed_qkv_spec = mixed_qkv + mixed_qkv_non_spec = None + else: + mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx) + mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx) + else: + mixed_qkv_spec = None + mixed_qkv_non_spec = mixed_qkv + + # 1.1: Process the multi-query part + if spec_sequence_masks is not None: + mixed_qkv_spec = causal_conv1d_update( + mixed_qkv_spec, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=spec_state_indices_tensor[:, 0][: attn_metadata.num_spec_decodes], + num_accepted_tokens=num_accepted_tokens, + query_start_loc=spec_query_start_loc, + max_query_len=spec_state_indices_tensor.size(-1), + validate_data=False, + ) + + # 1.2: Process the remaining part + if attn_metadata.num_prefills > 0: + if mixed_qkv_non_spec is not None: + conv_input_non_spec = mixed_qkv_non_spec + if mamba_all_prefix_mode: + token_state_indices = _ensure_prefill_token_state_indices( + non_spec_state_indices_tensor, + non_spec_query_start_loc, + ).contiguous() + conv_state_indices = token_state_indices[:, 0].contiguous() + else: + conv_state_indices = non_spec_state_indices_tensor + + conv_weights_T = conv_weights.transpose(0, 1) + mixed_qkv_non_spec = torch.ops._C_ascend.causal_conv1d_fn( + mixed_qkv_non_spec, + conv_weights_T, + self.conv1d.bias, + activation=self.activation, + conv_state=self_kv_cache[0], + has_initial_state=has_initial_state, + non_spec_state_indices_tensor=conv_state_indices, + non_spec_query_start_loc=non_spec_query_start_loc, + pad_slot_id=PAD_SLOT_ID, + ) + if mamba_all_prefix_mode: + _write_all_mode_token_conv_states( + mixed_qkv_non_spec=conv_input_non_spec, + conv_state=conv_state, + token_state_indices=token_state_indices, + query_start_loc=non_spec_query_start_loc, + has_initial_state=has_initial_state, + state_width=self.conv1d.weight.size(2) - 1, + ) + elif attn_metadata.num_decodes > 0: + mixed_qkv_non_spec = causal_conv1d_update( + mixed_qkv_non_spec, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=non_spec_state_indices_tensor[: attn_metadata.num_actual_tokens], + validate_data=True, + ) + else: + mixed_qkv_non_spec = None + query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec) + query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(mixed_qkv_non_spec) + + if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None: + g, beta = fused_gdn_gating_patch(self.A_log, a, b, self.dt_bias) + if spec_sequence_masks is not None: + if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: + g_spec = g + beta_spec = beta + g_non_spec = None + beta_non_spec = None + else: + g_spec = g.index_select(1, spec_token_indx) + beta_spec = beta.index_select(1, spec_token_indx) + g_non_spec = g.index_select(1, non_spec_token_indx) + beta_non_spec = beta.index_select(1, non_spec_token_indx) + else: + g_spec = None + beta_spec = None + g_non_spec = g + beta_non_spec = beta + + # 2. Recurrent attention + + # 2.1: Process the multi-query part + if spec_sequence_masks is not None: + core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule( + q=query_spec, + k=key_spec, + v=value_spec, + g=g_spec, + beta=beta_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1], + ssm_state_indices=spec_state_indices_tensor, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=True, + ) + else: + core_attn_out_spec, last_recurrent_state = None, None + + # 2.2: Process the remaining part + if attn_metadata.num_prefills > 0: + if mamba_all_prefix_mode: + if token_state_indices is None: + token_state_indices = _ensure_prefill_token_state_indices( + non_spec_state_indices_tensor, + non_spec_query_start_loc, + ).contiguous() + if (~has_initial_state).any(): + init_state_indices = token_state_indices[:, 0] + ssm_state[init_state_indices[~has_initial_state]] = 0 + + core_attn_out_non_spec, last_recurrent_state = fused_recurrent_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=non_spec_query_start_loc, + ssm_state_indices=token_state_indices, + use_qk_l2norm_in_kernel=True, + ) + else: + initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() + initial_state[~has_initial_state, ...] = 0 + ( + core_attn_out_non_spec, + last_recurrent_state, + ) = chunk_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=initial_state, + output_final_state=True, + cu_seqlens=non_spec_query_start_loc, + head_first=False, + use_qk_l2norm_in_kernel=True, + ) + # chunk path only produces the final state. + ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(ssm_state.dtype) + elif attn_metadata.num_decodes > 0: + core_attn_out_non_spec, last_recurrent_state = fused_recurrent_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1], + ssm_state_indices=non_spec_state_indices_tensor, + use_qk_l2norm_in_kernel=True, + ) + else: + core_attn_out_non_spec, last_recurrent_state = None, None + + elif attn_metadata.num_decodes > 0: + if mamba_all_prefix_mode: + g_non_spec, beta_non_spec = fused_gdn_gating_patch(self.A_log, a, b, self.dt_bias) + core_attn_out_non_spec, last_recurrent_state = fused_recurrent_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1], + ssm_state_indices=non_spec_state_indices_tensor, + use_qk_l2norm_in_kernel=True, + ) + else: + core_attn_out_non_spec = fused_sigmoid_gating_delta_rule_update( + A_log=self.A_log.contiguous(), + dt_bias=self.dt_bias.contiguous(), + q=query_non_spec.contiguous(), + k=key_non_spec.contiguous(), + v=value_non_spec.contiguous(), + a=a.contiguous(), + b=b.contiguous(), + initial_state_source=ssm_state, + initial_state_indices=non_spec_state_indices_tensor, + cu_seqlens=non_spec_query_start_loc, + use_qk_l2norm_in_kernel=True, + softplus_beta=1.0, + softplus_threshold=20.0, + ) + + # 3. Merge core attention output + if spec_sequence_masks is not None and core_attn_out_non_spec is not None: + merged_out = torch.empty( + (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), + dtype=core_attn_out_non_spec.dtype, + device=core_attn_out_non_spec.device, + ) + merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec) + merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec) + if not enable_sp(): + core_attn_out[:num_actual_tokens] = merged_out.squeeze(0) + else: + core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)[:num_actual_tokens] + elif spec_sequence_masks is not None: + if not enable_sp(): + core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0) + else: + core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)[:num_actual_tokens] + else: + if not enable_sp(): + core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0) + else: + core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)[:num_actual_tokens] + maybe_save_kv_layer_to_connector("", []) + + +Qwen3_5GatedDeltaNet._forward_core = AscendQwen3_5GatedDeltaNet._forward_core diff --git a/vllm_ascend/patch/worker/patch_qwen3_next.py b/vllm_ascend/patch/worker/patch_qwen3_next.py index 7e7a5eecc07..2921a9b201c 100644 --- a/vllm_ascend/patch/worker/patch_qwen3_next.py +++ b/vllm_ascend/patch/worker/patch_qwen3_next.py @@ -215,13 +215,14 @@ def _forward_core( actual_seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1] query_spec = l2norm_fwd(query_spec) key_spec = l2norm_fwd(key_spec) + recurrent_state = ssm_state if ssm_state.dtype == torch.bfloat16 else ssm_state.to(torch.bfloat16) core_attn_out_spec = torch_npu.npu_recurrent_gated_delta_rule( query=query_spec.squeeze(0), key=key_spec.squeeze(0), value=value_spec.squeeze(0), g=g_spec.squeeze(0), beta=beta_spec.squeeze(0), - state=ssm_state, + state=recurrent_state, scale=key_spec.shape[-1] ** -0.5, actual_seq_lengths=actual_seq_lengths, ssm_state_indices=spec_state_indices_tensor.flatten(), @@ -259,13 +260,14 @@ def _forward_core( actual_seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1] query_non_spec = l2norm_fwd(query_non_spec) key_non_spec = l2norm_fwd(key_non_spec) + recurrent_state = ssm_state if ssm_state.dtype == torch.bfloat16 else ssm_state.to(torch.bfloat16) core_attn_out_non_spec = torch_npu.npu_recurrent_gated_delta_rule( query=query_non_spec.squeeze(0), key=key_non_spec.squeeze(0), value=value_non_spec.squeeze(0), g=g_non_spec.squeeze(0), beta=beta_non_spec.squeeze(0), - state=ssm_state, + state=recurrent_state, scale=key_non_spec.shape[-1] ** -0.5, actual_seq_lengths=actual_seq_lengths, ssm_state_indices=non_spec_state_indices_tensor, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 1bb40291ed2..9848c050090 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -47,8 +47,8 @@ from vllm.utils.mem_utils import DeviceMemoryProfiler from vllm.utils.torch_utils import get_dtype_size from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata -from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder -from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata, GDNAttentionMetadataBuilder +from vllm.v1.attention.backends.utils import CommonAttentionMetadata, PAD_SLOT_ID from vllm.v1.attention.selector import get_attn_backend # type: ignore from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import ( @@ -77,6 +77,10 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import record_function_or_nullcontext +from vllm.v1.worker import mamba_utils +from vllm.v1.worker.cp_utils import ( + get_total_cp_world_size, +) from vllm.v1.worker.gpu_model_runner import AsyncGPUModelRunnerOutput, GPUModelRunner from vllm.v1.worker.ubatch_utils import ( UBatchSlices, @@ -108,6 +112,10 @@ from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort from vllm_ascend.sample.sampler import AscendSampler from vllm_ascend.spec_decode import get_spec_decode_method +try: + from vllm_ascend.spec_decode.draft_proposer import AscendDraftModelProposer +except ModuleNotFoundError: + AscendDraftModelProposer = None from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer @@ -393,6 +401,18 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.cpu_slot_mapping = None self.sampling_done_event: torch.npu.Event | None = None + if vllm_version_is("0.17.0"): + # self.cudagraph_batch_sizes sorts in ascending order. + if ( + self.compilation_config.cudagraph_capture_sizes + and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + ): + self.cudagraph_batch_sizes = sorted(self.compilation_config.cudagraph_capture_sizes) + else: + self.cudagraph_batch_sizes = [] + self.mamba_state_idx: dict[str, int] = {} + self._mamba_copy_bufs: mamba_utils.MambaCopyBuffers | None = None + @property def use_cp(self) -> bool: return self.pcp_size * self.dcp_size > 1 @@ -406,7 +426,11 @@ def _sync_device(self) -> None: def _set_up_drafter(self): # Set up speculative decoding. self.drafter: ( - AscendNgramProposer | AscendEagleProposer | AscendSuffixDecodingProposer | AscendMedusaProposer | None + AscendNgramProposer + | AscendEagleProposer + | AscendSuffixDecodingProposer + | AscendMedusaProposer + | None ) = None self.actual_seq_lengths_q: list[int] = [] self.decode_token_per_req = 1 @@ -420,9 +444,6 @@ def _set_up_drafter(self): assert isinstance(self.drafter, AscendEagleProposer) self.use_aux_hidden_state_outputs = self.drafter.eagle3_use_aux_hidden_state self.rejection_sampler = RejectionSampler(self.sampler) - self.actual_seq_lengths_q = list( - range(self.decode_token_per_req, self.max_num_tokens + 1, self.decode_token_per_req) - ) self.discard_request_indices = self._make_buffer(self.max_num_reqs, dtype=torch.int64) self.num_discarded_requests = 0 @@ -521,11 +542,24 @@ def get_model(self) -> nn.Module: return self.model.unwrap() return self.model - def _pad_query_start_loc_for_fia(self, num_tokens_padded: int, num_reqs_padded: int, num_reqs: int) -> int: + def _pad_query_start_loc_for_fia( + self, + num_tokens_padded: int, + num_reqs_padded: int, + num_reqs: int, + cudagraph_runtime_mode: CUDAGraphMode | None = None, + batch_desc_num_reqs: int | None = None, + ) -> int: """ This function is only designed to satisfied the constraint that when the layout is TND, the first dimension of `hidden_states` must equal the last element of `actual_seq_lengths_q`. """ + # TODO: need refactor later, related to vllm PR #34043 this pr delete func + # relax_for_mixed_batch_cudagraphs, num_reqs no longer equals the actual number of requests. + if cudagraph_runtime_mode == CUDAGraphMode.FULL: + num_reqs_padded = num_reqs + else: + num_reqs_padded = batch_desc_num_reqs if batch_desc_num_reqs is not None else num_reqs if num_tokens_padded == num_reqs_padded * self.uniform_decode_query_len: # Uniform-batch case: num_reqs must be no greater than num_reqs_padded @@ -961,7 +995,7 @@ def propose_draft_token_ids( draft_token_ids = self.drafter.propose( valid_sampled_token_ids, sampling_metadata, spec_decode_metadata, sample_hidden_states ) - elif self.speculative_config.use_eagle(): + elif self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model(): common_attn_metadata = spec_decode_common_attn_metadata sampled_token_ids = valid_sampled_token_ids @@ -1008,6 +1042,8 @@ def propose_draft_token_ids( long_seq_metadata = None # type: ignore num_prefill_reqs = 0 num_decode_reqs = 0 + + num_rejected_tokens_gpu = None if spec_decode_metadata is None: # update pcp related params if self.pcp_size > 1: @@ -1016,7 +1052,7 @@ def propose_draft_token_ids( target_positions = self._get_positions(num_scheduled_tokens) target_hidden_states = hidden_states if self.use_aux_hidden_state_outputs: - target_hidden_states = torch.cat([h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1) + target_hidden_states = torch.cat([h for h in aux_hidden_states], dim=-1) else: token_indices_to_sample = None # input_ids can be None for multimodal models. @@ -1043,8 +1079,10 @@ def propose_draft_token_ids( ) else: assert self.drafter is not None - common_attn_metadata, token_indices, token_indices_to_sample = self.drafter.prepare_inputs_padded( - common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count + common_attn_metadata, token_indices, token_indices_to_sample, num_rejected_tokens_gpu = ( + self.drafter.prepare_inputs_padded( + common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count + ) ) if self.pcp_size > 1: target_token_ids = input_ids_pcp_full[token_indices] @@ -1065,7 +1103,7 @@ def propose_draft_token_ids( target_positions=target_positions, target_hidden_states=target_hidden_states, next_token_ids=next_token_ids, - last_token_indices=token_indices_to_sample, + token_indices_to_sample=token_indices_to_sample, common_attn_metadata=common_attn_metadata, sampling_metadata=sampling_metadata, req_scheduled_tokens=req_scheduled_tokens, @@ -1074,6 +1112,7 @@ def propose_draft_token_ids( num_decode_reqs=num_decode_reqs, scheduler_output=scheduler_output, num_scheduled_tokens=num_scheduled_tokens, + num_rejected_tokens_gpu=num_rejected_tokens_gpu, ) else: raise ValueError(f"Unknown speculative decoding method: {self.speculative_config.method}") @@ -1206,6 +1245,23 @@ def execute_model( pad_attn = cudagraph_mode == CUDAGraphMode.FULL + # NOTE(Angazenn): According to https://github.com/vllm-project/vllm/pull/30877, + # there should be a corresponding 'postprocess_mamba'. However, it is called inside + # '_update_states_after_model_execute', which is not overridden in vLLM-Ascend. + # We simply utilize the implementation in vLLM. + if self.cache_config.mamba_cache_mode in ("align", "all"): + mamba_utils.preprocess_mamba( + scheduler_output, + self.kv_cache_config, + self.cache_config, + self.mamba_state_idx, + self.input_batch, + self.requests, + self.compilation_config.static_forward_context, + self.model.get_mamba_state_copy_func(), + self._get_mamba_copy_bufs(), + ) + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices @@ -1218,7 +1274,9 @@ def execute_model( # Another possible condition is num_tokens_padded != num_tokens_unpadded # but this scope is way too big and the consequences are unpredictable old_num_reqs_padded = num_reqs_padded - num_reqs_padded = self._pad_query_start_loc_for_fia(num_tokens_padded, num_reqs_padded, num_reqs) + num_reqs_padded = self._pad_query_start_loc_for_fia( + num_tokens_padded, num_reqs_padded, num_reqs, cudagraph_mode, batch_desc.num_reqs + ) if enable_sp() and num_tokens_padded == num_tokens_unpadded: if num_reqs_padded > old_num_reqs_padded: num_reqs_padded = old_num_reqs_padded @@ -1291,48 +1349,27 @@ def execute_model( # Run forward pass clear_kv_metadata = self.speculative_config is None - if vllm_version_is("0.16.0"): - with ( - record_function_or_nullcontext("forward"), - set_ascend_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_tokens_padded, - num_tokens_across_dp=num_tokens_across_dp, - aclgraph_runtime_mode=cudagraph_mode, - batch_descriptor=batch_desc, - num_actual_tokens=scheduler_output.total_num_scheduled_tokens, - model_instance=self.model, - max_tokens_across_pcp=0 if self.pcp_size == 1 else self.pcp_manager.max_num_tokens_across_pcp, - skip_compiled=has_encoder_input, - ), - self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, - ): - hidden_states = self._model_forward( - num_tokens_padded, input_ids, positions, intermediate_tensors, inputs_embeds, **model_kwargs - ) - else: - with ( - record_function_or_nullcontext("forward"), - set_ascend_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_tokens_padded, - num_tokens_across_dp=num_tokens_across_dp, - aclgraph_runtime_mode=cudagraph_mode, - batch_descriptor=batch_desc, - num_actual_tokens=scheduler_output.total_num_scheduled_tokens, - model_instance=self.model, - max_tokens_across_pcp=0 if self.pcp_size == 1 else self.pcp_manager.max_num_tokens_across_pcp, - skip_compiled=has_encoder_input, - ), - self.maybe_get_kv_connector_output( - scheduler_output, clear_metadata=clear_kv_metadata - ) as kv_connector_output, - ): - hidden_states = self._model_forward( - num_tokens_padded, input_ids, positions, intermediate_tensors, inputs_embeds, **model_kwargs - ) + with ( + record_function_or_nullcontext("forward"), + set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens_padded, + num_tokens_across_dp=num_tokens_across_dp, + aclgraph_runtime_mode=cudagraph_mode, + batch_descriptor=batch_desc, + num_actual_tokens=scheduler_output.total_num_scheduled_tokens, + model_instance=self.model, + max_tokens_across_pcp=0 if self.pcp_size == 1 else self.pcp_manager.max_num_tokens_across_pcp, + skip_compiled=has_encoder_input, + ), + self.maybe_get_kv_connector_output( + scheduler_output, clear_metadata=clear_kv_metadata + ) as kv_connector_output, + ): + hidden_states = self._model_forward( + num_tokens_padded, input_ids, positions, intermediate_tensors, inputs_embeds, **model_kwargs + ) with record_function_or_nullcontext("post process"): aux_hidden_states = None if self.use_aux_hidden_state_outputs: @@ -1418,6 +1455,11 @@ def sample_tokens( if self.execute_model_state is None: # Nothing to do (PP non-final rank case), output isn't used. + # receive sampled token ids from the last PP rank when using + # async scheduling + pipeline parallelism so downstream code + # (e.g., PCP input preparation) can access them. + if self.use_async_scheduling and get_pp_group().world_size > 1: + self._pp_receive_prev_sampled_token_ids_to_input_batch() if not kv_connector_output: return None # noqa # In case of PP with kv transfer, we need to pass through the @@ -1499,16 +1541,16 @@ def propose_draft_token_ids(sampled_token_ids): with record_function_or_nullcontext("draft_token"): if self.speculative_config: - use_padded_batch_for_eagle = ( + use_padded_batch = ( self.speculative_config - and self.speculative_config.use_eagle() + and (self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model()) and not self.speculative_config.disable_padded_drafter_batch ) - if use_padded_batch_for_eagle: + if use_padded_batch: # EAGLE speculative decoding can use the GPU sampled tokens # as inputs, and does not need to wait for bookkeeping to finish. propose_draft_token_ids(sampler_output.sampled_token_ids) - if self.speculative_config and not use_padded_batch_for_eagle: + if self.speculative_config and not use_padded_batch: # ngram and other speculative decoding methods use the sampled # tokens on the CPU, so they are run after bookkeeping. propose_draft_token_ids(valid_sampled_token_ids) @@ -1552,6 +1594,14 @@ def propose_draft_token_ids(sampled_token_ids): global_stream().wait_event(self.sampling_done_event) self._update_states_after_model_execute(sampler_output.sampled_token_ids, scheduler_output) + # In async scheduling + PP, broadcast sampled token ids from the + # last PP rank so other PP ranks can receive them without going + # through the scheduler/engine IPC path. + if self.use_async_scheduling: + pp = get_pp_group() + if pp.world_size > 1 and pp.is_last_rank: + self._pp_broadcast_prev_sampled_token_ids(sampler_output.sampled_token_ids) + if not self.use_async_scheduling: return model_runner_output return AsyncGPUModelRunnerOutput( @@ -1877,23 +1927,14 @@ def dispatch_cudagraph(num_tokens, disable_full=False, valid_modes=None): if force_eager: return (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded)) - if vllm_version_is("0.16.0"): - return self.cudagraph_dispatcher.dispatch( - num_tokens=num_tokens, - has_lora=has_lora, - uniform_decode=uniform_decode, - disable_full=disable_full, - num_active_loras=num_active_loras, - ) - else: - return self.cudagraph_dispatcher.dispatch( - num_tokens=num_tokens, - has_lora=has_lora, - uniform_decode=uniform_decode, - valid_modes=valid_modes, - invalid_modes={CUDAGraphMode.FULL} if disable_full else None, - num_active_loras=num_active_loras, - ) + return self.cudagraph_dispatcher.dispatch( + num_tokens=num_tokens, + has_lora=has_lora, + uniform_decode=uniform_decode, + valid_modes=valid_modes, + invalid_modes={CUDAGraphMode.FULL} if disable_full else None, + num_active_loras=num_active_loras, + ) cudagraph_mode, batch_descriptor = dispatch_cudagraph(num_tokens_padded, use_cascade_attn or has_encoder_output) num_tokens_padded = batch_descriptor.num_tokens @@ -1915,16 +1956,10 @@ def dispatch_cudagraph(num_tokens, disable_full=False, valid_modes=None): dp_rank = self.parallel_config.data_parallel_rank num_tokens_padded = int(num_tokens_across_dp[dp_rank].item()) # Re-dispatch with DP padding - if vllm_version_is("0.16.0"): - cudagraph_mode, batch_descriptor = dispatch_cudagraph( - num_tokens_padded, - disable_full=synced_cudagraph_mode <= CUDAGraphMode.PIECEWISE.value, - ) - else: - cudagraph_mode, batch_descriptor = dispatch_cudagraph( - num_tokens_padded, - valid_modes={CUDAGraphMode(synced_cudagraph_mode)}, - ) + cudagraph_mode, batch_descriptor = dispatch_cudagraph( + num_tokens_padded, + valid_modes={CUDAGraphMode(synced_cudagraph_mode)}, + ) # Assert to make sure the agreed upon token count is correct otherwise # num_tokens_across_dp will no-longer be valid assert batch_descriptor.num_tokens == num_tokens_padded @@ -2073,6 +2108,48 @@ def _get_block_table_and_slot_mapping(kv_cache_gid: int): cm_base.num_logits_indices = logits_indices.size(0) cm_base.logits_indices_padded = self._prepare_kv_sharing_fast_prefill(logits_indices) + def _maybe_normalize_mamba_all_prefill_state_indices(attn_metadata_i: AttentionMetadata) -> None: + if not ( + self.cache_config.enable_prefix_caching and self.cache_config.mamba_cache_mode == "all" + ): + return + if not isinstance(attn_metadata_i, GDNAttentionMetadata): + return + if attn_metadata_i.num_prefills <= 0: + return + + state_indices = attn_metadata_i.non_spec_state_indices_tensor + query_start_loc = attn_metadata_i.non_spec_query_start_loc + if state_indices is None or query_start_loc is None: + return + + if state_indices.dim() > 1: + attn_metadata_i.non_spec_state_indices_tensor = state_indices.contiguous() + return + + seq_lens = query_start_loc[1:] - query_start_loc[:-1] + if seq_lens.numel() == 0: + attn_metadata_i.non_spec_state_indices_tensor = state_indices.unsqueeze(1).contiguous() + return + if seq_lens.shape[0] != state_indices.shape[0]: + attn_metadata_i.non_spec_state_indices_tensor = state_indices.unsqueeze(1).contiguous() + return + + max_query_len = int(seq_lens.max().item()) + token_state_indices = torch.full( + (state_indices.shape[0], max_query_len), + PAD_SLOT_ID, + dtype=state_indices.dtype, + device=state_indices.device, + ) + positions = torch.arange(max_query_len, device=state_indices.device).unsqueeze(0) + valid_mask = positions < seq_lens.unsqueeze(1) + attn_metadata_i.non_spec_state_indices_tensor = torch.where( + valid_mask, + state_indices.unsqueeze(1), + token_state_indices, + ).contiguous() + def _build_attn_group_metadata( kv_cache_gid: int, attn_gid: int, @@ -2102,6 +2179,7 @@ def _build_attn_group_metadata( common_attn_metadata=common_attn_metadata, **extra_attn_metadata_args, ) + _maybe_normalize_mamba_all_prefill_state_indices(attn_metadata_i) if ubid is None: assert isinstance(attn_metadata, dict) @@ -2140,7 +2218,12 @@ def _build_attn_group_metadata( if kv_cache_gid > 0: cm.block_table_tensor, cm.slot_mapping = _get_block_table_and_slot_mapping(kv_cache_gid) if self.speculative_config and spec_decode_common_attn_metadata is None: - if isinstance(self.drafter, AscendEagleProposer): + draft_proposer_types = ( + (AscendEagleProposer, AscendDraftModelProposer) + if AscendDraftModelProposer is not None + else (AscendEagleProposer,) + ) + if isinstance(self.drafter, draft_proposer_types): if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names: spec_decode_common_attn_metadata = cm else: @@ -2324,8 +2407,9 @@ def _dummy_run( cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) self.query_start_loc.np[1 : num_reqs_padded + 1] = cum_num_tokens self.query_start_loc.copy_to_gpu() - - num_reqs_padded = self._pad_query_start_loc_for_fia(num_tokens_padded, num_reqs_padded, num_reqs) + num_reqs_padded = self._pad_query_start_loc_for_fia( + num_tokens_padded, num_reqs_padded, num_reqs, cudagraph_runtime_mode, batch_desc.num_reqs + ) pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL attn_metadata, _ = self._build_attention_metadata( @@ -2518,6 +2602,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ kv_cache_config = deepcopy(kv_cache_config) self.kv_cache_config = kv_cache_config + self._mamba_copy_bufs = None self.may_add_encoder_only_layers_to_kv_cache_config() self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) # NOTE(cmq): initialize_attn_backend must before using self.attn_groups @@ -2530,6 +2615,19 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.may_reinitialize_input_batch(kv_cache_config) kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) + if vllm_version_is("0.17.0"): + # TODO: refactor the logic of attention + # Initialize drafter attention group initialization + if self.speculative_config and ( + self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model() + ): + draft_proposer_types = ( + (AscendEagleProposer, AscendDraftModelProposer) + if AscendDraftModelProposer is not None + else (AscendEagleProposer,) + ) + assert isinstance(self.drafter, draft_proposer_types) + self.drafter.initialize_attn_backend(kv_cache_config, self.kernel_block_sizes) if has_kv_transfer_group(): get_kv_transfer_group().register_kv_caches(kv_caches) @@ -2839,8 +2937,8 @@ def _reshape_kv_cache_tensors( # a conv state in some special models. target_shape = (num_blocks, *shape) - target_idx += torch.prod(torch.tensor(target_shape)).item() - tensor = raw_tensor.view(dtype)[start_idx:target_idx].view(target_shape) + target_idx += math.prod(target_shape) * get_dtype_size(dtype) + tensor = raw_tensor[start_idx:target_idx].view(dtype).view(target_shape) start_idx = target_idx state_tensors.append(tensor) kv_caches[layer_name] = state_tensors @@ -2868,7 +2966,7 @@ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: # For attention backends that support virtual block splitting, # use the supported block sizes from the backend # For other backends (like Mamba), use [0] (no splitting) - kernel_block_sizes = [] + self.kernel_block_sizes = [] for kv_cache_group_id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): kv_cache_spec = kv_cache_group.kv_cache_spec if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs): @@ -2895,15 +2993,30 @@ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: else: # Fallback to cache config block_size if no backend found kernel_block_size_list = [self.cache_config.block_size] - kernel_block_sizes.append(kernel_block_size_list) + self.kernel_block_sizes.append(kernel_block_size_list) else: # This is likely Mamba or other non-attention cache, # no splitting. # NOTE: set kernel_block_sizes to 0 to disable slotmapping computation # of mamba block. In this case, BlockTable.block_size will never equal # to kernel_block_sizes[0] - kernel_block_sizes.append([0]) - if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [[self.cache_config.block_size]]: + self.kernel_block_sizes.append([0]) + + max_num_blocks = [] + max_model_len = max(self.max_model_len, self.max_encoder_len) + for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): + if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec): + continue + max_num_blocks_per_req = cdiv(max_model_len, block_sizes[i] * get_total_cp_world_size()) + if isinstance(kv_cache_group.kv_cache_spec, MambaSpec): + mamba_blocks_per_req = ( + max_num_blocks_per_req if self.cache_config.enable_prefix_caching else 1 + ) + kv_cache_group.kv_cache_spec.num_speculative_blocks + + max_num_blocks_per_req = max(max_num_blocks_per_req, mamba_blocks_per_req) + max_num_blocks.append(max_num_blocks_per_req) + + if block_sizes != [self.cache_config.block_size] or self.kernel_block_sizes != [[self.cache_config.block_size]]: assert self.cache_config.cpu_offload_gb == 0, ( "Cannot re-initialize the input batch when CPU weight " "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 @@ -2911,7 +3024,7 @@ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: ) self.input_batch = NPUInputBatch( max_num_reqs=self.max_num_reqs, - max_model_len=max(self.model_config.max_model_len, self.max_encoder_len), + max_model_len=max_model_len, max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, @@ -2925,7 +3038,7 @@ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: if self.vllm_config.speculative_config else 0 ), - kernel_block_sizes=kernel_block_sizes, + kernel_block_sizes=self.kernel_block_sizes, ) def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: @@ -3078,8 +3191,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: mamba_layers[layer_name] = attn_module if len(mamba_layers) > 0: - if self.vllm_config.cache_config.enable_prefix_caching: - raise NotImplementedError("Prefix caching is not supported for Mamba yet.") mamba_page_size_padded = 0 for layer_name, mamba_module in mamba_layers.items(): if spec := mamba_module.get_kv_cache_spec(self.vllm_config): @@ -3209,6 +3320,8 @@ def _replace_gpu_model_runner_function_wrapper(target_module_name): target_module = sys.modules[target_module_name] setattr(target_module, "graph_capture", graph_capture) # noqa: B010 yield + except Exception as e: + raise RuntimeError(f"NPUModelRunner failed, error is {e}") finally: setattr(target_module, "graph_capture", graph_capture) # noqa: B010