Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
09a2b54
Copy "gemm reduce" to "gemm bias add reduce"
rocking5566 Jun 1, 2022
300ac4e
Implement gemm bias add reduction
rocking5566 Jun 2, 2022
10962ae
Merge commit '1677cf705eb0f1f96e60d052df0e024bdf007b62' into gemm_add…
rocking5566 Jun 6, 2022
d6b08e3
Fix compiler error due to merge from develop
rocking5566 Jun 6, 2022
af1812f
Add tensor operation for gemm + bias + add + reduce
rocking5566 Jun 6, 2022
a5842a7
Add gemm_bias_add_reduce to ckProfiler
rocking5566 Jun 7, 2022
78f28a1
Add c1 functor
rocking5566 Jun 8, 2022
b3812da
Refine type
rocking5566 Jun 9, 2022
aa02705
Use reduceAccDataType instead of explicitly float
rocking5566 Jun 13, 2022
e3976f1
Change to use check_err()
rocking5566 Jun 13, 2022
3e4d275
Do relu in float32 instead of bhalf_t. Because bhalf_t is unsigned
rocking5566 Jun 13, 2022
46eca0a
Refactor relu. using type_trait instead of overloading
rocking5566 Jun 13, 2022
c44818e
Rename DxsReduceAccElementwiseOperation to DxsReduceAccElementwiseOpe…
rocking5566 Jun 13, 2022
edb2f81
Fix denominator
rocking5566 Jun 14, 2022
bf44347
Refine naming
rocking5566 Jun 14, 2022
dfeade7
Fix denominator in host
rocking5566 Jun 14, 2022
6c0636c
Remove useless include header
rocking5566 Jun 14, 2022
3d091db
Use AccDataType
rocking5566 Jun 14, 2022
94d5f72
Fix static_cast order
rocking5566 Jun 14, 2022
f9d22b0
Refine type
rocking5566 Jun 14, 2022
a76eac2
[What] Remove tuple type in the base class
rocking5566 Jun 15, 2022
537b3bf
Merge branch 'develop' into gemm_add_bias_reduction
rocking5566 Jun 16, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions example/21_gemm_layernorm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
# Example targets for the GEMM + layernorm samples; each call names one target
# and the single source file it is built from.
# NOTE(review): add_example_executable is presumably a project-defined helper
# declared in a parent CMakeLists - confirm its exact behavior there.
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_fp16 gemm_bias_relu_add_layernorm_xdl_fp16.cpp)
add_example_executable(example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp)
425 changes: 425 additions & 0 deletions example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp

Large diffs are not rendered by default.

Large diffs are not rendered by default.

50 changes: 50 additions & 0 deletions include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,56 @@ using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<DPtrsGlobal,
DxsInElementwiseOperation,
DxsAccElementwiseOperation>>;

// Abstract interface (all members are pure virtual) for a device operation that,
// going by its name, fuses a GEMM with bias, add and reduction stages. The actual
// computation is defined by the concrete implementations, which are not visible here.
//
// Template parameters:
//   DPtrsGlobal                - type of the p_dxs argument
//                                (presumably a bundle of reduction output pointers - TODO confirm).
//   AElementwiseOperation      - type of a_element_op.
//   BElementwiseOperation      - type of b_element_op.
//   CElementwiseOperation      - type of c_element_op.
//   C1ElementwiseOperation     - type of c1_element_op.
//   DxsInElementwiseOperation  - type of dxs_in_element_op.
//   DxsAccElementwiseOperation - type of dxs_out_element_op.
template <typename DPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation>
struct DeviceGemmBiasAddReduce : public BaseOperator
{
// Builds the argument object for one problem instance and returns it as an owning
// pointer to the BaseArgument base type.
//
// Pointers:
//   p_a, p_b   - read-only (const) buffers.
//   p_c        - the only non-const raw pointer, so presumably the main output - TODO confirm.
//   p_c0, p_c1 - additional read-only buffers
//                (presumably the bias and the added tensor - TODO confirm).
//   p_dxs      - passed by value as DPtrsGlobal.
// Sizes and strides:
//   M, N, K    - problem dimensions (presumably A is MxK, B is KxN, C is MxN - TODO confirm).
//   StrideA, StrideB, StrideC, StrideC1 - strides for the matching buffers. No stride is
//                taken for p_c0.
// Element-wise functors are taken by value.
// NOTE(review): the last functor parameter is named dxs_out_element_op while its type is
// DxsAccElementwiseOperation - confirm which name is intended.
//
// BatchCount defaults to 1. Default arguments of virtual functions are selected by the
// static type used at the call site, so overriders should repeat the same default.
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const void* p_c0,
const void* p_c1,
DPtrsGlobal p_dxs,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t StrideC1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op,
ck::index_t BatchCount = 1) = 0;

// Returns an owning pointer to the invoker, typed as the BaseInvoker base class.
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

// Convenience alias: a std::unique_ptr to the DeviceGemmBiasAddReduce interface
// instantiated with exactly the template arguments given here, in the same order.
template <typename DPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation>
using DeviceGemmBiasAddReducePtr =
std::unique_ptr<DeviceGemmBiasAddReduce<DPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
C1ElementwiseOperation,
DxsInElementwiseOperation,
DxsAccElementwiseOperation>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,21 @@ struct AddHardswishAdd
}
};

// Element-wise rectified linear unit, overloaded per supported element type.
//
// Every overload reads its input from `x` and writes the result to `y`:
// `y = x` when x is strictly positive, otherwise `y = 0`.
struct Relu
{
    __host__ __device__ void operator()(float& y, const float& x) const { y = x > 0 ? x : 0; }

    __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x > 0 ? x : 0; }

    // bfloat16 overload.
    //
    // This overload assumes bhalf_t holds the raw bfloat16 bit pattern in an unsigned
    // 16-bit integer (enforced by the static_assert below). Under that representation a
    // plain `x > 0` is wrong: negative values have the sign bit (0x8000) set and therefore
    // compare as large positive integers, so they would pass through unclamped.
    //
    // Instead, the bit pattern is classified directly. With 1 sign, 8 exponent and
    // 7 mantissa bits, the strictly positive, non-NaN encodings are exactly
    // 0x0001 (smallest subnormal) through 0x7F80 (+infinity). Everything else
    // (+0, positive NaNs 0x7F81..0x7FFF, and all sign-bit-set encodings) maps to +0,
    // which matches what `x > 0 ? x : 0` yields for the floating-point overloads.
    __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
    {
        static_assert(sizeof(bhalf_t) == 2 &&
                          static_cast<bhalf_t>(-1) > static_cast<bhalf_t>(0),
                      "Relu expects bhalf_t to be an unsigned 16-bit bfloat16 bit pattern");

        constexpr bhalf_t positive_infinity_bits = 0x7F80;

        y = (x > 0 && x <= positive_infinity_bits) ? x : static_cast<bhalf_t>(0);
    }

    __host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x > 0 ? x : 0; }

    __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x > 0 ? x : 0; }

    __host__ __device__ void operator()(double& y, const double& x) const { y = x > 0 ? x : 0; }
};

struct Normalize
{
Normalize(float epsilon = 1e-4) : epsilon_(epsilon) {}
Expand Down
Loading