Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions python/sglang/srt/speculative/eagle_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,25 @@ def generate_attn_arg_prefill(
kv_indices,
req_to_token.size(1),
)
mask_numel = (
paged_kernel_lens_sum * self.draft_token_num
+ (self.draft_token_num**2) * batch_size
)
if self.custom_mask.numel() < mask_numel:
# FIXME(attn): temporary fix for custom mask padding with cuda graph
self.custom_mask = torch.cat(
[
self.custom_mask,
torch.full(
(mask_numel - self.custom_mask.numel(),),
True,
dtype=torch.bool,
device=device,
),
],
dim=0,
)

return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask

def verify(
Expand Down
5 changes: 5 additions & 0 deletions test/registered/spec/eagle/test_eagle_infer_b.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sglang.srt.environ import envs
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.test.few_shot_gsm8k import run_eval as run_gsm8k_eval
from sglang.test.kits.radix_cache_server_kit import run_radix_attention_test
from sglang.test.server_fixtures.eagle_fixture import EagleServerBase
from sglang.test.test_utils import (
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
Expand All @@ -39,6 +40,10 @@ def test_request_abort(self):
for p in threads:
p.join()

def test_radix_attention(self):
run_radix_attention_test(self.base_url)
self.assertIsNone(self.process.poll())

def test_max_token_one(self):
requests.get(self.base_url + "/flush_cache")

Expand Down
Loading