Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 22 additions & 22 deletions csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ __global__ void buildMinLatencyActiveExpertMapsKernel(
bool const smart_routing, int const cluster_rank, int const cluster_size,
int const num_experts_smem) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif
// Use one block to process the min latency case
int tid = threadIdx.x;
Expand Down Expand Up @@ -247,7 +247,7 @@ __global__ void buildMinLatencyActiveExpertMapsKernel(
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}

Expand Down Expand Up @@ -309,7 +309,7 @@ __global__ void fusedBuildExpertMapsSortFirstTokenKernel(

// Wait PDL before reading token_selected_experts
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif

// build expert map
Expand Down Expand Up @@ -350,7 +350,7 @@ __global__ void fusedBuildExpertMapsSortFirstTokenKernel(

// We are done with compute, launch the dependent kernels while the stores are in flight
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif

// write to shared memory and global memory
Expand Down Expand Up @@ -550,7 +550,7 @@ __global__ void blockExpertPrefixSumKernel(int const* token_selected_experts,
int const token_id = block_id * kNumTokensPerBlock + threadIdx.x;

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif

int expanded_token_id = -1;
Expand Down Expand Up @@ -579,7 +579,7 @@ __global__ void blockExpertPrefixSumKernel(int const* token_selected_experts,
}

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}

Expand Down Expand Up @@ -633,7 +633,7 @@ __global__ void globalExpertPrefixSumLargeKernel(int const* blocked_expert_count
int cnt = 0;

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif

// Note: Because of limited registers, cannot store thread-level prefix sum or enable #pragma
Expand Down Expand Up @@ -662,7 +662,7 @@ __global__ void globalExpertPrefixSumLargeKernel(int const* blocked_expert_count
}

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}

Expand All @@ -676,7 +676,7 @@ __global__ void globalExpertPrefixSumKernel(int const* blocked_expert_counts,
__shared__ typename BlockScan::TempStorage temp_storage;

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif

int const cnt = threadIdx.x < num_experts_per_node * num_blocks_per_seq
Expand All @@ -696,7 +696,7 @@ __global__ void globalExpertPrefixSumKernel(int const* blocked_expert_counts,
}

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}

Expand Down Expand Up @@ -759,7 +759,7 @@ __global__ void mergeExpertPrefixSumKernel(int const* blocked_expert_counts,
int const token_id = block_id * blockDim.x + threadIdx.x;

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif

int const cnt = blocked_expert_counts[target_expert_id * num_blocks_per_seq + block_id];
Expand All @@ -774,7 +774,7 @@ __global__ void mergeExpertPrefixSumKernel(int const* blocked_expert_counts,
}

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}

Expand Down Expand Up @@ -1241,7 +1241,7 @@ __global__ void computeStridesTmaWarpSpecializedKernel(
}

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif

// Both gemms use the same token offset
Expand Down Expand Up @@ -1319,7 +1319,7 @@ __global__ void computeStridesTmaWarpSpecializedKernel(
quant_params.groupwise.fc2.weight_scales),
bias2, gemm2_output, router_scales, permuted_row_to_unpermuted_row, expert);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}

Expand Down Expand Up @@ -1386,7 +1386,7 @@ __global__ void expandInputRowsKernel(
"of the expansion");

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif

constexpr int VecSize = is_nvfp4 ? TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize
Expand Down Expand Up @@ -1508,7 +1508,7 @@ __global__ void expandInputRowsKernel(
}

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif

