Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@ namespace ck {
namespace tensor_operation {
namespace element_wise {

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wnon-virtual-dtor"
struct UnaryOpBase
{
public:
__host__ __device__ virtual ~UnaryOpBase() = default;
__host__ __device__ ~UnaryOpBase() = default;

__host__ __device__ UnaryOpBase() = default;
__host__ __device__ UnaryOpBase(const UnaryOpBase&) = default;
__host__ __device__ constexpr UnaryOpBase() = default;
__host__ __device__ constexpr UnaryOpBase(const UnaryOpBase&) = default;
__host__ __device__ constexpr UnaryOpBase(UnaryOpBase&&) = default;
__host__ __device__ UnaryOpBase& operator=(const UnaryOpBase&) = default;
__host__ __device__ UnaryOpBase(UnaryOpBase&&) = default;
__host__ __device__ UnaryOpBase& operator=(UnaryOpBase&&) = default;

__host__ __device__ virtual inline void operator()(float& y, const float& x) const = 0;
Expand Down Expand Up @@ -50,8 +52,14 @@ struct PassThroughPack2
constexpr const static bool is_pack2_invocable = true;
};

struct PassThrough : public UnaryOpBase
struct PassThrough final : public UnaryOpBase
{
__host__ __device__ constexpr PassThrough() = default;
__host__ __device__ constexpr PassThrough(const PassThrough&) = default;
__host__ __device__ constexpr PassThrough(PassThrough&&) = default;
__host__ __device__ PassThrough& operator=(const PassThrough&) = default;
__host__ __device__ PassThrough& operator=(PassThrough&&) = default;
__host__ __device__ ~PassThrough() = default;

__host__ __device__ inline void operator()(float& y, const float& x) const final { y = x; }

Expand Down Expand Up @@ -409,8 +417,15 @@ struct UnarySquare
};
};

struct UnaryAbs : public UnaryOpBase
struct UnaryAbs final : public UnaryOpBase
{
__host__ __device__ constexpr UnaryAbs() = default;
__host__ __device__ constexpr UnaryAbs(const UnaryAbs&) = default;
__host__ __device__ constexpr UnaryAbs(UnaryAbs&&) = default;
__host__ __device__ UnaryAbs& operator=(const UnaryAbs&) = default;
__host__ __device__ UnaryAbs& operator=(UnaryAbs&&) = default;
__host__ __device__ ~UnaryAbs() = default;

__host__ __device__ inline void operator()(float& y, const float& x) const final
{
y = ck::math::abs(x);
Expand Down Expand Up @@ -459,8 +474,15 @@ struct UnarySqrt
};
};

struct Relu : public UnaryOpBase
struct Relu final : public UnaryOpBase
{
__host__ __device__ constexpr Relu() = default;
__host__ __device__ constexpr Relu(const Relu&) = default;
__host__ __device__ constexpr Relu(Relu&&) = default;
__host__ __device__ Relu& operator=(const Relu&) = default;
__host__ __device__ Relu& operator=(Relu&&) = default;
__host__ __device__ ~Relu() = default;

__host__ __device__ inline void operator()(float& y, const float& x) const final
{
y = x > 0 ? x : 0;
Expand Down Expand Up @@ -633,8 +655,14 @@ struct Gelu
}
};

struct Sigmoid : public UnaryOpBase
struct Sigmoid final : public UnaryOpBase
{
__host__ __device__ constexpr Sigmoid() = default;
__host__ __device__ constexpr Sigmoid(const Sigmoid&) = default;
__host__ __device__ constexpr Sigmoid(Sigmoid&&) = default;
__host__ __device__ Sigmoid& operator=(const Sigmoid&) = default;
__host__ __device__ Sigmoid& operator=(Sigmoid&&) = default;
__host__ __device__ ~Sigmoid() = default;

__host__ __device__ inline void operator()(float& y, const float& x) const final
{
Expand Down Expand Up @@ -688,8 +716,15 @@ struct Silu
};
};

struct TanH : public UnaryOpBase
struct TanH final : public UnaryOpBase
{
__host__ __device__ constexpr TanH() = default;
__host__ __device__ constexpr TanH(const TanH&) = default;
__host__ __device__ constexpr TanH(TanH&&) = default;
__host__ __device__ TanH& operator=(const TanH&) = default;
__host__ __device__ TanH& operator=(TanH&&) = default;
__host__ __device__ ~TanH() = default;

__host__ __device__ inline void operator()(float& y, const float& x) const final
{
y = ck::math::tanh(x);
Expand Down Expand Up @@ -959,8 +994,12 @@ struct Rcp
};
};

struct Swish : public UnaryOpBase
struct Swish final : public UnaryOpBase
{
__host__ __device__ constexpr Swish(const Swish&) = default;
__host__ __device__ constexpr Swish(Swish&&) = default;
__host__ __device__ ~Swish() = default;

__host__ __device__ Swish(float beta = 1.0f) : beta_(beta) {}

__host__ __device__ float get_beta() const { return beta_; }
Expand Down Expand Up @@ -1019,8 +1058,12 @@ struct Swish : public UnaryOpBase
}
};

struct SoftRelu : public UnaryOpBase
struct SoftRelu final : public UnaryOpBase
{
__host__ __device__ constexpr SoftRelu(const SoftRelu&) = default;
__host__ __device__ constexpr SoftRelu(SoftRelu&&) = default;
__host__ __device__ ~SoftRelu() = default;

__host__ __device__ SoftRelu(float alpha = 1.0f) : alpha_(alpha) {}

__host__ __device__ float get_alpha() const { return alpha_; }
Expand Down Expand Up @@ -1070,8 +1113,12 @@ struct SoftRelu : public UnaryOpBase
}
};

