Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/op/atomic_add.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

#include "../layout/layout.h"
#include "../target/utils.h"
#include "../transform/atomicadd_vectorize.h"
#include "../transform/common/loop_fusion_utils.h"
#include "../transform/loop_partition.h"
#include "builtin.h"
Expand Down
89 changes: 60 additions & 29 deletions src/tl_templates/cuda/atomic.h
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,8 @@ TL_DEVICE T1 AtomicAddRet(T1 *address, T2 val,
}
}

// TODO add memory_order for vectorized atomic add
TL_DEVICE void AtomicAddx2(half_t *ref, half_t *val,
template <typename src_type>
TL_DEVICE void AtomicAddx2(half_t *ref, src_type *val,
int memory_order = int(cuda::memory_order_relaxed)) {
if (memory_order == int(cuda::memory_order_relaxed)) {
atomicAdd(reinterpret_cast<half2 *>(ref),
Expand Down Expand Up @@ -374,8 +374,9 @@ TL_DEVICE void AtomicAddx2(half_t *ref, half_t *val,
}
}

template <typename src_type>
TL_DEVICE half2
AtomicAddx2Ret(half_t *ref, half_t *val,
AtomicAddx2Ret(half_t *ref, src_type *val,
int memory_order = int(cuda::memory_order_relaxed)) {
if (memory_order == int(cuda::memory_order_relaxed)) {
return atomicAdd(reinterpret_cast<half2 *>(ref),
Expand Down Expand Up @@ -419,7 +420,8 @@ AtomicAddx2Ret(half_t *ref, half_t *val,
}

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
TL_DEVICE void AtomicAddx2(bfloat16_t *ref, bfloat16_t *val,
template <typename src_type>
TL_DEVICE void AtomicAddx2(bfloat16_t *ref, src_type *val,
int memory_order = int(cuda::memory_order_relaxed)) {
if (memory_order == int(cuda::memory_order_relaxed)) {
atomicAdd(
Expand Down Expand Up @@ -458,8 +460,9 @@ TL_DEVICE void AtomicAddx2(bfloat16_t *ref, bfloat16_t *val,
}
}

template <typename src_type>
TL_DEVICE __nv_bfloat162
AtomicAddx2Ret(bfloat16_t *ref, bfloat16_t *val,
AtomicAddx2Ret(bfloat16_t *ref, src_type *val,
int memory_order = int(cuda::memory_order_relaxed)) {
if (memory_order == int(cuda::memory_order_relaxed)) {
return atomicAdd(
Expand Down Expand Up @@ -502,13 +505,19 @@ AtomicAddx2Ret(bfloat16_t *ref, bfloat16_t *val,
#endif

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
TL_DEVICE void AtomicAddx2(float *ref, float *val,
// Load two consecutive floats through `val` and return them as a float2.
// NOTE(review): this is a reinterpret-cast load — `val` must point to at
// least 8 bytes of float data with float2 (8-byte) alignment; TODO confirm
// all call sites satisfy this. The pointer is only read, hence const.
template <typename T> TL_DEVICE float2 ToFloat2(const T *val) {
  return *reinterpret_cast<const float2 *>(val);
}

// Identity overload so callers can pass a float2 by value directly.
TL_DEVICE float2 ToFloat2(float2 val) { return val; }

template <typename ValType>
TL_DEVICE void AtomicAddx2(float *ref, ValType val,
int memory_order = int(cuda::memory_order_relaxed)) {
float2 add_val = ToFloat2(val);
if (memory_order == int(cuda::memory_order_relaxed)) {
atomicAdd(reinterpret_cast<float2 *>(ref),
static_cast<float2>(*reinterpret_cast<float2 *>(val)));
atomicAdd(reinterpret_cast<float2 *>(ref), add_val);
} else {
float2 add_val = *reinterpret_cast<float2 *>(val);
unsigned long long ref_addr = reinterpret_cast<unsigned long long>(ref);
float2 ret_val;
if (memory_order == int(cuda::memory_order_release) ||
Expand All @@ -532,14 +541,14 @@ TL_DEVICE void AtomicAddx2(float *ref, float *val,
}
}

template <typename ValType>
TL_DEVICE float2
AtomicAddx2Ret(float *ref, float *val,
AtomicAddx2Ret(float *ref, ValType val,
int memory_order = int(cuda::memory_order_relaxed)) {
float2 add_val = ToFloat2(val);
if (memory_order == int(cuda::memory_order_relaxed)) {
return atomicAdd(reinterpret_cast<float2 *>(ref),
static_cast<float2>(*reinterpret_cast<float2 *>(val)));
return atomicAdd(reinterpret_cast<float2 *>(ref), add_val);
} else {
float2 add_val = *reinterpret_cast<float2 *>(val);
unsigned long long ref_addr = reinterpret_cast<unsigned long long>(ref);
float2 ret_val;
if (memory_order == int(cuda::memory_order_release) ||
Expand All @@ -564,16 +573,22 @@ AtomicAddx2Ret(float *ref, float *val,
}
}

TL_DEVICE void AtomicAddx4(float *ref, float *val,
// Load four consecutive floats through `val` and return them as a float4.
// NOTE(review): reinterpret-cast load — `val` must point to at least
// 16 bytes of float data with float4 (16-byte) alignment; TODO confirm
// call sites. The pointer is only read, hence const.
template <typename T> TL_DEVICE float4 ToFloat4(const T *val) {
  return *reinterpret_cast<const float4 *>(val);
}

// Identity overload so callers can pass a float4 by value directly.
TL_DEVICE float4 ToFloat4(float4 val) { return val; }

template <typename dst_dtype, typename ValType>
TL_DEVICE void AtomicAddx4(dst_dtype *ref, ValType val,
int memory_order = int(cuda::memory_order_relaxed)) {
float4 add_val = ToFloat4(val);
if (memory_order == int(cuda::memory_order_relaxed)) {
atomicAdd(reinterpret_cast<float4 *>(ref),
static_cast<float4>(*reinterpret_cast<float4 *>(val)));
atomicAdd(reinterpret_cast<float4 *>(ref), add_val);
} else {
// Since atomicAdd does not support memory order, atomic_ref does not
// support vectorized atomic operation we can only inline ptx code here
// Note: Vectorized atomic operations only support global space
float4 add_val = *reinterpret_cast<float4 *>(val);
unsigned long long ref_addr = reinterpret_cast<unsigned long long>(ref);
float4 ret_val;
if (memory_order == int(cuda::memory_order_release) ||
Expand Down Expand Up @@ -606,14 +621,14 @@ TL_DEVICE void AtomicAddx4(float *ref, float *val,
}
}

template <typename dst_dtype, typename ValType>
TL_DEVICE float4
AtomicAddx4Ret(float *ref, float *val,
AtomicAddx4Ret(dst_dtype *ref, ValType val,
int memory_order = int(cuda::memory_order_relaxed)) {
float4 add_val = ToFloat4(val);
if (memory_order == int(cuda::memory_order_relaxed)) {
return atomicAdd(reinterpret_cast<float4 *>(ref),
static_cast<float4>(*reinterpret_cast<float4 *>(val)));
return atomicAdd(reinterpret_cast<float4 *>(ref), add_val);
} else {
float4 add_val = *reinterpret_cast<float4 *>(val);
unsigned long long ref_addr = reinterpret_cast<unsigned long long>(ref);
float4 ret_val;
if (memory_order == int(cuda::memory_order_release) ||
Expand Down Expand Up @@ -647,40 +662,56 @@ AtomicAddx4Ret(float *ref, float *val,
}
}
#else
TL_DEVICE void AtomicAddx2(float *ref, float *val,
// Pre-SM90 copies of the ToFloat2/ToFloat4 helpers (this branch and the
// SM90 branch are mutually exclusive, so the duplication is benign).

// Load two consecutive floats through `val` as a float2.
// NOTE(review): reinterpret-cast load — requires 8-byte-aligned float data
// at `val`; TODO confirm call sites. Pointer is only read, hence const.
template <typename T> TL_DEVICE float2 ToFloat2(const T *val) {
  return *reinterpret_cast<const float2 *>(val);
}

// Identity overload for callers that already hold a float2 value.
TL_DEVICE float2 ToFloat2(float2 val) { return val; }

// Load four consecutive floats through `val` as a float4.
// NOTE(review): reinterpret-cast load — requires 16-byte-aligned float data
// at `val`; TODO confirm call sites. Pointer is only read, hence const.
template <typename T> TL_DEVICE float4 ToFloat4(const T *val) {
  return *reinterpret_cast<const float4 *>(val);
}

// Identity overload for callers that already hold a float4 value.
TL_DEVICE float4 ToFloat4(float4 val) { return val; }

// Pre-SM90 fallback: emulate a 2-wide float atomic add with two scalar
// atomicAdd calls on ref[0] and ref[1]. Each element update is atomic on
// its own, but the pair is NOT updated atomically as a unit. The
// `memory_order` parameter is accepted for interface parity with the SM90
// path but cannot be honored by scalar atomicAdd, so it is ignored.
// Fix: the scraped diff left both the old `*reinterpret_cast` load and the
// new `ToFloat2` load in place (duplicate declaration of add_val); only the
// final ToFloat2 version is kept.
template <typename ValType>
TL_DEVICE void AtomicAddx2(float *ref, ValType val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
  (void)memory_order; // no ordered vectorized atomics before SM90
  float2 add_val = ToFloat2(val);
  atomicAdd(ref + 0, add_val.x);
  atomicAdd(ref + 1, add_val.y);
}

// Pre-SM90 fallback of AtomicAddx2 that also returns the values read from
// ref[0] and ref[1] by each scalar atomicAdd (atomicAdd returns the old
// value). The two elements are not updated atomically as a pair, and the
// two returned old values are not a consistent snapshot. `memory_order`
// is accepted for interface parity only and is ignored here.
// Fix: removed diff residue — the stale old signature line and the
// duplicate `add_val` declaration; only the final version is kept.
template <typename ValType>
TL_DEVICE float2
AtomicAddx2Ret(float *ref, ValType val,
               int memory_order = int(cuda::memory_order_relaxed)) {
  (void)memory_order; // no ordered vectorized atomics before SM90
  float2 add_val = ToFloat2(val);
  float2 ret;
  ret.x = atomicAdd(ref + 0, add_val.x);
  ret.y = atomicAdd(ref + 1, add_val.y);
  return ret;
}

// Pre-SM90 fallback: emulate a 4-wide atomic add with four scalar
// atomicAdd calls on consecutive elements of `ref`. Each element update is
// atomic on its own, but the quad is NOT updated atomically as a unit.
// `memory_order` is accepted for interface parity only and is ignored.
// NOTE(review): `dst_dtype` must be a type for which atomicAdd accepts a
// float addend (float in practice) — TODO confirm no other instantiations.
// Fix: removed diff residue — the stale pre-template signature line and the
// duplicate `add_val` declaration; only the final version is kept.
template <typename dst_dtype, typename ValType>
TL_DEVICE void AtomicAddx4(dst_dtype *ref, ValType val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
  (void)memory_order; // no ordered vectorized atomics before SM90
  float4 add_val = ToFloat4(val);
  atomicAdd(ref + 0, add_val.x);
  atomicAdd(ref + 1, add_val.y);
  atomicAdd(ref + 2, add_val.z);
  atomicAdd(ref + 3, add_val.w);
}

template <typename dst_dtype, typename ValType>
TL_DEVICE float4
AtomicAddx4Ret(float *ref, float *val,
AtomicAddx4Ret(dst_dtype *ref, ValType val,
int memory_order = int(cuda::memory_order_relaxed)) {
(void)memory_order;
float4 add_val = *reinterpret_cast<float4 *>(val);
float4 add_val = ToFloat4(val);
float4 ret;
ret.x = atomicAdd(ref + 0, add_val.x);
ret.y = atomicAdd(ref + 1, add_val.y);
Expand Down
162 changes: 0 additions & 162 deletions src/transform/atomicadd_vectorize.cc

This file was deleted.

Loading
Loading