7 changes: 5 additions & 2 deletions sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu
@@ -120,8 +120,10 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
      reinterpret_cast<typename ScheduleConfig::LayoutSFB*>(layout_sfb.data_ptr())};

  cutlass::KernelHardwareInfo hw_info;
+  // TODO(qiyuhang): get device_id by cudaGetDevice
  hw_info.device_id = 0;
-  hw_info.sm_count = 132;
+  // TODO(qiyuhang): get sm_count by cudaGetDeviceProperties
+  hw_info.sm_count = 78;  // H20 config
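A minimal sketch of what the two TODOs describe, using the standard CUDA runtime API; the helper name query_hw_info is an assumption for illustration, not code from this PR:

#include <cuda_runtime.h>
#include "cutlass/kernel_hardware_info.h"

// Hypothetical helper: query the active device and its SM count at runtime
// instead of hard-coding per-GPU values (e.g. 132 for H100, 78 for H20).
static cutlass::KernelHardwareInfo query_hw_info() {
  cutlass::KernelHardwareInfo hw_info;
  int device_id = 0;
  cudaGetDevice(&device_id);  // resolves the first TODO
  int sm_count = 0;
  // cudaDeviceGetAttribute is a lighter-weight alternative to the
  // cudaGetDeviceProperties call named in the second TODO; either works.
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device_id);
  hw_info.device_id = device_id;
  hw_info.sm_count = sm_count;
  return hw_info;
}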
yuan-luo (Collaborator) commented:
This kernel's sm_count configuration is not good on H800:

$ python3 sgl-kernel/benchmark/bench_fp8_blockwise_group_gemm.py

Benchmark: expected_m_per_group=128, n=512, k=7168, num_groups=256
deepgemm: 493.7056064605713 us
cutlass: 460.3487968444824 us

Benchmark: expected_m_per_group=256, n=512, k=7168, num_groups=256
deepgemm: 563.2063865661621 us
cutlass: 671.1967945098877 us

Benchmark: expected_m_per_group=256, n=256, k=7168, num_groups=256
deepgemm: 391.4144039154053 us
cutlass: 397.8463888168335 us

Benchmark: expected_m_per_group=512, n=256, k=7168, num_groups=256
deepgemm: 598.5184192657471 us
cutlass: 743.3504104614258 us

HydraQYH (Collaborator) replied:

@yuan-luo What is the SM count on H800? This looks like it may be caused by load imbalance; we have some engineers working on optimizing load balancing. Since I don't have access to an H800 machine, could you please provide an ncu report?

HydraQYH (Collaborator) added:

@yuan-luo We use the SM count as the number of CTAs. I checked and found that H800 has 144 SM cores, the same as H100. In the second and fourth configurations there are only (256/128) * (512/128) * 256 = 2048 output tiles, and 2048 / 144 ≈ 14.22 waves, so the last wave may keep only about 20% of the SM cores busy. I suspect this is the cause of the performance issue; an ncu report would help confirm it.
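For illustration, the wave-quantization arithmetic in this comment as a standalone sketch; the 128x128 tile shape and all counts are taken from the discussion above, and nothing here is code from the PR:

#include <cstdio>

// One CTA per 128x128 output tile, one persistent CTA per SM:
// the tail wave runs only tiles % sm_count CTAs.
int main() {
  const int m = 256, n = 512, num_groups = 256;  // second benchmark config
  const int tile_m = 128, tile_n = 128;
  const int sm_count = 144;                      // H800, per the comment
  const int tiles = (m / tile_m) * (n / tile_n) * num_groups;  // 2048
  printf("tiles=%d, full waves=%d, tail wave uses %d of %d SMs (%.0f%%)\n",
         tiles, tiles / sm_count, tiles % sm_count, sm_count,
         100.0 * (tiles % sm_count) / sm_count);
  return 0;
}

This prints tiles=2048, full waves=14, tail wave uses 32 of 144 SMs (22%), consistent with the roughly 20% estimate above.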

yuan-luo (Collaborator) replied:

@HydraQYH I'm working on a related PR for the FP8 MoE kernel on Hopper. I'll update it and link it to this PR later.


typename GemmKernel::EpilogueArguments epilogue_args{
{},
@@ -140,7 +142,8 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(

  at::cuda::CUDAGuard device_guard{(char)a_ptrs.get_device()};
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a_ptrs.get_device());

+  // TODO(qiyuhang): skip can_implement when problem_sizes_host is nullptr
  auto can_implement_status = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");
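A minimal sketch of the guard this TODO describes; the availability and exact name of a host-side problem_sizes_host pointer at this point in the function is an assumption:

// Sketch: can_implement() needs host-visible problem sizes, so when only
// device-side sizes exist (problem_sizes_host == nullptr) the check would
// be skipped instead of failing.
if (problem_sizes_host != nullptr) {
  auto can_implement_status = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");
}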
