Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ struct PassThrough
__host__ __device__ void operator()(float& y, const float& x) const { y = x; }

__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x; }

__host__ __device__ void operator()(ushort& y, const ushort& x) const { y = x; }

__host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; }

__host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; }
};

struct Add
Expand Down
2 changes: 0 additions & 2 deletions composable_kernel/include/utility/common_header.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@
#include "amd_address_space.hpp"
#include "amd_buffer_addressing.hpp"
#include "static_buffer.hpp"
// TODO remove this
#include "static_buffer_of_vector_type_v2.hpp"
#include "dynamic_buffer.hpp"
#include "is_known_at_compile_time.hpp"
#include "transpose_vectors.hpp"
Expand Down
10 changes: 10 additions & 0 deletions composable_kernel/include/utility/dynamic_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ struct DynamicBuffer
is_same<remove_cvref_t<X>, int8x4_t>::value) ||
(is_same<remove_cvref_t<T>, int8_t>::value &&
is_same<remove_cvref_t<X>, int8x8_t>::value) ||
(is_same<remove_cvref_t<T>, int8_t>::value &&
is_same<remove_cvref_t<X>, int8x16_t>::value) ||
(is_same<remove_cvref_t<T>, int8x4_t>::value &&
is_same<remove_cvref_t<X>, int8x4_t>::value) ||
(is_same<remove_cvref_t<T>, int8x8_t>::value &&
Expand Down Expand Up @@ -212,6 +214,14 @@ struct DynamicBuffer
*c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32x2_t*>(&x);
}
else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
is_same<remove_cvref_t<X>, int8x16_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32x4_t*>(&x);
}
else if constexpr(is_same<remove_cvref_t<T>, int8x4_t>::value &&
is_same<remove_cvref_t<X>, int8x4_t>::value)
{
Expand Down

This file was deleted.

2 changes: 2 additions & 0 deletions device_operation/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ set(DEVICE_BATCHED_GEMM_INSTANCE_SOURCE
set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp;
)

Expand Down
Loading