Skip to content

Commit ac70a5b

Browse files
authored
Merge pull request #1729 from LLNL/feature/seanofthemillers/adding_radeon_support
Adding Radeon support by controlling wave size
2 parents 0aef7cc + d0a65e7 commit ac70a5b

File tree

6 files changed

+16
-8
lines changed

6 files changed

+16
-8
lines changed

CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ endif()
171171

172172
if(RAJA_ENABLE_HIP)
173173
message(STATUS "HIP version: ${hip_VERSION}")
174+
set(RAJA_HIP_WAVESIZE "64" CACHE STRING "Set the wave size for GPU architecture. E.g. MI200/MI300 this is 64.")
174175
if("${hip_VERSION}" VERSION_LESS "3.5")
175176
message(FATAL_ERROR "Trying to use HIP/ROCm version ${hip_VERSION}. RAJA requires HIP/ROCm version 3.5 or newer. ")
176177
endif()

include/RAJA/config.hpp.in

+2
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,8 @@ static_assert(RAJA_HAS_SOME_CXX14,
182182
#cmakedefine RAJA_ENABLE_NV_TOOLS_EXT
183183
#cmakedefine RAJA_ENABLE_ROCTX
184184

185+
#cmakedefine RAJA_HIP_WAVESIZE @RAJA_HIP_WAVESIZE@
186+
185187
/*!
186188
******************************************************************************
187189
*

include/RAJA/policy/hip/policy.hpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -324,8 +324,9 @@ struct DeviceConstants
324324
// values for HIP warp size and max block size.
325325
//
326326
#if defined(__HIP_PLATFORM_AMD__)
327-
constexpr DeviceConstants device_constants(64, 1024, 64); // MI300A
328-
// constexpr DeviceConstants device_constants(64, 1024, 128); // MI250X
327+
constexpr DeviceConstants device_constants(RAJA_HIP_WAVESIZE, 1024, 64); // MI300A
328+
// constexpr DeviceConstants device_constants(RAJA_HIP_WAVESIZE, 1024, 128); // MI250X
329+
329330
#elif defined(__HIP_PLATFORM_NVIDIA__)
330331
constexpr DeviceConstants device_constants(32, 1024, 32); // V100
331332
#endif

include/RAJA/policy/tensor/arch/hip/hip_wave.hpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ namespace expt
5757

5858
public:
5959

60-
static constexpr int s_num_elem = 64;
60+
static constexpr int s_num_elem = policy::hip::device_constants.WARP_SIZE;
6161

6262
/*!
6363
* @brief Default constructor, zeros register contents
@@ -780,8 +780,8 @@ namespace expt
780780

781781
// Third: mask off everything but output_segment
782782
// this is because all output segments are valid at this point
783-
// (5-segbits), the 5 is since the warp-width is 32 == 1<<5
784-
int our_output_segment = get_lane()>>(6-segbits);
783+
static constexpr int log2_warp_size = RAJA::log2(RAJA::policy::hip::device_constants.WARP_SIZE);
784+
int our_output_segment = get_lane()>>(log2_warp_size-segbits);
785785
bool in_output_segment = our_output_segment == output_segment;
786786
if(!in_output_segment){
787787
result.get_raw_value() = 0;
@@ -828,8 +828,9 @@ namespace expt
828828

829829
// First: tree reduce values within each segment
830830
element_type x = m_value;
831+
static constexpr int log2_warp_size = RAJA::log2(RAJA::policy::hip::device_constants.WARP_SIZE);
831832
RAJA_UNROLL
832-
for(int i = 0;i < 6-segbits; ++ i){
833+
for(int i = 0;i < log2_warp_size-segbits; ++ i){
833834

834835
// tree shuffle
835836
int delta = s_num_elem >> (i+1);

include/RAJA/policy/tensor/arch/hip/traits.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ namespace expt {
2929
struct RegisterTraits<RAJA::expt::hip_wave_register, T>{
3030
using element_type = T;
3131
using register_policy = RAJA::expt::hip_wave_register;
32-
static constexpr camp::idx_t s_num_elem = 64;
32+
static constexpr camp::idx_t s_num_elem = policy::hip::device_constants.WARP_SIZE;
33+
3334
static constexpr camp::idx_t s_num_bits = sizeof(T) * s_num_elem;
3435
using int_element_type = int32_t;
3536
};

test/include/RAJA_test-tensor.hpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,9 @@ struct TensorTestHelper<RAJA::expt::hip_wave_register>
8787
void exec(BODY const &body){
8888
hipDeviceSynchronize();
8989

90-
RAJA::forall<RAJA::hip_exec<64>>(RAJA::RangeSegment(0,64),
90+
static constexpr int warp_size = RAJA::policy::hip::device_constants.WARP_SIZE;
91+
92+
RAJA::forall<RAJA::hip_exec<warp_size>>(RAJA::RangeSegment(0,warp_size),
9193
[=] RAJA_HOST_DEVICE (int ){
9294
body();
9395
});

0 commit comments

Comments
 (0)