Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ struct PassThrough

__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x; }

__host__ __device__ void operator()(ushort& y, const ushort& x) const { y = x; }
__host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { y = x; }

__host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; }

Expand Down
8 changes: 4 additions & 4 deletions composable_kernel/include/tensor_operation/xdlops_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ struct MfmaSelector
}

template <>
static constexpr auto GetMfma<ushort, 32, 32>()
static constexpr auto GetMfma<bhalf_t, 32, 32>()
{
#if defined(CK_AMD_GPU_GFX90A)
return MfmaInstr::mfma_f32_32x32x8bf16_1k;
Expand All @@ -484,7 +484,7 @@ struct MfmaSelector
}

template <>
static constexpr auto GetMfma<ushort, 16, 16>()
static constexpr auto GetMfma<bhalf_t, 16, 16>()
{
#if defined(CK_AMD_GPU_GFX90A)
return MfmaInstr::mfma_f32_16x16x16bf16_1k;
Expand Down Expand Up @@ -662,8 +662,8 @@ struct XdlopsGemm
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{
static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
is_same<base_type, ushort>::value || is_same<base_type, int8_t>::value,
"base base_type must be float, half, ushort, and int8_t!");
is_same<base_type, bhalf_t>::value || is_same<base_type, int8_t>::value,
"base base_type must be float, half, bfloat16, and int8_t!");

static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
mfma_instr.template run<MPerXdlops, NPerXdlops>(p_a_wave[k], p_b_wave[k], p_c_thread);
Expand Down
30 changes: 15 additions & 15 deletions composable_kernel/include/utility/amd_buffer_addressing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,19 @@ llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");

// buffer load i16
__device__ ushort
__device__ bhalf_t
llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16");

__device__ ushort2_t
__device__ bhalf2_t
llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16");

__device__ ushort4_t
__device__ bhalf4_t
llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
Expand Down Expand Up @@ -149,21 +149,21 @@ llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,

// buffer store i16
__device__ void
llvm_amdgcn_raw_buffer_store_i16(ushort vdata,
llvm_amdgcn_raw_buffer_store_i16(bhalf_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");

__device__ void
llvm_amdgcn_raw_buffer_store_i16x2(ushort2_t vdata,
llvm_amdgcn_raw_buffer_store_i16x2(bhalf2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");

__device__ void
llvm_amdgcn_raw_buffer_store_i16x4(ushort4_t vdata,
llvm_amdgcn_raw_buffer_store_i16x4(bhalf4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
Expand Down Expand Up @@ -266,7 +266,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
(is_same<T, double>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, ushort>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");
Expand Down Expand Up @@ -365,7 +365,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
return bit_cast<half8_t>(tmp);
}
}
else if constexpr(is_same<T, ushort>::value)
else if constexpr(is_same<T, bhalf_t>::value)
{
if constexpr(N == 1)
{
Expand All @@ -387,7 +387,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

return bit_cast<ushort8_t>(tmp);
return bit_cast<bhalf8_t>(tmp);
}
}
else if constexpr(is_same<T, int32_t>::value)
Expand Down Expand Up @@ -522,7 +522,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
(is_same<T, double>::value && (N == 1 || N == 2)) ||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, ushort>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");
Expand Down Expand Up @@ -625,7 +625,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
#endif
}
}
else if constexpr(is_same<T, ushort>::value)
else if constexpr(is_same<T, bhalf_t>::value)
{
if constexpr(N == 1)
{
Expand Down Expand Up @@ -653,18 +653,18 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
}
else if constexpr(N == 8)
{
vector_type<half_t, 8> tmp{src_thread_data};
vector_type<bhalf_t, 8> tmp{src_thread_data};

llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}],
llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<bhalf4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);

llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<bhalf4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(half_t),
dst_wave_addr_offset + 4 * sizeof(bhalf_t),
0);
}
}
Expand Down
8 changes: 4 additions & 4 deletions composable_kernel/include/utility/amd_xdlops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ template <>
struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
{
template <class FloatC>
__device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c)
__device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
Expand All @@ -221,7 +221,7 @@ template <>
struct intrin_mfma_f32_16x16x16bf16_1k<16, 16>
{
template <class FloatC>
__device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c)
__device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
Expand All @@ -235,7 +235,7 @@ template <>
struct intrin_mfma_f32_32x32x4bf16<32, 32>
{
template <class FloatC>
__device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c)
__device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
Expand All @@ -249,7 +249,7 @@ template <>
struct intrin_mfma_f32_16x16x8bf16<16, 16>
{
template <class FloatC>
__device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c)
__device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
Expand Down
21 changes: 11 additions & 10 deletions composable_kernel/include/utility/data_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

namespace ck {

using bhalf_t = ushort;
using half_t = _Float16;

// vector_type
Expand Down Expand Up @@ -107,9 +108,9 @@ struct scalar_type<half_t>
};

