Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 109 additions & 66 deletions composable_kernel/include/utility/amd_buffer_addressing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,49 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
index_t src_wave_addr_offset)
{
static_assert(
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, double>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)),
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");

if constexpr(is_same<T, float>::value)
if constexpr(is_same<T, double>::value)
{
// use fp32 load to mimic fp64 load
if constexpr(N == 1)
{
const float2_t tmp = llvm_amdgcn_raw_buffer_load_fp32x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

return as_type<double>(tmp);
}
else if constexpr(N == 2)
{
const float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

return as_type<double2_t>(tmp);
}
else if constexpr(N == 4)
{
const float4_t f32_0 = llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

const float4_t f32_1 =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(float),
0);
vector_type<double, 4> tmp;

tmp.AsType<double2_t>()(Number<0>{}) = as_type<double2_t>(f32_0);
tmp.AsType<double2_t>()(Number<1>{}) = as_type<double2_t>(f32_1);

return tmp.AsType<double4_t>()(Number<0>{});
}
}
else if constexpr(is_same<T, float>::value)
{
if constexpr(N == 1)
{
Expand Down Expand Up @@ -267,25 +303,11 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
}
else if constexpr(N == 8)
{
#if 0
vector_type<half_t, 8> tmp;

tmp.AsType<half4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp16x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

tmp.AsType<half4_t>()(Number<1>{}) =
llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(half_t),
0);

return tmp.AsType<half8_t>()(Number<0>{});
#else
// use fp32 load to mimic fp16 load
float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

return as_type<half8_t>(tmp);
#endif
}
}
else if constexpr(is_same<T, int32_t>::value)
Expand Down Expand Up @@ -417,13 +439,34 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
index_t dst_wave_addr_offset)
{
static_assert(
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, double>::value && (N == 1 || N == 2)) ||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)),
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");

if constexpr(is_same<T, float>::value)
if constexpr(is_same<T, double>::value)
{
// use fp32 store to mimic fp64 store
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_fp32x2(as_type<float2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_fp32x4(as_type<float4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
}
else if constexpr(is_same<T, float>::value)
{
if constexpr(N == 1)
{
Expand All @@ -450,6 +493,49 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
0);
}
}
else if constexpr(is_same<T, half_t>::value)
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_fp16(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 8)
{
vector_type<half_t, 8> tmp{src_thread_data};

llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);

llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(half_t),
0);
}
}
else if constexpr(is_same<T, int32_t>::value)
{
if constexpr(N == 1)
Expand Down Expand Up @@ -536,49 +622,6 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
0);
}
}
else if constexpr(is_same<T, half_t>::value)
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_fp16(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 8)
{
vector_type<half_t, 8> tmp{src_thread_data};

llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);

llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(half_t),
0);
}
}
}

// buffer_load requires:
Expand Down
11 changes: 11 additions & 0 deletions composable_kernel/include/utility/data_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ struct scalar_type<vector_type<T, N>>
};

//
template <>
struct scalar_type<double>
{
using type = double;
static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<float>
{
Expand Down Expand Up @@ -864,6 +871,10 @@ struct vector_type<T, 256>
}
};

// fp64
using double2_t = typename vector_type<double, 2>::type;
using double4_t = typename vector_type<double, 4>::type;

// fp32
using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type;
Expand Down
9 changes: 7 additions & 2 deletions composable_kernel/include/utility/dynamic_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,14 @@ __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize el
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true>{p, element_space_size};
}

template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
template <
AddressSpaceEnum_t BufferAddressSpace,
typename T,
typename ElementSpaceSize,
typename X,
typename enable_if<is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value, bool>::type = false>
__host__ __device__ constexpr auto
make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, T invalid_element_value)
make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element_value)
{
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, false>{
p, element_space_size, invalid_element_value};
Expand Down
3 changes: 3 additions & 0 deletions composable_kernel/include/utility/type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ using remove_reference_t = typename std::remove_reference<T>::type;
template <typename T>
using remove_cv_t = typename std::remove_cv<T>::type;

template <typename T>
using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;

template <typename T>
inline constexpr bool is_pointer_v = std::is_pointer<T>::value;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,13 +374,8 @@ extern "C" __global__ void
CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1{},
CGridBlockCluster_BlockId_To_GM10_GN10{}));

const auto desc_tuple = *reinterpret_cast<const DescTuple*>(
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
// TODO: how to cast?
(const void*)p_desc_tuple
#pragma clang diagnostic pop
);
const auto desc_tuple =
*reinterpret_cast<const DescTuple*>(cast_pointer_to_generic_address_space(p_desc_tuple));

const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = desc_tuple[I0];
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = desc_tuple[I1];
Expand Down