// Pad zeros in the extra SFs along the N dimension, we do this to ensure there are no nan values
Expand Down Expand Up @@ -1710,7 +1710,7 @@ __global__ void finalizeMoeRoutingKernel(
auto* reduced_row_ptr_v = reinterpret_cast<OutputElem*>(reduced_row_ptr);

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif

#pragma unroll
Expand Down Expand Up @@ -1746,7 +1746,7 @@ __global__ void finalizeMoeRoutingKernel(
reduced_row_ptr_v[elem_index] = output_elem;
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}

Expand All @@ -1766,7 +1766,7 @@ __global__ void finalizeMoeRoutingNoFillingKernel(
assert(unpadded_cols <= padded_cols);

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif

int64_t const num_valid_tokens = expert_first_token_offset[num_experts_per_node];
Expand Down Expand Up @@ -1849,7 +1849,7 @@ __global__ void finalizeMoeRoutingNoFillingKernel(
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}

Expand Down Expand Up @@ -2101,7 +2101,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
int64_t const num_valid_tokens = expert_first_token_offset[num_experts_per_node];

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif
for (int64_t token = blockIdx.x; token < num_valid_tokens; token += gridDim.x) {
size_t gemm_result_offset = token * inter_size * gated_size_mul;
Expand Down Expand Up @@ -2216,7 +2216,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
}

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif

// Pad zeros in the extra SFs along the N dimension, we do this to ensure there are no nan values
Expand Down
4 changes: 2 additions & 2 deletions csrc/fused_moe/noAuxTcKernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx
int64_t const numExpertsPerGroup,
double const routedScalingFactor) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif

// declare shared memory structure
Expand Down Expand Up @@ -216,7 +216,7 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx
}

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ namespace arch {
CUTLASS_DEVICE
void launch_dependent_grids() {
#if (defined(CUTLASS_GDC_ENABLED))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}

Expand All @@ -53,7 +53,7 @@ void launch_dependent_grids() {
CUTLASS_DEVICE
void wait_on_dependent_grids() {
#if (defined(CUTLASS_GDC_ENABLED))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif
}

Expand Down
8 changes: 4 additions & 4 deletions csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ quantize_with_block_size(
int numPaddedColThreads = numPaddedCols / ELTS_PER_THREAD;
int numColThreadsForSf = numColsForSf / ELTS_PER_THREAD;

asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();

// Input tensor batch/row/col loops.
// Optimization: Iterate over actual rows first (hot path), then padding rows (cold path)
Expand Down Expand Up @@ -313,7 +313,7 @@ quantize_with_block_size(
}
}
}
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}

Expand Down Expand Up @@ -376,7 +376,7 @@ quantize_with_block_size_tma(
int numPaddedRowsForSf = isSfSwizzledLayout ? PadUpFn(numRows, rowTile) : numRows;
int numColsForSf = isSfSwizzledLayout ? PadUpFn(numPaddedCols, 4 * SF_VEC_SIZE) : numPaddedCols;

asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();

// TMA barrier initialization.
if (warpIdx == 0 and laneIdx == 0) {
Expand Down Expand Up @@ -501,7 +501,7 @@ quantize_with_block_size_tma(
}
}
}
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}

Expand Down
4 changes: 2 additions & 2 deletions csrc/xqa/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -768,13 +768,13 @@ __device__ inline bool warpElectSync() {

__device__ inline void preExit() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("griddepcontrol.launch_dependents;\n");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}

__device__ inline void acqBulk() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("griddepcontrol.wait;\n");
cudaGridDependencySynchronize();
#endif
}
Comment on lines 769 to 779
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Programmatic launch API usage and guards =="
rg -n -C2 'cudaGridDependencySynchronize|cudaTriggerProgrammaticLaunchCompletion|__CUDACC_VER_MAJOR__' csrc include

echo
echo "== Repository CUDA-version references =="
fd '^(README.*|CMakeLists\.txt|.*\.cmake|Dockerfile.*)$' . -X rg -n 'CUDA|11\.8|12\.' {} 2>/dev/null || true

Repository: flashinfer-ai/flashinfer

Length of output: 50380


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Repository CUDA version constraints =="
find . -maxdepth 2 \( -name "CMakeLists.txt" -o -name "README.md" -o -name "setup.py" -o -name "*.cmake" \) -type f | head -20 | while read f; do
  echo "--- $f ---"
  grep -i "cuda" "$f" | grep -E "(11|12|version|minimum)" | head -10 || true
done

