diff --git a/composable_kernel/include/utility/amd_buffer_addressing.hpp b/composable_kernel/include/utility/amd_buffer_addressing.hpp index 57081b7fd72..a54607a0534 100644 --- a/composable_kernel/include/utility/amd_buffer_addressing.hpp +++ b/composable_kernel/include/utility/amd_buffer_addressing.hpp @@ -209,13 +209,49 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w index_t src_wave_addr_offset) { static_assert( - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)), + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); - if constexpr(is_same::value) + if constexpr(is_same::value) + { + // use fp32 load to mimic fp64 load + if constexpr(N == 1) + { + const float2_t tmp = llvm_amdgcn_raw_buffer_load_fp32x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return as_type(tmp); + } + else if constexpr(N == 2) + { + const float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return as_type(tmp); + } + else if constexpr(N == 4) + { + const float4_t f32_0 = llvm_amdgcn_raw_buffer_load_fp32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + const float4_t f32_1 = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(float), + 0); + vector_type tmp; + + tmp.AsType()(Number<0>{}) = as_type(f32_0); + tmp.AsType()(Number<1>{}) = as_type(f32_1); + + return tmp.AsType()(Number<0>{}); + } + } + else if constexpr(is_same::value) { if constexpr(N == 1) { @@ -267,25 +303,11 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w } else if constexpr(N == 8) { -#if 0 - vector_type tmp; - - tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp16x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - - tmp.AsType()(Number<1>{}) = - llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(half_t), - 0); - - return tmp.AsType()(Number<0>{}); -#else + // use fp32 load to mimic fp16 load float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); return as_type(tmp); -#endif } } else if constexpr(is_same::value) @@ -417,13 +439,34 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src index_t dst_wave_addr_offset) { static_assert( - (is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 1 || N == 2)) || + (is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same::value && (N == 1 || N == 2 || N == 4)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)), + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); - if constexpr(is_same::value) + if constexpr(is_same::value) + { + // use fp32 store to mimic fp64 store + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_fp32x2(as_type(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_fp32x4(as_type(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + } + else if constexpr(is_same::value) { if constexpr(N == 1) { @@ -450,6 +493,49 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src 0); } } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_fp16(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 8) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 4 * sizeof(half_t), + 0); + } + } else if constexpr(is_same::value) { if constexpr(N == 1) @@ -536,49 +622,6 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src 0); } } - else if constexpr(is_same::value) - { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_fp16(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 8) - { - vector_type tmp{src_thread_data}; - - llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - - llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 4 * sizeof(half_t), - 0); - } - } } // buffer_load requires: diff --git a/composable_kernel/include/utility/data_type.hpp b/composable_kernel/include/utility/data_type.hpp index 24a2190e843..bfaac8a939d 100644 --- a/composable_kernel/include/utility/data_type.hpp +++ b/composable_kernel/include/utility/data_type.hpp @@ -73,6 +73,13 @@ struct scalar_type> }; // +template <> +struct scalar_type +{ + using type = double; + static constexpr index_t vector_size = 1; +}; + template <> struct scalar_type { @@ -864,6 +871,10 @@ struct vector_type } }; +// fp64 +using double2_t = typename vector_type::type; +using double4_t = typename vector_type::type; + // fp32 using float2_t = typename vector_type::type; using float4_t = typename vector_type::type; diff --git a/composable_kernel/include/utility/dynamic_buffer.hpp b/composable_kernel/include/utility/dynamic_buffer.hpp index 4d583e3ce7f..c3e3cc79aea 100644 --- a/composable_kernel/include/utility/dynamic_buffer.hpp +++ b/composable_kernel/include/utility/dynamic_buffer.hpp @@ -234,9 +234,14 @@ __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize el return DynamicBuffer{p, element_space_size}; } -template +template < + AddressSpaceEnum_t BufferAddressSpace, + typename T, + typename ElementSpaceSize, + typename X, + typename enable_if, remove_cvref_t>::value, bool>::type = false> __host__ __device__ constexpr auto -make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, T invalid_element_value) +make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element_value) { return DynamicBuffer{ p, element_space_size, invalid_element_value}; diff --git a/composable_kernel/include/utility/type.hpp b/composable_kernel/include/utility/type.hpp index b7902ad4968..89a2bdbde63 100644 --- a/composable_kernel/include/utility/type.hpp +++ b/composable_kernel/include/utility/type.hpp @@ -22,6 +22,9 @@ using remove_reference_t = typename std::remove_reference::type; template using remove_cv_t = typename std::remove_cv::type; +template +using remove_cvref_t = remove_cv_t>; + template inline constexpr bool is_pointer_v = std::is_pointer::value; diff --git a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp index c1208ac3cbe..71239e0ecc9 100644 --- a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp +++ b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp @@ -374,13 +374,8 @@ extern "C" __global__ void CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1{}, CGridBlockCluster_BlockId_To_GM10_GN10{})); - const auto desc_tuple = *reinterpret_cast( -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wold-style-cast" - // TODO: how to cast? - (const void*)p_desc_tuple -#pragma clang diagnostic pop - ); + const auto desc_tuple = + *reinterpret_cast(cast_pointer_to_generic_address_space(p_desc_tuple)); const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = desc_tuple[I0]; const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = desc_tuple[I1];