Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions csrc/causal_conv1d/op_host/causal_conv1d_tiling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,29 @@ static ge::graphStatus GetShapeDtypeInfo(gert::TilingContext* context, CausalCon
auto ciShapePtr = context->GetInputShape(CACHE_INDICES_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context, ciShapePtr);
auto ciShape = EnsureNotScalar(ciShapePtr->GetStorageShape());
if(ciShape.GetDimNum() != 1){return ge::GRAPH_FAILED;}
if(ciShape.GetDim(0) != batch){return ge::GRAPH_FAILED;}
int64_t cacheIndicesMode = 0;
int64_t maxQueryLen = 1;
if (ciShape.GetDimNum() == 1) {
if (ciShape.GetDim(0) != batch) {
return ge::GRAPH_FAILED;
}
cacheIndicesMode = 0;
maxQueryLen = 1;
} else if (ciShape.GetDimNum() == 2) {
if (ciShape.GetDim(0) != batch) {
return ge::GRAPH_FAILED;
}
cacheIndicesMode = 1;
maxQueryLen = ciShape.GetDim(1);
if (maxQueryLen <= 0) {
return ge::GRAPH_FAILED;
}
if (inputMode == 1 && maxQueryLen < seqLen) {
return ge::GRAPH_FAILED;
}
} else {
return ge::GRAPH_FAILED;
}

auto hisShapePtr = context->GetInputShape(HAS_INITIAL_STATE_INDEX);
OP_CHECK_NULL_WITH_CONTEXT(context, hisShapePtr);
Expand Down Expand Up @@ -289,6 +310,8 @@ static ge::graphStatus GetShapeDtypeInfo(gert::TilingContext* context, CausalCon
tiling.set_stateLen(stateLen);
tiling.set_numCacheLines(numCacheLines);
tiling.set_batch(batch);
tiling.set_cacheIndicesMode(cacheIndicesMode);
tiling.set_maxQueryLen(maxQueryLen);
return ge::GRAPH_SUCCESS;
}

Expand Down
3 changes: 3 additions & 0 deletions csrc/causal_conv1d/op_host/causal_conv1d_tiling.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ BEGIN_TILING_DATA_DEF(CausalConv1dTilingData)

TILING_DATA_FIELD_DEF(int64_t, dimTileSize);
TILING_DATA_FIELD_DEF(int64_t, blocksPerSeq);

TILING_DATA_FIELD_DEF(int64_t, cacheIndicesMode);
TILING_DATA_FIELD_DEF(int64_t, maxQueryLen);
END_TILING_DATA_DEF;
struct CausalConv1dCompileInfo {
uint64_t ubSize = 0;
Expand Down
50 changes: 46 additions & 4 deletions csrc/causal_conv1d/op_kernel/causal_conv1d.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ struct CausalConv1dTilingData {
// Channel-wise tiling
int64_t dimTileSize;
int64_t blocksPerSeq;

// cacheIndices shape mode: 0 -> [batch], 1 -> [batch, maxQueryLen]
int64_t cacheIndicesMode;
int64_t maxQueryLen;
};
#endif // CAUSAL_CONV1D_TILING_DATA_H_

Expand Down Expand Up @@ -383,6 +387,8 @@ __aicore__ inline void CausalConv1d<T>::Process()
const int32_t seqLen = tilingData_->seqLen;
const int32_t dimTileSize = static_cast<int32_t>(tilingData_->dimTileSize);
const int32_t blocksPerSeq = static_cast<int32_t>(tilingData_->blocksPerSeq);
const int32_t cacheIndicesMode = static_cast<int32_t>(tilingData_->cacheIndicesMode);
const int32_t maxQueryLen = static_cast<int32_t>(tilingData_->maxQueryLen);

const uint32_t blockIdx = GetBlockIdx();
const uint32_t blockNum = GetBlockNum();
Expand Down Expand Up @@ -417,16 +423,52 @@ __aicore__ inline void CausalConv1d<T>::Process()
continue;
}

const int32_t cacheIdx = cacheIndicesGm.GetValue(seq);
if (cacheIdx == tilingData_->padSlotId) {
continue;
int32_t cacheIdx = 0;
if (cacheIndicesMode == 0) {
cacheIdx = cacheIndicesGm.GetValue(seq);
if (cacheIdx == tilingData_->padSlotId) {
continue;
}
} else {
if (maxQueryLen <= 0) {
continue;
}
cacheIdx = cacheIndicesGm.GetValue(static_cast<int64_t>(seq) * maxQueryLen);
if (cacheIdx == tilingData_->padSlotId) {
continue;
}
if (len > maxQueryLen) {
// guard for malformed metadata; fallback to final-state path only
len = maxQueryLen;
}
}

const bool hasInit = hasInitialStateGm.GetValue(seq);

InitRing(cacheIdx, hasInit, start, len, c0, dimTileSize, dim, dbg);
RunSeq(start, len, c0, dimTileSize, dim, dbg);
WriteBackState(cacheIdx, len, c0, dimTileSize, dim, dbg);
if (cacheIndicesMode == 0) {
WriteBackState(cacheIdx, len, c0, dimTileSize, dim, dbg);
} else {
const int32_t stateLen = tilingData_->stateLen;
LocalTensor<T> ring = inBuf.Get<T>();
for (int32_t t = 0; t < len; ++t) {
const int32_t tokenCacheIdx =
cacheIndicesGm.GetValue(static_cast<int64_t>(seq) * maxQueryLen + t);
if (tokenCacheIdx == tilingData_->padSlotId) {
continue;
}
for (int32_t pos = 0; pos < (MAX_WIDTH - 1); ++pos) {
const int32_t tap = (MAX_WIDTH - 2) - pos;
const int32_t slot = (tap == 0) ? SlotCurr(t) : SlotHist(t, tap);
const int64_t stateOffset =
static_cast<int64_t>(tokenCacheIdx) * stateLen * dim + static_cast<int64_t>(pos) * dim + c0;
PipeBarrier<PIPE_ALL>();
DataCopy(convStatesGm[stateOffset], ring[slot * MAX_BLOCK_DIM], dimTileSize);
PipeBarrier<PIPE_ALL>();
}
}
}
}

ReleaseEvents();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,63 @@
causal_conv1d_update_npu as causal_conv1d_update
from vllm_ascend.utils import enable_custom_op


def test_ascend_causal_conv1d_2d_state_indices_writeback():
    """Validate token-level state snapshots when cache indices are 2D [batch, max_query_len]."""
    torch.random.manual_seed(0)
    enable_custom_op()

    device = "npu"
    dtype = torch.float16
    dim = 384  # keep aligned with kernel tiling candidates
    width = 4
    state_len = width - 1
    num_tokens = 5

    # Two variable-length sequences: [3, 2]
    query_start_loc = torch.tensor([0, 3, 5], device=device, dtype=torch.int32)

    # Token t carries the constant value t+1 across the whole channel dim,
    # so each token is uniquely identifiable in the written-back states.
    token_rows = [
        torch.full((dim,), float(t + 1), dtype=torch.float32)
        for t in range(num_tokens)
    ]
    x_cpu = torch.stack(token_rows, dim=0)
    x = x_cpu.to(device=device, dtype=dtype)

    # Zero weights isolate state writeback semantics from convolution math.
    weight = torch.zeros((width, dim), device=device, dtype=dtype)
    conv_states = torch.zeros((32, state_len, dim), device=device, dtype=dtype)
    has_initial_state = torch.tensor([False, False], device=device, dtype=torch.bool)
    cache_indices_2d = torch.tensor(
        [[10, 11, 12], [20, 21, PAD_SLOT_ID]],
        device=device,
        dtype=torch.int32,
    )

    torch.ops._C_ascend.causal_conv1d_fn(
        x,
        weight,
        None,
        activation="silu",
        conv_state=conv_states,
        has_initial_state=has_initial_state,
        non_spec_state_indices_tensor=cache_indices_2d,
        non_spec_query_start_loc=query_start_loc,
        pad_slot_id=PAD_SLOT_ID,
    )

    conv_cpu = conv_states.float().cpu()
    zeros = torch.zeros(dim)
    # Each cache line should hold the last (width - 1) tokens seen up to and
    # including that token, left-padded with zeros since has_initial_state is
    # False for both sequences.
    expected = {
        10: torch.stack([zeros, zeros, x_cpu[0]], dim=0),
        11: torch.stack([zeros, x_cpu[0], x_cpu[1]], dim=0),
        12: torch.stack([x_cpu[0], x_cpu[1], x_cpu[2]], dim=0),
        20: torch.stack([zeros, zeros, x_cpu[3]], dim=0),
        21: torch.stack([zeros, x_cpu[3], x_cpu[4]], dim=0),
    }

    for cache_line, ref in expected.items():
        validate_cmp(conv_cpu[cache_line], ref, torch.float32, device="cpu")

def validate_cmp(y_cal, y_ref, dtype, device='npu'):
y_cal = y_cal.to(device)
y_ref = y_ref.to(device)
Expand Down
49 changes: 49 additions & 0 deletions tests/ut/core/test_recompute_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from types import SimpleNamespace
from unittest.mock import patch

from tests.ut.base import TestBase
from vllm_ascend.core.recompute_scheduler import RecomputeScheduler


class TestRecomputeScheduler(TestBase):
    """Checks how RecomputeScheduler toggles mamba block-aligned splitting."""

    def _fake_scheduler_init(self, *args, **kwargs):
        # Stand-in for the upstream Scheduler.__init__: record the config and
        # start from the upstream default of block-aligned splitting enabled.
        self.vllm_config = kwargs["vllm_config"]
        self.need_mamba_block_aligned_split = True

    def _build_vllm_config(self, model_type: str, mamba_cache_mode: str):
        """Assemble the minimal config namespace tree the scheduler inspects."""
        model_config = SimpleNamespace(
            hf_text_config=SimpleNamespace(model_type=model_type))
        cache_config = SimpleNamespace(mamba_cache_mode=mamba_cache_mode)
        return SimpleNamespace(
            speculative_config=None,
            kv_transfer_config=None,
            model_config=model_config,
            cache_config=cache_config,
        )

    @patch("vllm_ascend.core.recompute_scheduler.Scheduler.__init__", new=_fake_scheduler_init)
    def test_all_mode_disables_block_aligned_split_for_qwen3_5(self):
        # Hybrid model + "all" cache mode should switch the split off.
        scheduler = RecomputeScheduler(
            vllm_config=self._build_vllm_config("qwen3_5", "all"))
        self.assertFalse(scheduler.need_mamba_block_aligned_split)

    @patch("vllm_ascend.core.recompute_scheduler.Scheduler.__init__", new=_fake_scheduler_init)
    def test_non_all_mode_keeps_block_aligned_split_for_qwen3_5(self):
        # Any cache mode other than "all" leaves the baseline untouched.
        scheduler = RecomputeScheduler(
            vllm_config=self._build_vllm_config("qwen3_5", "align"))
        self.assertTrue(scheduler.need_mamba_block_aligned_split)

    @patch("vllm_ascend.core.recompute_scheduler.Scheduler.__init__", new=_fake_scheduler_init)
    def test_all_mode_keeps_non_hybrid_unchanged(self):
        # Non-hybrid model types keep the baseline even in "all" mode.
        scheduler = RecomputeScheduler(
            vllm_config=self._build_vllm_config("llama", "all"))
        self.assertTrue(scheduler.need_mamba_block_aligned_split)
75 changes: 75 additions & 0 deletions tests/ut/patch/platform/test_patch_mamba_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from types import SimpleNamespace
from unittest.mock import patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.patch.platform.patch_mamba_config import verify_and_update_config


class _FakeModelCls:

@staticmethod
def get_mamba_state_shape_from_config(_vllm_config):
# ssm: 256 bytes, conv: 128 bytes when dtype is fp16
return [(128,), (64,)]

@staticmethod
def get_mamba_state_dtype_from_config(_vllm_config):
return [torch.float16, torch.float16]


class TestPatchMambaConfig(TestBase):
    """Checks mamba_block_size selection in the patched verify_and_update_config."""

    def _build_vllm_config(self, mamba_cache_mode: str):
        """Minimal config namespace tree consumed by verify_and_update_config."""
        return SimpleNamespace(
            cache_config=SimpleNamespace(
                cache_dtype="auto",
                block_size=None,
                mamba_page_size_padded=None,
                enable_prefix_caching=True,
                mamba_cache_mode=mamba_cache_mode,
                mamba_block_size=None,
            ),
            model_config=SimpleNamespace(
                dtype=torch.float16,
                architecture="FakeArch",
                max_model_len=4096,
                get_num_kv_heads=lambda _parallel: 1,
                get_head_size=lambda: 1,
            ),
            parallel_config=SimpleNamespace(),
        )

    def _run_verify(self, mock_resolve_model_cls, mamba_cache_mode: str):
        # Route model resolution to the fake class, then invoke the classmethod
        # unbound (cls is unused by the hook under test).
        mock_resolve_model_cls.return_value = (_FakeModelCls, None)
        vllm_config = self._build_vllm_config(mamba_cache_mode)
        verify_and_update_config.__func__(None, vllm_config)
        return vllm_config

    @patch("vllm_ascend.patch.platform.patch_mamba_config.MambaModelConfig.verify_and_update_config")
    @patch("vllm_ascend.patch.platform.patch_mamba_config.ModelRegistry.resolve_model_cls")
    def test_prefix_caching_all_mode_uses_block_size(self, mock_resolve_model_cls, _mock_verify):
        cfg = self._run_verify(mock_resolve_model_cls, "all")
        self.assertEqual(cfg.cache_config.mamba_block_size, cfg.cache_config.block_size)

    @patch("vllm_ascend.patch.platform.patch_mamba_config.MambaModelConfig.verify_and_update_config")
    @patch("vllm_ascend.patch.platform.patch_mamba_config.ModelRegistry.resolve_model_cls")
    def test_prefix_caching_align_mode_uses_block_size(self, mock_resolve_model_cls, _mock_verify):
        cfg = self._run_verify(mock_resolve_model_cls, "align")
        self.assertEqual(cfg.cache_config.mamba_block_size, cfg.cache_config.block_size)

    @patch("vllm_ascend.patch.platform.patch_mamba_config.MambaModelConfig.verify_and_update_config")
    @patch("vllm_ascend.patch.platform.patch_mamba_config.ModelRegistry.resolve_model_cls")
    def test_prefix_caching_non_all_align_uses_max_model_len(self, mock_resolve_model_cls, _mock_verify):
        cfg = self._run_verify(mock_resolve_model_cls, "legacy")
        self.assertEqual(cfg.cache_config.mamba_block_size, cfg.model_config.max_model_len)
Loading
Loading