Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 49 additions & 10 deletions python/sglang/srt/speculative/build_eagle_tree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# NOTE: Please run this file to make sure the test cases are correct.

from typing import List
import math
from enum import IntEnum
from typing import List, Optional

import torch

Expand Down Expand Up @@ -40,6 +42,12 @@ def build_tree_kernel_efficient_preprocess(
return parent_list, top_scores_index, draft_tokens


class TreeMaskMode(IntEnum):
FULL_MASK = 0
QLEN_ONLY = 1
QLEN_ONLY_BITPACKING = 2


def build_tree_kernel_efficient(
verified_id: torch.Tensor,
score_list: List[torch.Tensor],
Expand All @@ -50,6 +58,9 @@ def build_tree_kernel_efficient(
topk: int,
spec_steps: int,
num_verify_tokens: int,
tree_mask_mode: TreeMaskMode,
tree_mask_buf: Optional[torch.Tensor],
position_buf: Optional[torch.Tensor],
):
parent_list, top_scores_index, draft_tokens = (
build_tree_kernel_efficient_preprocess(
Expand All @@ -66,15 +77,37 @@ def build_tree_kernel_efficient(
device = seq_lens.device
# e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
# where each row indicates the attending pattern of each draft token
# if use_partial_packed_tree_mask is True, tree_mask: num_draft_token (flattened, packed)
if tree_mask_buf is not None:
tree_mask = tree_mask_buf
elif tree_mask_mode == TreeMaskMode.QLEN_ONLY:
tree_mask = torch.full(
(num_verify_tokens * bs * num_verify_tokens,),
True,
dtype=torch.bool,
device=device,
)
elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
tree_mask = torch.zeros(
(num_verify_tokens * bs,),
dtype=packed_dtypes[packed_dtype_idx],
device=device,
)
Comment on lines +91 to +97
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The calculation for packed_dtype_idx can result in an index out of bounds for packed_dtypes. If num_verify_tokens is large, packed_dtype_idx can be out of bounds, leading to an IndexError. Consider adding a check to ensure packed_dtype_idx is within the valid range.

        packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
        if packed_dtype_idx >= len(packed_dtypes):
            raise ValueError(f"num_verify_tokens={num_verify_tokens} is too large for bitpacking.")
        tree_mask = torch.zeros(
            (num_verify_tokens * bs,),
            dtype=packed_dtypes[packed_dtype_idx],
            device=device,

elif tree_mask_mode == TreeMaskMode.FULL_MASK:
tree_mask = torch.full(
(
seq_lens_sum * num_verify_tokens
+ num_verify_tokens * num_verify_tokens * bs,
),
True,
device=device,
)
else:
raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")

# TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
tree_mask = torch.full(
(
seq_lens_sum * num_verify_tokens
+ num_verify_tokens * num_verify_tokens * bs,
),
True,
device=device,
)
retrive_index = torch.full(
(bs, num_verify_tokens), -1, device=device, dtype=torch.long
)
Expand All @@ -87,7 +120,12 @@ def build_tree_kernel_efficient(
# position: where each token belongs to
# e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
# then, positions = [7, 8, 8, 9]
positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
if position_buf is not None:
positions = position_buf
else:
positions = torch.empty(
(bs * num_verify_tokens,), device=device, dtype=torch.long
)

sgl_build_tree_kernel_efficient(
parent_list,
Expand All @@ -101,6 +139,7 @@ def build_tree_kernel_efficient(
topk,
spec_steps,
num_verify_tokens,
tree_mask_mode,
)
return (
tree_mask,
Expand Down
3 changes: 2 additions & 1 deletion sgl-kernel/csrc/common_extension.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def(
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num, int tree_mask_mode) -> "
"()");
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);

m.def(
Expand Down
51 changes: 36 additions & 15 deletions sgl-kernel/csrc/speculative/eagle_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include "pytorch_extension_utils_rocm.h"
#endif

typedef enum { FULL_MASK = 0, QLEN_ONLY = 1, QLEN_ONLY_BITPACKING = 2 } TreeMaskMode;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider using enum class instead of typedef enum for better type safety and to avoid polluting the global namespace. This enforces stronger scoping and prevents implicit conversions.

enum class TreeMaskMode { FULL_MASK = 0, QLEN_ONLY = 1, QLEN_ONLY_BITPACKING = 2 };


// parent_list [bs, topk * (depth - 1) + 1)]
// selected_index [bs, draft_token_num - 1]
// verified_seq_len [bs]
Expand All @@ -40,7 +42,8 @@ __global__ void build_tree_efficient(
int64_t* retrive_next_sibling,
int topk,
int depth,
int draft_token_num) {
int draft_token_num,
int tree_mask_mode) {
int bid = blockIdx.x;
int tid = threadIdx.x;

Expand All @@ -52,7 +55,13 @@ __global__ void build_tree_efficient(
seq_tree_idx += verified_seq_len[i] * draft_token_num;
}
int seq_len = verified_seq_len[bid];
int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
int token_tree_idx;
if (tree_mask_mode == FULL_MASK) {
token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
} else {
Comment on lines +59 to +61
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The calculation of token_tree_idx for FULL_MASK mode relies on seq_tree_idx, which is computed using an inefficient loop. Consider pre-calculating this prefix sum on the CPU and passing it to the kernel to improve performance.

token_tree_idx = draft_token_num * draft_token_num * bid + draft_token_num * tid + 1;
}
tree_mask[token_tree_idx - 1] = true;
for (int i = 0; i < draft_token_num - 1; i++) {
tree_mask[token_tree_idx + i] = false;
}
Expand Down Expand Up @@ -124,26 +133,38 @@ void build_tree_kernel_efficient(
at::Tensor retrive_next_sibling,
int64_t topk,
int64_t depth,
int64_t draft_token_num) {
int64_t draft_token_num,
int64_t tree_mask_mode) {
// TODO (ying) check shape
// TODO (ying) check type
int bs = parent_list.size(0);
dim3 grid(bs);
dim3 block(draft_token_num);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

build_tree_efficient<<<grid, block, 0, stream>>>(
static_cast<int64_t*>(parent_list.data_ptr()),
static_cast<int64_t*>(selected_index.data_ptr()),
static_cast<int64_t*>(verified_seq_len.data_ptr()),
static_cast<bool*>(tree_mask.data_ptr()),
static_cast<int64_t*>(positions.data_ptr()),
static_cast<int64_t*>(retrive_index.data_ptr()),
static_cast<int64_t*>(retrive_next_token.data_ptr()),
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
int32_t(topk),
int32_t(depth),
int32_t(draft_token_num));
if (tree_mask_mode == QLEN_ONLY_BITPACKING) {
size_t num_bytes_per_item = 1;
if (draft_token_num > 16) {
num_bytes_per_item = 4;
} else if (draft_token_num > 8) {
num_bytes_per_item = 2;
}
throw std::runtime_error("Not implemented");
Comment on lines +145 to +152
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The QLEN_ONLY_BITPACKING mode is not implemented and throws a std::runtime_error. Either complete the implementation or disable this mode in the Python code to prevent runtime errors.

} else {
build_tree_efficient<<<grid, block, 0, stream>>>(
static_cast<int64_t*>(parent_list.data_ptr()),
static_cast<int64_t*>(selected_index.data_ptr()),
static_cast<int64_t*>(verified_seq_len.data_ptr()),
static_cast<bool*>(tree_mask.data_ptr()),
static_cast<int64_t*>(positions.data_ptr()),
static_cast<int64_t*>(retrive_index.data_ptr()),
static_cast<int64_t*>(retrive_next_token.data_ptr()),
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
int32_t(topk),
int32_t(depth),
int32_t(draft_token_num),
int32_t(tree_mask_mode));
}
}

template <typename IdType, typename IdType2>
Expand Down
3 changes: 2 additions & 1 deletion sgl-kernel/include/sgl_kernel_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,8 @@ void build_tree_kernel_efficient(
at::Tensor retrive_next_sibling,
int64_t topk,
int64_t depth,
int64_t draft_token_num);
int64_t draft_token_num,
int64_t tree_mask_mode);

void segment_packbits(
at::Tensor x,
Expand Down
Loading