Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/sglang/srt/mem_cache/allocator_ascend.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ def alloc_extend(
device=self.device,
)
torch.ops.npu.alloc_extend(
prefix_lens,
seq_lens,
last_loc,
prefix_lens.to(torch.int64),
seq_lens.to(torch.int64),
last_loc.to(torch.int64),
self.free_pages,
self.page_size,
out_indices,
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ def __init__(self, model_runner: ModelRunner):
seq_len_fill_value=self.seq_len_fill_value,
encoder_len_fill_value=self.encoder_len_fill_value,
num_tokens_per_bs=self.num_tokens_per_bs,
cache_loc_dtype=self._cache_loc_dtype(),
)

self.tbo_plugin = TboCudaGraphRunnerPlugin()
Expand Down
3 changes: 2 additions & 1 deletion python/sglang/srt/model_executor/input_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,14 @@ def create(
seq_len_fill_value: int,
encoder_len_fill_value: int,
num_tokens_per_bs: int,
cache_loc_dtype: torch.dtype,
) -> "GraphInputBuffers":
with torch.device(device):
input_ids = torch.zeros((max_num_token,), dtype=torch.int64)
input_embeds = torch.zeros((max_num_token, hidden_size), dtype=dtype)
req_pool_indices = torch.zeros((max_bs,), dtype=torch.int32)
seq_lens = torch.full((max_bs,), seq_len_fill_value, dtype=torch.int32)
out_cache_loc = torch.zeros((max_num_token,), dtype=torch.int64)
out_cache_loc = torch.zeros((max_num_token,), dtype=cache_loc_dtype)
positions = torch.zeros((max_num_token,), dtype=torch.int64)
mrope_positions = torch.zeros((3, max_num_token), dtype=torch.int64)
num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/model_executor/npu_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def replay(
self.replay_prepare(forward_batch, pp_proxy_tensors)
else:
# In speculative decoding, these two fields are still needed.
self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
self.positions[: self.raw_num_token].copy_(forward_batch.positions)
self.buffers.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
self.buffers.positions[: self.raw_num_token].copy_(forward_batch.positions)

# Replay
if not is_deepseek_nsa(self.model_runner.model_config.hf_config):
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/speculative/spec_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ def assign_req_to_token_pool_func(
torch.ops.npu.cache_loc_assign(
req_pool_indices,
req_to_token,
start_offset,
end_offset,
start_offset.to(torch.int64),
end_offset.to(torch.int64),
out_cache_loc,
)

Expand Down
22 changes: 0 additions & 22 deletions test/srt/ascend/test_ascend_tp1_bf16.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
run_bench_offline_throughput,
)

TEST_MODEL_MATRIX = {
Expand Down Expand Up @@ -71,26 +69,6 @@ def test_a_gsm8k(self):
finally:
kill_process_tree(process.pid)

def test_b_throughput(self):
for model in self.models:
with self.subTest(model=model):
print(f"##=== Testing throughput: {model} ===##")

output_throughput = run_bench_offline_throughput(
model,
[
*self.common_args,
],
)

print(f"##=== {model} throughput: {output_throughput} ===##")

if is_in_ci():
self.assertGreater(
output_throughput,
TEST_MODEL_MATRIX[model]["output_throughput"],
)


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion test/srt/run_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@
],
"per-commit-16-npu-a3": [
TestFile("ascend/test_ascend_deepep.py", 400),
TestFile("ascend/test_ascend_deepseek_mtp.py", 400),
# TestFile("ascend/test_ascend_deepseek_mtp.py", 400),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This test file is commented out, likely to fix the CI. To avoid this becoming technical debt, please add a TODO comment with a reference to the relevant issue(s) (e.g., #13478, #13676). This will help track the need to re-enable the test once the underlying bugs are fully resolved.

        # TODO(#13478, #13676): Re-enable this test after fixing the related bugs.
        # TestFile("ascend/test_ascend_deepseek_mtp.py", 400),

],
}

Expand Down
Loading