-
Notifications
You must be signed in to change notification settings - Fork 905
Add flashinfer.fused_rmsnorm_silu() with native kernel backend #2965
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
kahyunnam
merged 18 commits into
flashinfer-ai:main
from
kahyunnam:knam/fused-rmsnorm-silu-option3_direct_kernel
Apr 8, 2026
Merged
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
0040d23
native kernel
kahyunnam 0aede25
remove cudnn references
kahyunnam 1f425eb
support checks
kahyunnam 5e192b0
Remove unused C++ knob LUT; Python LUT is the sole source of truth
kahyunnam 44829ed
address gemini-code-assist comment
kahyunnam f41322a
fix and clean up
kahyunnam 6fb1746
Fix include order in rmsnorm_silu.cu to match header dependencies
kahyunnam 5609758
add fallback logic (if misses LUT) to aot precompile for better dynam…
kahyunnam 3611dea
nvfp4 return Union of y and block_scale
kahyunnam 837768c
address https://github.com/flashinfer-ai/flashinfer/pull/2965#discuss…
kahyunnam 23bc908
address https://github.com/flashinfer-ai/flashinfer/pull/2965#discuss…
kahyunnam 92a2edd
changes
kahyunnam c9e785c
add optional user-pre-allocated-input for block scale output tensor
kahyunnam 9c4d9ad
add notes
kahyunnam fc18fbb
clarification on sm100 optimized, sm80+ supported
kahyunnam 86e38f0
update docs
kahyunnam 89df4cb
add to microbenchmarks
kahyunnam 9f15d23
Merge branch 'main' into knam/fused-rmsnorm-silu-option3_direct_kernel
kahyunnam File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| /* | ||
| * Copyright (c) 2026 by FlashInfer team. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| #include "tvm_ffi_utils.h" | ||
|
|
||
| void rmsnorm_silu(TensorView output, TensorView input, TensorView weight, double eps, | ||
| TensorView workspace, TensorView scale_row_out, int64_t sm_count); | ||
|
|
||
| TVM_FFI_DLL_EXPORT_TYPED_FUNC(rmsnorm_silu, rmsnorm_silu); |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,115 @@ | ||
| /* | ||
| * Copyright (c) 2026 by FlashInfer team. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| // clang-format off | ||
| // Include order matters: headers → config (defines Ktraits) → kernel (uses Ktraits) | ||
| #include <algorithm> | ||
| #include <flashinfer/norm/ln_silu_headers.cuh> | ||
| #include "rmsnorm_silu_config.inc" | ||
| #include <flashinfer/norm/ln_fwd_silu_kernel.cuh> | ||
| // clang-format on | ||
|
|
||
| #include "tvm_ffi_utils.h" | ||
|
|
||
| void rmsnorm_silu(TensorView output, TensorView input, TensorView weight, double eps, | ||
| TensorView workspace, TensorView scale_row_out, int64_t sm_count) { | ||
| CHECK_LAST_DIM_CONTIGUOUS_INPUT(input); | ||
| CHECK_LAST_DIM_CONTIGUOUS_INPUT(output); | ||
| CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight); | ||
| CHECK_DEVICE(input, weight); | ||
| CHECK_DIM(2, input); | ||
| CHECK_DIM(2, output); | ||
| CHECK_DIM(1, weight); | ||
|
|
||
| int rows = input.size(0); | ||
| int cols = input.size(1); | ||
| TVM_FFI_ICHECK_EQ(cols, HIDDEN_SIZE) << "Input cols must match compiled HIDDEN_SIZE"; | ||
| TVM_FFI_ICHECK_EQ(output.size(0), rows); | ||
|
|
||
| ffi::CUDADeviceGuard device_guard(input.device().device_id); | ||
| const cudaStream_t stream = get_stream(input.device()); | ||
|
|
||
| // Grid dimensions (same logic as Sm100RmsNormSiluEngine::execute) | ||
| int ctas_per_col_max = (rows + WARPS_M - 1) / WARPS_M; | ||
| int ctas_per_col; | ||
| if (KERNEL_CFG == 2) { | ||
| ctas_per_col = ctas_per_col_max; | ||
| } else { | ||
| ctas_per_col = | ||
| std::min(static_cast<int>(sm_count) * DESIRED_OCCUPANCY / CTAS_PER_ROW, ctas_per_col_max); | ||
| } | ||
| ctas_per_col = std::max(ctas_per_col, 1); | ||
|
|
||
| dim3 grid(CTAS_PER_ROW * ctas_per_col); | ||
| dim3 block(WARPS_M * WARPS_N * 32); | ||
|
|
||
| // Pack kernel params | ||
| PersistentLnFwdParams params{}; | ||
| params.rows = rows; | ||
| params.cols = cols; | ||
| params.ctas_per_col = ctas_per_col; | ||
| params.isRMSNorm = true; | ||
| params.noScale = false; | ||
| params.noBias = true; | ||
| params.isBatchFirst = true; | ||
| params.batchSize = 1; | ||
| params.seqLen = rows; | ||
| params.epsilon = static_cast<float>(eps); | ||
| params.x = input.data_ptr(); | ||
| params.z = output.data_ptr(); | ||
| params.gamma = weight.data_ptr(); | ||
|
|
||
| // Workspace layout (128-byte aligned segments) | ||
| char* ws_ptr = static_cast<char*>(workspace.data_ptr()); | ||
|
|
||
| // [0] rs: rows * sizeof(float) | ||
| params.rs = ws_ptr; | ||
| int64_t off = static_cast<int64_t>(rows) * sizeof(float); | ||
| off = ((off + 127) / 128) * 128; | ||
|
|
||
| // [aligned] fp8_scale: sizeof(float) | ||
| if (isFP8Out) { | ||
| params.fp8_out = true; | ||
| float* default_scale = reinterpret_cast<float*>(ws_ptr + off); | ||
| // Set scale = 1.0f via cudaMemcpyAsync from host | ||
| static const float one = 1.0f; | ||
| cudaMemcpyAsync(default_scale, &one, sizeof(float), cudaMemcpyHostToDevice, stream); | ||
| params.scale = default_scale; | ||
| } | ||
| off += sizeof(float); | ||
| off = ((off + 127) / 128) * 128; | ||
|
|
||
| // scale_row: passed as separate output tensor (NVFP4 only) | ||
| if (isFP4Out) { | ||
| params.scale_row = scale_row_out.data_ptr(); | ||
| } | ||
|
|
||
| // [aligned] cooperative workspace + barriers (multi-CTA only) | ||
| if (CTAS_PER_ROW > 1) { | ||
| params.workspace = ws_ptr + off; | ||
| int64_t coop_ws_size = | ||
| static_cast<int64_t>(ctas_per_col) * WARPS_M * CTAS_PER_ROW * sizeof(float) * 2 * 2; | ||
| off += coop_ws_size; | ||
| off = ((off + 127) / 128) * 128; | ||
|
|
||
| params.barrier = reinterpret_cast<int*>(ws_ptr + off); | ||
| cudaMemsetAsync(params.barrier, 0, 2 * ctas_per_col * sizeof(int32_t), stream); | ||
| } | ||
|
|
||
| reduced_divisor divisor(rows); | ||
|
|
||
| ln_fwd_kernel<<<grid, block, 0, stream>>>(params, divisor); | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,3 +17,4 @@ Kernels for normalization layers. | |
| gemma_rmsnorm | ||
| gemma_fused_add_rmsnorm | ||
| layernorm | ||
| fused_rmsnorm_silu | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.