Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,45 @@ struct PassThrough
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x; }
};

// Element-wise functor that writes the sum of its two inputs into `dst`.
// Overloads are provided for float and for half_t operands.
struct Add
{
    // dst = src_y + bias; every operand is already float.
    __host__ __device__ constexpr void
    operator()(float& dst, const float& src_y, const float& bias) const
    {
        dst = src_y + bias;
    }

    // dst = src_y + bias with half_t operands.
    __host__ __device__ constexpr void
    operator()(half_t& dst, const half_t& src_y, const half_t& bias) const
    {
        // FIXME - Use float (acc type) bias in the future.
        dst = src_y + bias;
    }
};

// Element-wise functor that stores `alpha * src_y + beta * bias` into `dst`.
// Both scale factors are supplied at construction and held as floats.
struct AlphaBetaAdd
{
    float alpha_; // scale applied to src_y
    float beta_;  // scale applied to bias

    AlphaBetaAdd(float alpha, float beta) : alpha_{alpha}, beta_{beta} {}

    // float operands: the scaled sum is evaluated and stored as float.
    __host__ __device__ constexpr void
    operator()(float& dst, const float& src_y, const float& bias) const
    {
        dst = alpha_ * src_y + beta_ * bias;
    }

    // half_t operands: both inputs are converted to float, combined with the
    // float scale factors, and the result is converted back to half_t.
    __host__ __device__ constexpr void
    operator()(half_t& dst, const half_t& src_y, const half_t& bias) const
    {
        // FIXME - Let src_y be acc type
        const float src_y_f32 = static_cast<float>(src_y);
        const float bias_f32  = static_cast<float>(bias);

        // Kept as one expression so the evaluation matches the float overload.
        dst = static_cast<half_t>(alpha_ * src_y_f32 + beta_ * bias_f32);
    }
};

struct AddRelu
{
__host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const
Expand Down
29 changes: 29 additions & 0 deletions device_operation/include/device_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,35 @@ namespace ck {
namespace tensor_operation {
namespace device {

// Abstract interface (every member is pure virtual) for a device operation
// whose argument bundle carries a bias pointer in addition to the a/b/c
// pointers. Judging by the name this is a GEMM-with-bias operation.
//
// Template parameters are the types of the three element-wise operation
// objects that MakeArgumentPointer receives by value.
template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceGemmBias : BaseOperator
{
    // Builds the argument object for one invocation.
    //   p_a, p_b, p_bias - read-only, type-erased input pointers
    //   p_c              - the only mutable pointer; presumably the output
    //   M, N, K          - presumably the GEMM problem sizes (confirm in implementations)
    //   StrideA/B/C      - presumably the strides of the a, b and c buffers
    //   *_element_op     - element-wise operation objects, taken by value
    // Ownership of the returned BaseArgument passes to the caller.
    virtual auto MakeArgumentPointer(const void* p_a,
                                     const void* p_b,
                                     const void* p_bias,
                                     void* p_c,
                                     ck::index_t M,
                                     ck::index_t N,
                                     ck::index_t K,
                                     ck::index_t StrideA,
                                     ck::index_t StrideB,
                                     ck::index_t StrideC,
                                     AElementwiseOperation a_element_op,
                                     BElementwiseOperation b_element_op,
                                     CElementwiseOperation c_element_op)
        -> std::unique_ptr<BaseArgument> = 0;

    // Returns a caller-owned invoker object for this operation.
    virtual auto MakeInvokerPointer() -> std::unique_ptr<BaseInvoker> = 0;
};

// Owning-pointer shorthand: a std::unique_ptr to the DeviceGemmBias interface
// instantiated with the same three element-wise operation types.
template <typename AElementOp, typename BElementOp, typename CElementOp>
using DeviceGemmBiasPtr =
    std::unique_ptr<DeviceGemmBias<AElementOp, BElementOp, CElementOp>>;

template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
Expand Down
Loading