Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ci/setup_python.env
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,6 @@

# Uncomment to override TVM-FFI version:
# TVM_FFI_REF=

# Uncomment to override nvidia-cutlass-dsl version:
# CUTLASS_DSL_VERSION=
44 changes: 38 additions & 6 deletions flashinfer/cute_dsl/gemm_allreduce_two_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def spin_lock_multimem_arrive(lock_ptr: Pointer, loc=None, ip=None) -> None:


# HACK https://github.com/NVIDIA/cutlass/issues/2845
import functools
import inspect

from cutlass._mlir.dialects import nvvm
from cutlass.cutlass_dsl import T
from cutlass._mlir.dialects.nvvm import (
Expand All @@ -39,6 +42,35 @@ def spin_lock_multimem_arrive(lock_ptr: Pointer, loc=None, ip=None) -> None:
)


@functools.lru_cache(maxsize=None)
def _nvvm_atomicrmw_has_res_param():
    """Return True when ``nvvm.atomicrmw`` accepts a ``res`` parameter.

    The CUDA 12 MLIR bindings take an explicit result-type argument named
    ``res``; the CUDA 13 bindings dropped it.  The signature is inspected
    once and the answer is cached for all later calls.
    """
    params = inspect.signature(nvvm.atomicrmw).parameters
    return "res" in params


def _nvvm_atomicrmw_compat(
    res_type, op, ptr, a, *, b=None, mem_order=None, syncscope=None, loc=None, ip=None
):
    """Call nvvm.atomicrmw compatible with both CUDA 12 and CUDA 13.

    The CUDA 12 bindings expect the result type as the first positional
    argument (``atomicrmw(res, op, ptr, a, ...)``); the CUDA 13 bindings
    removed ``res`` (``atomicrmw(op, ptr, a, ...)``).  ``res_type`` is
    simply not forwarded on CUDA 13.
    """
    # Keyword arguments shared by both signature variants.
    extra = dict(b=b, mem_order=mem_order, syncscope=syncscope, loc=loc, ip=ip)
    if _nvvm_atomicrmw_has_res_param():
        # CUDA 12 signature: result type passed explicitly.
        return nvvm.atomicrmw(res_type, op, ptr, a, **extra)
    # CUDA 13 signature: ``res`` parameter removed.
    return nvvm.atomicrmw(op, ptr, a, **extra)


@cute.jit
def spin_lock_atom_cas_acquire_wait(
lock_ptr: Pointer,
Expand All @@ -55,7 +87,7 @@ def spin_lock_atom_cas_acquire_wait(
if scope == "gpu":
result = 0
while result != expected_val:
result = nvvm.atomicrmw(
result = _nvvm_atomicrmw_compat(
T.i32(),
AtomicOpKind.CAS,
lock_ptr.llvm_ptr,
Expand All @@ -69,7 +101,7 @@ def spin_lock_atom_cas_acquire_wait(
elif scope == "sys":
result = 0
while result != expected_val:
result = nvvm.atomicrmw(
result = _nvvm_atomicrmw_compat(
T.i32(),
AtomicOpKind.CAS,
lock_ptr.llvm_ptr,
Expand All @@ -92,7 +124,7 @@ def sm_wise_inter_gpu_multimem_barrier(
bdimx, bdimy, _ = cute.arch.grid_dim()
pid = bidx + bidy * bdimx + bidz * bdimx * bdimy
distributed.multimem_red_release_sys_add1(barrier_mc + pid, loc=loc, ip=ip)
cute.arch.fence_proxy(cute.arch.ProxyKind.alias)
cute.arch.fence_proxy("alias")

# v4.3.1 does not have mem_order="acquire" variant in `distributed` module
# filed issue https://github.com/NVIDIA/cutlass/issues/2845
Expand Down Expand Up @@ -1251,8 +1283,8 @@ def kernel(
)
# Fence and barrier to make sure shared memory store is visible to TMA store
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
epilog_threads = 32 * len(self.epilog_warp_id)
cute.arch.barrier(
Expand Down Expand Up @@ -1312,7 +1344,7 @@ def kernel(
flag = barrier_flag_mc.iterator + tile_id
cute.arch.fence_acq_rel_gpu()
spin_lock_multimem_arrive(flag)
cute.arch.fence_proxy(cute.arch.ProxyKind.alias)
cute.arch.fence_proxy("alias")

#
# Advance to next tile
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1512,8 +1512,8 @@ def kernel(
sInfo[(4, tile_info_producer_state.index)] = mn_limit
# fence view async shared
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)

self.sched_sync_barrier.arrive_and_wait()
Expand Down Expand Up @@ -1548,8 +1548,8 @@ def kernel(
sInfo[(4, tile_info_producer_state.index)] = mn_limit
# fence view async shared
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)

self.sched_sync_barrier.arrive_and_wait()
Expand All @@ -1569,8 +1569,8 @@ def kernel(
sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(0)
sInfo[(4, tile_info_producer_state.index)] = -1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
self.sched_sync_barrier.arrive_and_wait()
tile_info_pipeline.producer_commit(tile_info_producer_state)
Expand Down Expand Up @@ -1669,8 +1669,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -1844,8 +1844,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -1886,8 +1886,8 @@ def kernel(
valid_tile_info[0] = sInfo[(3, tile_info_consumer_state.index)]
is_valid_tile = valid_tile_info[0] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -1927,8 +1927,8 @@ def kernel(
valid_tile_info[0] = sInfo[(3, tile_info_consumer_state.index)]
is_valid_tile = valid_tile_info[0] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -1968,8 +1968,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -2051,8 +2051,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -2152,8 +2152,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -2368,8 +2368,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -2480,8 +2480,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -2811,8 +2811,8 @@ def kernel(
)
# Fence and barrier to make sure shared memory store is visible to TMA store
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
self.epilog_sync_barrier.arrive_and_wait()
#
Expand Down Expand Up @@ -2845,8 +2845,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1380,8 +1380,8 @@ def kernel(
sInfo[(4, tile_info_producer_state.index)] = mn_limit
# fence view async shared
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)

self.sched_sync_barrier.arrive_and_wait()
Expand Down Expand Up @@ -1416,8 +1416,8 @@ def kernel(
sInfo[(4, tile_info_producer_state.index)] = mn_limit
# fence view async shared
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)

self.sched_sync_barrier.arrive_and_wait()
Expand All @@ -1438,8 +1438,8 @@ def kernel(
sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(0)
sInfo[(4, tile_info_producer_state.index)] = cutlass.Int32(0)
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
self.sched_sync_barrier.arrive_and_wait()
tile_info_pipeline.producer_commit(tile_info_producer_state)
Expand Down Expand Up @@ -1467,8 +1467,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -1573,8 +1573,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -1659,8 +1659,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -1818,8 +1818,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -1886,8 +1886,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down Expand Up @@ -2023,8 +2023,8 @@ def kernel(

if cutlass.const_expr(self.use_blkred):
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
#
# Async arrive accumulator buffer empty
Expand All @@ -2037,8 +2037,8 @@ def kernel(

if cutlass.const_expr(self.use_blkred):
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
if is_valid_row:
coord_n = mma_tile_coord_mnl[1] * self.cta_tile_shape_mnk[1]
Expand Down Expand Up @@ -2073,8 +2073,8 @@ def kernel(
tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
is_valid_tile = tile_info[3] == 1
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
"async.shared",
space="cta",
)
tile_info_pipeline.consumer_release(tile_info_consumer_state)
tile_info_consumer_state.advance()
Expand Down
Loading
Loading