struct Power : public UnaryOpBase
struct Power final : public UnaryOpBase
{
__host__ __device__ constexpr Power(const Power&) = default;
__host__ __device__ constexpr Power(Power&&) = default;
__host__ __device__ ~Power() = default;

__host__ __device__ Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
: alpha_(alpha), beta_(beta), gamma_(gamma)
{
Expand Down Expand Up @@ -1148,8 +1195,12 @@ struct Power : public UnaryOpBase
}
};

struct ClippedRelu : public UnaryOpBase
struct ClippedRelu final : public UnaryOpBase
{
__host__ __device__ constexpr ClippedRelu(const ClippedRelu&) = default;
__host__ __device__ constexpr ClippedRelu(ClippedRelu&&) = default;
__host__ __device__ ~ClippedRelu() = default;

__host__ __device__ ClippedRelu(float alpha = 0.f, float beta = 1.f)
: alpha_(alpha), beta_(beta)
{
Expand Down Expand Up @@ -1205,8 +1256,11 @@ struct ClippedRelu : public UnaryOpBase
}
};

struct LeakyRelu : public UnaryOpBase
struct LeakyRelu final : public UnaryOpBase
{
__host__ __device__ constexpr LeakyRelu(const LeakyRelu&) = default;
__host__ __device__ constexpr LeakyRelu(LeakyRelu&&) = default;
__host__ __device__ ~LeakyRelu() = default;

__host__ __device__ LeakyRelu(float alpha = 0.f) : alpha_(alpha) {}

Expand Down Expand Up @@ -1250,8 +1304,11 @@ struct LeakyRelu : public UnaryOpBase
}
};

struct Elu : public UnaryOpBase
struct Elu final : public UnaryOpBase
{
__host__ __device__ constexpr Elu(const Elu&) = default;
__host__ __device__ constexpr Elu(Elu&&) = default;
__host__ __device__ ~Elu() = default;

__host__ __device__ Elu(float alpha = 1.f) : alpha_(alpha) {}

Expand Down Expand Up @@ -1296,8 +1353,11 @@ struct Elu : public UnaryOpBase
}
};

