From 36be85535b08f9cc567d640ad407a78ff4346f6e Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Wed, 22 Oct 2025 14:35:59 +0000 Subject: [PATCH 1/9] [Feature] Add memory_order PTX for vectorized (2x) atomic add --- src/tl_templates/cuda/atomic.h | 81 +++++++++++++++++++++++++++++----- 1 file changed, 69 insertions(+), 12 deletions(-) diff --git a/src/tl_templates/cuda/atomic.h b/src/tl_templates/cuda/atomic.h index 4ee85a1ad..8d3d8a2e3 100644 --- a/src/tl_templates/cuda/atomic.h +++ b/src/tl_templates/cuda/atomic.h @@ -6,6 +6,7 @@ #include #include +#include using cutlass::bfloat16_t; using cutlass::half_t; @@ -45,8 +46,8 @@ TL_DEVICE void AtomicMax(T1 &ref, T2 val, int memory_order = int(cuda::memory_order_relaxed)) { using NT1 = typename normalize_atomic_type::type; T1 *address = &ref; - if constexpr (std::is_same_v || - std::is_same_v) { + if constexpr ((std::is_same_v || + std::is_same_v) && memory_order == int(cuda::memory_order_relaxed)) { atomicMax(reinterpret_cast(address), static_cast(val)); } else { cuda::atomic_ref aref(*address); @@ -59,8 +60,8 @@ TL_DEVICE T1 AtomicMaxRet(T1 &ref, T2 val, int memory_order = int(cuda::memory_order_relaxed)) { using NT1 = typename normalize_atomic_type::type; T1 *address = &ref; - if constexpr (std::is_same_v || - std::is_same_v) { + if constexpr ((std::is_same_v || + std::is_same_v) && memory_order == int(cuda::memory_order_relaxed)) { return static_cast( atomicMax(reinterpret_cast(address), static_cast(val))); } else { @@ -75,8 +76,8 @@ TL_DEVICE void AtomicMin(T1 &ref, T2 val, int memory_order = int(cuda::memory_order_relaxed)) { using NT1 = typename normalize_atomic_type::type; T1 *address = &ref; - if constexpr (std::is_same_v || - std::is_same_v) { + if constexpr ((std::is_same_v || + std::is_same_v) && memory_order == int(cuda::memory_order_relaxed)) { atomicMin(reinterpret_cast(address), static_cast(val)); } else { cuda::atomic_ref aref(*address); @@ -89,8 +90,8 @@ TL_DEVICE T1 AtomicMinRet(T1 &ref, T2 val, int 
memory_order = int(cuda::memory_order_relaxed)) { using NT1 = typename normalize_atomic_type::type; T1 *address = &ref; - if constexpr (std::is_same_v || - std::is_same_v) { + if constexpr ((std::is_same_v || + std::is_same_v) && memory_order == int(cuda::memory_order_relaxed)) { return static_cast( atomicMin(reinterpret_cast(address), static_cast(val))); } else { @@ -135,15 +136,71 @@ TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val, // TODO add memory_order for vectorized atomic add TL_DEVICE void AtomicAddx2(half_t *ref, half_t *val, int memory_order = int(cuda::memory_order_relaxed)) { - atomicAdd(reinterpret_cast(ref), - static_cast(*reinterpret_cast(val))); + if (memory_order == int(cuda::memory_order_relaxed)) { + atomicAdd(reinterpret_cast(ref), + static_cast(*reinterpret_cast(val))); + } else { + // Since atomicAdd does not support memory order, atomic_ref does not support vectorized atomic operation + // we can only inline ptx code here + // Note: Vectorized atomic operations only support global space + __half2 add_val = *reinterpret_cast<__half2 *>(val); + unsigned short add_val_x_cast = *reinterpret_cast(&add_val.x); + unsigned short add_val_y_cast = *reinterpret_cast(&add_val.y); + unsigned long long ref_addr = reinterpret_cast(ref); + __half ret_val_x, ret_val_y; + unsigned short ret_val_x_cast = *reinterpret_cast(&ret_val_x); + unsigned short ret_val_y_cast = *reinterpret_cast(&ret_val_y); + if (memory_order == int(cuda::memory_order_release) || memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : 
"memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } + } } TL_DEVICE half2 AtomicAddx2Ret(half_t *ref, half_t *val, int memory_order = int(cuda::memory_order_relaxed)) { - return atomicAdd(reinterpret_cast(ref), - static_cast(*reinterpret_cast(val))); + if (memory_order == int(cuda::memory_order_relaxed)) { + return atomicAdd(reinterpret_cast(ref), + static_cast(*reinterpret_cast(val))); + } else { + __half2 add_val = *reinterpret_cast<__half2 *>(val); + unsigned short add_val_x_cast = *reinterpret_cast(&add_val.x); + unsigned short add_val_y_cast = *reinterpret_cast(&add_val.y); + unsigned long long ref_addr = reinterpret_cast(ref); + __half ret_val_x, ret_val_y; + unsigned short ret_val_x_cast = *reinterpret_cast(&ret_val_x); + unsigned short ret_val_y_cast = *reinterpret_cast(&ret_val_y); + if (memory_order == int(cuda::memory_order_release) || memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } + return 
half2(*reinterpret_cast<__half *>(&ret_val_x_cast), *reinterpret_cast<__half *>(&ret_val_y_cast)); + } } #if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750)) From 5ced7762b52041d10e18e560f6d25157ff09956c Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Thu, 23 Oct 2025 05:54:58 +0000 Subject: [PATCH 2/9] [Feature] Add memory_order PTX for all vectorized atomic add --- src/tl_templates/cuda/atomic.h | 171 ++++++++++++++++++++++++++++++--- 1 file changed, 159 insertions(+), 12 deletions(-) diff --git a/src/tl_templates/cuda/atomic.h b/src/tl_templates/cuda/atomic.h index 8d3d8a2e3..203b14cf0 100644 --- a/src/tl_templates/cuda/atomic.h +++ b/src/tl_templates/cuda/atomic.h @@ -143,6 +143,7 @@ TL_DEVICE void AtomicAddx2(half_t *ref, half_t *val, // Since atomicAdd does not support memory order, atomic_ref does not support vectorized atomic operation // we can only inline ptx code here // Note: Vectorized atomic operations only support global space + // Note: for 16-bit value, we need to reinterpret_cast the value to unsigned short and use "h" register in assembly __half2 add_val = *reinterpret_cast<__half2 *>(val); unsigned short add_val_x_cast = *reinterpret_cast(&add_val.x); unsigned short add_val_y_cast = *reinterpret_cast(&add_val.y); @@ -206,45 +207,191 @@ AtomicAddx2Ret(half_t *ref, half_t *val, #if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750)) TL_DEVICE void AtomicAddx2(bfloat16_t *ref, bfloat16_t *val, int memory_order = int(cuda::memory_order_relaxed)) { - atomicAdd( + if (memory_order == int(cuda::memory_order_relaxed)) { + atomicAdd( reinterpret_cast<__nv_bfloat162 *>(ref), static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val))); + } else { + __nv_bfloat162 add_val = *reinterpret_cast<__nv_bfloat162 *>(val); + unsigned short add_val_x_cast = *reinterpret_cast(&add_val.x); + unsigned short add_val_y_cast = *reinterpret_cast(&add_val.y); + unsigned long long ref_addr = reinterpret_cast(ref); + __nv_bfloat162 ret_val; + 
unsigned short ret_val_x_cast = *reinterpret_cast(&ret_val.x); + unsigned short ret_val_y_cast = *reinterpret_cast(&ret_val.y); + if (memory_order == int(cuda::memory_order_release) || memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } + } } TL_DEVICE __nv_bfloat162 AtomicAddx2Ret(bfloat16_t *ref, bfloat16_t *val, int memory_order = int(cuda::memory_order_relaxed)) { - return atomicAdd( - reinterpret_cast<__nv_bfloat162 *>(ref), - static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val))); + if (memory_order == int(cuda::memory_order_relaxed)) { + return atomicAdd( + reinterpret_cast<__nv_bfloat162 *>(ref), + static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val))); + } else { + __nv_bfloat162 add_val = *reinterpret_cast<__nv_bfloat162 *>(val); + unsigned short add_val_x_cast = *reinterpret_cast(&add_val.x); + unsigned short add_val_y_cast = *reinterpret_cast(&add_val.y); + unsigned long long ref_addr = reinterpret_cast(ref); + __nv_bfloat162 ret_val; + unsigned short ret_val_x_cast = *reinterpret_cast(&ret_val.x); + unsigned short ret_val_y_cast = *reinterpret_cast(&ret_val.y); + if (memory_order == int(cuda::memory_order_release) || memory_order == 
int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } + return __nv_bfloat162(*reinterpret_cast<__nv_bfloat16 *>(&ret_val_x_cast), *reinterpret_cast<__nv_bfloat16 *>(&ret_val_y_cast)); + } } #endif #if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900)) TL_DEVICE void AtomicAddx2(float *ref, float *val, int memory_order = int(cuda::memory_order_relaxed)) { - atomicAdd(reinterpret_cast(ref), - static_cast(*reinterpret_cast(val))); + if (memory_order == int(cuda::memory_order_relaxed)) { + atomicAdd(reinterpret_cast(ref), + static_cast(*reinterpret_cast(val))); + } else { + float2 add_val = *reinterpret_cast(val); + unsigned long long ref_addr = reinterpret_cast(ref); + float2 ret_val; + if (memory_order == int(cuda::memory_order_release) || memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" + : "=f"(ret_val.x), "=f"(ret_val.y) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" + : "=f"(ret_val.x), "=f"(ret_val.y) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) + : "memory"); + } else 
if (memory_order == int(cuda::memory_order_acq_rel) || memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" + : "=f"(ret_val.x), "=f"(ret_val.y) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) + : "memory"); + } + } } TL_DEVICE float2 AtomicAddx2Ret(float *ref, float *val, int memory_order = int(cuda::memory_order_relaxed)) { - return atomicAdd(reinterpret_cast(ref), - static_cast(*reinterpret_cast(val))); + if (memory_order == int(cuda::memory_order_relaxed)) { + return atomicAdd(reinterpret_cast(ref), + static_cast(*reinterpret_cast(val))); + } else { + float2 add_val = *reinterpret_cast(val); + unsigned long long ref_addr = reinterpret_cast(ref); + float2 ret_val; + if (memory_order == int(cuda::memory_order_release) || memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" + : "=f"(ret_val.x), "=f"(ret_val.y) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" + : "=f"(ret_val.x), "=f"(ret_val.y) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" + : "=f"(ret_val.x), "=f"(ret_val.y) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) + : "memory"); + } + return ret_val; + } } TL_DEVICE void AtomicAddx4(float *ref, float *val, int memory_order = int(cuda::memory_order_relaxed)) { - atomicAdd(reinterpret_cast(ref), - static_cast(*reinterpret_cast(val))); + if (memory_order == int(cuda::memory_order_relaxed)) { + atomicAdd(reinterpret_cast(ref), + static_cast(*reinterpret_cast(val))); + } else { + // Since atomicAdd does not support memory order, atomic_ref does not 
support vectorized atomic operation + // we can only inline ptx code here + // Note: Vectorized atomic operations only support global space + float4 add_val = *reinterpret_cast(val); + unsigned long long ref_addr = reinterpret_cast(ref); + float4 ret_val; + if (memory_order == int(cuda::memory_order_release) || memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.v4.f32 {%0,%1,%2,%3}, [%4], {%5,%6,%7,%8};" + : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), "=f"(ret_val.w) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), "f"(add_val.z), "f"(add_val.w) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.v4.f32 {%0,%1,%2,%3}, [%4], {%5,%6,%7,%8};" + : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), "=f"(ret_val.w) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), "f"(add_val.z), "f"(add_val.w) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.v4.f32 {%0,%1,%2,%3}, [%4], {%5,%6,%7,%8};" + : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), "=f"(ret_val.w) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), "f"(add_val.z), "f"(add_val.w) + : "memory"); + } + } } TL_DEVICE float4 AtomicAddx4Ret(float *ref, float *val, int memory_order = int(cuda::memory_order_relaxed)) { - return atomicAdd(reinterpret_cast(ref), - static_cast(*reinterpret_cast(val))); + if (memory_order == int(cuda::memory_order_relaxed)) { + return atomicAdd(reinterpret_cast(ref), + static_cast(*reinterpret_cast(val))); + } else { + float4 add_val = *reinterpret_cast(val); + unsigned long long ref_addr = reinterpret_cast(ref); + float4 ret_val; + if (memory_order == int(cuda::memory_order_release) || memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.v4.f32 {%0,%1,%2,%3}, [%4], {%5,%6,%7,%8};" + : "=f"(ret_val.x), 
"=f"(ret_val.y), "=f"(ret_val.z), "=f"(ret_val.w) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), "f"(add_val.z), "f"(add_val.w) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acquire)) { + asm volatile("atom.acquire.gpu.global.add.v4.f32 {%0,%1,%2,%3}, [%4], {%5,%6,%7,%8};" + : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), "=f"(ret_val.w) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), "f"(add_val.z), "f"(add_val.w) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.v4.f32 {%0,%1,%2,%3}, [%4], {%5,%6,%7,%8};" + : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), "=f"(ret_val.w) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), "f"(add_val.z), "f"(add_val.w) + : "memory"); + } + return ret_val; + } } #endif From c65c9d3536b1e37b663cfa983e0ac7b4e891e194 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Thu, 23 Oct 2025 05:58:05 +0000 Subject: [PATCH 3/9] [Lint] --- src/tl_templates/cuda/atomic.h | 239 +++++++++++++++++++++------------ 1 file changed, 151 insertions(+), 88 deletions(-) diff --git a/src/tl_templates/cuda/atomic.h b/src/tl_templates/cuda/atomic.h index 203b14cf0..82eeccfda 100644 --- a/src/tl_templates/cuda/atomic.h +++ b/src/tl_templates/cuda/atomic.h @@ -5,8 +5,8 @@ #endif #include -#include #include +#include using cutlass::bfloat16_t; using cutlass::half_t; @@ -47,7 +47,8 @@ TL_DEVICE void AtomicMax(T1 &ref, T2 val, using NT1 = typename normalize_atomic_type::type; T1 *address = &ref; if constexpr ((std::is_same_v || - std::is_same_v) && memory_order == int(cuda::memory_order_relaxed)) { + std::is_same_v) && + memory_order == int(cuda::memory_order_relaxed)) { atomicMax(reinterpret_cast(address), static_cast(val)); } else { cuda::atomic_ref aref(*address); @@ -61,7 +62,8 @@ TL_DEVICE T1 AtomicMaxRet(T1 &ref, T2 val, using NT1 = typename normalize_atomic_type::type; T1 *address = &ref; if constexpr 
((std::is_same_v || - std::is_same_v) && memory_order == int(cuda::memory_order_relaxed)) { + std::is_same_v) && + memory_order == int(cuda::memory_order_relaxed)) { return static_cast( atomicMax(reinterpret_cast(address), static_cast(val))); } else { @@ -77,7 +79,8 @@ TL_DEVICE void AtomicMin(T1 &ref, T2 val, using NT1 = typename normalize_atomic_type::type; T1 *address = &ref; if constexpr ((std::is_same_v || - std::is_same_v) && memory_order == int(cuda::memory_order_relaxed)) { + std::is_same_v) && + memory_order == int(cuda::memory_order_relaxed)) { atomicMin(reinterpret_cast(address), static_cast(val)); } else { cuda::atomic_ref aref(*address); @@ -91,7 +94,8 @@ TL_DEVICE T1 AtomicMinRet(T1 &ref, T2 val, using NT1 = typename normalize_atomic_type::type; T1 *address = &ref; if constexpr ((std::is_same_v || - std::is_same_v) && memory_order == int(cuda::memory_order_relaxed)) { + std::is_same_v) && + memory_order == int(cuda::memory_order_relaxed)) { return static_cast( atomicMin(reinterpret_cast(address), static_cast(val))); } else { @@ -140,32 +144,42 @@ TL_DEVICE void AtomicAddx2(half_t *ref, half_t *val, atomicAdd(reinterpret_cast(ref), static_cast(*reinterpret_cast(val))); } else { - // Since atomicAdd does not support memory order, atomic_ref does not support vectorized atomic operation - // we can only inline ptx code here + // Since atomicAdd does not support memory order, atomic_ref does not + // support vectorized atomic operation we can only inline ptx code here // Note: Vectorized atomic operations only support global space - // Note: for 16-bit value, we need to reinterpret_cast the value to unsigned short and use "h" register in assembly + // Note: for 16-bit value, we need to reinterpret_cast the value to unsigned + // short and use "h" register in assembly __half2 add_val = *reinterpret_cast<__half2 *>(val); - unsigned short add_val_x_cast = *reinterpret_cast(&add_val.x); - unsigned short add_val_y_cast = *reinterpret_cast(&add_val.y); + 
unsigned short add_val_x_cast = + *reinterpret_cast(&add_val.x); + unsigned short add_val_y_cast = + *reinterpret_cast(&add_val.y); unsigned long long ref_addr = reinterpret_cast(ref); __half ret_val_x, ret_val_y; - unsigned short ret_val_x_cast = *reinterpret_cast(&ret_val_x); - unsigned short ret_val_y_cast = *reinterpret_cast(&ret_val_y); - if (memory_order == int(cuda::memory_order_release) || memory_order == int(cuda::memory_order_consume)) { - asm volatile("atom.release.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" - : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) - : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) - : "memory"); + unsigned short ret_val_x_cast = + *reinterpret_cast(&ret_val_x); + unsigned short ret_val_y_cast = + *reinterpret_cast(&ret_val_y); + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile( + "atom.release.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); } else if (memory_order == int(cuda::memory_order_acquire)) { - asm volatile("atom.acquire.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" - : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) - : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acq_rel) || memory_order == int(cuda::memory_order_seq_cst)) { - asm volatile("atom.acq_rel.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" - : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) - : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) - : "memory"); + asm volatile( + "atom.acquire.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == 
int(cuda::memory_order_seq_cst)) { + asm volatile( + "atom.acq_rel.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); } } } @@ -178,29 +192,39 @@ AtomicAddx2Ret(half_t *ref, half_t *val, static_cast(*reinterpret_cast(val))); } else { __half2 add_val = *reinterpret_cast<__half2 *>(val); - unsigned short add_val_x_cast = *reinterpret_cast(&add_val.x); - unsigned short add_val_y_cast = *reinterpret_cast(&add_val.y); + unsigned short add_val_x_cast = + *reinterpret_cast(&add_val.x); + unsigned short add_val_y_cast = + *reinterpret_cast(&add_val.y); unsigned long long ref_addr = reinterpret_cast(ref); __half ret_val_x, ret_val_y; - unsigned short ret_val_x_cast = *reinterpret_cast(&ret_val_x); - unsigned short ret_val_y_cast = *reinterpret_cast(&ret_val_y); - if (memory_order == int(cuda::memory_order_release) || memory_order == int(cuda::memory_order_consume)) { - asm volatile("atom.release.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" - : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) - : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) - : "memory"); + unsigned short ret_val_x_cast = + *reinterpret_cast(&ret_val_x); + unsigned short ret_val_y_cast = + *reinterpret_cast(&ret_val_y); + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile( + "atom.release.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); } else if (memory_order == int(cuda::memory_order_acquire)) { - asm volatile("atom.acquire.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" - : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) - : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) - : "memory"); - } else if (memory_order == int(cuda::memory_order_acq_rel) || memory_order == 
int(cuda::memory_order_seq_cst)) { - asm volatile("atom.acq_rel.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" - : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) - : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) - : "memory"); + asm volatile( + "atom.acquire.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile( + "atom.acq_rel.gpu.global.add.noftz.v2.f16 {%0,%1}, [%2], {%3,%4};" + : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) + : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) + : "memory"); } - return half2(*reinterpret_cast<__half *>(&ret_val_x_cast), *reinterpret_cast<__half *>(&ret_val_y_cast)); + return half2(*reinterpret_cast<__half *>(&ret_val_x_cast), + *reinterpret_cast<__half *>(&ret_val_y_cast)); } } @@ -209,17 +233,22 @@ TL_DEVICE void AtomicAddx2(bfloat16_t *ref, bfloat16_t *val, int memory_order = int(cuda::memory_order_relaxed)) { if (memory_order == int(cuda::memory_order_relaxed)) { atomicAdd( - reinterpret_cast<__nv_bfloat162 *>(ref), - static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val))); + reinterpret_cast<__nv_bfloat162 *>(ref), + static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val))); } else { __nv_bfloat162 add_val = *reinterpret_cast<__nv_bfloat162 *>(val); - unsigned short add_val_x_cast = *reinterpret_cast(&add_val.x); - unsigned short add_val_y_cast = *reinterpret_cast(&add_val.y); + unsigned short add_val_x_cast = + *reinterpret_cast(&add_val.x); + unsigned short add_val_y_cast = + *reinterpret_cast(&add_val.y); unsigned long long ref_addr = reinterpret_cast(ref); __nv_bfloat162 ret_val; - unsigned short ret_val_x_cast = *reinterpret_cast(&ret_val.x); - unsigned short ret_val_y_cast = *reinterpret_cast(&ret_val.y); - if (memory_order == 
int(cuda::memory_order_release) || memory_order == int(cuda::memory_order_consume)) { + unsigned short ret_val_x_cast = + *reinterpret_cast(&ret_val.x); + unsigned short ret_val_y_cast = + *reinterpret_cast(&ret_val.y); + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { asm volatile("atom.release.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) @@ -229,7 +258,8 @@ TL_DEVICE void AtomicAddx2(bfloat16_t *ref, bfloat16_t *val, : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) : "memory"); - } else if (memory_order == int(cuda::memory_order_acq_rel) || memory_order == int(cuda::memory_order_seq_cst)) { + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { asm volatile("atom.acq_rel.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) @@ -247,13 +277,18 @@ AtomicAddx2Ret(bfloat16_t *ref, bfloat16_t *val, static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val))); } else { __nv_bfloat162 add_val = *reinterpret_cast<__nv_bfloat162 *>(val); - unsigned short add_val_x_cast = *reinterpret_cast(&add_val.x); - unsigned short add_val_y_cast = *reinterpret_cast(&add_val.y); + unsigned short add_val_x_cast = + *reinterpret_cast(&add_val.x); + unsigned short add_val_y_cast = + *reinterpret_cast(&add_val.y); unsigned long long ref_addr = reinterpret_cast(ref); __nv_bfloat162 ret_val; - unsigned short ret_val_x_cast = *reinterpret_cast(&ret_val.x); - unsigned short ret_val_y_cast = *reinterpret_cast(&ret_val.y); - if (memory_order == int(cuda::memory_order_release) || memory_order == int(cuda::memory_order_consume)) { + unsigned short ret_val_x_cast = + *reinterpret_cast(&ret_val.x); + unsigned short 
ret_val_y_cast = + *reinterpret_cast(&ret_val.y); + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { asm volatile("atom.release.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) @@ -263,13 +298,15 @@ AtomicAddx2Ret(bfloat16_t *ref, bfloat16_t *val, : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) : "memory"); - } else if (memory_order == int(cuda::memory_order_acq_rel) || memory_order == int(cuda::memory_order_seq_cst)) { + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { asm volatile("atom.acq_rel.gpu.global.add.v2.bf16 {%0,%1}, [%2], {%3,%4};" : "=h"(ret_val_x_cast), "=h"(ret_val_y_cast) : "l"(ref_addr), "h"(add_val_x_cast), "h"(add_val_y_cast) : "memory"); } - return __nv_bfloat162(*reinterpret_cast<__nv_bfloat16 *>(&ret_val_x_cast), *reinterpret_cast<__nv_bfloat16 *>(&ret_val_y_cast)); + return __nv_bfloat162(*reinterpret_cast<__nv_bfloat16 *>(&ret_val_x_cast), + *reinterpret_cast<__nv_bfloat16 *>(&ret_val_y_cast)); } } #endif @@ -284,7 +321,8 @@ TL_DEVICE void AtomicAddx2(float *ref, float *val, float2 add_val = *reinterpret_cast(val); unsigned long long ref_addr = reinterpret_cast(ref); float2 ret_val; - if (memory_order == int(cuda::memory_order_release) || memory_order == int(cuda::memory_order_consume)) { + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { asm volatile("atom.release.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" : "=f"(ret_val.x), "=f"(ret_val.y) : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) @@ -294,7 +332,8 @@ TL_DEVICE void AtomicAddx2(float *ref, float *val, : "=f"(ret_val.x), "=f"(ret_val.y) : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) : "memory"); - } else if (memory_order == 
int(cuda::memory_order_acq_rel) || memory_order == int(cuda::memory_order_seq_cst)) { + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { asm volatile("atom.acq_rel.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" : "=f"(ret_val.x), "=f"(ret_val.y) : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) @@ -313,7 +352,8 @@ AtomicAddx2Ret(float *ref, float *val, float2 add_val = *reinterpret_cast(val); unsigned long long ref_addr = reinterpret_cast(ref); float2 ret_val; - if (memory_order == int(cuda::memory_order_release) || memory_order == int(cuda::memory_order_consume)) { + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { asm volatile("atom.release.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" : "=f"(ret_val.x), "=f"(ret_val.y) : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) @@ -323,7 +363,8 @@ AtomicAddx2Ret(float *ref, float *val, : "=f"(ret_val.x), "=f"(ret_val.y) : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) : "memory"); - } else if (memory_order == int(cuda::memory_order_acq_rel) || memory_order == int(cuda::memory_order_seq_cst)) { + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { asm volatile("atom.acq_rel.gpu.global.add.v2.f32 {%0,%1}, [%2], {%3,%4};" : "=f"(ret_val.x), "=f"(ret_val.y) : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y) @@ -339,26 +380,37 @@ TL_DEVICE void AtomicAddx4(float *ref, float *val, atomicAdd(reinterpret_cast(ref), static_cast(*reinterpret_cast(val))); } else { - // Since atomicAdd does not support memory order, atomic_ref does not support vectorized atomic operation - // we can only inline ptx code here + // Since atomicAdd does not support memory order, atomic_ref does not + // support vectorized atomic operation we can only inline ptx code here // Note: Vectorized atomic operations only support global space float4 add_val = *reinterpret_cast(val); 
unsigned long long ref_addr = reinterpret_cast(ref); float4 ret_val; - if (memory_order == int(cuda::memory_order_release) || memory_order == int(cuda::memory_order_consume)) { - asm volatile("atom.release.gpu.global.add.v4.f32 {%0,%1,%2,%3}, [%4], {%5,%6,%7,%8};" - : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), "=f"(ret_val.w) - : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), "f"(add_val.z), "f"(add_val.w) + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.release.gpu.global.add.v4.f32 {%0,%1,%2,%3}, [%4], " + "{%5,%6,%7,%8};" + : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), + "=f"(ret_val.w) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), + "f"(add_val.z), "f"(add_val.w) : "memory"); } else if (memory_order == int(cuda::memory_order_acquire)) { - asm volatile("atom.acquire.gpu.global.add.v4.f32 {%0,%1,%2,%3}, [%4], {%5,%6,%7,%8};" - : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), "=f"(ret_val.w) - : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), "f"(add_val.z), "f"(add_val.w) + asm volatile("atom.acquire.gpu.global.add.v4.f32 {%0,%1,%2,%3}, [%4], " + "{%5,%6,%7,%8};" + : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), + "=f"(ret_val.w) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), + "f"(add_val.z), "f"(add_val.w) : "memory"); - } else if (memory_order == int(cuda::memory_order_acq_rel) || memory_order == int(cuda::memory_order_seq_cst)) { - asm volatile("atom.acq_rel.gpu.global.add.v4.f32 {%0,%1,%2,%3}, [%4], {%5,%6,%7,%8};" - : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), "=f"(ret_val.w) - : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), "f"(add_val.z), "f"(add_val.w) + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == int(cuda::memory_order_seq_cst)) { + asm volatile("atom.acq_rel.gpu.global.add.v4.f32 {%0,%1,%2,%3}, [%4], " + "{%5,%6,%7,%8};" + : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), + "=f"(ret_val.w) + : 
"l"(ref_addr), "f"(add_val.x), "f"(add_val.y), + "f"(add_val.z), "f"(add_val.w) : "memory"); } } @@ -374,20 +426,31 @@ AtomicAddx4Ret(float *ref, float *val, float4 add_val = *reinterpret_cast(val); unsigned long long ref_addr = reinterpret_cast(ref); float4 ret_val; - if (memory_order == int(cuda::memory_order_release) || memory_order == int(cuda::memory_order_consume)) { - asm volatile("atom.global.gpu.release.add.v4.f32 {%0,%1,%2,%3}, [%4], {%5,%6,%7,%8};" - : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), "=f"(ret_val.w) - : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), "f"(add_val.z), "f"(add_val.w) + if (memory_order == int(cuda::memory_order_release) || + memory_order == int(cuda::memory_order_consume)) { + asm volatile("atom.global.gpu.release.add.v4.f32 {%0,%1,%2,%3}, [%4], " + "{%5,%6,%7,%8};" + : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), + "=f"(ret_val.w) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), + "f"(add_val.z), "f"(add_val.w) : "memory"); } else if (memory_order == int(cuda::memory_order_acquire)) { - asm volatile("atom.global.gpu.acquire.add.v4.f32 {%0,%1,%2,%3}, [%4], {%5,%6,%7,%8};" - : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), "=f"(ret_val.w) - : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), "f"(add_val.z), "f"(add_val.w) + asm volatile("atom.global.gpu.acquire.add.v4.f32 {%0,%1,%2,%3}, [%4], " + "{%5,%6,%7,%8};" + : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), + "=f"(ret_val.w) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), + "f"(add_val.z), "f"(add_val.w) : "memory"); - } else if (memory_order == int(cuda::memory_order_acq_rel) || memory_order == int(cuda::memory_order_seq_cst)) { - asm volatile("atom.global.gpu.acq_rel.add.v4.f32 {%0,%1,%2,%3}, [%4], {%5,%6,%7,%8};" - : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), "=f"(ret_val.w) - : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), "f"(add_val.z), "f"(add_val.w) + } else if (memory_order == int(cuda::memory_order_acq_rel) || + memory_order == 
int(cuda::memory_order_seq_cst)) { + asm volatile("atom.global.gpu.acq_rel.add.v4.f32 {%0,%1,%2,%3}, [%4], " + "{%5,%6,%7,%8};" + : "=f"(ret_val.x), "=f"(ret_val.y), "=f"(ret_val.z), + "=f"(ret_val.w) + : "l"(ref_addr), "f"(add_val.x), "f"(add_val.y), + "f"(add_val.z), "f"(add_val.w) : "memory"); } return ret_val; From 0f4b8c74a3a08ddd7f68087b87fb5c5af8e865d5 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Thu, 23 Oct 2025 06:05:24 +0000 Subject: [PATCH 4/9] test --- tilelang/language/allocate.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py index 55e1fdfd5..d92a24380 100644 --- a/tilelang/language/allocate.py +++ b/tilelang/language/allocate.py @@ -14,6 +14,7 @@ with the appropriate memory scope. """ +from __future__ import annotations from tilelang import tvm as tvm from tvm.script import tir as T from tvm.tir import PrimExpr From 0185eb537d4f62f43766e73a00947a5529641e53 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Thu, 23 Oct 2025 06:06:51 +0000 Subject: [PATCH 5/9] [BugFix] FIx init optional argument in alloc_var --- tilelang/language/allocate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py index d92a24380..8c51691fb 100644 --- a/tilelang/language/allocate.py +++ b/tilelang/language/allocate.py @@ -19,7 +19,7 @@ from tvm.script import tir as T from tvm.tir import PrimExpr from tvm.script.parser.tir import block_attr -from typing import Union +from typing import Optional def alloc_shared(shape, dtype, scope="shared.dyn"): @@ -68,7 +68,7 @@ def alloc_fragment(shape, dtype, scope="local.fragment"): return T.alloc_buffer(shape, dtype, scope=scope) -def alloc_var(dtype, *args, scope="local.var", init: Union[PrimExpr] = None): +def alloc_var(dtype, *args, scope="local.var", init: Optional[PrimExpr] = None): """Allocate a single-element variable buffer. 
Args: From f10dd820028361ac1843845b2d8d418fb948cefc Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Thu, 23 Oct 2025 06:07:27 +0000 Subject: [PATCH 6/9] bug fix --- tilelang/language/allocate.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py index 8c51691fb..7403a7d3e 100644 --- a/tilelang/language/allocate.py +++ b/tilelang/language/allocate.py @@ -14,7 +14,6 @@ with the appropriate memory scope. """ -from __future__ import annotations from tilelang import tvm as tvm from tvm.script import tir as T from tvm.tir import PrimExpr From 694f2019d703b2d491747adaa204e6e09dd98b5f Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Thu, 23 Oct 2025 06:12:21 +0000 Subject: [PATCH 7/9] bug fix --- tilelang/language/allocate.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py index 7403a7d3e..924ac5d7d 100644 --- a/tilelang/language/allocate.py +++ b/tilelang/language/allocate.py @@ -14,6 +14,8 @@ with the appropriate memory scope. """ +from __future__ import annotations + from tilelang import tvm as tvm from tvm.script import tir as T from tvm.tir import PrimExpr From 5097d34fda6b4830c1b5349c363e87a3881d0561 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Thu, 23 Oct 2025 06:13:03 +0000 Subject: [PATCH 8/9] lint fix --- tilelang/language/allocate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py index 924ac5d7d..4e03aec88 100644 --- a/tilelang/language/allocate.py +++ b/tilelang/language/allocate.py @@ -69,7 +69,7 @@ def alloc_fragment(shape, dtype, scope="local.fragment"): return T.alloc_buffer(shape, dtype, scope=scope) -def alloc_var(dtype, *args, scope="local.var", init: Optional[PrimExpr] = None): +def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None): """Allocate a single-element variable buffer. 
Args: From e27534717d0fc5747f3fe85b406e06ff06a01975 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Thu, 23 Oct 2025 06:13:24 +0000 Subject: [PATCH 9/9] lint fix --- tilelang/language/allocate.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py index 4e03aec88..19c22990a 100644 --- a/tilelang/language/allocate.py +++ b/tilelang/language/allocate.py @@ -20,7 +20,6 @@ from tvm.script import tir as T from tvm.tir import PrimExpr from tvm.script.parser.tir import block_attr -from typing import Optional def alloc_shared(shape, dtype, scope="shared.dyn"):