echo
echo "== CUDA version checks in recently modified kernels =="
rg -n "__CUDACC_VER" csrc/fused_moe/ csrc/xqa/ include/flashinfer/ -A1 -B1 | head -80

Repository: flashinfer-ai/flashinfer

Length of output: 7420


Restore the CUDA 12 guard around these runtime calls.

Lines 771 and 777 switch from inline PTX to CUDA runtime APIs (cudaTriggerProgrammaticLaunchCompletion and cudaGridDependencySynchronize), but only guard on __CUDA_ARCH__. The repository officially supports CUDA 12.6+ and every other use of these APIs in the codebase gates them with __CUDACC_VER_MAJOR__ >= 12 (see include/flashinfer/sampling.cuh, include/flashinfer/norm.cuh, include/flashinfer/trtllm/fmha/lse.cuh). Without the version guard, compilation can fail on older toolchains.

Suggested patch
 __device__ inline void preExit() {
-#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
   cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }
 
 __device__ inline void acqBulk() {
-#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
   cudaGridDependencySynchronize();
 #endif
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
__device__ inline void preExit() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("griddepcontrol.launch_dependents;\n");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
__device__ inline void acqBulk() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("griddepcontrol.wait;\n");
cudaGridDependencySynchronize();
#endif
}
__device__ inline void preExit() {
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  cudaTriggerProgrammaticLaunchCompletion();
#endif
}
__device__ inline void acqBulk() {
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  cudaGridDependencySynchronize();
#endif
}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/xqa/utils.cuh` around lines 769 - 779, The runtime calls in preExit()
and acqBulk() are only guarded by __CUDA_ARCH__ but must also be gated by the
CUDA compiler version; wrap the cudaTriggerProgrammaticLaunchCompletion() call
in preExit and cudaGridDependencySynchronize() call in acqBulk with an
additional compile-time check for __CUDACC_VER_MAJOR__ >= 12 (i.e. require both
(__CUDA_ARCH__ >= 900) and (__CUDACC_VER_MAJOR__ >= 12)) so these APIs are only
used when the toolchain supports CUDA 12+; update the preExit and acqBulk macros
accordingly to match the pattern used elsewhere (e.g.,
include/flashinfer/*.cuh).


Expand Down
4 changes: 2 additions & 2 deletions include/flashinfer/activation.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ __global__ void act_and_mul_kernel(T* __restrict__ out, const T* __restrict__ in
const int64_t offset = token_idx * 2 * d;

#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif

#pragma unroll 1
Expand All @@ -59,7 +59,7 @@ __global__ void act_and_mul_kernel(T* __restrict__ out, const T* __restrict__ in
}

#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {

CUTLASS_DEVICE void operator()(const Params& params, char* smem) {
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif

TileScheduler tile_scheduler{params.tile_scheduler};
Expand Down
2 changes: 1 addition & 1 deletion include/flashinfer/attention/blackwell/plan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ __global__ void plan_kernel(int* qo_segment_offsets, int* kv_segment_offsets, in
}
}
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}

Expand Down
8 changes: 4 additions & 4 deletions include/flashinfer/attention/cascade.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ __global__ void PersistentVariableLengthMergeStatesKernel(
float* s_smem = (float*)(smem + num_smem_stages * bdy * head_dim * sizeof(DTypeIn));

#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif

#pragma unroll 1
Expand Down Expand Up @@ -462,7 +462,7 @@ __global__ void PersistentVariableLengthMergeStatesKernel(
}
}
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}

Expand All @@ -485,7 +485,7 @@ __global__ void PersistentVariableLengthAttentionSumKernel(DTypeIn* __restrict__

vec_t<float, vec_size> v_sum_vec;
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif

#pragma unroll 1
Expand Down Expand Up @@ -548,7 +548,7 @@ __global__ void PersistentVariableLengthAttentionSumKernel(DTypeIn* __restrict__
v_sum_vec.cast_store(v_sum + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
}
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}

Expand Down
Loading