Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,41 @@ struct PassThrough
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x; }
};

/// Binary element-wise functor: writes the sum of its two inputs into the output reference.
struct Add
{
    /// float overload: y receives x0 + x1.
    __host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const
    {
        const float sum = x0 + x1;
        y               = sum;
    }

    /// half_t overload: y receives x0 + x1, with the addition applied directly
    /// to the half_t operands (no explicit widening is performed here).
    __host__ __device__ constexpr void
    operator()(half_t& y, const half_t& x0, const half_t& x1) const
    {
        // FIXME - Use float (acc type) bias in the future.
        y = x0 + x1;
    }
};

/// Binary element-wise functor computing the scaled sum
///     y = alpha * x0 + beta * x1
/// where alpha and beta are fixed when the functor is constructed.
struct AlphaBetaAdd
{
    /// Stores the two scaling coefficients.
    ///
    /// Qualified `__host__ __device__ constexpr` to match the call operators below,
    /// so the functor can be constructed wherever it can be invoked (host code,
    /// device code, or a constant expression).
    __host__ __device__ constexpr AlphaBetaAdd(float alpha, float beta)
        : alpha_(alpha), beta_(beta)
    {
    }

    /// float overload: y receives alpha_ * x0 + beta_ * x1.
    __host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const
    {
        y = alpha_ * x0 + beta_ * x1;
    }

    /// half_t overload: both inputs are widened to float, combined with the
    /// float coefficients, and the result is converted back to half_t.
    __host__ __device__ constexpr void
    operator()(half_t& y, const half_t& x0, const half_t& x1) const
    {
        // FIXME - Let x0 be acc type
        y = static_cast<half_t>(alpha_ * static_cast<float>(x0) + beta_ * static_cast<float>(x1));
    }

    float alpha_; ///< Coefficient multiplied with x0.
    float beta_;  ///< Coefficient multiplied with x1.
};

struct AddRelu
{
__host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const
Expand Down
29 changes: 29 additions & 0 deletions device_operation/include/device_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,35 @@ namespace ck {
namespace tensor_operation {
namespace device {

/// Abstract interface for a device operation that takes three input/output
/// buffers (a, b, c) plus an additional bias buffer, together with one
/// element-wise functor per A, B and C.
///
/// Concrete implementations override the two factory methods below.
/// NOTE(review): the meanings of M/N/K and the strides are inferred from their
/// names (GEMM problem sizes and leading dimensions) - confirm against an
/// implementing class.
template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceGemmBias : public BaseOperator
{
    /// Packages the buffer pointers, sizes, strides and element-wise functors
    /// into an owning BaseArgument handle.
    ///
    /// p_a, p_b and p_bias are read-only untyped pointers; p_c is the only
    /// mutable pointer. The functors are taken by value.
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        const void* p_bias,
                        void* p_c,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t StrideA,
                        ck::index_t StrideB,
                        ck::index_t StrideC,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CElementwiseOperation c_element_op) = 0;

    /// Creates an owning BaseInvoker handle for this operation.
    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

/// Owning (std::unique_ptr) handle to a DeviceGemmBias instantiated with the
/// same three element-wise operation types.
template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
using DeviceGemmBiasPtr =
    std::unique_ptr<DeviceGemmBias<AElementwiseOperation,
                                   BElementwiseOperation,
                                   CElementwiseOperation>>;

template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
Expand Down
Loading