feat: integrate deepgemm into EPMoE #6821
The diff adds roughly 210 lines of DeepGEMM preprocessing and post-reorder kernels to the EP-MoE kernels module, immediately after the existing `tma_align_input_scale` helper:

`@@ -1085,3 +1085,210 @@ def tma_align_input_scale(input_scale: torch.Tensor):`

```python
# (unchanged context: tail of tma_align_input_scale)
        BLOCK_SIZE_K=BLOCK_SIZE_K,
    )
    return output.t()[:m]
```
First, a kernel that derives each expert's token count from the segment index pointer:

```python
@triton.jit
def compute_masked_m_triton_kernel(seg_indptr, masked_m, num_experts, N):
    # One program per expert: an expert's token count is the width of its
    # segment in the index pointer built over the sorted top-k ids.
    expert_id = tl.program_id(0)
    start = tl.load(seg_indptr + expert_id)
    end = tl.load(seg_indptr + expert_id + 1)
    tl.store(masked_m + expert_id, (end - start))
```
Next, the source-to-destination mapping for the padded per-expert layout:

```python
@triton.jit
def deepgemm_compute_src2dst_triton_kernel(
    topk_ids,
    reorder_ids,
    seg_indptr,
    src2dst,
    max_m,
    num_toks,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = dst_id < num_toks
    src_id = tl.load(reorder_ids + dst_id, mask=mask)
    expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))
    expert_dst_start = tl.load(seg_indptr + expert_id)
    expert_dst_offset = dst_id - expert_dst_start
    # Each expert owns a fixed slab of max_m rows, so a token's destination is
    # its expert's slab base plus its rank within that expert's segment
    # (e.g. with max_m=8, the first token routed to expert 2 lands at row 16).
    dst_id = expert_id * max_m + expert_dst_offset
    tl.store(src2dst + src_id, dst_id, mask=mask)
```
A scatter kernel then copies each token's activations and quantization scales into the padded per-expert buffers:

```python
@triton.jit
def fill_gateup_input_triton_kernel(
    input_ptr,
    scale_ptr,
    gateup_input_ptr,
    gateup_input_scale_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    start_expert_id,
    end_expert_id,
    topk,
    m_max,
    hidden_size,
    scale_size,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per source token: scatter the token's data to every local
    # expert it is routed to.
    src_idx = tl.program_id(0)
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk
    src_ptr = input_ptr + src_idx * hidden_size
    scale_src_ptr = scale_ptr + src_idx * scale_size

    for idx in range(topk):
        expert_id = tl.load(topk_ids_ptr + idx)
        if expert_id >= start_expert_id and expert_id <= end_expert_id:
            # Rebase the global destination row onto this rank's expert slice.
            dst_idx = tl.load(src2dst_ptr + idx)
            dst_idx = dst_idx - start_expert_id * m_max
            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
                offset = start_offset + tl.arange(0, BLOCK_SIZE)
                mask = offset < hidden_size
                in_data = tl.load(src_ptr + offset, mask=mask)
                tl.store(dst_ptr + offset, in_data, mask=mask)
            scale_dst_ptr = gateup_input_scale_ptr + dst_idx * scale_size
            for start_offset in tl.range(0, scale_size, BLOCK_SIZE):
                offset = start_offset + tl.arange(0, BLOCK_SIZE)
                mask = offset < scale_size
                in_scale = tl.load(scale_src_ptr + offset, mask=mask)
                tl.store(scale_dst_ptr + offset, in_scale, mask=mask)
```
A small helper pads the row dimension to a power of two:

```python
def exp2_upper(num: int) -> int:
    # Round up to the next power of two (minimum 4, since the loop starts
    # at 2**2); used to size each expert's padded slab of rows.
    for i in range(2, 31):
        value = pow(2, i)
        if num <= value:
            return value
    return num
```
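For reference, a quick sanity check of the rounding behavior; the floor of 4 follows from the loop starting at `2**2`:

```python
assert exp2_upper(1) == 4    # never below 4
assert exp2_upper(6) == 8    # rounds up to the next power of two
assert exp2_upper(64) == 64  # exact powers of two pass through unchanged
```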
The preprocessing entry point ties these together:

```python
def moe_ep_deepgemm_preproess(
    topk_ids: torch.Tensor,
    num_experts: int,
    hidden_states: torch.Tensor,
    top_k: int,
    start_expert_id,
    end_expert_id,
    block_shape,
    output_dtype: torch.dtype = torch.float8_e4m3fn,
):
    # Copy hidden states into the (num_local_experts, m_max, hidden) padded
    # layout that the masked grouped GEMM consumes.
    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
    masked_m = torch.zeros(num_experts, device=topk_ids.device, dtype=torch.int32)

    compute_seg_indptr_triton_kernel[(num_experts,)](
        reorder_topk_ids, seg_indptr, topk_ids.numel()
    )

    grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
    compute_masked_m_triton_kernel[(num_experts,)](
        seg_indptr, masked_m, num_experts, reorder_topk_ids.numel()
    )

    # Pad the per-expert row count to a power of two so every expert slab has
    # the same stride.
    m_max = exp2_upper(hidden_states.size(0))
    expected_m = (topk_ids.numel() + num_experts - 1) // num_experts
    gateup_input = torch.empty(
        (int(end_expert_id - start_expert_id + 1), m_max, hidden_states.size(1)),
        device=hidden_states.device,
        dtype=output_dtype,
    )

    deepgemm_compute_src2dst_triton_kernel[grid](
        topk_ids,
        reorder_ids,
        seg_indptr,
        src2dst,
        m_max,
        topk_ids.numel(),
        BLOCK_SIZE=256,
    )

    if block_shape is not None:
        assert len(block_shape) == 2
        block_n, block_k = block_shape[0], block_shape[1]
        hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)

    # Note: `scale` is only defined when block_shape is not None, so the code
    # below implicitly requires block-quantized FP8 input.
    gateup_input_scale = torch.empty(
        (gateup_input.size(0), gateup_input.size(1), scale.size(1)),
        device=hidden_states.device,
        dtype=scale.dtype,
    )

    fill_gateup_input_triton_kernel[(hidden_states.shape[0],)](
        hidden_states,
        scale,
        gateup_input,
        gateup_input_scale,
        src2dst,
        topk_ids,
        start_expert_id,
        end_expert_id,
        top_k,
        m_max,
        hidden_states.size(1),
        scale.size(1),
        BLOCK_SIZE=1024,
    )

    return (
        m_max,
        masked_m[start_expert_id : (end_expert_id + 1)],
        expected_m,
        src2dst,
        gateup_input,
        gateup_input_scale,
    )
```
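A sketch of how this preprocessing step might be driven. The shapes, expert range, and `block_shape` below are hypothetical for illustration, not taken from the PR; it assumes `moe_ep_deepgemm_preproess` and its dependency `per_token_group_quant_fp8` are importable as in the diff above:

```python
import torch

# Hypothetical setup: 8 tokens, hidden size 2048, 16 experts with top-4
# routing, and experts 0-7 resident on this rank.
num_tokens, hidden, num_experts, top_k = 8, 2048, 16, 4
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k), device="cuda")
hidden_states = torch.randn(num_tokens, hidden, device="cuda", dtype=torch.bfloat16)

m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
    moe_ep_deepgemm_preproess(
        topk_ids,
        num_experts,
        hidden_states,
        top_k,
        start_expert_id=0,
        end_expert_id=7,
        block_shape=[128, 128],  # block-quantized FP8 is assumed (see note above)
    )
)
# gateup_input is (8 local experts, m_max, hidden) in fp8; masked_m tells the
# masked grouped GEMM how many rows of each expert slab are valid.
```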
Finally, a post-reorder kernel gathers expert outputs back into per-token rows, applying the routing weights:

```python
@triton.jit
def deepgemm_post_reorder_triton_kernel(
    down_output_ptr,
    output_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    topk_weights_ptr,
    start_expert_id,
    end_expert_id,
    topk,
    hidden_size,
    max_m,
    BLOCK_SIZE: tl.constexpr,
):
    InDtype = down_output_ptr.dtype.element_ty

    # One program per source token: accumulate the outputs of its local
    # top-k experts, weighted by the routing probabilities.
    src_idx = tl.program_id(0)
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk
    topk_weights_ptr = topk_weights_ptr + src_idx * topk

    computed = False
    store_ptr = output_ptr + src_idx * hidden_size
    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
        offset = start_offset + tl.arange(0, BLOCK_SIZE)
        mask = offset < hidden_size

        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
        for idx in range(topk):
            expert_id = tl.load(topk_ids_ptr + idx)
            if expert_id >= start_expert_id and expert_id <= end_expert_id:
                computed = True
                dst_idx = tl.load(src2dst_ptr + idx)
                dst_idx = dst_idx - start_expert_id * max_m
                weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
                load_ptr = down_output_ptr + dst_idx * hidden_size
                in_data = tl.load(load_ptr + offset, mask=mask)
                sum_vec += in_data * weigh_scale
        tl.store(store_ptr + offset, sum_vec, mask=mask)

    # Tokens with no expert on this rank get an explicit zero output.
    if computed == False:
        for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
            offset = start_offset + tl.arange(0, BLOCK_SIZE)
            mask = offset < hidden_size
            tl.store(
                store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask
            )
```
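A hedged launch sketch for this kernel, mirroring how the fill kernel is launched in the preprocessing function above (one program per source token). The tensor names and shapes are assumptions reusing the hypothetical setup from the earlier sketch, not code from the PR:

```python
# Assumed tensors: 8 local experts of m_max padded rows each, bf16 output.
down_output = torch.randn(8 * m_max, hidden, device="cuda", dtype=torch.bfloat16)
output = torch.empty(num_tokens, hidden, device="cuda", dtype=torch.bfloat16)
topk_weights = torch.rand(num_tokens, top_k, device="cuda", dtype=torch.float32)

deepgemm_post_reorder_triton_kernel[(num_tokens,)](
    down_output,
    output,
    src2dst,
    topk_ids,
    topk_weights,
    0,      # start_expert_id on this rank
    7,      # end_expert_id on this rank
    top_k,
    hidden,
    m_max,
    BLOCK_SIZE=1024,
)
```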
The second file updates the EPMoE layer to import the new kernels:

`@@ -31,10 +31,12 @@`
```python
    get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.moe.ep_moe.kernels import (
    deepgemm_post_reorder_triton_kernel,
    ep_gather,
    ep_scatter,
    gelu_and_mul_triton_kernel,
    grouped_gemm_triton,
    moe_ep_deepgemm_preproess,
```
||||||
| moe_ep_deepgemm_preproess, | |
| moe_ep_deepgemm_preprocess, |
A second review comment: the parameter `N` in the `compute_masked_m_triton_kernel` function signature appears to be unused within the kernel's body. If it's not required for any logic (e.g., boundary checks that might be missing or planned), could it be removed to simplify the signature and avoid confusion? The kernel is launched with a grid of `(num_experts,)` and uses `tl.program_id(0)` to get `expert_id`, suggesting `N` is not needed for indexing in its current form.
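If the reviewer's suggestion were adopted, the kernel and its call site might look like the following sketch (an assumption about how the cleanup could land, not code from the PR):

```python
import triton
import triton.language as tl


@triton.jit
def compute_masked_m_triton_kernel(seg_indptr, masked_m, num_experts):
    # Logic unchanged: each program reads one expert's segment bounds.
    expert_id = tl.program_id(0)
    start = tl.load(seg_indptr + expert_id)
    end = tl.load(seg_indptr + expert_id + 1)
    tl.store(masked_m + expert_id, (end - start))


# The call site would drop the trailing reorder_topk_ids.numel() argument:
# compute_masked_m_triton_kernel[(num_experts,)](seg_indptr, masked_m, num_experts)
```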