diff --git a/ci/setup_python.env b/ci/setup_python.env index ffd9c49ac2..0305a6329b 100644 --- a/ci/setup_python.env +++ b/ci/setup_python.env @@ -15,3 +15,6 @@ # Uncomment to override TVM-FFI version: # TVM_FFI_REF= + +# Uncomment to override nvidia-cutlass-dsl version: +# CUTLASS_DSL_VERSION= diff --git a/flashinfer/cute_dsl/gemm_allreduce_two_shot.py b/flashinfer/cute_dsl/gemm_allreduce_two_shot.py index baf55468a4..20b0614304 100644 --- a/flashinfer/cute_dsl/gemm_allreduce_two_shot.py +++ b/flashinfer/cute_dsl/gemm_allreduce_two_shot.py @@ -30,6 +30,9 @@ def spin_lock_multimem_arrive(lock_ptr: Pointer, loc=None, ip=None) -> None: # HACK https://github.com/NVIDIA/cutlass/issues/2845 +import functools +import inspect + from cutlass._mlir.dialects import nvvm from cutlass.cutlass_dsl import T from cutlass._mlir.dialects.nvvm import ( @@ -39,6 +42,35 @@ def spin_lock_multimem_arrive(lock_ptr: Pointer, loc=None, ip=None) -> None: ) +@functools.lru_cache(maxsize=None) +def _nvvm_atomicrmw_has_res_param(): + return "res" in inspect.signature(nvvm.atomicrmw).parameters + + +def _nvvm_atomicrmw_compat( + res_type, op, ptr, a, *, b=None, mem_order=None, syncscope=None, loc=None, ip=None +): + """Call nvvm.atomicrmw compatible with both CUDA 12 and CUDA 13.""" + if _nvvm_atomicrmw_has_res_param(): + # CUDA 12: nvvm.atomicrmw(res, op, ptr, a, ...) + return nvvm.atomicrmw( + res_type, + op, + ptr, + a, + b=b, + mem_order=mem_order, + syncscope=syncscope, + loc=loc, + ip=ip, + ) + else: + # CUDA 13: nvvm.atomicrmw(op, ptr, a, ...) — res removed + return nvvm.atomicrmw( + op, ptr, a, b=b, mem_order=mem_order, syncscope=syncscope, loc=loc, ip=ip + ) + + @cute.jit def spin_lock_atom_cas_acquire_wait( lock_ptr: Pointer, @@ -55,7 +87,7 @@ def spin_lock_atom_cas_acquire_wait( if scope == "gpu": result = 0 while result != expected_val: - result = nvvm.atomicrmw( + result = _nvvm_atomicrmw_compat( T.i32(), AtomicOpKind.CAS, lock_ptr.llvm_ptr, @@ -69,7 +101,7 @@ def spin_lock_atom_cas_acquire_wait( elif scope == "sys": result = 0 while result != expected_val: - result = nvvm.atomicrmw( + result = _nvvm_atomicrmw_compat( T.i32(), AtomicOpKind.CAS, lock_ptr.llvm_ptr, @@ -92,7 +124,7 @@ def sm_wise_inter_gpu_multimem_barrier( bdimx, bdimy, _ = cute.arch.grid_dim() pid = bidx + bidy * bdimx + bidz * bdimx * bdimy distributed.multimem_red_release_sys_add1(barrier_mc + pid, loc=loc, ip=ip) - cute.arch.fence_proxy(cute.arch.ProxyKind.alias) + cute.arch.fence_proxy("alias") # v4.3.1 does not have mem_order="acquire" variant in `distributed` module # filed issue https://github.com/NVIDIA/cutlass/issues/2845 @@ -1251,8 +1283,8 @@ def kernel( ) # Fence and barrier to make sure shared memory store is visible to TMA store cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) epilog_threads = 32 * len(self.epilog_warp_id) cute.arch.barrier( @@ -1312,7 +1344,7 @@ def kernel( flag = barrier_flag_mc.iterator + tile_id cute.arch.fence_acq_rel_gpu() spin_lock_multimem_arrive(flag) - cute.arch.fence_proxy(cute.arch.ProxyKind.alias) + cute.arch.fence_proxy("alias") # # Advance to next tile diff --git a/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py b/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py index 10a1f7f822..f8c50c624f 100644 --- a/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py +++ b/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py @@ -1512,8 +1512,8 @@ def kernel( sInfo[(4, tile_info_producer_state.index)] = mn_limit # fence view async shared cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.sched_sync_barrier.arrive_and_wait() @@ -1548,8 +1548,8 @@ def kernel( sInfo[(4, tile_info_producer_state.index)] = mn_limit # fence view async shared cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.sched_sync_barrier.arrive_and_wait() @@ -1569,8 +1569,8 @@ def kernel( sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(0) sInfo[(4, tile_info_producer_state.index)] = -1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.sched_sync_barrier.arrive_and_wait() tile_info_pipeline.producer_commit(tile_info_producer_state) @@ -1669,8 +1669,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1844,8 +1844,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1886,8 +1886,8 @@ def kernel( valid_tile_info[0] = sInfo[(3, tile_info_consumer_state.index)] is_valid_tile = valid_tile_info[0] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1927,8 +1927,8 @@ def kernel( valid_tile_info[0] = sInfo[(3, tile_info_consumer_state.index)] is_valid_tile = valid_tile_info[0] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1968,8 +1968,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -2051,8 +2051,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -2152,8 +2152,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -2368,8 +2368,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -2480,8 +2480,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -2811,8 +2811,8 @@ def kernel( ) # Fence and barrier to make sure shared memory store is visible to TMA store cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.epilog_sync_barrier.arrive_and_wait() # @@ -2845,8 +2845,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() diff --git a/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py b/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py index ce4fb6269b..e07fab4eb6 100644 --- a/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py +++ b/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py @@ -1380,8 +1380,8 @@ def kernel( sInfo[(4, tile_info_producer_state.index)] = mn_limit # fence view async shared cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.sched_sync_barrier.arrive_and_wait() @@ -1416,8 +1416,8 @@ def kernel( sInfo[(4, tile_info_producer_state.index)] = mn_limit # fence view async shared cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.sched_sync_barrier.arrive_and_wait() @@ -1438,8 +1438,8 @@ def kernel( sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(0) sInfo[(4, tile_info_producer_state.index)] = cutlass.Int32(0) cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) self.sched_sync_barrier.arrive_and_wait() tile_info_pipeline.producer_commit(tile_info_producer_state) @@ -1467,8 +1467,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1573,8 +1573,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1659,8 +1659,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1818,8 +1818,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -1886,8 +1886,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() @@ -2023,8 +2023,8 @@ def kernel( if cutlass.const_expr(self.use_blkred): cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) # # Async arrive accumulator buffer empty @@ -2037,8 +2037,8 @@ def kernel( if cutlass.const_expr(self.use_blkred): cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) if is_valid_row: coord_n = mma_tile_coord_mnl[1] * self.cta_tile_shape_mnk[1] @@ -2073,8 +2073,8 @@ def kernel( tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] is_valid_tile = tile_info[3] == 1 cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) tile_info_pipeline.consumer_release(tile_info_consumer_state) tile_info_consumer_state.advance() diff --git a/flashinfer/fused_moe/cute_dsl/blackwell/utils.py b/flashinfer/fused_moe/cute_dsl/blackwell/utils.py index b1c5349de1..b57e260171 100644 --- a/flashinfer/fused_moe/cute_dsl/blackwell/utils.py +++ b/flashinfer/fused_moe/cute_dsl/blackwell/utils.py @@ -44,6 +44,7 @@ # This file is copied and modified from cutlass https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/cute/core.py import ctypes +import functools from typing import Union import cutlass @@ -197,6 +198,13 @@ def is_power_of_2(x: int) -> bool: return x > 0 and (x & (x - 1)) == 0 +@functools.lru_cache(maxsize=None) +def _nvvm_fmin_needs_res(): + import inspect + + return "res" in inspect.signature(nvvm.fmin).parameters + + @dsl_user_op def fmin( a: Union[float, cutlass.Float32], @@ -206,16 +214,15 @@ def fmin( loc=None, ip=None, ) -> cutlass.Float32: - return cutlass.Float32( - nvvm.fmin( - T.f32(), - cutlass.Float32(a).ir_value(loc=loc, ip=ip), - cutlass.Float32(b).ir_value(loc=loc, ip=ip), - nan=nan, - loc=loc, - ip=ip, - ) - ) + a_val = cutlass.Float32(a).ir_value(loc=loc, ip=ip) + b_val = cutlass.Float32(b).ir_value(loc=loc, ip=ip) + if _nvvm_fmin_needs_res(): + # CUDA 12: nvvm.fmin(res, a, b, ...) + result = nvvm.fmin(T.f32(), a_val, b_val, nan=nan, loc=loc, ip=ip) + else: + # CUDA 13: nvvm.fmin(a, b, ...) + result = nvvm.fmin(a_val, b_val, nan=nan, loc=loc, ip=ip) + return cutlass.Float32(result) def sigmoid_f32( diff --git a/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py b/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py index 1475850e46..5710f97fac 100644 --- a/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py +++ b/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py @@ -1469,8 +1469,8 @@ def kernel( ) # Fence and barrier to make sure shared memory store is visible to TMA store cute.arch.fence_proxy( - cute.arch.ProxyKind.async_shared, - space=cute.arch.SharedSpace.shared_cta, + "async.shared", + space="cta", ) epilog_threads = 32 * len(self.epilog_warp_id) cute.arch.barrier( diff --git a/requirements.txt b/requirements.txt index 7dd93c67f7..7eb97a4ab9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ einops ninja numpy nvidia-cudnn-frontend>=1.13.0 -nvidia-cutlass-dsl>=4.3.4 +nvidia-cutlass-dsl>=4.4.2 nvidia-ml-py packaging>=24.2 requests diff --git a/scripts/setup_test_env.sh b/scripts/setup_test_env.sh index 83480cbd6a..5cd61330f1 100755 --- a/scripts/setup_test_env.sh +++ b/scripts/setup_test_env.sh @@ -23,3 +23,22 @@ if [ -n "${TVM_FFI_REF:-}" ]; then echo "TVM-FFI override complete." echo "" fi + +# Override nvidia-cutlass-dsl if specified +if [ -n "${CUTLASS_DSL_VERSION:-}" ]; then + # Detect CUDA major version: only CUDA 13+ needs [cu13] extra + CUDA_MAJOR=$(python -c "import torch; print(torch.version.cuda.split('.')[0])" 2>/dev/null || echo "12") + if [ "$CUDA_MAJOR" = "13" ]; then + CUTLASS_DSL_PKG="nvidia-cutlass-dsl[cu13]==${CUTLASS_DSL_VERSION}" + else + CUTLASS_DSL_PKG="nvidia-cutlass-dsl==${CUTLASS_DSL_VERSION}" + fi + echo "========================================" + echo "Overriding nvidia-cutlass-dsl with: ${CUTLASS_DSL_PKG}" + echo "========================================" + # Clean uninstall old packages first (recommended by NVIDIA docs) + pip uninstall nvidia-cutlass-dsl nvidia-cutlass-dsl-libs-base nvidia-cutlass-dsl-libs-cu12 nvidia-cutlass-dsl-libs-cu13 -y 2>/dev/null || true + pip install "${CUTLASS_DSL_PKG}" + echo "nvidia-cutlass-dsl override complete." + echo "" +fi