Merged
41 changes: 39 additions & 2 deletions src/tl_templates/cuda/atomic.h
@@ -12,7 +12,11 @@ using cutlass::bfloat16_t;
using cutlass::half_t;

#define TL_DEVICE __forceinline__ __device__

#define TL_NOT_IMPLEMENTED() \
{ \
printf("%s not implemented\n", __PRETTY_FUNCTION__); \
asm volatile("brkpt;\n"); \
}
template <typename T> struct normalize_atomic_type {
using type = T;
};
@@ -63,8 +67,12 @@ TL_DEVICE void AtomicMax(T1 &ref, T2 val,
}
}
} else {
#if CUDART_VERSION >= 11080
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
aref.fetch_max(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
#else
TL_NOT_IMPLEMENTED();
#endif
Comment on lines +70 to +75

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain
Verify CUDA version requirement matches PR objectives.

The code gates atomic operations behind CUDART_VERSION >= 11080 (CUDA 11.8), but the PR title indicates support for CUDA 11.3 (version 11030). Please confirm whether CUDA 11.3 lacks cuda::atomic_ref support, or if the version check should be lowered to 11030.

Also applies to: 100-106, 163-169, 240-245, 318-324, 650-655, 661-666


🌐 Web query:

What CUDA version introduced cuda::atomic_ref support?

💡 Result:

cuda::atomic_ref was introduced in CUDA Toolkit 11.6. [1][2]

Sources:
[1] CUDA Core Compute Libraries — Synchronization Primitives (cuda::atomic_ref listed as CUDA 11.6).
[2] libcu++ 1.7.0 release notes (libcu++ 1.7.0 / CUDA Toolkit 11.6 adds atomic_ref).


Confirm PR target version—CUDA 11.3 cannot support cuda::atomic_ref operations.

The cuda::atomic_ref API was introduced in CUDA Toolkit 11.6, not 11.3. The version gates at CUDART_VERSION >= 11080 (CUDA 11.8) are unnecessarily restrictive, but lowering them to 11030 (CUDA 11.3) will not work—the API didn't exist until 11.6.

Either the PR target version should be CUDA 11.6+, or alternative implementations must be provided for CUDA 11.3-11.5. The current TL_NOT_IMPLEMENTED() fallback would fail at runtime for CUDA 11.3 users.

Applies to: lines 70-75, 100-106, 163-169, 240-245, 318-324, 650-655, 661-666

🤖 Prompt for AI Agents
In src/tl_templates/cuda/atomic.h around lines 70-75, 100-106, 163-169, 240-245,
318-324, 650-655 and 661-666 the code gates use cuda::atomic_ref which was
introduced in CUDA 11.6; update the CUDART_VERSION checks to >= 11060 if we
intend to require CUDA 11.6+ and document that PR target or CI must use 11.6+,
otherwise replace the TL_NOT_IMPLEMENTED() fallbacks with a portable atomic
implementation (e.g. an atomicCAS loop or intrinsics) for CUDA 11.3–11.5 to
perform the same operations, and ensure all version guards and documentation
reflect the chosen approach.
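For concreteness, the portable fallback the prompt describes could look like the sketch below. This is a minimal sketch, not the PR's code: the helper name atomic_fetch_max_fallback is hypothetical, it assumes T is int or unsigned int (types with a matching atomicCAS overload), and it provides no memory-order guarantees beyond the atomicity of the RMW itself.

template <typename T>
__device__ T atomic_fetch_max_fallback(T *address, T val) {
  // Hypothetical CAS-loop fallback for CUDA 11.3-11.5, where
  // cuda::atomic_ref is unavailable. Assumes T is int or unsigned int.
  T observed = *address;
  while (observed < val) {
    // atomicCAS returns the value that was at *address at compare time.
    T prev = atomicCAS(address, observed, val);
    if (prev == observed)
      break;         // CAS hit: `observed` is the pre-update value.
    observed = prev; // CAS missed: another thread wrote; retry against it.
  }
  return observed;   // like fetch_max: the value seen before our update
}

A fetch_min variant would flip the comparison; float and 16-bit types would additionally need bit-casting through a CAS-capable integer width.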

}
}

@@ -89,9 +97,13 @@ TL_DEVICE T1 AtomicMaxRet(T1 &ref, T2 val,
}
return static_cast<T1>(*reinterpret_cast<T1 *>(&old_val_ushort));
} else {
#if CUDART_VERSION >= 11080
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
return static_cast<T1>(
aref.fetch_max(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
#else
TL_NOT_IMPLEMENTED();
#endif
}
}

@@ -117,8 +129,13 @@ TL_DEVICE void AtomicMin(T1 &ref, T2 val,
}
}
} else {
#if CUDART_VERSION >= 11080
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
return static_cast<T1>(
aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
#else
TL_NOT_IMPLEMENTED();
#endif
Comment on lines +132 to +138

⚠️ Potential issue | 🔴 Critical

Critical: Return statement in void function.

AtomicMin is declared as returning void (line 111), but lines 134-135 contain a return statement with a value. This will cause a compilation error.

Apply this diff to remove the erroneous return statement:

 #if CUDART_VERSION >= 11080
     cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
-    return static_cast<T1>(
-        aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
+    aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
 #else
     TL_NOT_IMPLEMENTED();
 #endif
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
 #if CUDART_VERSION >= 11080
     cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
     aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
-    return static_cast<T1>(
-        aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
 #else
     TL_NOT_IMPLEMENTED();
 #endif
🤖 Prompt for AI Agents
In src/tl_templates/cuda/atomic.h around lines 132 to 138, the AtomicMin
implementation is in a function declared void but contains a returned value;
remove the erroneous return so the function does not return a value. Replace the
line returning static_cast<T1>(aref.fetch_min(...)) with a call to
aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order)) and
discard the result (or cast val beforehand) so the operation executes without
returning anything; keep the TL_NOT_IMPLEMENTED() fallback as-is.

}
}

@@ -143,9 +160,13 @@ TL_DEVICE T1 AtomicMinRet(T1 &ref, T2 val,
}
return static_cast<T1>(*reinterpret_cast<T1 *>(&old_val_ushort));
} else {
#if CUDART_VERSION >= 11080
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
return static_cast<T1>(
aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
#else
TL_NOT_IMPLEMENTED();
#endif
}
}

@@ -216,8 +237,12 @@ TL_DEVICE void AtomicAdd(T1 &ref, T2 val,
}
}
} else {
#if CUDART_VERSION >= 11080
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
aref.fetch_add(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
#else
TL_NOT_IMPLEMENTED();
#endif
}
}

@@ -290,9 +315,13 @@ TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val,
}
}
} else {
#if CUDART_VERSION >= 11080
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
return static_cast<T1>(
aref.fetch_add(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
#else
TL_NOT_IMPLEMENTED();
#endif
}
}

@@ -618,13 +647,21 @@ AtomicAddx4Ret(float *ref, float *val,
#endif

template <typename T> TL_DEVICE T AtomicLoad(T &ref, int memory_order) {
#if CUDART_VERSION >= 11080
cuda::atomic_ref<T, cuda::thread_scope_device> aref(ref);
return aref.load(cuda::memory_order(memory_order));
#else
TL_NOT_IMPLEMENTED();
#endif
}

template <typename T1, typename T2>
TL_DEVICE void AtomicStore(T1 &ref, T2 value, int memory_order) {
using NT1 = typename normalize_atomic_type<T1>::type;
#if CUDART_VERSION >= 11080
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(ref);
aref.store(cuda_cast<NT1>(value), cuda::memory_order(memory_order));
#else
TL_NOT_IMPLEMENTED();
#endif
}
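As context for the memory_order plumbing above, here is a hypothetical usage sketch; the kernel and buffer names are not from the PR. It assumes a CUDA 11.8+ toolkit (so the cuda::atomic_ref path is compiled in), the cuda::memory_order enumerators the header already relies on, and an sm_70+ GPU so the spinning thread makes independent forward progress.

__global__ void handoff_kernel(int *data, int *flag) {
  if (threadIdx.x == 0) {
    *data = 42;                                             // payload write
    AtomicStore(*flag, 1, int(cuda::memory_order_release)); // publish flag
  } else if (threadIdx.x == 1) {
    // Spin until the producer's release store becomes visible.
    while (AtomicLoad(*flag, int(cuda::memory_order_acquire)) == 0) {
    }
    printf("data = %d\n", *data); // acquire/release orders the payload write
  }
}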
9 changes: 9 additions & 0 deletions src/tl_templates/cuda/debug.h
@@ -1,6 +1,9 @@
#pragma once

#if __CUDA_ARCH_LIST__ >= 890
#include "./cuda_fp8.h"
#endif

#include "common.h"

#ifndef __CUDACC_RTC__
@@ -117,6 +120,7 @@ __device__ void debug_print_var<double>(const char *msg, double var) {
threadIdx.z, var);
}

#if __CUDA_ARCH_LIST__ >= 890
// Specialization for fp8_e4_t type
template <>
__device__ void debug_print_var<fp8_e4_t>(const char *msg, fp8_e4_t var) {
@@ -137,6 +141,8 @@ __device__ void debug_print_var<fp8_e5_t>(const char *msg, fp8_e5_t var) {
threadIdx.z, (float)var);
}

#endif

// Template declaration for device-side debug printing (buffer only)
template <typename T>
__device__ void debug_print_buffer_value(const char *msg, const char *buf_name,
@@ -242,6 +248,7 @@ __device__ void debug_print_buffer_value<double>(const char *msg,
}

// Specialization for fp8_e4_t type
#if __CUDA_ARCH_LIST__ >= 890
template <>
__device__ void debug_print_buffer_value<fp8_e4_t>(const char *msg,
const char *buf_name,
@@ -263,6 +270,8 @@ __device__ void debug_print_buffer_value<fp8_e5_t>(const char *msg,
threadIdx.z, buf_name, index, (float)var);
}

#endif

// Specialization for int16 type
template <>
__device__ void debug_print_buffer_value<int16_t>(const char *msg,
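The guard pattern added in this file reduces to one rule: nothing may name fp8_e4_t or fp8_e5_t unless cuda_fp8.h was actually included, which this diff ties to __CUDA_ARCH_LIST__ >= 890 (sm_89). A minimal sketch of the same pattern for any future fp8-touching helper; the function name is hypothetical:

#if __CUDA_ARCH_LIST__ >= 890
// Hypothetical helper: legal to define only under the sm_89 guard,
// since fp8_e4_t comes from the conditionally included cuda_fp8.h.
__device__ float fp8_to_float_for_logging(fp8_e4_t v) {
  return static_cast<float>(v); // widen before printf, as the specializations above do
}
#endif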
1 change: 0 additions & 1 deletion src/tl_templates/cuda/gemm_mma.h
@@ -8,7 +8,6 @@
#include <cute/underscore.hpp>

#include "common.h"
#include "cuda_fp8.h"
#include "intrin.h"

namespace cute::tl_mma {