struct Logistic : public UnaryOpBase
struct Logistic final : public UnaryOpBase
{
__host__ __device__ constexpr Logistic(const Logistic&) = default;
__host__ __device__ constexpr Logistic(Logistic&&) = default;
__host__ __device__ ~Logistic() = default;

__host__ __device__ Logistic(float alpha = 1.0f) : alpha_(alpha) {}

Expand Down Expand Up @@ -1631,8 +1691,23 @@ struct DynamicUnaryOp

__host__ __device__ ~DynamicUnaryOp()
{
if(unary_op_ptr_)
delete unary_op_ptr_;
switch(unary_op_type_)
{
case(UnaryOpType::Swish): delete static_cast<Swish*>(unary_op_ptr_); break;
case(UnaryOpType::Sigmoid): delete static_cast<Sigmoid*>(unary_op_ptr_); break;
case(UnaryOpType::PassThrough): delete static_cast<PassThrough*>(unary_op_ptr_); break;
case(UnaryOpType::Logistic): delete static_cast<Logistic*>(unary_op_ptr_); break;
case(UnaryOpType::TanH): delete static_cast<TanH*>(unary_op_ptr_); break;
case(UnaryOpType::Relu): delete static_cast<Relu*>(unary_op_ptr_); break;
case(UnaryOpType::SoftRelu): delete static_cast<SoftRelu*>(unary_op_ptr_); break;
case(UnaryOpType::UnaryAbs): delete static_cast<UnaryAbs*>(unary_op_ptr_); break;
case(UnaryOpType::Power): delete static_cast<Power*>(unary_op_ptr_); break;
case(UnaryOpType::ClippedRelu): delete static_cast<ClippedRelu*>(unary_op_ptr_); break;
case(UnaryOpType::LeakyRelu): delete static_cast<LeakyRelu*>(unary_op_ptr_); break;
case(UnaryOpType::Elu): delete static_cast<Elu*>(unary_op_ptr_); break;

default: break;
}
}

__device__ void InitUnaryOpPtrOnDevice()
Expand Down Expand Up @@ -1721,6 +1796,7 @@ struct DynamicUnaryOp
float beta;
float gamma;
};
#pragma clang diagnostic pop

} // namespace element_wise
} // namespace tensor_operation
Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/numeric/math.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/host/reference/reference_elementwise.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/host/reference/reference_permute.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
18 changes: 9 additions & 9 deletions include/ck_tile/ops/reduce/block/block_reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ namespace ck_tile {
// synchronize reduce result (cross lane reduction and broadcast on replicated dimension)
template <typename AccDistributedTensor_, typename ReduceFunc, bool WithBroadcast = true>
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
const ReduceFunc& reduce_func,
bool_constant<WithBroadcast> = {})
const ReduceFunc& reduce_func,
bool_constant<WithBroadcast> = {})
{
using Dstr = typename AccDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode;
Expand Down Expand Up @@ -116,7 +116,7 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
*/
template <typename AccDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor,
const ReduceFunc& reduce_func)
const ReduceFunc& reduce_func)
{
using Dstr = typename AccDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode;
Expand Down Expand Up @@ -175,9 +175,9 @@ template <typename AccDistributedTensor_,
index_t... InReduceDims,
typename ReduceFunc>
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_& acc_tensor,
const InDistributedTensor_& in_tensor,
sequence<InReduceDims...>,
const ReduceFunc& reduce_func)
const InDistributedTensor_& in_tensor,
sequence<InReduceDims...>,
const ReduceFunc& reduce_func)
{
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
Expand Down Expand Up @@ -250,9 +250,9 @@ template <typename AccDataType_,
typename ReduceFunc,
typename InDataType_>
CK_TILE_DEVICE auto block_tile_reduce(const InDistributedTensor_& in_tensor,
sequence<InReduceDims...> in_reduce_dims,
const ReduceFunc& reduce_func,
const InDataType_& reduce_init)
sequence<InReduceDims...> in_reduce_dims,
const ReduceFunc& reduce_func,
const InDataType_& reduce_init)
{
using InDataType = typename InDistributedTensor_::DataType;
using AccDataType = remove_cvref_t<AccDataType_>;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/ops/welford/block/block_welford.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down