[r2.18-rocm-enhanced] Use faster atomics on ROCM; Enable kernel argument preloading #2784

Open · wants to merge 2 commits into r2.18-rocm-enhanced
167 changes: 93 additions & 74 deletions tensorflow/core/util/gpu_device_functions.h
@@ -36,6 +36,8 @@ limitations under the License.
#include "third_party/gpus/cuda/include/cuda.h"
#else
#include "rocm/include/hip/hip_complex.h"
#include "rocm/include/hip/hip_fp16.h"
#include "rocm/include/hip/hip_bf16.h"
#endif

#include "tensorflow/core/platform/types.h"
@@ -567,9 +569,22 @@ __global__ void SetToValue(const int count, T* __restrict__ ptr, Tvalue value) {
}

namespace detail {

template <int N, typename T>
__device__ T* AddressSpaceHint(T* ptr) {
#if defined(TENSORFLOW_USE_ROCM)
using AS = __attribute__((address_space(N))) T*;
auto ptr_ = reinterpret_cast<AS>(reinterpret_cast<uintptr_t>(ptr));
return (T*)(ptr_);
#else
return ptr; // NOOP
#endif
}
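For readers not familiar with AMDGPU address spaces: a plain T* in device code is a "flat" (generic) pointer, for which the backend has to emit flat atomics. Casting the pointer to the address_space(1) (global) qualified type, as AddressSpaceHint<1> does above, tells the compiler the destination is global memory, so it can select the faster global atomic instructions. A minimal standalone HIP-style sketch of the same pattern follows; it is illustrative only, not part of this change, and the names (AsGlobal, AccumulateKernel) are invented for the example.

#include <hip/hip_runtime.h>
#include <cstdint>

// Cast a flat float* to a global-address-space pointer and back; the pointer
// value is unchanged, only the compiler's knowledge of where it points improves.
__device__ float* AsGlobal(float* ptr) {
  using GlobalFloatPtr = __attribute__((address_space(1))) float*;
  auto g = reinterpret_cast<GlobalFloatPtr>(reinterpret_cast<uintptr_t>(ptr));
  return (float*)g;
}

__global__ void AccumulateKernel(float* out, const float* in, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) atomicAdd(AsGlobal(out), in[i]);  // can lower to a global atomic
}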

// Helper function for atomic accumulation implemented as CAS.
template <typename T, typename F>
__device__ T GpuAtomicCasHelper(T* ptr, F accumulate) {
ptr = detail::AddressSpaceHint<1>(ptr);
T old = *ptr;
T assumed;
do {
@@ -591,24 +606,11 @@ __device__ float GpuAtomicCasHelper(float* ptr, F accumulate) {
}
template <typename F>
__device__ double GpuAtomicCasHelper(double* ptr, F accumulate) {
#if TENSORFLOW_USE_ROCM
// FIXME: remove the workaround below once bug is fixed.
// HIP has a bug in the implementation of __longlong_as_double
// So workaround it by using reinterpret_cast<double*>.
uint64_t result =
GpuAtomicCasHelper(reinterpret_cast<unsigned long long*>(ptr),
[accumulate](tensorflow::uint64 a) {
return __double_as_longlong(
accumulate(*(reinterpret_cast<double*>(&a))));
});
return *(reinterpret_cast<double*>(&result));
#else
return __longlong_as_double(GpuAtomicCasHelper(
reinterpret_cast<unsigned long long*>(ptr),
[accumulate](tensorflow::uint64 a) {
return __double_as_longlong(accumulate(__longlong_as_double(a)));
}));
#endif
}

// Overload of above function for half. Note that we don't have
@@ -628,31 +630,20 @@ __device__ Eigen::half GpuAtomicCasHelper(Eigen::half* ptr, F accumulate) {
#if defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__)
static_assert(__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__, "Not little endian");
#endif
intptr_t intptr = reinterpret_cast<intptr_t>(ptr);
uintptr_t intptr = reinterpret_cast<uintptr_t>(ptr);
uint32_t shift = (intptr & 0x2) * 8U;
uint32_t mask = 0xFFFF0000U >> shift;

assert(!(intptr & 0x1)); // should be 2-aligned.
if (intptr & 0x2) {
// The half is in the second part of the uint32 (upper 16 bits).
uint32* address = reinterpret_cast<uint32*>(intptr - 2);
uint32 result = GpuAtomicCasHelper(address, [accumulate](uint32 arg) {
unsigned short high = static_cast<unsigned short>(arg >> 16);
Eigen::half acc = accumulate(Eigen::numext::bit_cast<Eigen::half>(high));
return (static_cast<uint32>(Eigen::numext::bit_cast<uint16>(acc)) << 16) |
(arg & 0xffff);
});
return Eigen::numext::bit_cast<Eigen::half>(
static_cast<uint16>(result >> 16));
} else {
// The half is in the first part of the uint32 (lower 16 bits).
uint32* address = reinterpret_cast<uint32*>(intptr);
uint32 result = GpuAtomicCasHelper(address, [accumulate](uint32 arg) {
unsigned short low = static_cast<unsigned short>(arg & 0xffff);
Eigen::half acc = accumulate(Eigen::numext::bit_cast<Eigen::half>(low));
return (arg & 0xffff0000) |
static_cast<uint32>(Eigen::numext::bit_cast<uint16>(acc));
});
return Eigen::numext::bit_cast<Eigen::half>(
static_cast<uint16>(result & 0xffff));
}
uint32* address = reinterpret_cast<uint32*>(intptr & ~0x3);
uint32 result = GpuAtomicCasHelper(address, [accumulate, shift, mask](uint32 arg) {
uint16_t high = static_cast<uint16_t>(arg >> shift);
Eigen::half acc = accumulate(Eigen::numext::bit_cast<Eigen::half>(high));
return (static_cast<uint32>(Eigen::numext::bit_cast<uint16_t>(acc)) << shift) |
(arg & mask);
});
return Eigen::numext::bit_cast<Eigen::half>(
static_cast<uint16_t>(result >> shift));
}
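A quick host-side check of the branch-free indexing above (a sketch, not part of the diff): a half at byte offset 0 of its aligned 32-bit word gets shift 0 and mask 0xFFFF0000 (the neighbour's bits are preserved), while one at offset 2 gets shift 16 and mask 0x0000FFFF.

#include <cassert>
#include <cstdint>

int main() {
  const uintptr_t addrs[] = {0x1000, 0x1002};  // low-half and high-half cases
  for (uintptr_t intptr : addrs) {
    uint32_t shift = (intptr & 0x2) * 8U;   // 0 or 16
    uint32_t mask = 0xFFFF0000U >> shift;   // bits belonging to the neighbour
    uint32_t word = 0xAAAABBBBu;            // high half 0xAAAA, low half 0xBBBB
    uint16_t target = static_cast<uint16_t>(word >> shift);
    uint32_t updated =
        (static_cast<uint32_t>(target + 1) << shift) | (word & mask);
    assert(((updated & ~mask) >> shift) == static_cast<uint32_t>(target + 1));
    assert((updated & mask) == (word & mask));  // neighbour untouched
  }
  return 0;
}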

template <typename F>
@@ -720,10 +711,11 @@ __device__ CudaSupportedType<T>* ToCudaSupportedPtr(T* ptr) {

template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> GpuAtomicAdd(T* ptr, U value) {
return atomicAdd(detail::ToCudaSupportedPtr(ptr), value);
return atomicAdd(detail::ToCudaSupportedPtr(detail::AddressSpaceHint<1>(ptr)),
value);
}


#if !defined(TENSORFLOW_USE_ROCM)
__device__ inline Eigen::bfloat16 GpuAtomicAdd(Eigen::bfloat16* ptr,
Eigen::bfloat16 value) {
return detail::GpuAtomicCasHelper(
@@ -735,26 +727,73 @@ __device__ inline Eigen::half GpuAtomicAdd(Eigen::half* ptr,
return detail::GpuAtomicCasHelper(
ptr, [value](Eigen::half a) { return a + value; });
}
#endif

#if (__CUDA_ARCH__ < 600) || TENSORFLOW_USE_ROCM
#if (__CUDA_ARCH__ < 600)
__device__ inline double GpuAtomicAdd(double* ptr, double value) {
return detail::GpuAtomicCasHelper(ptr,
[value](double a) { return a + value; });
}
#endif

#if __gfx908__ || __gfx90a__ || __gfx940__ || __gfx941__ || __gfx942__
#if TENSORFLOW_USE_ROCM
template <typename T>
__device__ T GpuAtomicAddShared(T* dst, T val) {
return atomicAdd(detail::AddressSpaceHint<3>(dst), val);
}

#define ADDRSP1 __attribute__((address_space(1)))
__device__ float
#if __clang_major__ < 16
__llvm_amdgcn_global_atomic_add_f32(ADDRSP1 float* dst, float val) __asm("llvm.amdgcn.global.atomic.fadd.f32.p1f32.f32");
#else
__llvm_amdgcn_global_atomic_add_f32(ADDRSP1 float* dst, float val) __asm("llvm.amdgcn.global.atomic.fadd.f32.p1.f32");
#endif // clang_major
#endif // gfx
namespace detail {

template <typename P, typename T, typename F>
__device__ inline T GpuAtomicAddHalfHelper(T* ptr, T value, F add) {
typedef P __attribute__((ext_vector_type(2))) P2;
auto ptr2 = (__attribute__((address_space(1)))
P2*)(reinterpret_cast<uintptr_t>(ptr) & ~0x3);
uintptr_t shift = ((reinterpret_cast<uintptr_t>(ptr) & 0x2) * 8);
// Eigen::numext::bit_cast on ext_vector produces redundant inlined memcpy.
// Use union instead.
union {
P2 v2;
uint32_t i;
} u;

u.i = static_cast<uint32_t>(Eigen::numext::bit_cast<uint16_t>(value)) << shift;

// Performs + (T)0 on the adjacent location, so this is not idempotent with
// regard to its bit pattern. Should be fine as long as that location is
// used to hold a T.
u.v2 = add(ptr2, u.v2);
return Eigen::numext::bit_cast<T>(static_cast<uint16_t>(u.i >> shift));
}

} // namespace detail

__device__ inline Eigen::bfloat16 GpuAtomicAdd(Eigen::bfloat16* ptr,
Eigen::bfloat16 value) {
#if __has_builtin(__builtin_amdgcn_global_atomic_fadd_v2bf16)
return detail::GpuAtomicAddHalfHelper<short>(
ptr, value, [](auto p, auto v) {
return __builtin_amdgcn_global_atomic_fadd_v2bf16(p, v);
});
#else
return detail::GpuAtomicCasHelper(
ptr, [value](Eigen::bfloat16 a) { return a + value; });
#endif
}

__device__ inline Eigen::half GpuAtomicAdd(Eigen::half* ptr,
Eigen::half value) {
#if __has_builtin(__builtin_amdgcn_global_atomic_fadd_v2f16)
return detail::GpuAtomicAddHalfHelper<_Float16>(
ptr, value, [](auto p, auto v) {
return __builtin_amdgcn_global_atomic_fadd_v2f16(p, v);
});
#else
return detail::GpuAtomicCasHelper(
ptr, [value](Eigen::half a) { return a + value; });
#endif
}
#endif
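A hypothetical usage sketch, not part of this diff: callers keep writing ordinary GpuAtomicAdd calls on Eigen::half or Eigen::bfloat16 pointers; on targets where the packed builtins are available the add becomes a single hardware atomic on the aligned 32-bit word containing the element, and elsewhere it falls back to the CAS helper. The kernel name and signature below are invented for illustration.

#include "tensorflow/core/util/gpu_device_functions.h"

// Scatter-add of half-precision gradients into a shared output buffer.
__global__ void ScatterAddHalf(Eigen::half* out, const Eigen::half* grad,
                               const int* index, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    tensorflow::GpuAtomicAdd(out + index[i], grad[i]);
  }
}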

// GpuAtomicAdd
// Specializations of GpuAtomicAdd for complex types, which GpuAtomicAdd does
@@ -783,7 +822,7 @@ CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicAdd, CudaAtomicAdd);
// GpuAtomicSub
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> GpuAtomicSub(T* ptr, U value) {
return atomicSub(ptr, value);
return atomicSub(detail::AddressSpaceHint<1>(ptr), value);

Review discussion on this change:

Comment: May I ask why atomicAdd(detail::AddressSpaceHint<3>(dst), val); is using shared memory while atomicSub(detail::AddressSpaceHint<1>(ptr), value); is using global memory?

Author: AddressSpaceHint<3> is only for GpuAtomicAddShared, which is used by ROCm-specific kernels.

Comment: Looks like there is no GpuAtomicSubShared, which would use address space 3 if it existed. I'm not sure why it doesn't exist, since sub just calls add anyway.

@i-chaochen (Dec 16, 2024): Why is AddressSpaceHint<3> not used by atomicSub? atomicSub is a ROCm-specific kernel as well, I think (https://rocm.docs.amd.com/projects/HIP/en/docs-5.7.0/reference/kernel_language.html#atomic-functions): int atomicSub(int* address, int val)

}
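To make the address-space numbers in the review thread above concrete: address space 1 is global (device/HBM) memory, where tensors live, while address space 3 is the workgroup-local LDS that backs __shared__ variables, which is what GpuAtomicAddShared targets on the ROCm build. A hypothetical sketch contrasting the two; the kernel is invented for illustration.

// Each block accumulates into an LDS staging value, then one thread folds the
// block's partial sum into the global result.
__global__ void BlockReduceAdd(float* out, const float* in, int n) {
  __shared__ float partial;              // lives in LDS (address space 3)
  if (threadIdx.x == 0) partial = 0.0f;
  __syncthreads();

  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    tensorflow::GpuAtomicAddShared(&partial, in[i]);  // atomic on LDS
  }
  __syncthreads();

  if (threadIdx.x == 0) {
    tensorflow::GpuAtomicAdd(out, partial);           // atomic on global memory
  }
}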

// Specializations of subtraction which add the negative value.
@@ -821,7 +860,8 @@ CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicSub, CudaAtomicSub);
// GpuAtomicMax
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> GpuAtomicMax(T* ptr, U value) {
return atomicMax(detail::ToCudaSupportedPtr(ptr), value);
return atomicMax(detail::ToCudaSupportedPtr(detail::AddressSpaceHint<1>(ptr)),
value);
}

#if TENSORFLOW_USE_ROCM
@@ -894,7 +934,8 @@ CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicMax, CudaAtomicMax);
// GpuAtomicMin
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> GpuAtomicMin(T* ptr, U value) {
return atomicMin(detail::ToCudaSupportedPtr(ptr), value);
return atomicMin(detail::ToCudaSupportedPtr(detail::AddressSpaceHint<1>(ptr)),
value);
}

#if TENSORFLOW_USE_ROCM
@@ -963,28 +1004,6 @@ __device__ inline int64_t GpuAtomicMin(int64_t* ptr, int64_t value) {
}
#endif

#if __gfx908__ || __gfx90a__ || __gfx940__ || __gfx941__ || __gfx942__
// Low level instructions don't return. For now, assume that return value
// is always unused.
__device__ float GpuAtomicAdd(float* dst, float val) {
ADDRSP1 float* p = (ADDRSP1 float*) dst;
__llvm_amdgcn_global_atomic_add_f32(p, val);
return val;
}
#endif

template <typename T>
__device__ inline T GpuAtomicAddShared(T* ptr, T value) {
return GpuAtomicAdd(ptr, value);
}

#if __gfx908__ || __gfx90a__ || __gfx940__ || __gfx941__ || __gfx942__
__device__ float GpuAtomicAddShared(float* dst, float val) {
atomicAdd(dst, val);
return val;
}
#endif

CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicMin, CudaAtomicMin);

// GpuAtomicMul
@@ -183,11 +183,12 @@ def InvokeHipcc(argv, log=False):
# of link time. This allows the default host compiler (gcc) be used as the
# linker for TensorFlow on ROCm platform.
hipccopts += ' -fno-gpu-rdc '
hipccopts += ' -fcuda-flush-denormals-to-zero '
hipccopts += ' -fcuda-flush-denormals-to-zero -munsafe-fp-atomics '
hipccopts += undefines
hipccopts += defines
hipccopts += std_options
hipccopts += m_options
hipccopts += ' -mllvm=-amdgpu-kernarg-preload-count=16 '

if depfiles:
# Generate the dependency file
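For context on the two new compiler options: -munsafe-fp-atomics lets the AMDGPU backend lower floating-point atomicAdd to the native global atomic-add instructions instead of a compare-and-swap loop (trading strict IEEE/denormal behaviour for speed), and -amdgpu-kernarg-preload-count=16 asks the backend to preload kernel arguments into SGPRs so kernels can skip the initial loads of their argument buffer, where the hardware supports it. A small illustrative kernel of the kind both flags affect (a sketch, not part of this diff):

#include <hip/hip_runtime.h>

// With -munsafe-fp-atomics, the float atomicAdd below can compile to a native
// global atomic add rather than a CAS loop; with kernarg preloading, the
// pointer and size arguments arrive already loaded into SGPRs.
__global__ void SumKernel(float* __restrict__ sum,
                          const float* __restrict__ x, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    atomicAdd(sum, x[i]);
  }
}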
@@ -1048,6 +1048,10 @@ std::string GetROCDLDir(const DebugOptions& debug_options) {

void AMDGPUBackendInit(const DebugOptions& debug_options,
std::string& rocdl_dir_path) {
FeedLLVMWithFlags({
"-amdgpu-kernarg-preload-count=16",
});

llvm_ir::InitializeLLVMCommandLineOptions(
debug_options.xla_backend_extra_options());
