[MUSA][16/N] Add MUSA backend support for layers #19509
Conversation
Summary of Changes (Gemini Code Assist)
This pull request introduces comprehensive support for Moore Threads (MUSA) GPUs within the SGLang runtime (SRT). It enables efficient LLM inference on MUSA hardware by integrating MUSA-specific logic and implementations across critical components such as activation functions, normalization, DeepGEMM operations, Mixture-of-Experts layers, FP8 quantization, and sampling mechanisms.
Code Review
This pull request introduces MUSA backend support for various layers in the sglang library, including activation, quantization, layernorm, MoE, and FP8 linear operations. The changes involve adding new forward_musa implementations, updating conditional checks for kernel imports and execution paths, and modifying deep_gemm integration to handle MUSA-specific requirements. The code has been reviewed, and several issues related to conditional checks, potential dead code, and code duplication have been identified and reported in the review comments.
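The new forward_musa implementations follow SRT's existing per-backend dispatch convention, where each custom op selects a backend-specific forward method at runtime. A simplified, self-contained sketch of that pattern (illustrative only; the detection helper and class below are not the actual sglang CustomOp code) might look like:

```python
import torch
import torch.nn.functional as F


def is_musa() -> bool:
    # Illustrative detection helper; the real check lives in sglang's utils
    # and relies on the torch_musa plugin being installed.
    return hasattr(torch, "musa") and torch.musa.is_available()


class SiluAndMul(torch.nn.Module):
    """Minimal sketch of a custom op that dispatches to a MUSA-specific path."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if is_musa():
            return self.forward_musa(x)
        if x.is_cuda:
            return self.forward_cuda(x)
        return self.forward_native(x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = x.chunk(2, dim=-1)
        return F.silu(gate) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # In sglang this calls a fused sgl_kernel op; the sketch reuses the
        # native math.
        return self.forward_native(x)

    def forward_musa(self, x: torch.Tensor) -> torch.Tensor:
        # On MUSA devices this would call the MUSA build of the fused kernel;
        # the sketch again falls back to the native math.
        return self.forward_native(x)
```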
```python
if not _is_musa:
    intermediate_cache2 = torch.empty(
        (total_tokens, N // 2),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
```
```python
if not _is_musa:
    intermediate_cache2 = torch.empty(
        (M * topk_ids.shape[1], N // 2),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
```
```python
elif not _is_musa:
    hidden_states_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
        hidden_states_scale
    )
```
The not _is_musa path calls deep_gemm_wrapper.get_mn_major_tma_aligned_tensor, while the _is_musa path does nothing at all. This could lead to incorrect behavior or performance degradation on MUSA devices. Please ensure that the MUSA backend has a proper implementation for this case.
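One way to make the intent explicit (under the assumption that the MUSA DeepGEMM kernel consumes hidden_states_scale in its current row-major layout and needs no TMA re-alignment) would be to spell out the MUSA branch with a comment, for example:

```python
elif not _is_musa:
    hidden_states_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
        hidden_states_scale
    )
else:
    # Assumption: the MUSA kernel accepts hidden_states_scale as-is,
    # so no MN-major TMA alignment is performed here.
    pass
```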
```python
if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 or not _is_musa:
    down_input_scale = tma_align_input_scale(down_input_scale)
```

```python
if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 or not _is_musa:
    down_input_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
        down_input_scale
    )
```
```python
if _is_musa:
    return 31
```
```python
else:
    from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
```
```python
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and not _is_musa:
    return deepgemm_w8a8_block_fp8_linear_with_fallback
```
```python
assert input_scale is None
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
if input_2d.shape[0] < 4:
    # Small batches: call the fused GEMV kernel directly on the
    # unquantized activations.
    output = musa_fused_gemv(
        input_2d,
        weight,
        None,
        weight_scale,
    )
else:
    # Larger batches: quantize activations per token group, then run the
    # w8a8 scaled matmul through muDNN.
    q_input, x_scale = sglang_per_token_group_quant_fp8(
        input_2d, block_size[1], column_major_scales=False
    )
    output = musa_mudnn_w8a8_scaled_mm(
        q_input,
        weight,
        out_dtype=input_2d.dtype,
        scale_a=x_scale,
        scale_b=weight_scale,
    )
if bias is not None:
    output += bias
return output.to(dtype=input_2d.dtype).view(*output_shape)
```
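For reference, the per-token-group FP8 quantization performed by sglang_per_token_group_quant_fp8 on the large-batch path can be sketched in plain PyTorch as follows. The group size corresponds to block_size[1]; the function name and the exact clamping and scale handling here are a simplified assumption, not the kernel's implementation:

```python
import torch


def per_token_group_quant_fp8_reference(x: torch.Tensor, group_size: int = 128):
    """Pure-PyTorch reference for per-token-group FP8 quantization.

    Assumes float8_e4m3fn with a max representable value of 448; the real
    sgl_kernel / Triton implementation may handle clamping and scale layout
    differently.
    """
    assert x.shape[-1] % group_size == 0
    fp8_max = 448.0
    grouped = x.float().reshape(*x.shape[:-1], x.shape[-1] // group_size, group_size)
    # One scale per (token, group): scale = amax(|x|) / fp8_max.
    scale = grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    q = (grouped / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.reshape(x.shape), scale.squeeze(-1)
```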
```python
if _is_musa and logits.shape[0] < 1:
    batch_next_token_ids = torch.empty(
        [0], dtype=torch.int64, device=logits.device
    )
```
```python
if sampling_info.is_all_greedy:
    # Use torch.argmax if all requests use greedy sampling
    batch_next_token_ids = torch.argmax(logits, -1)
    # XXX (MUSA): protect against empty logits tensor
```
This can be removed to maintain consistency with the community.
```python
_is_musa = is_musa()

if _is_musa:
    from sgl_kernel import musa_fused_gemv, musa_mudnn_w8a8_scaled_mm
```
Instead of the _is_musa branch, using deep_gemm_wrapper.gemm_nt_f8f8bf16 could be a better option.
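A rough sketch of that suggestion, assuming deep_gemm_wrapper.gemm_nt_f8f8bf16 follows the ((lhs, lhs_scale), (rhs, rhs_scale), out) calling convention used on the CUDA path and that a MUSA build of the wrapper provides the same entry point:

```python
# Sketch only: route the quantized matmul through the shared DeepGEMM wrapper
# instead of branching on _is_musa. The signature below is assumed from the
# non-MUSA code path.
q_input, x_scale = sglang_per_token_group_quant_fp8(
    input_2d, block_size[1], column_major_scales=False
)
output = torch.empty(
    (q_input.shape[0], weight.shape[0]),
    dtype=torch.bfloat16,
    device=q_input.device,
)
deep_gemm_wrapper.gemm_nt_f8f8bf16((q_input, x_scale), (weight, weight_scale), output)
if bias is not None:
    output = output + bias
return output.to(dtype=input_2d.dtype).view(*output_shape)
```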
move to #22774
Motivation
This PR is the 16th in a series of pull requests (tracked in #16565) to add full support for Moore Threads GPUs, leveraging MUSA (Meta-computing Unified System Architecture) to accelerate LLM inference.
Dependencies: #17946
Modifications
Accuracy Tests
Benchmarking and Profiling
Checklist
Review Process
/tag-run-ci-label,/rerun-failed-ci,/tag-and-rerun-ci