template <>
struct scalar_type<ushort>
struct scalar_type<bhalf_t>
{
using type = ushort;
using type = bhalf_t;
static constexpr index_t vector_size = 1;
};

Expand Down Expand Up @@ -904,12 +905,12 @@ using half32_t = typename vector_type<half_t, 32>::type;
using half64_t = typename vector_type<half_t, 64>::type;

// bfp16
using ushort2_t = typename vector_type<ushort, 2>::type;
using ushort4_t = typename vector_type<ushort, 4>::type;
using ushort8_t = typename vector_type<ushort, 8>::type;
using ushort16_t = typename vector_type<ushort, 16>::type;
using ushort32_t = typename vector_type<ushort, 32>::type;
using ushort64_t = typename vector_type<ushort, 64>::type;
using bhalf2_t = typename vector_type<bhalf_t, 2>::type;
using bhalf4_t = typename vector_type<bhalf_t, 4>::type;
using bhalf8_t = typename vector_type<bhalf_t, 8>::type;
using bhalf16_t = typename vector_type<bhalf_t, 16>::type;
using bhalf32_t = typename vector_type<bhalf_t, 32>::type;
using bhalf64_t = typename vector_type<bhalf_t, 64>::type;

// i32
using int32x2_t = typename vector_type<int32_t, 2>::type;
Expand All @@ -936,7 +937,7 @@ __host__ __device__ Y type_convert(X x)

// convert bfp16 to fp32
template <>
inline __host__ __device__ float type_convert(ushort x)
inline __host__ __device__ float type_convert(bhalf_t x)
{
union
{
Expand All @@ -949,7 +950,7 @@ inline __host__ __device__ float type_convert(ushort x)

// convert fp32 to bfp16
template <>
inline __host__ __device__ ushort type_convert(float x)
inline __host__ __device__ bhalf_t type_convert(float x)
{
union
{
Expand Down
28 changes: 15 additions & 13 deletions device_operation/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ set(DEVICE_GEMM_INSTANCE_SOURCE
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp;
Expand All @@ -35,7 +37,7 @@ set(DEVICE_GEMM_INSTANCE_SOURCE
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp;
)
)

# device_gemm_bias_2d_instance
set(DEVICE_GEMM_BIAS_2D_INSTANCE_SOURCE
Expand Down Expand Up @@ -82,9 +84,9 @@ set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE
)

# device_conv1d_fwd_instance
set(DEVICE_CONV1D_FWD_INSTANCE_SOURCE
set(DEVICE_CONV1D_FWD_INSTANCE_SOURCE
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp;
)
)

# device_conv2d_fwd_bias_relu_instance
set(DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE
Expand All @@ -106,11 +108,11 @@ add_library(device_gemm_bias_2d_instance SHARED ${DEVICE_GEMM_BIAS_2D_INSTANCE_S
add_library(device_gemm_bias_relu_instance SHARED ${DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE})
add_library(device_gemm_bias_relu_add_instance SHARED ${DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE})
add_library(device_batched_gemm_instance SHARED ${DEVICE_BATCHED_GEMM_INSTANCE_SOURCE})
add_library(device_conv1d_fwd_instance SHARED ${DEVICE_CONV1D_FWD_INSTANCE_SOURCE})
add_library(device_conv2d_fwd_instance SHARED ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE})
add_library(device_conv2d_fwd_bias_relu_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE})
add_library(device_conv2d_fwd_bias_relu_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE})
add_library(device_conv2d_fwd_bias_relu_atomic_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE})
add_library(device_conv1d_fwd_instance SHARED ${DEVICE_CONV1D_FWD_INSTANCE_SOURCE})
add_library(device_conv2d_fwd_instance SHARED ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE})
add_library(device_conv2d_fwd_bias_relu_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE})
add_library(device_conv2d_fwd_bias_relu_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE})
add_library(device_conv2d_fwd_bias_relu_atomic_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE})

target_include_directories(device_gemm_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
target_include_directories(device_gemm_bias_2d_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
Expand Down Expand Up @@ -150,8 +152,8 @@ install(TARGETS device_gemm_bias_2d_instance LIBRARY DESTINATION lib)
install(TARGETS device_gemm_bias_relu_instance LIBRARY DESTINATION lib)
install(TARGETS device_gemm_bias_relu_add_instance LIBRARY DESTINATION lib)
install(TARGETS device_batched_gemm_instance LIBRARY DESTINATION lib)
install(TARGETS device_conv1d_fwd_instance LIBRARY DESTINATION lib)
install(TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib)
install(TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib)
install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib)
install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib)
install(TARGETS device_conv1d_fwd_instance LIBRARY DESTINATION lib)
install(TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib)
install(TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib)
install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib)
install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib)
Loading