Conversation
This reverts commit 68ad822.
/bot run
📝 Walkthrough

Added SM120/SM121 Blackwell-native Mixture-of-Experts fused backend: new static and dynamic CuTe-DSL kernels, dispatch/workspace/weight-view plumbing, extended CuteDSL FP4 ops/helpers, and benchmark/wrapper updates to advertise and route to the new backend. Changes touch kernels, dispatch, CuteDSL ops, wrapper integration, and benchmarks.

Sequence Diagram(s)

sequenceDiagram
participant Host as Host/Benchmark
participant Wrapper as CuteDslMoEWrapper
participant Dispatch as SM120 Dispatch
participant Workspace as Workspace Cache
participant StaticK as MoEStaticKernel
participant DynamicK as MoEDynamicKernel
Host->>Wrapper: run(fused_moe, inputs, weights)
Wrapper->>Dispatch: select backend (SM12x?) / validate
Dispatch->>Workspace: check or allocate workspace & weight-views
Workspace-->>Dispatch: workspace + weight-views
Dispatch->>Dispatch: decide static vs dynamic (routed rows threshold)
alt static backend
Dispatch->>StaticK: launch_sm120_static_moe(workspace, tensors)
StaticK->>StaticK: Phase 0: route/pack + quantize
StaticK->>StaticK: resident-grid barrier
StaticK->>StaticK: Phase 1: FC1+SiLU+quant
StaticK->>StaticK: Phase 2: FC2+scatter
StaticK-->>Host: scatter_output
else dynamic backend
Dispatch->>DynamicK: launch_sm120_dynamic_moe(workspace, tensors)
DynamicK->>DynamicK: init queues/barriers
DynamicK->>DynamicK: warp producers: pack & publish tasks
DynamicK->>DynamicK: warp consumers: pop tasks & compute
DynamicK-->>Host: scatter_output
end
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
🚥 Pre-merge checks: ❌ 3 failed checks (1 warning, 2 inconclusive)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Code Review
This pull request introduces support for Blackwell SM120/SM121 architectures within the CuTe DSL MoE kernels, implementing both static (decode) and queue-driven dynamic (prefill) execution paths. The changes include new PTX intrinsics for global and shared memory operations, specialized FP4 quantization helpers, and a unified dispatch layer. Review feedback identifies critical missing intrinsics in the common library and highlights that current implementations do not yet safely support expert parallelism, potentially leading to out-of-bounds accesses. Additionally, there is a recommendation to optimize performance in the dynamic kernel by reducing the frequency of global memory fences.
@dsl_user_op
def st_global_f32(base_ptr: Int64, value: Float32, *, loc=None, ip=None):
    """Store 32-bit float to global memory."""
    llvm.inline_asm(
        None,
        [
            Int64(base_ptr).ir_value(loc=loc, ip=ip),
            Float32(value).ir_value(loc=loc, ip=ip),
        ],
        "st.global.f32 [$0], $1;",
        "l,f",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )


@dsl_user_op
def st_global_i32(addr: Int64, val: Int32, *, loc=None, ip=None):
    """Store int32 to global memory."""
    llvm.inline_asm(
        None,
        [
            Int64(addr).ir_value(loc=loc, ip=ip),
            Int32(val).ir_value(loc=loc, ip=ip),
        ],
        "st.global.s32 [$0], $1;",
        "l,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
The PTX intrinsics st_global_u64 and get_ptr_as_int64 appear to be missing from this file, although they are imported and used in both moe_static_kernel.py and moe_dynamic_kernel.py. Specifically, st_global_u64 is used for storing packed FP4 values, and get_ptr_as_int64 is used to obtain raw addresses for global memory operations. Without these definitions, the SM120 MoE kernels will fail to compile or run.
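A minimal sketch of what the missing `st_global_u64` could look like, following the st_global_f32 / st_global_i32 pattern above. The PTX template and constraint string are kept as plain module-level constants so they can be sanity-checked without a GPU; the factory wiring (`llvm`, `dsl_user_op`, the integer type wrappers) is an assumption modeled on this file's existing definitions, and `get_ptr_as_int64` is not sketched here because its implementation depends on DSL pointer internals.

```python
# Hypothetical sketch, modeled on the st_global_f32 / st_global_i32
# definitions above; not the repository's actual code.
ST_GLOBAL_U64_ASM = "st.global.u64 [$0], $1;"
ST_GLOBAL_U64_CONSTRAINTS = "l,l"  # 64-bit address, 64-bit value


def make_st_global_u64(llvm, dsl_user_op, Int64, Uint64):
    """Build st_global_u64 given the DSL's llvm module and type wrappers."""

    @dsl_user_op
    def st_global_u64(addr, val, *, loc=None, ip=None):
        """Store uint64 (e.g. a packed FP4 group) to global memory."""
        llvm.inline_asm(
            None,
            [
                Int64(addr).ir_value(loc=loc, ip=ip),
                Uint64(val).ir_value(loc=loc, ip=ip),
            ],
            ST_GLOBAL_U64_ASM,
            ST_GLOBAL_U64_CONSTRAINTS,
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )

    return st_global_u64
```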
if local_expert_offset != 0:
    raise ValueError(
        "SM120 MoE does not support expert parallelism (local_expert_offset != 0). "
        "Use the SM100 CuTe DSL or CUTLASS backend for EP configurations."
    )
The check for expert parallelism (EP) on SM120 is insufficient. While it correctly identifies non-zero local_expert_offset, it fails to catch cases where num_local_experts != num_experts (e.g., Rank 0 in an EP configuration). The SM120 fused kernels currently do not support EP because they do not filter token_selected_experts based on the local expert range and they index weight tensors (which only contain local experts) using global expert IDs. This will lead to out-of-bounds memory accesses or incorrect results. Please update the check to ensure num_local_experts == num_experts when num_local_experts is provided.
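A hedged sketch of the stricter guard this comment asks for; the function name and the exact message are illustrative, not the repository's actual code.

```python
def validate_sm120_no_ep(num_experts, num_local_experts, local_expert_offset):
    """Reject any expert-parallel configuration on SM120/SM121.

    Catches both signals of EP: a non-zero local offset (any rank > 0)
    and a local expert count smaller than the global one (rank 0 included).
    """
    if local_expert_offset != 0 or (
        num_local_experts is not None and num_local_experts != num_experts
    ):
        raise ValueError(
            "SM120 MoE does not support expert parallelism; require "
            "local_expert_offset == 0 and num_local_experts == num_experts. "
            "Use the SM100 CuTe DSL or CUTLASS backend for EP configurations."
        )
```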
if backend == "dynamic" and num_local_experts != num_experts:
    backend = "static"
The fallback from dynamic to static backend when num_local_experts != num_experts is misleading. As noted in the fused_moe.py review, the static kernel also does not correctly support expert parallelism. It will attempt to process all experts in the batch, and since its weight views are sized for num_local_experts, indexing them with global expert IDs will cause out-of-bounds accesses. If EP is not supported for SM120, this condition should likely trigger an error instead of a fallback.
if full_tile_publish_enabled > Int32(0):
    cute.arch.sync_warp()
    # When the whole launch has fewer than one M-tile of routed
    # rows, only the final partial-tile flush can publish work.
    # Skip the per-row fence/counter path in that common micro case.
    _threadfence()
    cute.arch.sync_warp()
Executing _threadfence() (which maps to membar.gl) inside the producer loop for every routed pair may significantly degrade performance. While a fence is necessary to ensure visibility of packed_a_storage and scale_storage writes before a task is published, it might be more efficient to use a more targeted synchronization primitive or to only fence when a tile is actually completed, provided that visibility across all contributing threads is guaranteed.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
benchmarks/routines/moe.py (1)
1362-1445: ⚠️ Potential issue | 🟡 Minor

Update the bandwidth model for the SM12x BF16 path.

This block switches the benchmark input to x_bf16 on SM120/SM121, but the later calculate_moe_kernel_bandwidth(...) call still hard-codes input_format="nvfp4". The timing stays correct, but the reported TB/s for the new path is misleading.

Suggested follow-up

-    tb_per_sec = calculate_moe_kernel_bandwidth(
+    input_format = "bf16" if sm_major_bm == 12 else "nvfp4"
+    tb_per_sec = calculate_moe_kernel_bandwidth(
         num_tokens,
         hidden_size,
         intermediate_size,
         num_experts,
         top_k,
         median_time,
         input_dtype,
         weight_dtype,
-        input_format="nvfp4",
+        input_format=input_format,
         weight_format="nvfp4",
         routing_logits_dtype=None,
         active_experts=num_active_experts,
         verbose=args.verbose,
     )

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/moe.py` around lines 1362 - 1445, The bandwidth model still uses input_format="nvfp4" even when the SM12x BF16 path is selected (x_input = tensors["x_bf16"] for sm_major_bm == 12); update the call to calculate_moe_kernel_bandwidth(...) to select the correct input_format based on sm_major_bm (e.g., use "bf16" for SM12x and "nvfp4" otherwise) so reported TB/s matches the actual BF16 kernel path; find references to x_input and the calculate_moe_kernel_bandwidth invocation in this file (and any callers of run_cute_dsl_moe if present) and branch the input_format argument accordingly.
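The suggested branch can be isolated into a small helper (the name is hypothetical) so the benchmark input selection and the bandwidth model always agree on the format:

```python
def select_input_format(sm_major: int) -> str:
    """SM12x benchmarks feed bf16 activations (the kernel fuses
    quantization); other architectures feed pre-quantized nvfp4."""
    return "bf16" if sm_major == 12 else "nvfp4"
```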
🧹 Nitpick comments (4)
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py (1)
679-683: Unused bidx variable.

The bidx from cute.arch.block_idx() is unpacked but never used. Consider using _, _, bidz = cute.arch.block_idx() to satisfy the linter, though this is a minor cosmetic issue.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py` around lines 679 - 683, The variable `bidx` from cute.arch.block_idx() is unused; update the unpacking to discard it (e.g., use "_, _, bidz = cute.arch.block_idx()") or remove the unused name so the linter stops complaining, ensuring the rest of the code references bidz as before; change the line that currently assigns "bidx, _, bidz = cute.arch.block_idx()" accordingly in moe_dynamic_kernel.py.

flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py (1)
1186-1249: Workspace cache grows monotonically without bounds.

The _WORKSPACE_CACHE grows when encountering new problem shapes or when existing workspaces have insufficient capacity. Since workspaces are never evicted, long-running inference services could accumulate significant GPU memory if they see many distinct configurations.

Consider documenting this behavior or adding an optional cache clear API for production deployments.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py` around lines 1186 - 1249, The workspace cache (_WORKSPACE_CACHE) currently grows without bounds in _get_cached_workspace and should be made bounded and controllable; change the implementation to accept a configurable max cache size (module-level DEFAULT_MAX_WORKSPACE_CACHE or env/config), implement LRU eviction when inserting a new key into _WORKSPACE_CACHE (evict the least-recently-used entry before assigning the new workspace), and add a public API function clear_workspace_cache() (and optionally set_max_workspace_cache(size) or disable_workspace_cache()) to allow manual clearing or sizing from production code; update _get_cached_workspace to touch/access the key on lookup so the LRU policy works and document the new behavior in the function docstring.

flashinfer/fused_moe/cute_dsl/fused_moe.py (2)
439-449: Duplicate CUDA version check across __init__ and _moe_core_impl.

The CUDA 13 version check is performed both here in __init__ and again in _moe_core_impl (line 193). While defensive, this is redundant since the wrapper's run() method dispatches directly to launch_sm120_moe, bypassing _moe_core_impl on SM120.

Consider consolidating the check or documenting why both are needed. Also, minor is unused; use major, _ to satisfy the linter.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/fused_moe.py` around lines 439 - 449, The CUDA 13 runtime check is duplicated; consolidate by keeping the version guard in the class initializer (where self._is_sm120 is set) and remove the redundant check in _moe_core_impl (or vice versa if you prefer runtime check nearer dispatch), and update comments to document why the check stays in that single location; also fix the unused minor variable by changing the unpack to "major, _" where torch.cuda.get_device_capability(device) is called. Ensure references to get_cuda_version, self._is_sm120, _moe_core_impl, and launch_sm120_moe are updated accordingly so only one CUDA-version validation remains.
700-724: Weight view caching invalidation looks correct but has a subtle lifetime concern.

The caching key uses data_ptr() values, which correctly invalidates when weight tensors change. However, the cached _WeightViews object in _get_sm120_weight_views holds references to contiguous copies of scale tensors. If the original weight tensors are freed while cached views remain, the cache key (pointer-based) could collide with a new tensor allocated at the same address.

This is likely fine in practice since weights are typically long-lived, but worth noting for production use with dynamic weight swapping.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/fused_moe.py` around lines 700 - 724, The cache invalidation currently keys off tensor.data_ptr(), which can collide if a freed tensor's memory is reallocated; update the key construction used with _sm120_weight_views/_sm120_weight_key to include each tensor's storage identity and size (e.g., use tensor.storage().data_ptr(), tensor.storage().size(), tensor.dtype and tensor.device or tensor.storage().element_size()) for w1_weight, w1_weight_sf, w1_alpha, w2_weight, w2_weight_sf, w2_alpha so the key reflects the underlying allocation and layout, preventing accidental collisions when memory is reused; locate the weight_key tuple creation and replace each data_ptr() entry with a small tuple of (storage().data_ptr(), storage().size(), dtype, device) for robustness.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/routines/moe.py`:
- Around line 1360-1415: Detect SM120/SM121 (sm_major_bm == 12) and explicitly
reject Expert-Parallel (EP) configs by validating local_num_experts ==
num_experts and local_expert_offset == 0 before selecting the SM12x backend; if
the check fails, raise a clear error (or exit) so neither
cute_dsl_fused_moe_nvfp4 (used when use_functional) nor CuteDslMoEWrapper::run
are invoked with unsupported local-expert remapping parameters. Ensure the guard
references sm_major_bm, local_num_experts, local_expert_offset,
cute_dsl_fused_moe_nvfp4 and CuteDslMoEWrapper so it runs for both the
functional and wrapper code paths.
In `@flashinfer/cute_dsl/fp4_common.py`:
- Around line 1451-1471: The non-fast path in quantize_block_fp4 incorrectly
passes quantized_scale * global_scale_val into quantize_and_pack_16, but
quantize_and_pack_16 expects an inverse scale (reciprocal) like
quantize_block_fp4_fast; change the argument to the reciprocal of
(quantized_scale * global_scale_val) (ensuring Float32 math and guarding
divide-by-zero as currently done by the if) so quantize_and_pack_16 receives 1.0
/ (quantized_scale * global_scale_val) instead of the product.
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py`:
- Around line 198-218: The current inline asm in _spin_wait_global_eq_i32 loops
while val == expected (using setp.eq and branching on %p0) which is the opposite
of the function name; change the predicate to test inequality so it spins while
val != expected and exits when equal: in the asm block inside
_spin_wait_global_eq_i32 replace "setp.eq.s32 %p0, %val, $1;" with "setp.ne.s32
%p0, %val, $1;" (keep the rest of the loop/branch the same) so the function
semantics match its name and usage.
---
Outside diff comments:
In `@benchmarks/routines/moe.py`:
- Around line 1362-1445: The bandwidth model still uses input_format="nvfp4"
even when the SM12x BF16 path is selected (x_input = tensors["x_bf16"] for
sm_major_bm == 12); update the call to calculate_moe_kernel_bandwidth(...) to
select the correct input_format based on sm_major_bm (e.g., use "bf16" for SM12x
and "nvfp4" otherwise) so reported TB/s matches the actual BF16 kernel path;
find references to x_input and the calculate_moe_kernel_bandwidth invocation in
this file (and any callers of run_cute_dsl_moe if present) and branch the
input_format argument accordingly.
---
Nitpick comments:
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py`:
- Around line 1186-1249: The workspace cache (_WORKSPACE_CACHE) currently grows
without bounds in _get_cached_workspace and should be made bounded and
controllable; change the implementation to accept a configurable max cache size
(module-level DEFAULT_MAX_WORKSPACE_CACHE or env/config), implement LRU eviction
when inserting a new key into _WORKSPACE_CACHE (evict the least-recently-used
entry before assigning the new workspace), and add a public API function
clear_workspace_cache() (and optionally set_max_workspace_cache(size) or
disable_workspace_cache()) to allow manual clearing or sizing from production
code; update _get_cached_workspace to touch/access the key on lookup so the LRU
policy works and document the new behavior in the function docstring.
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py`:
- Around line 679-683: The variable `bidx` from cute.arch.block_idx() is unused;
update the unpacking to discard it (e.g., use "_, _, bidz =
cute.arch.block_idx()") or remove the unused name so the linter stops
complaining, ensuring the rest of the code references bidz as before; change the
line that currently assigns "bidx, _, bidz = cute.arch.block_idx()" accordingly
in moe_dynamic_kernel.py.
In `@flashinfer/fused_moe/cute_dsl/fused_moe.py`:
- Around line 439-449: The CUDA 13 runtime check is duplicated; consolidate by
keeping the version guard in the class initializer (where self._is_sm120 is set)
and remove the redundant check in _moe_core_impl (or vice versa if you prefer
runtime check nearer dispatch), and update comments to document why the check
stays in that single location; also fix the unused minor variable by changing
the unpack to "major, _" where torch.cuda.get_device_capability(device) is
called. Ensure references to get_cuda_version, self._is_sm120, _moe_core_impl,
and launch_sm120_moe are updated accordingly so only one CUDA-version validation
remains.
- Around line 700-724: The cache invalidation currently keys off
tensor.data_ptr(), which can collide if a freed tensor's memory is reallocated;
update the key construction used with _sm120_weight_views/_sm120_weight_key to
include each tensor's storage identity and size (e.g., use
tensor.storage().data_ptr(), tensor.storage().size(), tensor.dtype and
tensor.device or tensor.storage().element_size()) for w1_weight, w1_weight_sf,
w1_alpha, w2_weight, w2_weight_sf, w2_alpha so the key reflects the underlying
allocation and layout, preventing accidental collisions when memory is reused;
locate the weight_key tuple creation and replace each data_ptr() entry with a
small tuple of (storage().data_ptr(), storage().size(), dtype, device) for
robustness.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 61d7c731-6ce7-4c72-934b-349b759d1fb8
📒 Files selected for processing (8)
benchmarks/routines/flashinfer_benchmark_utils.py
benchmarks/routines/moe.py
flashinfer/cute_dsl/fp4_common.py
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/__init__.py
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_static_kernel.py
flashinfer/fused_moe/cute_dsl/fused_moe.py
use_functional = getattr(args, "use_functional_api", False)

# SM120 passes bf16 as x (kernel fuses quantization); SM100 passes FP4.
sm_major_bm = torch.cuda.get_device_capability(device)[0]
x_input = tensors["x_bf16"] if sm_major_bm == 12 else tensors["x"]

if use_functional:
    from flashinfer import cute_dsl_fused_moe_nvfp4
    from functools import partial

    if args.verbose >= 1:
        print(
            "[INFO] Using functional API (cute_dsl_fused_moe_nvfp4) with workspace cache"
        )

    # Pre-allocate output buffer to avoid per-call allocation
    moe_output = torch.empty(
        num_tokens, hidden_size, dtype=torch.bfloat16, device=device
    )

    runner = partial(
        cute_dsl_fused_moe_nvfp4,
        num_experts=num_experts,
        top_k=top_k,
        num_local_experts=local_num_experts,
        local_expert_offset=local_expert_offset,
        moe_output=moe_output,
    )

    # Warmup call to populate workspace cache before timed region
    runner(
        x=x_input,
        x_sf=tensors["x_sf"],
        token_selected_experts=tensors["token_selected_experts"],
        token_final_scales=tensors["token_final_scales"],
        w1_weight=tensors["w1_weight"],
        w1_weight_sf=tensors["w1_weight_sf"],
        w1_alpha=tensors["w1_alpha"],
        fc2_input_scale=tensors["fc2_input_scale"],
        w2_weight=tensors["w2_weight"],
        w2_weight_sf=tensors["w2_weight_sf"],
        w2_alpha=tensors["w2_alpha"],
    )
else:
    moe = CuteDslMoEWrapper(
        num_experts=num_experts,
        top_k=top_k,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        use_cuda_graph=is_cuda_graph_compatible,
        max_num_tokens=num_tokens,
        num_local_experts=local_num_experts,
        local_expert_offset=local_expert_offset,
    )
    runner = moe.run
Reject EP-style configs on SM120/SM121.
This path now runs the SM12x backend, but it still accepts local_num_experts != num_experts and non-zero local_expert_offset. SM120/SM121 does not support local-expert remapping, so those arguments can produce unsupported benchmark cases or route into experts that are not present in the locally-created weight tensors.
Suggested guard
use_functional = getattr(args, "use_functional_api", False)
# SM120 passes bf16 as x (kernel fuses quantization); SM100 passes FP4.
sm_major_bm = torch.cuda.get_device_capability(device)[0]
+ if sm_major_bm == 12 and (
+ local_num_experts != num_experts or local_expert_offset != 0
+ ):
+ raise ValueError(
+ "cute_dsl_fp4_block_scale_moe on SM120/SM121 does not support "
+ "local expert sharding; use local_num_experts=num_experts and "
+ "local_expert_offset=0."
+ )
    x_input = tensors["x_bf16"] if sm_major_bm == 12 else tensors["x"]

Based on learnings: Expert Parallelism (EP) is unsupported on SM120, and the SM120 dispatch paths intentionally do not forward local_expert_offset because kernel-side remapping is missing.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@benchmarks/routines/moe.py` around lines 1360 - 1415, Detect SM120/SM121
(sm_major_bm == 12) and explicitly reject Expert-Parallel (EP) configs by
validating local_num_experts == num_experts and local_expert_offset == 0 before
selecting the SM12x backend; if the check fails, raise a clear error (or exit)
so neither cute_dsl_fused_moe_nvfp4 (used when use_functional) nor
CuteDslMoEWrapper::run are invoked with unsupported local-expert remapping
parameters. Ensure the guard references sm_major_bm, local_num_experts,
local_expert_offset, cute_dsl_fused_moe_nvfp4 and CuteDslMoEWrapper so it runs
for both the functional and wrapper code paths.
@cute.jit
def quantize_block_fp4(
    values: cute.Tensor,
    max_abs: Float32,
    global_scale_val: Float32,
) -> Tuple[Uint64, Uint8]:
    """Quantize 16 float32 values to packed FP4 + e4m3 scale byte.

    Given 16 values and their pre-computed max_abs, derives the NVFP4 block
    scale, quantizes to FP4, and packs into a uint64. Returns
    (packed_fp4_u64, scale_byte).
    """
    scale_float = max_abs / (Float32(FLOAT4_E2M1_MAX) * global_scale_val)
    scale_float = fmin_f32(scale_float, Float32(FLOAT8_E4M3_MAX))
    scale_u32 = cvt_f32_to_e4m3(scale_float)
    scale_byte = Uint8(scale_u32 & Uint32(0xFF))
    quantized_scale = fp8_e4m3_to_f32(scale_u32)
    packed64 = Uint64(0)
    if quantized_scale != Float32(0.0) and global_scale_val != Float32(0.0):
        packed64 = quantize_and_pack_16(values, quantized_scale * global_scale_val)
    return packed64, scale_byte
Pass the reciprocal scale into quantize_and_pack_16.
quantize_and_pack_16() multiplies by an inverse scale, but this branch passes quantized_scale * global_scale_val directly. That makes the non-fast path quantize by the scale instead of dividing by it, while quantize_block_fp4_fast() already uses the reciprocal form. As written, the default fast_math=False path will pack incorrect FP4 values.
Suggested fix

 @cute.jit
 def quantize_block_fp4(
     values: cute.Tensor,
     max_abs: Float32,
     global_scale_val: Float32,
 ) -> Tuple[Uint64, Uint8]:
 @@
     scale_u32 = cvt_f32_to_e4m3(scale_float)
     scale_byte = Uint8(scale_u32 & Uint32(0xFF))
     quantized_scale = fp8_e4m3_to_f32(scale_u32)
     packed64 = Uint64(0)
     if quantized_scale != Float32(0.0) and global_scale_val != Float32(0.0):
-        packed64 = quantize_and_pack_16(values, quantized_scale * global_scale_val)
+        inv_scale = Float32(1.0) / (quantized_scale * global_scale_val)
+        packed64 = quantize_and_pack_16(values, inv_scale)
     return packed64, scale_byte

🤖 Prompt for AI Agents
return packed64, scale_byte🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/cute_dsl/fp4_common.py` around lines 1451 - 1471, The non-fast
path in quantize_block_fp4 incorrectly passes quantized_scale * global_scale_val
into quantize_and_pack_16, but quantize_and_pack_16 expects an inverse scale
(reciprocal) like quantize_block_fp4_fast; change the argument to the reciprocal
of (quantized_scale * global_scale_val) (ensuring Float32 math and guarding
divide-by-zero as currently done by the if) so quantize_and_pack_16 receives 1.0
/ (quantized_scale * global_scale_val) instead of the product.
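A pure-Python model of the scale math shows why the reciprocal matters: an inverse-scale quantizer multiplies by the reciprocal (i.e. divides by the scale), so it must receive 1/(quantized_scale * global_scale), not the product. The constant mirrors FLOAT4_E2M1_MAX; the helper name below is illustrative, not the real quantize_and_pack_16.

```python
FLOAT4_E2M1_MAX = 6.0  # largest magnitude representable in FP4 E2M1


def quantize_one(value, inv_scale):
    """Toy model of an inverse-scale quantizer: multiply by the
    reciprocal of the scale, then clamp to the FP4 range."""
    scaled = value * inv_scale
    return max(-FLOAT4_E2M1_MAX, min(FLOAT4_E2M1_MAX, scaled))


# Block with max_abs = 12 and global_scale = 1 -> block scale = 2.
max_abs = 12.0
global_scale = 1.0
scale = max_abs / (FLOAT4_E2M1_MAX * global_scale)  # 2.0

# Correct: pass the reciprocal; max_abs lands exactly at the FP4 max.
correct_max = quantize_one(max_abs, 1.0 / (scale * global_scale))
correct_small = quantize_one(1.5, 1.0 / (scale * global_scale))

# Buggy: pass the product; small values are inflated (1.5 -> 3.0 instead
# of 0.75), while the max value is silently masked by clamping.
wrong_small = quantize_one(1.5, scale * global_scale)
wrong_max = quantize_one(max_abs, scale * global_scale)
```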
@dsl_user_op
def _spin_wait_global_eq_i32(addr, expected, *, loc=None, ip=None):
    llvm.inline_asm(
        None,
        [
            Int64(addr).ir_value(loc=loc, ip=ip),
            Int32(expected).ir_value(loc=loc, ip=ip),
        ],
        "{\n"
        ".reg .pred %p0;\n"
        ".reg .s32 %val;\n"
        "spin_loop:\n"
        " ld.global.acquire.gpu.s32 %val, [$0];\n"
        " setp.eq.s32 %p0, %val, $1;\n"
        " @%p0 bra spin_loop;\n"
        "}",
        "l,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
Spin wait logic appears inverted.
The inline assembly spins while val == expected:
" setp.eq.s32 %p0, %val, $1;\n"
" @%p0 bra spin_loop;\n"
This means it exits when val != expected. However, the function is named _spin_wait_global_eq_i32, suggesting it should wait until val == expected. Looking at the usage at line 404:

_spin_wait_global_eq_i32(barrier_epoch_addr, old_epoch)

This spins while the epoch equals old_epoch, exiting when it changes. The semantic is "wait until NOT equal" but the name suggests "wait until equal." Either the logic or the name should be corrected to avoid confusion.
🔧 Suggested name fix to match semantics
-def _spin_wait_global_eq_i32(addr, expected, *, loc=None, ip=None):
+def _spin_wait_global_ne_i32(addr, expected, *, loc=None, ip=None):
+    """Spin until the value at `addr` is NOT equal to `expected`."""

Or alternatively fix the assembly logic to spin until equal:

-    " @%p0 bra spin_loop;\n"
+    " @!%p0 bra spin_loop;\n"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py` around
lines 198 - 218, The current inline asm in _spin_wait_global_eq_i32 loops while
val == expected (using setp.eq and branching on %p0) which is the opposite of
the function name; change the predicate to test inequality so it spins while val
!= expected and exits when equal: in the asm block inside
_spin_wait_global_eq_i32 replace "setp.eq.s32 %p0, %val, $1;" with "setp.ne.s32
%p0, %val, $1;" (keep the rest of the loop/branch the same) so the function
semantics match its name and usage.
/bot stop

The GitLab CI pipeline #48553770 has been cancelled.

/bot run

Force-pushed from 6b8800e to f75ea36.

/bot stop

/bot run

The GitLab CI pipeline #48557894 has been cancelled.
Actionable comments posted: 1
🧹 Nitpick comments (3)
flashinfer/fused_moe/cute_dsl/fused_moe.py (3)
189-189: Unused variable minor can be prefixed with underscore.

The minor variable from unpacking is not used. Prefix it with _ to suppress linter warnings.

🔧 Suggested fix

-    major, minor = torch.cuda.get_device_capability(x.device)
+    major, _ = torch.cuda.get_device_capability(x.device)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/fused_moe.py` at line 189, The unpacked `minor` from torch.cuda.get_device_capability is unused and triggers linter warnings; change the unpacking in the call to torch.cuda.get_device_capability(x.device) to use an underscore-prefixed variable (e.g., `_,` or `_minor`) instead of `minor`, so only the used `major` variable remains named and the linter warning is suppressed while leaving the call and x.device reference intact.
440-440: Unused variable minor can be prefixed with underscore.

Same as in _moe_core_impl: the minor variable is unused.

🔧 Suggested fix

-    major, minor = torch.cuda.get_device_capability(device)
+    major, _ = torch.cuda.get_device_capability(device)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/fused_moe.py` at line 440, The assignment in fused_moe.py currently unpacks torch.cuda.get_device_capability(device) into major, minor while `minor` is unused; update the unpacking in the relevant function (the line with `major, minor = torch.cuda.get_device_capability(device)`) to use an underscore-prefixed name (e.g., `major, _minor = torch.cuda.get_device_capability(device)`) to silence the unused-variable warning, and apply the same change where `minor` is unused in `_moe_core_impl` if present.
464-474: Consider skipping autotuner runner creation on SM120.

The _runner (autotuner) is created unconditionally but is never used on the SM120 path, since SM120 bypasses autotuning entirely. This is a minor inefficiency that could be addressed in a follow-up.

💡 Optional optimization

-        # Create auto-tuner runner (SM100 path only — SM120 bypasses autotuner)
-        self._runner = CuteDslFusedMoENvfp4Runner(
-            forward_impl=self._forward_with_tactic,
-            num_experts=num_experts,
-            top_k=top_k,
-            num_local_experts=self.num_local_experts,
-            local_expert_offset=local_expert_offset,
-            use_fused_finalize=True,
-            output_dtype=output_dtype,
-            enable_pdl=enable_pdl,
-        )
+        # Create auto-tuner runner (SM100 path only — SM120 bypasses autotuner)
+        if not self._is_sm120:
+            self._runner = CuteDslFusedMoENvfp4Runner(
+                forward_impl=self._forward_with_tactic,
+                num_experts=num_experts,
+                top_k=top_k,
+                num_local_experts=self.num_local_experts,
+                local_expert_offset=local_expert_offset,
+                use_fused_finalize=True,
+                output_dtype=output_dtype,
+                enable_pdl=enable_pdl,
+            )
+        else:
+            self._runner = None

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/fused_moe.py` around lines 464 - 474, The code unconditionally constructs self._runner (CuteDslFusedMoENvfp4Runner using forward_impl=self._forward_with_tactic) even though the SM120 path bypasses autotuning; wrap this instantiation in a conditional so the runner is only created when autotuning is enabled (i.e., not the SM120 path). Add a guard like "if not self._is_sm120:" or "if self._autotune_enabled:" around the CuteDslFusedMoENvfp4Runner creation and ensure self._runner stays None when skipped so other code expects that state.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/fused_moe/cute_dsl/fused_moe.py`:
- Around line 390-391: Replace the direct `@supported_compute_capability` usage by
extracting the capability check into small helper functions and applying
`@backend_requirement` with common_check on the API entrypoints: create a helper
like def _check___init__compute_capability(...): decorated with
`@supported_compute_capability`([100,103,120,121]) that returns True, then remove
`@supported_compute_capability` from the __init__ and instead annotate __init__
with `@backend_requirement`({}, common_check=_check___init__compute_capability)
and keep `@flashinfer_api`; do the same for cute_dsl_fused_moe_nvfp4 by adding a
`@supported_compute_capability-decorated` helper (e.g.,
_check_cute_dsl_fused_moe_nvfp4) returning True and replacing the direct
decorator with `@backend_requirement`({},
common_check=_check_cute_dsl_fused_moe_nvfp4) on the cute_dsl_fused_moe_nvfp4
function.
---
Nitpick comments:
In `@flashinfer/fused_moe/cute_dsl/fused_moe.py`:
- Line 189: The unpacked `minor` from torch.cuda.get_device_capability is unused
and triggers linter warnings; change the unpacking in the call to
torch.cuda.get_device_capability(x.device) to use an underscore-prefixed
variable (e.g., `_,` or `_minor`) instead of `minor`, so only the used `major`
variable remains named and the linter warning is suppressed while leaving the
call and x.device reference intact.
- Line 440: The assignment in fused_moe.py currently unpacks
torch.cuda.get_device_capability(device) into major, minor while `minor` is
unused; update the unpacking in the relevant function (the line with `major,
minor = torch.cuda.get_device_capability(device)`) to use an underscore-prefixed
name (e.g., `major, _minor = torch.cuda.get_device_capability(device)`) to
silence the unused-variable warning, and apply the same change where `minor` is
unused in `_moe_core_impl` if present.
- Around line 464-474: The code unconditionally constructs self._runner
(CuteDslFusedMoENvfp4Runner using forward_impl=self._forward_with_tactic) even
though the SM120 path bypasses autotuning; wrap this instantiation in a
conditional so the runner is only created when autotuning is enabled (i.e., not
the SM120 path). Add a guard like "if not self._is_sm120:" or "if
self._autotune_enabled:" around the CuteDslFusedMoENvfp4Runner creation and
ensure self._runner stays None when skipped so other code expects that state.
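The guard described above can be sketched with hypothetical stand-ins (`StubRunner` and `Wrapper` are illustrative names, not the real flashinfer types):

```python
class StubRunner:
    """Stand-in for CuteDslFusedMoENvfp4Runner; only records its kwargs."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


class Wrapper:
    """Sketch of the guarded runner construction the review suggests."""

    def __init__(self, is_sm120: bool):
        self._is_sm120 = is_sm120
        # Only build the autotuner runner on paths that actually use it;
        # SM120 bypasses autotuning, so leave the slot as None there.
        if not self._is_sm120:
            self._runner = StubRunner(use_fused_finalize=True)
        else:
            self._runner = None
```

Callers that expect `self._runner` must then tolerate `None` on the SM120 path, which is why the prompt asks to keep that state explicit.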
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 6cbdfe34-76cd-4595-980f-48cca5a47def
📒 Files selected for processing (1)
flashinfer/fused_moe/cute_dsl/fused_moe.py
```python
@supported_compute_capability([100, 103, 120, 121])
@flashinfer_api
```
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
🏁 Script executed:
```shell
#!/bin/bash
# Description: Search for both decorators to understand their usage patterns in the codebase.
echo "=== @backend_requirement usage ==="
rg -n '@backend_requirement' --type=py -C2
echo ""
echo "=== @supported_compute_capability usage ==="
rg -n '@supported_compute_capability' --type=py -C2
echo ""
echo "=== Decorator definitions ==="
ast-grep --pattern 'def backend_requirement($$$)'
ast-grep --pattern 'def supported_compute_capability($$$)'
```

Repository: flashinfer-ai/flashinfer

Length of output: 50380
🏁 Script executed:
```shell
cd flashinfer/fused_moe/cute_dsl && sed -n '388,410p' fused_moe.py
```

Repository: flashinfer-ai/flashinfer
Length of output: 686
🏁 Script executed:
```shell
cd flashinfer/fused_moe/cute_dsl && sed -n '843,860p' fused_moe.py
```

Repository: flashinfer-ai/flashinfer
Length of output: 543
Replace @supported_compute_capability with @backend_requirement pattern per coding guidelines.
The functions at these lines should follow the codebase pattern for APIs with architecture-specific requirements. Instead of applying @supported_compute_capability directly to the API function, extract the check logic into a separate function and use @backend_requirement({}, common_check=...) on the main API. For example, other similar functions in the codebase use:
```python
@supported_compute_capability([100, 103, 120, 121])
def _check_function(...):
    return True

@backend_requirement({}, common_check=_check_function)
@flashinfer_api
def api_function(...):
    ...
```

This applies to lines 390-391 (the `__init__` method) and 845-846 (the `cute_dsl_fused_moe_nvfp4` function).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/fused_moe/cute_dsl/fused_moe.py` around lines 390 - 391, Replace
the direct `@supported_compute_capability` usage by extracting the capability
check into small helper functions and applying `@backend_requirement` with
common_check on the API entrypoints: create a helper like def
_check___init__compute_capability(...): decorated with
`@supported_compute_capability`([100,103,120,121]) that returns True, then remove
`@supported_compute_capability` from the __init__ and instead annotate __init__
with `@backend_requirement`({}, common_check=_check___init__compute_capability)
and keep `@flashinfer_api`; do the same for cute_dsl_fused_moe_nvfp4 by adding a
`@supported_compute_capability-decorated` helper (e.g.,
_check_cute_dsl_fused_moe_nvfp4) returning True and replacing the direct
decorator with `@backend_requirement`({},
common_check=_check_cute_dsl_fused_moe_nvfp4) on the cute_dsl_fused_moe_nvfp4
function.
|
/bot stop |
|
The GitLab CI pipeline #48563921 has been cancelled. |
📌 Description
Debugging issue only reproducible on CI for issue #3066
🔍 Related Issues
#3066
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Improvements