Add BGMV MoE CUDA kernels for multi-LoRA#3249
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughImplements fused BGMV MoE shrink/expand CUDA kernels with configurable thread/tile parameters across dtype combinations (fp16, bf16, float, mixed-precision), C++ dispatch routing tensors to the correct instantiation, TVM-FFI bindings, lazy Python API with JIT compilation, comprehensive tests including reference implementations, and standalone/integrated benchmarks. ChangesBGMV MoE Implementation
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested labels
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request implements multi-LoRA Mixture-of-Experts (MoE) BGMV CUDA kernels, providing optimized shrink and expand operations with PyTorch bindings and JIT support. The reviewer identified a critical stability issue where a missing check for an empty output_slices vector could cause a crash. Furthermore, several performance improvements were suggested, specifically moving the scale multiplication out of inner loops in both kernels. A bug in the shrink kernel's epilogue was also noted, where per-warp bounds checking could lead to incorrect results when dimensions are not warp-aligned.
I am having trouble creating individual review comments. Click here to see my feedback.
csrc/bgmv_moe/moe_bgmv_ops.cu (168-172)
The code accesses output_slices[0] without checking if the vector is empty. This will cause a crash if num_slices is 0. Adding a TORCH_CHECK ensures stability.
TORCH_CHECK(!output_slices.empty(), "BGMV MoE expand: output_slices must not be empty");
int32_t first_feat_out = static_cast<int32_t>(output_slices[0]);
for (size_t i = 1; i < output_slices.size(); ++i) {
TORCH_CHECK(output_slices[i] == first_feat_out,
"BGMV MoE expand: all output_slices must be equal");
}
csrc/bgmv_moe/moe_bgmv_impl.cuh (213-229)
The bounds check for partial tiles in the epilogue is performed per-warp, which is incorrect when feat_in is not a multiple of the warp tile size. It must be checked per-thread to avoid summing uninitialized or stale data from shared memory. Additionally, moving the scale multiplication out of the inner loop improves performance.
if (ts + toff < feat_in) x_vec.load(X_shared + (cs * PAIRS_PER_BLOCK + pp) * tile_size + toff);
#pragma unroll
for (int r = 0; r < RANK_TILE; ++r) {
if (j0 + r < feat_out) {
float sum = 0.f;
if (ts + toff < feat_in) {
w_vec.load(W_shared + ((cs * PAIRS_PER_BLOCK + pp) * RANK_TILE + r) * tile_size + toff);
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) sum += float(w_vec[i]) * float(x_vec[i]);
}
#pragma unroll
for (size_t off = tx / 2; off > 0; off /= 2)
sum += __shfl_down_sync(0xffffffff, sum, off);
if (threadIdx.x == 0) {
y_warpwise[pp * RANK_TILE * ty + r * ty + threadIdx.y] = sum * scale;
}
}
}
csrc/bgmv_moe/moe_bgmv_impl.cuh (151-164)
Multiplying by scale inside the inner loop is redundant. It is more efficient to apply it once after the warp reduction when writing to shared memory.
x_vec.load(X_shared + (cs * PAIRS_PER_BLOCK + pp) * tile_size + toff);
#pragma unroll
for (int r = 0; r < RANK_TILE; ++r) {
if (j0 + r < feat_out) {
w_vec.load(W_shared + ((cs * PAIRS_PER_BLOCK + pp) * RANK_TILE + r) * tile_size + toff);
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) sum += float(w_vec[i]) * float(x_vec[i]);
#pragma unroll
for (size_t off = tx / 2; off > 0; off /= 2)
sum += __shfl_down_sync(0xffffffff, sum, off);
if (threadIdx.x == 0) y_warpwise[pp * RANK_TILE * ty + r * ty + threadIdx.y] = sum * scale;
}
}
csrc/bgmv_moe/moe_bgmv_impl.cuh (290)
Applying scale inside the inner loop is redundant. It should be applied once after the loop for better performance.
for (size_t i = 0; i < vec_size; ++i) sum += float(w_vec[i]) * float(x_vec[i]);
sum *= scale;
There was a problem hiding this comment.
Actionable comments posted: 5
🧹 Nitpick comments (7)
csrc/bgmv_moe/moe_bgmv_fp32_fp16_fp16.cu (1)
11-14: 💤 Low valueOptional:
#undef INST_MOE_BGMV_SHRINK_ONLYafter the macro expansion.The macro
INST_MOE_BGMV_SHRINK_ONLYis defined at translation-unit scope and never undefined. Since this is a.cufile with no further uses, it is harmless today, but adding#undefafter theFOR_MOE_ALL_WIDE_NARROW(...)line keeps the per-variant naming hygienic and aligns with how analogous files might be amalgamated/included in unit tests later.FOR_MOE_ALL_WIDE_NARROW(INST_MOE_BGMV_SHRINK_ONLY, float, nv_half, nv_half) + +#undef INST_MOE_BGMV_SHRINK_ONLY🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@csrc/bgmv_moe/moe_bgmv_fp32_fp16_fp16.cu` around lines 11 - 14, The macro INST_MOE_BGMV_SHRINK_ONLY is left defined after its use; after the FOR_MOE_ALL_WIDE_NARROW(INST_MOE_BGMV_SHRINK_ONLY, float, nv_half, nv_half) line, add an undef for INST_MOE_BGMV_SHRINK_ONLY to remove it from translation-unit scope (i.e., insert an `#undef` INST_MOE_BGMV_SHRINK_ONLY immediately following that macro expansion).csrc/bgmv_moe/moe_bgmv_impl.cuh (4)
246-258: ⚡ Quick winDocument the
Y += ...accumulation contract.Both kernels accumulate into
Y(+=for shrink at line 254,atomicAddfor expand at line 297), which silently requires callers to zero-initializeYbefore invocation. The Python tests/bench currently do this (torch.zeros(...)), but the contract isn't documented in the kernel header or host wrappers. A short comment onmoe_bgmv_shrink_sliced/moe_bgmv_expand_slicedhost-side declarations would prevent misuse.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@csrc/bgmv_moe/moe_bgmv_impl.cuh` around lines 246 - 258, Document that the device kernels accumulate into the output buffer Y (the device code uses Y += ... in the shrink kernel and atomicAdd in the expand kernel), so callers must provide a zero-initialized Y before invoking moe_bgmv_shrink_sliced and moe_bgmv_expand_sliced; add a brief comment in the host-side declarations / header above moe_bgmv_shrink_sliced and moe_bgmv_expand_sliced (or in the wrapper functions that call the kernels) stating the accumulation contract and that the caller is responsible for zeroing Y (e.g., use torch.zeros or cudaMemset) to avoid silent incorrect results.
320-326: 💤 Low valueUnchecked CUDA runtime calls in dispatch path.
cudaGetDeviceandcudaDeviceGetAttributereturncudaError_tbut the values are discarded; if either fails,sm_majorstays at 0 and the kernel silently picks the prefill (non-extended) path. On systems where the device is properly initialized this won't matter, but defensive checking (or caching the capability once viaat::cuda::getCurrentDeviceProperties()when ATen is available) would surface real failures.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@csrc/bgmv_moe/moe_bgmv_impl.cuh` around lines 320 - 326, The cuda runtime calls cudaGetDevice and cudaDeviceGetAttribute in the dispatch path are unchecked, so errors leave sm_major at 0 and silently force the non-extended path; update the dispatch logic (around the sm_major/extended/decode calculation that uses cudaGetDevice, cudaDeviceGetAttribute, sm_major, extended and MoeShrinkKernelConfig::decode_threshold) to check and handle the returned cudaError_t values (propagate/log and choose a safe fallback) or, when ATen is available, obtain and cache the device compute capability from at::cuda::getCurrentDeviceProperties() and derive sm_major from that to avoid calling the runtime repeatedly and to surface real failures.
359-371: ⚖️ Poor tradeoffSilent kernel-not-launched path when no
if constexprbranch matches.In both
moe_bgmv_shrink_sliced(lines 359–367) andmoe_bgmv_expand_sliced(lines 392–413), if none of the cascadedif constexprconditions are satisfied for the instantiatedfeat_in/feat_out, the function returns without launching anything and the user seesYunchanged with no diagnostic. For the expand path this is more likely (e.g.,feat_in/vec_sizenot dividing 32, 16, and 8 withfeat_outdivisibility).Adding a terminal
else { static_assert(false, ...); }(using a dependent-type trick to defer the assert) catches misconfigurations at compile time when an instantiation is added without a matching launch path.🛠️ Sketch
} else if constexpr (feat_in % cfg_tx == 0) { DISPATCH(1); + } else { + static_assert(sizeof(in_T) == 0, + "moe_bgmv_shrink_sliced: feat_in is not divisible by cfg_tx; " + "no kernel variant available for this configuration."); }Apply the analogous pattern to
moe_bgmv_expand_sliced.Also applies to: 392-414
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@csrc/bgmv_moe/moe_bgmv_impl.cuh` around lines 359 - 371, The cascaded if constexpr chains in moe_bgmv_shrink_sliced and moe_bgmv_expand_sliced end with no terminal branch, causing a silent no-kernel-launched path for unmatched feat_in/feat_out cases; add a final else that triggers a compile-time failure using a dependent false static_assert (e.g., use an always_false<T> or dependent_false<decltype(feat_in)> trick) after the DISPATCH/LAUNCH chain so that when none of the DISPATCH(...) branches match the compiler emits a clear error message about unsupported feat_in/feat_out configuration; apply this same pattern to both functions immediately after the existing if constexpr cascade (before the `#undef` DISPATCH / `#undef` LAUNCH) and reference the same DISPATCH symbol names so the assert runs only when no DISPATCH branch was selected.
263-299: ⚖️ Poor tradeoffMinor:
x_vecis reloaded by every(threadIdx.y, threadIdx.z)lane group.
x_vec.load(...)at line 285 only varies onthreadIdx.x, so allty * tzlane groups within the block redundantly fetch the same slice ofX. Since the kernel is memory-bound onWthis isn't critical, but stagingx_vecthrough shared memory once per block (or through a single warp's load + warp shuffle) would cut redundant global reads on long-rank/large-feat_outconfigs. Optional.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@csrc/bgmv_moe/moe_bgmv_impl.cuh` around lines 263 - 299, The kernel moe_bgmv_expand_sliced_kernel redundantly reloads x_vec in every (threadIdx.y, threadIdx.z) lane group because x_vec.load(X + ...) only varies with threadIdx.x; to fix, load the per-threadIdx.x slice of X once per block and share it across ty*tz lanes by staging into shared memory (e.g., a shared array keyed by block.thread_rank()/threadIdx.x) or have one warp/wavefront load x_vec and broadcast via warp shuffle, then use the shared/broadcasted vec for the subsequent w_vec multiply; update references to x_vec in moe_bgmv_expand_sliced_kernel accordingly so the global load occurs only once per unique threadIdx.x rather than for every (threadIdx.y, threadIdx.z).tests/moe/bench_bgmv_moe.py (1)
370-372: 💤 Low valueSurface the Triton failure reason during runs.
The
except Exceptionswallows JIT/launch failures and only stores a truncated error string into the results dict, which is never printed bymain(). Consider printing a warning so users can tell why Triton timing fell back to NaN. This is acceptable for a benchmark but currently silent.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/moe/bench_bgmv_moe.py` around lines 370 - 372, The except block that sets results["triton_sgmv_us"]=NaN and stores a truncated results["triton_error"] hides the real Triton failure; update that except handler to emit a visible warning/log message (e.g., using logging.warning or warnings.warn) including the full exception message and traceback (use traceback.format_exc()) so users see why Triton timing fell back to NaN, keep storing the error in results["triton_error"] but do not rely on main() to surface it.flashinfer/fused_moe/bgmv_moe.py (1)
256-279: ⚡ Quick winAssert (or document) that all slices share the same
lora_stride.
lora_stride_aandlora_stride_bare overwritten on each loop iteration and only the last slice's value is forwarded to the kernel. This is fine when every slice'slora_a/b_weightstensor has the samestride(0), but nothing here enforces it; a heterogeneously-shaped weight list would silently produce wrong addressing.A cheap
assert weights.stride(0) == lora_stride_a(and same for_b) inside the loop would turn a silent miscompute into a clear error. Optional, but recommended for the public high-level API.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/fused_moe/bgmv_moe.py` around lines 256 - 279, The loop that sets lora_stride_a/lora_stride_b overwrites the stride per slice and can silently misaddress kernels if slices have differing stride(0); update the loops that call fill_w_ptr (the one using w_ptr_a/lora_a_weights and the one for w_ptr_b/lora_b_weights) to capture the stride returned for the first slice and assert on each subsequent iteration that the current slice's tensor stride(0) matches that saved stride (or raise a clear error); specifically, in the for s in range(num_slices) loops around fill_w_ptr, after calling fill_w_ptr for s==0 store the stride value (or use the returned lora_stride_*), and for s>0 assert lora_{a,b}_weights[s].stride(0) == lora_stride_{a,b} (or raise ValueError) so the kernel always receives a consistent lora_stride.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@csrc/bgmv_moe/moe_bgmv_ops.cu`:
- Around line 168-172: Add an explicit guard that output_slices is non-empty
before dereferencing: check output_slices.empty() and TORCH_CHECK(false, ...) if
empty; then validate the int64→int32 narrowing by checking the first element
fits in int32 (e.g., compare against INT32_MIN/INT32_MAX) before performing
static_cast and use TORCH_CHECK to fail with a clear message if it does not;
keep the existing loop that ensures all entries in output_slices equal the first
(and add a short comment near output_slices/first_feat_out documenting the “all
slices must be equal” contract so the kernel-launch requirement is explicit).
In `@flashinfer/fused_moe/bgmv_moe.py`:
- Around line 48-88: Remove the entire PyTorch-JIT fallback block that calls
torch.utils.cpp_extension.load (the try/except that builds module
name="flashinfer_bgmv_moe_cuda" from the csrc list and then raises ImportError
on failure); instead do not perform this separate load and let failures from
gen_bgmv_moe_module() (in flashinfer.jit.bgmv_moe) surface to the caller. In
practice, delete the load(...) block and its except, and ensure any callers
still rely on gen_bgmv_moe_module() or the FLASHINFER_JIT_DIR path so we avoid
duplicating the source list and bypassing FlashInfer's JIT layout.
- Around line 253-289: The code treats output_dim incorrectly and does
per-element CUDA writes for slice_start_loc: change the total_feat_out
calculation to check "output_dim is not None" when computing total_feat_out from
feat_out_per_slice so an explicit 0 is honored; and avoid the GPU sync loop by
building slice_start_loc on the CPU (e.g., compute cumulative start locations
from feat_out_per_slice into a CPU tensor or use torch.cumsum on a CPU tensor)
and then move the completed tensor to the device once before use (replace the
per-iteration assignments to slice_start_loc with a CPU-side build and a single
.to(device) transfer).
In `@tests/moe/bench_bgmv_moe.py`:
- Around line 404-409: slice_start_loc currently uses torch.zeros(...) so every
slice writes to the same y_accum columns; change it to compute per-slice
starting offsets like torch.arange(num_slices, dtype=torch.int64, device=device)
* feat_out so each slice maps to a distinct column range in y_accum; also ensure
any per-slice outputs (output_slices) are created/used per-slice (not by
replicating the same object) so writes target the correct y_accum segment for
each slice.
- Around line 28-30: Replace the no-op sys.path.insert(0, sys.path[0]) with
inserting the actual directory containing this script so the local import of
generate_test_data and reference_bgmv_moe works reliably; specifically, import
os and insert os.path.dirname(__file__) (or its abspath) at the front of
sys.path before the from test_bgmv_moe import generate_test_data,
reference_bgmv_moe line so the module is found regardless of current working
directory or how the file is executed.
---
Nitpick comments:
In `@csrc/bgmv_moe/moe_bgmv_fp32_fp16_fp16.cu`:
- Around line 11-14: The macro INST_MOE_BGMV_SHRINK_ONLY is left defined after
its use; after the FOR_MOE_ALL_WIDE_NARROW(INST_MOE_BGMV_SHRINK_ONLY, float,
nv_half, nv_half) line, add an undef for INST_MOE_BGMV_SHRINK_ONLY to remove it
from translation-unit scope (i.e., insert an `#undef` INST_MOE_BGMV_SHRINK_ONLY
immediately following that macro expansion).
In `@csrc/bgmv_moe/moe_bgmv_impl.cuh`:
- Around line 246-258: Document that the device kernels accumulate into the
output buffer Y (the device code uses Y += ... in the shrink kernel and
atomicAdd in the expand kernel), so callers must provide a zero-initialized Y
before invoking moe_bgmv_shrink_sliced and moe_bgmv_expand_sliced; add a brief
comment in the host-side declarations / header above moe_bgmv_shrink_sliced and
moe_bgmv_expand_sliced (or in the wrapper functions that call the kernels)
stating the accumulation contract and that the caller is responsible for zeroing
Y (e.g., use torch.zeros or cudaMemset) to avoid silent incorrect results.
- Around line 320-326: The cuda runtime calls cudaGetDevice and
cudaDeviceGetAttribute in the dispatch path are unchecked, so errors leave
sm_major at 0 and silently force the non-extended path; update the dispatch
logic (around the sm_major/extended/decode calculation that uses cudaGetDevice,
cudaDeviceGetAttribute, sm_major, extended and
MoeShrinkKernelConfig::decode_threshold) to check and handle the returned
cudaError_t values (propagate/log and choose a safe fallback) or, when ATen is
available, obtain and cache the device compute capability from
at::cuda::getCurrentDeviceProperties() and derive sm_major from that to avoid
calling the runtime repeatedly and to surface real failures.
- Around line 359-371: The cascaded if constexpr chains in
moe_bgmv_shrink_sliced and moe_bgmv_expand_sliced end with no terminal branch,
causing a silent no-kernel-launched path for unmatched feat_in/feat_out cases;
add a final else that triggers a compile-time failure using a dependent false
static_assert (e.g., use an always_false<T> or
dependent_false<decltype(feat_in)> trick) after the DISPATCH/LAUNCH chain so
that when none of the DISPATCH(...) branches match the compiler emits a clear
error message about unsupported feat_in/feat_out configuration; apply this same
pattern to both functions immediately after the existing if constexpr cascade
(before the `#undef` DISPATCH / `#undef` LAUNCH) and reference the same DISPATCH
symbol names so the assert runs only when no DISPATCH branch was selected.
- Around line 263-299: The kernel moe_bgmv_expand_sliced_kernel redundantly
reloads x_vec in every (threadIdx.y, threadIdx.z) lane group because
x_vec.load(X + ...) only varies with threadIdx.x; to fix, load the
per-threadIdx.x slice of X once per block and share it across ty*tz lanes by
staging into shared memory (e.g., a shared array keyed by
block.thread_rank()/threadIdx.x) or have one warp/wavefront load x_vec and
broadcast via warp shuffle, then use the shared/broadcasted vec for the
subsequent w_vec multiply; update references to x_vec in
moe_bgmv_expand_sliced_kernel accordingly so the global load occurs only once
per unique threadIdx.x rather than for every (threadIdx.y, threadIdx.z).
In `@flashinfer/fused_moe/bgmv_moe.py`:
- Around line 256-279: The loop that sets lora_stride_a/lora_stride_b overwrites
the stride per slice and can silently misaddress kernels if slices have
differing stride(0); update the loops that call fill_w_ptr (the one using
w_ptr_a/lora_a_weights and the one for w_ptr_b/lora_b_weights) to capture the
stride returned for the first slice and assert on each subsequent iteration that
the current slice's tensor stride(0) matches that saved stride (or raise a clear
error); specifically, in the for s in range(num_slices) loops around fill_w_ptr,
after calling fill_w_ptr for s==0 store the stride value (or use the returned
lora_stride_*), and for s>0 assert lora_{a,b}_weights[s].stride(0) ==
lora_stride_{a,b} (or raise ValueError) so the kernel always receives a
consistent lora_stride.
In `@tests/moe/bench_bgmv_moe.py`:
- Around line 370-372: The except block that sets results["triton_sgmv_us"]=NaN
and stores a truncated results["triton_error"] hides the real Triton failure;
update that except handler to emit a visible warning/log message (e.g., using
logging.warning or warnings.warn) including the full exception message and
traceback (use traceback.format_exc()) so users see why Triton timing fell back
to NaN, keep storing the error in results["triton_error"] but do not rely on
main() to surface it.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 4fbc3b7b-9214-4130-8d1a-5a8ddf70a2c0
📥 Commits
Reviewing files that changed from the base of the PR and between cb44e7d and 478addebd343adb6e5812a6a9fa3494223843656.
📒 Files selected for processing (21)
csrc/bgmv_moe/kernel_config.hcsrc/bgmv_moe/moe_bgmv_bf16_bf16_bf16.cucsrc/bgmv_moe/moe_bgmv_bf16_fp32_bf16.cucsrc/bgmv_moe/moe_bgmv_binding.cucsrc/bgmv_moe/moe_bgmv_config.hcsrc/bgmv_moe/moe_bgmv_fp16_fp16_fp16.cucsrc/bgmv_moe/moe_bgmv_fp16_fp32_fp16.cucsrc/bgmv_moe/moe_bgmv_fp32_bf16_bf16.cucsrc/bgmv_moe/moe_bgmv_fp32_fp16_fp16.cucsrc/bgmv_moe/moe_bgmv_impl.cuhcsrc/bgmv_moe/moe_bgmv_ops.cucsrc/bgmv_moe/moe_bgmv_ops.hcsrc/bgmv_moe/setup.pycsrc/bgmv_moe/vec_dtypes.cuhflashinfer/fused_moe/__init__.pyflashinfer/fused_moe/bgmv_moe.pyflashinfer/jit/__init__.pyflashinfer/jit/bgmv_moe.pytests/moe/bench_bgmv_moe.pytests/moe/test_bgmv_moe.pytests/moe/triton_moe_lora_baseline.py
7968961 to
b2eb2ff
Compare
|
ac384ec to
1bdf3ee
Compare
There was a problem hiding this comment.
Actionable comments posted: 6
🧹 Nitpick comments (1)
benchmarks/routines/moe.py (1)
2559-2680: ⚡ Quick winOnly suppress missing baseline imports here.
The broad
Exceptionpath turns real grouped-mm setup regressions into a benign “baseline skipped” message, which makes the benchmark output misleading.ImportErroris the case worth tolerating; the rest should still fail.Suggested fix
- except (ImportError, Exception) as e: + except ImportError as e: if args.verbose >= 1: print(f"[INFO] grouped_mm_bf16 baseline skipped: {e}")🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@benchmarks/routines/moe.py` around lines 2559 - 2680, The try/except around the grouped_mm_bf16 setup is currently catching all Exceptions and hiding real setup/runtime errors; change the except clause to only catch ImportError (e.g., "except ImportError as e:") so that missing optional dependency skips the baseline while other errors in the grouped_mm_bf16 setup (inside the try block referencing grouped_mm_bf16, num_groups, g_m_indptr, etc.) will propagate and fail loudly; keep the existing args.verbose print inside that ImportError handler.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@benchmarks/bench_bgmv_moe.py`:
- Around line 184-188: w_ptr_a and w_ptr_b are only populated for slice 0;
change the initialization to call fill_w_ptr for each slice so every row in
w_ptr_* points to the correct weight chunk—iterate over slice indices and call
fill_w_ptr(w_ptr_a, data["lora_a_weights"][i], num_experts, i) and similarly for
w_ptr_b using data["lora_b_weights"][i]; ensure you still allocate
w_ptr_a/w_ptr_b as torch.zeros(num_slices, num_experts, ...) and update
lora_stride_a/lora_stride_b appropriately (e.g., keep or collect per-slice
strides returned by fill_w_ptr).
- Around line 142-158: The custom benchmarking logic in benchmark_fn should be
replaced with the repo-standard helper: call
flashinfer.testing.bench_gpu_time(fn, warmup=warmup, repeat=repeat) instead of
manually using torch.cuda.synchronize and time.perf_counter_ns; update the
function body of benchmark_fn to delegate to bench_gpu_time and add the required
import for flashinfer.testing.bench_gpu_time at the top of the file so
CUPTI/CUDA event handling and fallback behavior are preserved.
In `@benchmarks/routines/moe.py`:
- Around line 301-322: Validate the BGMV CLI args right after parsing (before
any CUDA/JIT dispatch) for the flags defined on parser ("--rank", "--num_loras",
"--num_slices", and "--input_dtype"); ensure rank and num_loras are positive
ints (>0), num_slices is an int >=1, and input_dtype is one of the supported
dtypes (explicitly list allowed strings used later by the kernel/JIT). If a
value is invalid, call parser.error(...) or raise argparse.ArgumentTypeError to
fail fast with a clear message. Add these checks in the same scope where parser
is used (e.g., the routine that calls parser.parse_args() / run_bgmv_moe) so
invalid inputs are rejected at the CLI boundary rather than deferred to
CUDA/JIT.
- Around line 2466-2471: The pre-allocation sets w_ptr_a/w_ptr_b with num_slices
rows but only fills row 0; update the initialization to iterate over slice
indices and call fill_w_ptr for each slice so each row is populated (use
fill_w_ptr(w_ptr_a, lora_a_weights[slice], num_experts, slice) and similarly for
w_ptr_b with lora_b_weights[slice]); ensure any returned stride values
(lora_stride_a/lora_stride_b) are collected per-slice (e.g., into a list or
tensor) if later code expects per-slice strides.
- Around line 119-125: The --intermediate_size argument should remain required
for legacy MoE routines; change the parser.add_argument call for
"--intermediate_size" to required=True (remove the default=0) so non-bgmv_moe
paths must supply a valid size, and relax it only for the bgmv_moe path by
conditionally adding or overriding this argument when the selected routine is
"bgmv_moe" (e.g., add a separate parser branch or set a non-required fallback
for bgmv_moe). Update references to parser.add_argument, the
"--intermediate_size" flag, and the bgmv_moe routine to implement this
conditional behavior.
In `@csrc/bgmv_moe/kernel_config.h`:
- Around line 25-28: The extended 3-stage shared-memory config
(num_stages_extended = 3) can be enabled on A100 because the runtime gate in
moe_bgmv_impl.cuh checks sm_major >= 8, but A100 (sm_80) cannot support the
216KB request; change the runtime gating to require sm_major >= 9 so only sm_90+
uses the 3-stage path and A100 falls back to num_stages_default (2-stage) to
stay within its 164KB limit. Locate the check that reads sm_major >= 8 in
moe_bgmv_impl.cuh (the branch that selects num_stages_extended) and update it to
sm_major >= 9 so extended mode is restricted to H100/H200.
---
Nitpick comments:
In `@benchmarks/routines/moe.py`:
- Around line 2559-2680: The try/except around the grouped_mm_bf16 setup is
currently catching all Exceptions and hiding real setup/runtime errors; change
the except clause to only catch ImportError (e.g., "except ImportError as e:")
so that missing optional dependency skips the baseline while other errors in the
grouped_mm_bf16 setup (inside the try block referencing grouped_mm_bf16,
num_groups, g_m_indptr, etc.) will propagate and fail loudly; keep the existing
args.verbose print inside that ImportError handler.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 3518775a-b310-42ce-b0eb-e0e53dbe8246
📥 Commits
Reviewing files that changed from the base of the PR and between 478addebd343adb6e5812a6a9fa3494223843656 and 1bdf3ee0d0422ead21064db37ff6989ac6324747.
📒 Files selected for processing (22)
benchmarks/bench_bgmv_moe.pybenchmarks/routines/flashinfer_benchmark_utils.pybenchmarks/routines/moe.pycsrc/bgmv_moe/kernel_config.hcsrc/bgmv_moe/moe_bgmv_bf16_bf16_bf16.cucsrc/bgmv_moe/moe_bgmv_bf16_fp32_bf16.cucsrc/bgmv_moe/moe_bgmv_binding.cucsrc/bgmv_moe/moe_bgmv_config.hcsrc/bgmv_moe/moe_bgmv_fp16_fp16_fp16.cucsrc/bgmv_moe/moe_bgmv_fp16_fp32_fp16.cucsrc/bgmv_moe/moe_bgmv_fp32_bf16_bf16.cucsrc/bgmv_moe/moe_bgmv_fp32_fp16_fp16.cucsrc/bgmv_moe/moe_bgmv_impl.cuhcsrc/bgmv_moe/moe_bgmv_ops.cucsrc/bgmv_moe/moe_bgmv_ops.hcsrc/bgmv_moe/vec_dtypes.cuhflashinfer/fused_moe/__init__.pyflashinfer/fused_moe/bgmv_moe.pyflashinfer/jit/__init__.pyflashinfer/jit/bgmv_moe.pyscripts/task_jit_run_tests_part5.shtests/moe/test_bgmv_moe.py
✅ Files skipped from review due to trivial changes (4)
- csrc/bgmv_moe/moe_bgmv_fp16_fp16_fp16.cu
- flashinfer/jit/init.py
- csrc/bgmv_moe/moe_bgmv_fp32_fp16_fp16.cu
- benchmarks/routines/flashinfer_benchmark_utils.py
🚧 Files skipped from review as they are similar to previous changes (11)
- csrc/bgmv_moe/moe_bgmv_fp32_bf16_bf16.cu
- csrc/bgmv_moe/moe_bgmv_bf16_bf16_bf16.cu
- csrc/bgmv_moe/moe_bgmv_fp16_fp32_fp16.cu
- csrc/bgmv_moe/moe_bgmv_bf16_fp32_bf16.cu
- csrc/bgmv_moe/moe_bgmv_binding.cu
- flashinfer/jit/bgmv_moe.py
- csrc/bgmv_moe/moe_bgmv_ops.cu
- tests/moe/test_bgmv_moe.py
- csrc/bgmv_moe/vec_dtypes.cuh
- flashinfer/fused_moe/bgmv_moe.py
- csrc/bgmv_moe/moe_bgmv_impl.cuh
There was a problem hiding this comment.
Actionable comments posted: 5
🧹 Nitpick comments (4)
csrc/bgmv_moe/moe_bgmv_ops.h (2)
10-10: ⚡ Quick winAdd an explicit
#include <vector>forstd::vector.
std::vectoris used in the signature on line 19 and is currently brought in only transitively throughtorch/all.h. If that header is ever replaced or refactored this TU will silently break.♻️ Proposed fix
`#include` <torch/all.h> +#include <vector>🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@csrc/bgmv_moe/moe_bgmv_ops.h` at line 10, Add an explicit include for the vector header because std::vector is used in the function signature (the symbol std::vector appears on the signature around line 19); update the top of moe_bgmv_ops.h to add `#include` <vector> so the translation unit does not rely on transitive includes from <torch/all.h>.
16-20: ⚡ Quick winPass
output_slicesbyconstreference, not by value.
std::vector<int64_t> output_slicesis passed by value, causing unnecessary heap allocation and copying on every call. The function only reads the vector (never modifies it), so the signature should beconst std::vector<int64_t>&.Update both the declaration in
moe_bgmv_ops.h(line 19) and the definition inmoe_bgmv_ops.cu(line 139). The PyBind11 binding inmoe_bgmv_binding.cuwill continue to work transparently.♻️ Proposed fix
void dispatch_bgmv_moe_expand(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor topk_weights, torch::Tensor lora_indices, - torch::Tensor slice_start_loc, std::vector<int64_t> output_slices, + torch::Tensor slice_start_loc, const std::vector<int64_t>& output_slices, int64_t lora_stride);🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@csrc/bgmv_moe/moe_bgmv_ops.h` around lines 16 - 20, The declaration and definition of dispatch_bgmv_moe_expand currently take std::vector<int64_t> output_slices by value causing unnecessary copies; change the signature in the header (dispatch_bgmv_moe_expand) and the matching definition in the implementation to take const std::vector<int64_t>& output_slices instead, and ensure all references/uses inside the function treat it as a const reference (no mutation); the PyBind11 binding requires no change.csrc/bgmv_moe/kernel_config.h (1)
1-40: 💤 Low valueConsider renaming to
.cuhto suppress C-mode Clang false positives.The static analysis tool (Clang 14) is analyzing this
.has C, whereconstexpris unknown, producing a cascade of "unknown type name 'constexpr'" and "type name does not allow storage class to be specified" errors on everystatic constexprmember. The actual nvcc build is unaffected (C++ mode), but these errors may create persistent CI noise.The common convention in CUDA projects is: if a header file contains CUDA kernels and/or device methods, it should be a
.cuh. Even though this file contains no device code, using.cuh(or.hpp) would signal C++ compilation context to analysis tooling and eliminate the false positives.♻️ Proposed rename
-// kernel_config.h +// kernel_config.cuhAny include sites referencing
"kernel_config.h"incsrc/bgmv_moe/would need updating to"kernel_config.cuh".🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@csrc/bgmv_moe/kernel_config.h` around lines 1 - 40, Clang treats kernel_config.h as C and flags C++ keywords like constexpr; rename the header to kernel_config.cuh and update all includes that reference "kernel_config.h" (e.g., in files under csrc/bgmv_moe/) to "kernel_config.cuh" so the C++/CUDA tooling sees the file as C++/CUDA; ensure the types/structs MoeShrinkKernelConfig and MoeExpandKernelConfig remain unchanged and that build/CI passes after the include updates.benchmarks/bench_bgmv_moe.py (1)
142-158: ⚡ Quick winAdd a short note explaining why this benchmark bypasses
bench_gpu_time().Without the portability rationale, this looks like an accidental divergence from the repo benchmark convention and is likely to get “fixed” again later. A one-line comment here is enough.
Based on learnings: In benchmark scripts under
benchmarks/, it may be intentional to measure GPU time without usingflashinfer.testing.bench_gpu_time. If so, add a short comment explaining thatbench_gpu_timecan require CUPTI or CUDA-event fallback dependencies, and that the custom timer keeps the script portable.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@benchmarks/bench_bgmv_moe.py` around lines 142 - 158, Add a one-line comment above benchmark_fn explaining that this benchmark intentionally bypasses flashinfer.testing.bench_gpu_time to avoid CUPTI or CUDA-event fallback dependencies and to keep the script portable, and that it measures GPU wall-clock time via torch.cuda.synchronize() + time.perf_counter_ns(); reference the function name benchmark_fn in the comment so reviewers understand this is an intentional divergence from the repo convention.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@benchmarks/bench_bgmv_moe.py`:
- Around line 292-295: The blanket except Exception that sets
results["gg_kernel_us"] and results["gg_full_us"] to NaN must be replaced: only
catch the specific, expected optional-failure cases from grouped_mm_bf16 (e.g.,
ImportError, RuntimeError, or the exact exception the optional path raises), log
an explicit skip reason via the existing logger (or print) every time you skip,
and re-raise unexpected exceptions so real regressions/bugs surface; ensure the
code path that prints comparison numbers checks for NaN in
results["gg_kernel_us"]/results["gg_full_us"] and treats them as skipped rather
than silently comparing.
In `@benchmarks/routines/moe.py`:
- Around line 2681-2683: The current broad except in the grouped_mm_bf16
baseline swallow hides real bugs; narrow it to only expected optional failures
(e.g., catch ImportError and ModuleNotFoundError or other known setup
exceptions) and for any other Exception re-raise so genuine errors surface;
always emit the skip reason regardless of args.verbose (use the same print
message that includes the exception text like f"[INFO] grouped_mm_bf16 baseline
skipped: {e}") so the error is visible even when args.verbose == 0; reference
the existing except block handling grouped_mm_bf16 and the args.verbose variable
when making these changes.
- Around line 2590-2592: The baseline grouped_mm is only flattening
lora_a_weights[0]/lora_b_weights[0] which undercounts work when num_slices>1;
update the baseline path to either (A) run per-slice baseline by iterating
slices and reshaping each slice's lora_a_weights[slice_idx] -> g_lora_a and
lora_b_weights[slice_idx] -> g_lora_b (or reshape the weights to shape
(num_slices, num_loras * num_experts, rank, hidden_size) and process each slice)
and accumulate y_accum/TFLOPS the same way BGMV does, or (B) skip/disable the
grouped_mm baseline comparison when num_slices != 1. Modify the code that
creates g_lora_a/g_lora_b (references: lora_a_weights, lora_b_weights, g_lora_a,
g_lora_b, num_slices) to implement the chosen approach so the baseline work
matches the BGMV path.
In `@csrc/bgmv_moe/moe_bgmv_config.h`:
- Around line 1-2: Add a direct include for <cstdint> to make int64_t/int32_t
defined and wrap C++-only declarations (the template forward declarations around
the symbols referenced on lines ~63 and ~71 and the code using int64_t/int32_t
around lines ~65–79) with an `#ifdef` __cplusplus / `#endif` guard so the header is
safe to include from C; in short, `#include` <cstdint> at the top and enclose the
template forward declarations and any other C++-only constructs inside a C++
language guard to prevent C parsers from choking.
In `@flashinfer/fused_moe/__init__.py`:
- Around line 49-61: The except ImportError branch currently hides has_bgmv_moe
so callers can't probe availability; update the except block to define a
fallback function has_bgmv_moe() that returns False and ensure it's exported
(i.e., present in the module namespace) when the import fails, and keep setting
_bgmv_moe_available = False; apply the same pattern to the other similar block
around the 102-110 region so both bgmv_moe-related import blocks always expose
has_bgmv_moe().
---
Nitpick comments:
In `@benchmarks/bench_bgmv_moe.py`:
- Around line 142-158: Add a one-line comment above benchmark_fn explaining that
this benchmark intentionally bypasses flashinfer.testing.bench_gpu_time to avoid
CUPTI or CUDA-event fallback dependencies and to keep the script portable, and
that it measures GPU wall-clock time via torch.cuda.synchronize() +
time.perf_counter_ns(); reference the function name benchmark_fn in the comment
so reviewers understand this is an intentional divergence from the repo
convention.
In `@csrc/bgmv_moe/kernel_config.h`:
- Around line 1-40: Clang treats kernel_config.h as C and flags C++ keywords
like constexpr; rename the header to kernel_config.cuh and update all includes
that reference "kernel_config.h" (e.g., in files under csrc/bgmv_moe/) to
"kernel_config.cuh" so the C++/CUDA tooling sees the file as C++/CUDA; ensure
the types/structs MoeShrinkKernelConfig and MoeExpandKernelConfig remain
unchanged and that build/CI passes after the include updates.
In `@csrc/bgmv_moe/moe_bgmv_ops.h`:
- Line 10: Add an explicit include for the vector header because std::vector is
used in the function signature (the symbol std::vector appears on the signature
around line 19); update the top of moe_bgmv_ops.h to add `#include` <vector> so
the translation unit does not rely on transitive includes from <torch/all.h>.
- Around line 16-20: The declaration and definition of dispatch_bgmv_moe_expand
currently take std::vector<int64_t> output_slices by value causing unnecessary
copies; change the signature in the header (dispatch_bgmv_moe_expand) and the
matching definition in the implementation to take const std::vector<int64_t>&
output_slices instead, and ensure all references/uses inside the function treat
it as a const reference (no mutation); the PyBind11 binding requires no change.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 3cb1feb7-c3d3-4c89-833b-ee5a71e17e82
📥 Commits
Reviewing files that changed from the base of the PR and between 1bdf3ee0d0422ead21064db37ff6989ac6324747 and cc6ab3b02260006280ef19487e74af01fec32cbb.
📒 Files selected for processing (22)
benchmarks/bench_bgmv_moe.pybenchmarks/routines/flashinfer_benchmark_utils.pybenchmarks/routines/moe.pycsrc/bgmv_moe/kernel_config.hcsrc/bgmv_moe/moe_bgmv_bf16_bf16_bf16.cucsrc/bgmv_moe/moe_bgmv_bf16_fp32_bf16.cucsrc/bgmv_moe/moe_bgmv_binding.cucsrc/bgmv_moe/moe_bgmv_config.hcsrc/bgmv_moe/moe_bgmv_fp16_fp16_fp16.cucsrc/bgmv_moe/moe_bgmv_fp16_fp32_fp16.cucsrc/bgmv_moe/moe_bgmv_fp32_bf16_bf16.cucsrc/bgmv_moe/moe_bgmv_fp32_fp16_fp16.cucsrc/bgmv_moe/moe_bgmv_impl.cuhcsrc/bgmv_moe/moe_bgmv_ops.cucsrc/bgmv_moe/moe_bgmv_ops.hcsrc/bgmv_moe/vec_dtypes.cuhflashinfer/fused_moe/__init__.pyflashinfer/fused_moe/bgmv_moe.pyflashinfer/jit/__init__.pyflashinfer/jit/bgmv_moe.pyscripts/task_jit_run_tests_part5.shtests/moe/test_bgmv_moe.py
✅ Files skipped from review due to trivial changes (2)
- csrc/bgmv_moe/moe_bgmv_fp16_fp16_fp16.cu
- scripts/task_jit_run_tests_part5.sh
🚧 Files skipped from review as they are similar to previous changes (14)
- csrc/bgmv_moe/moe_bgmv_bf16_fp32_bf16.cu
- csrc/bgmv_moe/moe_bgmv_fp16_fp32_fp16.cu
- csrc/bgmv_moe/moe_bgmv_fp32_bf16_bf16.cu
- flashinfer/jit/init.py
- csrc/bgmv_moe/moe_bgmv_ops.cu
- csrc/bgmv_moe/vec_dtypes.cuh
- benchmarks/routines/flashinfer_benchmark_utils.py
- flashinfer/fused_moe/bgmv_moe.py
- csrc/bgmv_moe/moe_bgmv_binding.cu
- flashinfer/jit/bgmv_moe.py
- csrc/bgmv_moe/moe_bgmv_impl.cuh
- csrc/bgmv_moe/moe_bgmv_bf16_bf16_bf16.cu
- csrc/bgmv_moe/moe_bgmv_fp32_fp16_fp16.cu
- tests/moe/test_bgmv_moe.py
There was a problem hiding this comment.
🧹 Nitpick comments (1)
csrc/bgmv_moe/moe_bgmv_ops.h (1)
19-20: ⚡ Quick winChange
output_slicesparameter to const reference to avoid unnecessary copies.At Line 19,
std::vector<int64_t> output_slicesis passed by value. Since the parameter is only read (lines 168-175 in moe_bgmv_ops.cu verify no mutations), preferconst std::vector<int64_t>&to avoid copying on each call.Proposed change
-void dispatch_bgmv_moe_expand(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr, - torch::Tensor sorted_token_ids, torch::Tensor expert_ids, - torch::Tensor topk_weights, torch::Tensor lora_indices, - torch::Tensor slice_start_loc, std::vector<int64_t> output_slices, - int64_t lora_stride); +void dispatch_bgmv_moe_expand(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr, + torch::Tensor sorted_token_ids, torch::Tensor expert_ids, + torch::Tensor topk_weights, torch::Tensor lora_indices, + torch::Tensor slice_start_loc, const std::vector<int64_t>& output_slices, + int64_t lora_stride);🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@csrc/bgmv_moe/moe_bgmv_ops.h` around lines 19 - 20, The function declaration taking std::vector<int64_t> output_slices should use a const reference to avoid copying; change the parameter in the declaration in moe_bgmv_ops.h from std::vector<int64_t> output_slices to const std::vector<int64_t>& output_slices, and make the matching change in the implementation/definition in moe_bgmv_ops.cu (the function that reads output_slices around lines 168-175). Ensure any callers still compile (they normally will) and do not modify output_slices inside the function.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@csrc/bgmv_moe/moe_bgmv_ops.h`:
- Around line 19-20: The function declaration taking std::vector<int64_t>
output_slices should use a const reference to avoid copying; change the
parameter in the declaration in moe_bgmv_ops.h from std::vector<int64_t>
output_slices to const std::vector<int64_t>& output_slices, and make the
matching change in the implementation/definition in moe_bgmv_ops.cu (the
function that reads output_slices around lines 168-175). Ensure any callers
still compile (they normally will) and do not modify output_slices inside the
function.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ddabfb59-4621-4b71-a905-84bc5dcf4092
📥 Commits
Reviewing files that changed from the base of the PR and between cc6ab3b02260006280ef19487e74af01fec32cbb and 67bc5e0.
📒 Files selected for processing (22)
benchmarks/bench_bgmv_moe.pybenchmarks/routines/flashinfer_benchmark_utils.pybenchmarks/routines/moe.pycsrc/bgmv_moe/kernel_config.hcsrc/bgmv_moe/moe_bgmv_bf16_bf16_bf16.cucsrc/bgmv_moe/moe_bgmv_bf16_fp32_bf16.cucsrc/bgmv_moe/moe_bgmv_binding.cucsrc/bgmv_moe/moe_bgmv_config.hcsrc/bgmv_moe/moe_bgmv_fp16_fp16_fp16.cucsrc/bgmv_moe/moe_bgmv_fp16_fp32_fp16.cucsrc/bgmv_moe/moe_bgmv_fp32_bf16_bf16.cucsrc/bgmv_moe/moe_bgmv_fp32_fp16_fp16.cucsrc/bgmv_moe/moe_bgmv_impl.cuhcsrc/bgmv_moe/moe_bgmv_ops.cucsrc/bgmv_moe/moe_bgmv_ops.hcsrc/bgmv_moe/vec_dtypes.cuhflashinfer/fused_moe/__init__.pyflashinfer/fused_moe/bgmv_moe.pyflashinfer/jit/__init__.pyflashinfer/jit/bgmv_moe.pyscripts/task_jit_run_tests_part5.shtests/moe/test_bgmv_moe.py
✅ Files skipped from review due to trivial changes (4)
- scripts/task_jit_run_tests_part5.sh
- csrc/bgmv_moe/moe_bgmv_fp16_fp16_fp16.cu
- benchmarks/routines/flashinfer_benchmark_utils.py
- flashinfer/jit/init.py
🚧 Files skipped from review as they are similar to previous changes (14)
- csrc/bgmv_moe/moe_bgmv_fp16_fp32_fp16.cu
- csrc/bgmv_moe/moe_bgmv_bf16_bf16_bf16.cu
- csrc/bgmv_moe/moe_bgmv_fp32_fp16_fp16.cu
- csrc/bgmv_moe/moe_bgmv_fp32_bf16_bf16.cu
- csrc/bgmv_moe/moe_bgmv_bf16_fp32_bf16.cu
- csrc/bgmv_moe/vec_dtypes.cuh
- tests/moe/test_bgmv_moe.py
- flashinfer/jit/bgmv_moe.py
- benchmarks/routines/moe.py
- csrc/bgmv_moe/moe_bgmv_impl.cuh
- csrc/bgmv_moe/moe_bgmv_ops.cu
- csrc/bgmv_moe/moe_bgmv_binding.cu
- flashinfer/fused_moe/bgmv_moe.py
- benchmarks/bench_bgmv_moe.py
yzh119
left a comment
There was a problem hiding this comment.
The original bgmv kernel was actually written by me many years ago (https://le.qun.ch/en/blog/2023/09/11/multi-lora-potentials/) in punica.
This PR is an extension of bgmv to MoE setting which make sense to me but please make the following changes before merging:
- the vec_dtypes.cuh is actually based on an old version of flashinfer (in 2023) and have some bugs that are fixed later, please use https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/vec_dtypes.cuh instead.
- this PR uses torch extensions instead of tvm-ffi convention, which is a regression because we expect these operators to be framework agnostic. flashinfer codebase shifted to tvm-ffi since #1641 and #1795, please check these PRs and https://github.com/flashinfer-ai/flashinfer/blob/main/.claude/skills/add-cuda-kernel/SKILL.md on the FFI convention in flashinfer.
There was a problem hiding this comment.
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@csrc/bgmv_moe/moe_bgmv_ops.cu`:
- Around line 81-83: The code casts sorted_token_ids and expert_ids to int64_t*
in shrink but only checks lora_indices/w_ptr dtypes; add dtype validation for
sorted_token_ids and expert_ids (ensure torch::kInt64) before performing pointer
casts in the shrink function so the pointer arithmetic uses the correct element
width; apply the same dtype checks where similar casts occur later (the block
around the second occurrence currently at the 99-106 region) and update error
messages to clearly name the offending tensor(s) if the check fails.
- Around line 122-126: In the expand routine, before any raw pointer
reinterpretation of sorted_token_ids, expert_ids, topk_weights, lora_indices,
and slice_start_loc, add explicit metadata guards: validate tensor ranks (ensure
they are 1D/flat as expected) and validate dtypes for sorted_token_ids and
expert_ids (e.g., int32/int64) plus topk_weights (float) and
lora_indices/slice_start_loc (integral), and fail early with a clear CHECK or
TORCH_CHECK if types/ranks mismatch; apply the same additional rank/dtype checks
at the second occurrence around the 149-157 block so both places validate
metadata prior to calling data_ptr or reinterpret_cast.
In `@flashinfer/fused_moe/bgmv_moe.py`:
- Around line 100-102: Validate that output_slices is non-empty and that all
slices share the same shape/stride before any use of output_slices[0] or before
launching kernels: check len(output_slices) > 0, then iterate slices to confirm
each slice length/stride matches the first slice (and matches lora_stride where
relevant). If validation fails, raise a clear ValueError describing which
invariant broke. Apply the same checks at the other kernel-launch sites noted
(the blocks around lines referred to as 202-217 and 231-257) to ensure uniform
per-slice kernel shape/stride before launching.
In `@flashinfer/jit/bgmv_moe.py`:
- Around line 111-113: The call that generates nvcc_flags in gen_bgmv_moe_module
currently uses supported_major_versions=None and must be restricted; update the
compilation context so get_nvcc_flags_list is invoked with
supported_major_versions=[9,10,11,12] (or construct a CompilationContext/kernels
definition that sets supported_major_versions=[9,10,11,12]) to limit JIT
compilation to SM majors 9, 10, 11 and 12 (adjusting the nvcc_flags assignment
and any related CompilationContext construction in
gen_bgmv_moe_module/current_compilation_context accordingly).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 2fbb6214-a4d6-4ec0-a289-f41cfb6f8aa9
📒 Files selected for processing (21)
benchmarks/bench_bgmv_moe.pybenchmarks/routines/flashinfer_benchmark_utils.pybenchmarks/routines/moe.pycsrc/bgmv_moe/kernel_config.hcsrc/bgmv_moe/moe_bgmv_bf16_bf16_bf16.cucsrc/bgmv_moe/moe_bgmv_bf16_fp32_bf16.cucsrc/bgmv_moe/moe_bgmv_binding.cucsrc/bgmv_moe/moe_bgmv_config.hcsrc/bgmv_moe/moe_bgmv_fp16_fp16_fp16.cucsrc/bgmv_moe/moe_bgmv_fp16_fp32_fp16.cucsrc/bgmv_moe/moe_bgmv_fp32_bf16_bf16.cucsrc/bgmv_moe/moe_bgmv_fp32_fp16_fp16.cucsrc/bgmv_moe/moe_bgmv_impl.cuhcsrc/bgmv_moe/moe_bgmv_ops.cucsrc/bgmv_moe/moe_bgmv_ops.hflashinfer/fused_moe/__init__.pyflashinfer/fused_moe/bgmv_moe.pyflashinfer/jit/__init__.pyflashinfer/jit/bgmv_moe.pyscripts/task_jit_run_tests_part5.shtests/moe/test_bgmv_moe.py
✅ Files skipped from review due to trivial changes (4)
- csrc/bgmv_moe/moe_bgmv_fp16_fp16_fp16.cu
- csrc/bgmv_moe/moe_bgmv_bf16_bf16_bf16.cu
- benchmarks/routines/flashinfer_benchmark_utils.py
- flashinfer/jit/init.py
🚧 Files skipped from review as they are similar to previous changes (9)
- csrc/bgmv_moe/moe_bgmv_fp32_fp16_fp16.cu
- csrc/bgmv_moe/moe_bgmv_fp16_fp32_fp16.cu
- csrc/bgmv_moe/moe_bgmv_fp32_bf16_bf16.cu
- scripts/task_jit_run_tests_part5.sh
- csrc/bgmv_moe/moe_bgmv_bf16_fp32_bf16.cu
- benchmarks/routines/moe.py
- benchmarks/bench_bgmv_moe.py
- csrc/bgmv_moe/moe_bgmv_impl.cuh
- tests/moe/test_bgmv_moe.py
|
Hi @yzh119 , seems that @taehokim20 has revised the PR according to your suggestions. Would you mind review it again? Thank you! |
yzh119
left a comment
There was a problem hiding this comment.
Would you mind also adding it to https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/aot.py?
Added |
|
Gentle ping @yzh119, the AOT is added, should the next step be CI? 😁 |
|
/bot run |
|
I have added skips to the unit tests (skip 5090 and rtx pro 6000). |
|
Hi @yzh119 do you mind triggering the CI again? Thank you! |
|
/bot run |
|
[FAILED] Pipeline #53028071: 12/20 passed |
|
https://gitlab-master.nvidia.com/dl/flashinfer/flashinfer-ci/-/jobs/329006070/viewer |
Is there anything that I can fix more? I cannot access the link. |
Co-authored-by: Claude
📌 Description
This PR adds fused CUDA kernels for applying multiple LoRA adapters through MoE expert routing. When serving multiple LoRA adapters with MoE models, each (token, expert) pair needs to be routed through the correct LoRA adapter. These kernels fuse this operation into two efficient CUDA kernels:
bgmv_moe_shrink): Projects input through LoRA-A matrices -- compute-bound, uses async pipeline with RANK_TILE tiling and multi-pair blockingbgmv_moe_expand): Projects through LoRA-B matrices with routing weights -- memory-bound, uses warp-level reduction with atomicAddKey optimizations:
cp.asyncpipeline on SM90+ (216 KB shared memory on H100)Main changes:
csrc/bgmv_moe/)flashinfer/fused_moe/bgmv_moe.py)flashinfer/jit/bgmv_moe.py)tests/moe/test_bgmv_moe.py)benchmarks/bench_bgmv_moe.py) and integration intobenchmarks/routines/moe.pyscripts/task_jit_run_tests_part5.sh)Supported configurations:
Out of scope for this PR:
🔍 Related Issues
N/A
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).Tested with:
Results:
Correctness Tests 123 passed
📊 Performance (H100 80GB HBM3)
3.3–6.7x faster than FlashInfer grouped_mm_bf16 (kernel only), 3.6–9.5x faster including sort overhead, across decode and prefill regimes.
Setup: 8 LoRA adapters, 128 experts, rank=32, top_k=2, BF16
Benchmark notes:
grouped_mm_bf16kernel only (cuDNN, pre-sorted input, no sort overhead).grouped_mm_bf16with token sorting (sort + kernel). When combining multi-LoRA × MoE, it createsnum_loras × num_expertsgroups (e.g., 8 × 128 = 1024) with 0-1 tokens each — a poor fit for grouped GEMM's large-group optimization.cp.asyncpipeline with explicit shared memory control, warp-level reduction (no minimum tile size constraint), and pointer indirection for zero-copy multi-adapter access.Summary by CodeRabbit
Release Notes
New Features
bgmv_moe,bgmv_moe_shrink,bgmv_moe_expand, andfill_w_ptrTests