File tree 6 files changed +16
-8
lines changed
6 files changed +16
-8
lines changed Original file line number Diff line number Diff line change @@ -171,6 +171,7 @@ endif()
171
171
172
172
if (RAJA_ENABLE_HIP)
173
173
message (STATUS "HIP version: ${hip_VERSION} " )
174
+ set (RAJA_HIP_WAVESIZE "64" CACHE STRING "Set the wave size for GPU architecture. E.g. MI200/MI300 this is 64." )
174
175
if ("${hip_VERSION} " VERSION_LESS "3.5" )
175
176
message (FATAL_ERROR "Trying to use HIP/ROCm version ${hip_VERSION} . RAJA requires HIP/ROCm version 3.5 or newer. " )
176
177
endif ()
Original file line number Diff line number Diff line change @@ -182,6 +182,8 @@ static_assert(RAJA_HAS_SOME_CXX14,
182
182
#cmakedefine RAJA_ENABLE_NV_TOOLS_EXT
183
183
#cmakedefine RAJA_ENABLE_ROCTX
184
184
185
+ #cmakedefine RAJA_HIP_WAVESIZE @RAJA_HIP_WAVESIZE@
186
+
185
187
/*!
186
188
******************************************************************************
187
189
*
Original file line number Diff line number Diff line change @@ -324,8 +324,9 @@ struct DeviceConstants
324
324
// values for HIP warp size and max block size.
325
325
//
326
326
#if defined(__HIP_PLATFORM_AMD__)
327
- constexpr DeviceConstants device_constants (64 , 1024 , 64 ); // MI300A
328
- // constexpr DeviceConstants device_constants(64, 1024, 128); // MI250X
327
+ constexpr DeviceConstants device_constants (RAJA_HIP_WAVESIZE, 1024 , 64 ); // MI300A
328
+ // constexpr DeviceConstants device_constants(RAJA_HIP_WAVESIZE, 1024, 128); // MI250X
329
+
329
330
#elif defined(__HIP_PLATFORM_NVIDIA__)
330
331
constexpr DeviceConstants device_constants (32 , 1024 , 32 ); // V100
331
332
#endif
Original file line number Diff line number Diff line change @@ -57,7 +57,7 @@ namespace expt
57
57
58
58
public:
59
59
60
- static constexpr int s_num_elem = 64 ;
60
+ static constexpr int s_num_elem = policy::hip::device_constants.WARP_SIZE ;
61
61
62
62
/* !
63
63
* @brief Default constructor, zeros register contents
@@ -780,8 +780,8 @@ namespace expt
780
780
781
781
// Third: mask off everything but output_segment
782
782
// this is because all output segments are valid at this point
783
- // (5-segbits), the 5 is since the warp-width is 32 == 1<<5
784
- int our_output_segment = get_lane ()>>(6 -segbits);
783
+ static constexpr int log2_warp_size = RAJA::log2 (RAJA::policy::hip::device_constants. WARP_SIZE );
784
+ int our_output_segment = get_lane ()>>(log2_warp_size -segbits);
785
785
bool in_output_segment = our_output_segment == output_segment;
786
786
if (!in_output_segment){
787
787
result.get_raw_value () = 0 ;
@@ -828,8 +828,9 @@ namespace expt
828
828
829
829
// First: tree reduce values within each segment
830
830
element_type x = m_value;
831
+ static constexpr int log2_warp_size = RAJA::log2 (RAJA::policy::hip::device_constants.WARP_SIZE );
831
832
RAJA_UNROLL
832
- for (int i = 0 ;i < 6 -segbits; ++ i){
833
+ for (int i = 0 ;i < log2_warp_size -segbits; ++ i){
833
834
834
835
// tree shuffle
835
836
int delta = s_num_elem >> (i+1 );
Original file line number Diff line number Diff line change @@ -29,7 +29,8 @@ namespace expt {
29
29
struct RegisterTraits <RAJA::expt::hip_wave_register, T>{
30
30
using element_type = T;
31
31
using register_policy = RAJA::expt::hip_wave_register;
32
- static constexpr camp::idx_t s_num_elem = 64 ;
32
+ static constexpr camp::idx_t s_num_elem = policy::hip::device_constants.WARP_SIZE;
33
+
33
34
static constexpr camp::idx_t s_num_bits = sizeof (T) * s_num_elem;
34
35
using int_element_type = int32_t ;
35
36
};
Original file line number Diff line number Diff line change @@ -87,7 +87,9 @@ struct TensorTestHelper<RAJA::expt::hip_wave_register>
87
87
void exec (BODY const &body){
88
88
hipDeviceSynchronize ();
89
89
90
- RAJA::forall<RAJA::hip_exec<64 >>(RAJA::RangeSegment (0 ,64 ),
90
+ static constexpr int warp_size = RAJA::policy::hip::device_constants.WARP_SIZE ;
91
+
92
+ RAJA::forall<RAJA::hip_exec<warp_size>>(RAJA::RangeSegment (0 ,warp_size),
91
93
[=] RAJA_HOST_DEVICE (int ){
92
94
body ();
93
95
});
You can’t perform that action at this time.
0 commit comments