-
Notifications
You must be signed in to change notification settings - Fork 3.4k
feat: integrate deepgemm into EPMoE #6821
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 26 commits
92d647c
e057acb
19ec50e
3ce1a91
3d51a71
c80fc3c
af94a8b
2022070
55ea483
988a522
1f81f01
b4ae984
c6d51d2
0397c25
aeb437d
d2c19bf
0d3e793
ee1a2b1
d530ee7
cb6ea36
2769865
a7a6235
a1493e5
b1edec5
773f151
4975723
6584c21
c54a141
86f06b6
15c9514
afd55c2
b596fa2
fe97d1f
431f47a
1422da6
bbbd98a
35cc918
1a26425
a902338
d6b3afd
b51f324
ad9abb2
1869b18
77123c6
ddc0c33
8e71771
a6d61f6
e30b3ab
aafbe4e
0ea5bd9
87d68e9
26040f4
bbb127d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1085,3 +1085,207 @@ def tma_align_input_scale(input_scale: torch.Tensor): | |
| BLOCK_SIZE_K=BLOCK_SIZE_K, | ||
| ) | ||
| return output.t()[:m] | ||
|
|
||
|
|
||
@triton.jit
def compute_masked_m_triton_kernel(seg_indptr, masked_m):
    """Write each expert's token count into ``masked_m``.

    One program instance per expert:
    ``masked_m[e] = seg_indptr[e + 1] - seg_indptr[e]`` (segment length).
    """
    eid = tl.program_id(0)
    seg_begin = tl.load(seg_indptr + eid)
    seg_finish = tl.load(seg_indptr + eid + 1)
    tl.store(masked_m + eid, seg_finish - seg_begin)
|
|
||
|
|
||
@triton.jit
def deepgemm_compute_src2dst_triton_kernel(
    topk_ids,
    reorder_ids,
    seg_indptr,
    src2dst,
    m_max,
    num_toks,
    BLOCK_SIZE: tl.constexpr,
):
    """Build the source->destination index map for the padded layout.

    Each (token, top-k) entry is mapped to a slot in an expert-major
    buffer where every expert owns a padded stripe of ``m_max`` rows:
    ``dst = expert_id * m_max + position_within_expert_segment``.
    """
    block = tl.program_id(axis=0)
    sorted_pos = block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = sorted_pos < num_toks
    # reorder_ids maps sorted (expert-grouped) position -> original index.
    token_idx = tl.load(reorder_ids + sorted_pos, mask=in_range)
    eid = tl.load(topk_ids + token_idx, mask=(token_idx < num_toks))
    seg_begin = tl.load(seg_indptr + eid)
    # Offset inside the expert's segment, rebased onto its padded stripe.
    padded_pos = eid * m_max + (sorted_pos - seg_begin)
    tl.store(src2dst + token_idx, padded_pos, mask=in_range)
