Integrate DeepGeMM MegaMoE #40843
Conversation
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: yasong.wang <yasong.wang@inferact.ai>
Signed-off-by: Zhewen Li <zhewenli@inferact.ai>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Documentation preview: https://vllm--40843.org.readthedocs.build/en/40843/
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request adds support for the Deepseek V4 model architecture, including a horizontally-fused MLA kernel, an MTP draft model for speculative decoding, and a custom tokenizer. It introduces optimized CUDA and Triton kernels for softplus-sqrt Top-K gating and MHC blocks. Feedback centers on performance and memory safety: the reviewer recommends replacing frequent GPU tensor allocations in metadata builders and fused kernels with pre-allocated buffers to support CUDA graph capture and reduce overhead. Additionally, the reviewer identified missing record_stream calls for tensors executed on auxiliary streams and advised making hardcoded sccache settings in the Dockerfile configurable.
```
    && export SCCACHE_BUCKET=inferact-sccache \
    && export SCCACHE_REGION=us-west-2 \
    && export SCCACHE_S3_NO_CREDENTIALS=0 \
```
```python
        lambda: self.indexer(
            hidden_states, qr, positions, self.indexer_rotary_emb
        ),
        kv_insert_and_compress,
        self.ln_events[0],
        self.ln_events[1],
        self.aux_stream,
    )
elif self.compressor is not None:
```
In attention_impl, several tensors are used on an auxiliary stream via maybe_execute_in_parallel without being recorded on that stream. To ensure memory safety and prevent the caching allocator from reusing memory while the auxiliary stream is still active, you must call record_stream(self.aux_stream) on hidden_states, qr, and positions before the parallel execution block.
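A minimal sketch of the requested fix, assuming the auxiliary stream is a plain `torch.cuda.Stream`; the `run_indexer_on_aux_stream` wrapper is illustrative and does not reproduce the PR's `maybe_execute_in_parallel` helper:

```python
import torch

# Record every tensor the auxiliary stream will read before launching work
# on it, so the caching allocator cannot recycle that memory while the
# stream is still running. Tensor and stream names mirror the snippet above.
def run_indexer_on_aux_stream(
    hidden_states: torch.Tensor,
    qr: torch.Tensor,
    positions: torch.Tensor,
    indexer,
    rotary_emb,
    aux_stream: torch.cuda.Stream,
) -> torch.Tensor:
    # Mark each input as in use on aux_stream; when these tensors are later
    # freed, the allocator will wait for aux_stream to pass this point.
    hidden_states.record_stream(aux_stream)
    qr.record_stream(aux_stream)
    positions.record_stream(aux_stream)

    # Order aux_stream after the producing (current) stream, then run there.
    aux_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(aux_stream):
        out = indexer(hidden_states, qr, positions, rotary_emb)
    return out  # caller must sync with aux_stream before consuming `out`
```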
```python
    # SWA-only layer: no compressor, no overlap.
    self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)
```
```python
    # Handle dummy run (no metadata).
    if not isinstance(attn_metadata, dict):
        # Reserve _forward_prefill's bf16-gather workspace; the dummy
        # run returns before mla_attn runs, so without this the shared
        # workspace locks below the real prefill size.
        sub = self.mla_attn
```
```python
global_indices, topk_lens = compute_global_topk_indices_and_lens(
    self.topk_indices_buffer[:num_decode_tokens],
    swa_metadata.token_to_req_indices,
    attn_metadata.block_table[:num_decodes],
    block_size,
    is_valid,
)
```
```python
combined_indices, combined_lens = combine_topk_swa_indices(
    topk_indices[query_start:query_end],
    query_start_loc[
        num_decodes + chunk_start : num_decodes + chunk_end + 1
    ],
    seq_lens[chunk_start:chunk_end],
    gather_lens[chunk_start:chunk_end],
    self.window_size,
    self.compress_ratio,
    top_k,
    M,
    N,
)
```
```python
post_mix = torch.empty(
    num_tokens,
    hc_mult,
    dtype=torch.float32,
    device=residual.device,
)
comb_mix = torch.empty(
    num_tokens,
    hc_mult2,
    dtype=torch.float32,
    device=residual.device,
)
layer_input = torch.empty(
    num_tokens,
    hidden_size,
    dtype=torch.bfloat16,
    device=residual.device,
)

gemm_out_mul = torch.empty(
    n_splits,
    num_tokens,
    hc_mult3,
    dtype=torch.float32,
    device=residual.device,
)
gemm_out_sqrsum = torch.empty(
    n_splits,
    num_tokens,
    dtype=torch.float32,
    device=residual.device,
)
```
The mhc_pre function performs multiple GPU tensor allocations (post_mix, comb_mix, layer_input, gemm_out_mul, gemm_out_sqrsum) on every forward pass. These should be moved to the layer's initialization or managed via a persistent workspace to avoid allocation overhead and support CUDA graph capture.
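A minimal sketch of the suggested pattern, assuming an engine-level `max_num_tokens` bound is known at init time; the `MHCPreWorkspace` class and its `view` helper are hypothetical names, not the PR's actual API:

```python
import torch

# Allocate the buffers once at layer construction, sized to the worst case,
# and hand out per-step slices. Slices are views into stable device memory,
# so nothing is allocated inside a CUDA-graph-captured region.
class MHCPreWorkspace:
    def __init__(self, max_num_tokens: int, hc_mult: int, hc_mult2: int,
                 hidden_size: int, device: torch.device):
        self.post_mix = torch.empty(
            max_num_tokens, hc_mult, dtype=torch.float32, device=device)
        self.comb_mix = torch.empty(
            max_num_tokens, hc_mult2, dtype=torch.float32, device=device)
        self.layer_input = torch.empty(
            max_num_tokens, hidden_size, dtype=torch.bfloat16, device=device)

    def view(self, num_tokens: int):
        # Views of the persistent buffers; device addresses never change
        # across steps, which is what CUDA graph replay requires.
        return (
            self.post_mix[:num_tokens],
            self.comb_mix[:num_tokens],
            self.layer_input[:num_tokens],
        )
```

The GEMM-split buffers (`gemm_out_mul`, `gemm_out_sqrsum`) would follow the same pattern with an extra leading `n_splits` dimension.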
```python
output_scale_packed = torch.zeros(
    (num_packed_groups, tma_aligned_M),
    dtype=torch.int32,
    device=input.device,
).T[:M, :]
```
```python
token_to_seq = torch.empty(total_seq_lens, dtype=torch.int32, device=device)

cu_seq_lens = torch.empty(num_reqs + 1, dtype=torch.int32, device=device)
# Assigning to slice avoids cpu sync.
cu_seq_lens[:1] = 0
torch.cumsum(compressed_seq_lens[start_idx:end_idx], dim=0, out=cu_seq_lens[1:])

query_start_loc = (
    query_start_loc[start_idx : end_idx + 1] - query_start_loc[start_idx]
)

total_query_len = int(
    (query_start_loc_cpu[end_idx] - query_start_loc_cpu[start_idx]).item()
)
if query_slice is not None:
    qs_start = query_slice.start
    qs_stop = query_slice.stop
else:
    qs_start = 0
    qs_stop = total_query_len
output_query_len = qs_stop - qs_start

cu_seq_len_ks = torch.empty(output_query_len, dtype=torch.int32, device=device)
cu_seq_len_ke = torch.empty(output_query_len, dtype=torch.int32, device=device)
```
The build_prefill_chunk_metadata function allocates several GPU tensors (token_to_seq, cu_seq_lens, cu_seq_len_ks, cu_seq_len_ke) during the metadata build phase. In the V1 architecture, metadata builders should avoid GPU allocations to maintain performance and ensure CUDA graph stability. These should be pre-allocated.
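A hedged sketch of one way to satisfy this, assuming `max_num_tokens` and `max_num_reqs` bounds are available when the builder is constructed; `PrefillChunkBuffers` and its methods are illustrative, not the PR's actual API:

```python
import torch

# Pre-allocate the metadata buffers once in the builder's __init__ and fill
# slices of them on every build call, so the build phase performs no device
# allocations. Buffer names mirror the snippet above.
class PrefillChunkBuffers:
    def __init__(self, max_num_tokens: int, max_num_reqs: int,
                 device: torch.device):
        self.token_to_seq = torch.empty(
            max_num_tokens, dtype=torch.int32, device=device)
        self.cu_seq_lens = torch.empty(
            max_num_reqs + 1, dtype=torch.int32, device=device)
        self.cu_seq_len_ks = torch.empty(
            max_num_tokens, dtype=torch.int32, device=device)
        self.cu_seq_len_ke = torch.empty(
            max_num_tokens, dtype=torch.int32, device=device)

    def build_cu_seq_lens(self, compressed_seq_lens: torch.Tensor,
                          num_reqs: int) -> torch.Tensor:
        cu = self.cu_seq_lens[: num_reqs + 1]
        cu[:1] = 0  # assigning to a slice avoids a CPU sync, as in the diff
        torch.cumsum(compressed_seq_lens, dim=0, out=cu[1:])
        return cu
```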
```python
pfx_gather_lens = torch.empty(
    num_prefills, dtype=torch.int32, device=seq_lens.device
)
```

```python
index_q_fp8 = torch.empty_like(index_q, dtype=torch.float8_e4m3fn)
_fused_indexer_q_rope_quant_kernel[(num_tokens, num_index_q_heads)](
    positions,
    index_q,
    index_q.stride(0),
    index_q.stride(1),
    index_q_cos_sin_cache,
    index_q_cos_sin_cache.stride(0),
    index_q_cos_sin_cache.shape[-1] // 2,
    index_q_fp8,
```
The fused_indexer_q_rope_quant function performs multiple GPU tensor allocations (index_weights_out, index_q_packed, index_q_scale, index_q_fp8) on every call. These allocations should be moved to the layer's initialization or handled via a persistent workspace to avoid overhead and support CUDA graph capture.
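A short sketch of the same persistent-workspace idea applied to the FP8 output buffer, assuming the head count and head dim are fixed at init; `IndexerQuantWorkspace` is a hypothetical name, and the other buffers the comment lists (`index_weights_out`, `index_q_packed`, `index_q_scale`) would follow the same pattern:

```python
import torch

# One persistent FP8 buffer sized to the token bound; per-call slices are
# views, so CUDA graph replay always sees the same device address.
class IndexerQuantWorkspace:
    def __init__(self, max_num_tokens: int, num_index_q_heads: int,
                 head_dim: int, device: torch.device):
        self.index_q_fp8 = torch.empty(
            max_num_tokens, num_index_q_heads, head_dim,
            dtype=torch.float8_e4m3fn, device=device,
        )

    def fp8_out(self, num_tokens: int) -> torch.Tensor:
        return self.index_q_fp8[:num_tokens]
```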
No description provided.