|
|
||
|
|
||
@triton.jit
def fill_gateup_input_triton_kernel(
    input_ptr,
    scale_ptr,
    gateup_input_ptr,
    gateup_input_scale_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    start_expert_id,
    end_expert_id,
    topk,
    m_max,
    hidden_size,
    scale_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter quantized activations and their per-group scales into the
    padded, per-local-expert gateup input buffers.

    One program instance per source token. For every top-k selection that
    lands on an expert in ``[start_expert_id, end_expert_id]``, the token's
    row (and its scale row) is copied to the padded destination slot given
    by ``src2dst``.
    """
    token_i32 = tl.program_id(0)
    token = token_i32.to(tl.int64)

    # Advance all per-token base pointers once, up front.
    src2dst_ptr = src2dst_ptr + token * topk
    topk_ids_ptr = topk_ids_ptr + token * topk
    row_ptr = input_ptr + token * hidden_size
    row_scale_ptr = scale_ptr + token * scale_size

    for k in range(topk):
        eid = tl.load(topk_ids_ptr + k)
        if eid >= start_expert_id and eid <= end_expert_id:
            slot_i32 = tl.load(src2dst_ptr + k)
            slot = slot_i32.to(tl.int64)
            # src2dst indexes the global expert-major layout; rebase onto
            # this rank's local expert range.
            slot = slot - start_expert_id * m_max
            out_ptr = gateup_input_ptr + slot * hidden_size
            for base in tl.range(0, hidden_size, BLOCK_SIZE):
                offs = base + tl.arange(0, BLOCK_SIZE)
                valid = offs < hidden_size
                vals = tl.load(row_ptr + offs, mask=valid)
                tl.store(out_ptr + offs, vals, mask=valid)
            out_scale_ptr = gateup_input_scale_ptr + slot * scale_size
            for base in tl.range(0, scale_size, BLOCK_SIZE):
                offs = base + tl.arange(0, BLOCK_SIZE)
                valid = offs < scale_size
                svals = tl.load(row_scale_ptr + offs, mask=valid)
                tl.store(out_scale_ptr + offs, svals, mask=valid)
|
|
||
|
|
||
def moe_ep_deepgemm_preprocess(
    topk_ids: torch.Tensor,
    num_experts: int,
    hidden_states: torch.Tensor,
    top_k: int,
    start_expert_id,
    end_expert_id,
    block_shape,
    output_dtype: torch.dtype = torch.float8_e4m3fn,
):
    """Prepare inputs for DeepGEMM's masked grouped GEMM in EP MoE.

    Sorts tokens by expert, quantizes activations to FP8 with per-group
    scales, and scatters them into padded per-local-expert buffers.

    Args:
        topk_ids: (num_tokens, top_k) routed expert ids.
        num_experts: total number of experts (global, across all ranks).
        hidden_states: (num_tokens, hidden_size) activations to quantize.
        top_k: number of experts selected per token.
        start_expert_id / end_expert_id: inclusive range of experts owned
            by this rank.
        block_shape: two-element quantization block shape; block_shape[1]
            is the group size along the hidden dimension.
        output_dtype: dtype of the scattered gateup buffer.

    Returns:
        (m_max, masked_m for local experts, expected_m, src2dst,
         gateup_input, gateup_input_scale).
    """
    # Stable sort groups token slots by expert while preserving the
    # original order within each expert.
    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
    masked_m = torch.zeros(num_experts, device=topk_ids.device, dtype=torch.int32)

    compute_seg_indptr_triton_kernel[(num_experts,)](
        reorder_topk_ids, seg_indptr, topk_ids.numel()
    )

    grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
    compute_masked_m_triton_kernel[(num_experts,)](seg_indptr, masked_m)

    # For masked grouped GEMM, shape M should be a multiple of the block M
    # (current block M: 256):
    # https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165
    m_max = (hidden_states.size(0) + 255) // 256 * 256
    # Average tokens per expert, rounded up — a size hint for DeepGEMM.
    expected_m = (topk_ids.numel() + num_experts - 1) // num_experts
    gateup_input = torch.empty(
        (int(end_expert_id - start_expert_id + 1), m_max, hidden_states.size(1)),
        device=hidden_states.device,
        dtype=output_dtype,
    )

    deepgemm_compute_src2dst_triton_kernel[grid](
        topk_ids,
        reorder_ids,
        seg_indptr,
        src2dst,
        m_max,
        topk_ids.numel(),
        BLOCK_SIZE=256,
    )

    assert block_shape is not None, "block_shape must not be None for DeepGEMM preprocessing"
    assert len(block_shape) == 2
    # Only the group size along the hidden dimension is needed here.
    block_k = block_shape[1]
    hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)

    gateup_input_scale = torch.empty(
        (gateup_input.size(0), gateup_input.size(1), scale.size(1)),
        device=hidden_states.device,
        dtype=scale.dtype,
    )

    fill_gateup_input_triton_kernel[(hidden_states.shape[0],)](
        hidden_states,
        scale,
        gateup_input,
        gateup_input_scale,
        src2dst,
        topk_ids,
        start_expert_id,
        end_expert_id,
        top_k,
        m_max,
        hidden_states.size(1),
        scale.size(1),
        BLOCK_SIZE=1024,
    )

    return (
        m_max,
        masked_m[start_expert_id : (end_expert_id + 1)],
        expected_m,
        src2dst,
        gateup_input,
        gateup_input_scale,
    )
|
|
||
|
|
||
@triton.jit
def deepgemm_post_reorder_triton_kernel(
    down_output_ptr,
    output_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    topk_weights_ptr,
    start_expert_id,
    end_expert_id,
    topk,
    hidden_size,
    max_m,
    BLOCK_SIZE: tl.constexpr,
):
    """Gather down-projection outputs back into token order, weighting
    each local expert's contribution by its router weight and summing.

    One program instance per token. A token none of whose top-k experts
    fall in ``[start_expert_id, end_expert_id]`` gets an all-zero row.
    """
    out_dtype = down_output_ptr.dtype.element_ty

    token_i32 = tl.program_id(0)
    token = token_i32.to(tl.int64)
    src2dst_ptr = src2dst_ptr + token * topk
    topk_ids_ptr = topk_ids_ptr + token * topk
    topk_weights_ptr = topk_weights_ptr + token * topk

    touched = False
    out_row_ptr = output_ptr + token * hidden_size

    lanes = tl.arange(0, BLOCK_SIZE)
    for base in tl.range(0, hidden_size, BLOCK_SIZE):
        offs = base + lanes
        valid = offs < hidden_size

        acc = tl.zeros([BLOCK_SIZE], dtype=out_dtype)
        for k in range(topk):
            eid = tl.load(topk_ids_ptr + k)
            if eid >= start_expert_id and eid <= end_expert_id:
                touched = True
                slot_i32 = tl.load(src2dst_ptr + k)
                slot = slot_i32.to(tl.int64)
                # Rebase the global padded slot onto this rank's range.
                slot = slot - start_expert_id * max_m
                router_w = tl.load(topk_weights_ptr + k).to(out_dtype)
                src_row_ptr = down_output_ptr + slot * hidden_size
                acc += tl.load(src_row_ptr + offs, mask=valid) * router_w
        tl.store(out_row_ptr + offs, acc, mask=valid)

    if not touched:
        # No local expert handled this token: emit an explicit zero row.
        for base in tl.range(0, hidden_size, BLOCK_SIZE):
            offs = base + lanes
            valid = offs < hidden_size
            tl.store(
                out_row_ptr + offs, tl.zeros([BLOCK_SIZE], dtype=out_dtype), mask=valid
            )
Uh oh!
There was an error while loading. Please reload this page.