Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
09a2b54
Copy "gemm reduce" to "gemm bias add reduce"
rocking5566 Jun 1, 2022
300ac4e
Implement gemm bias add reduction
rocking5566 Jun 2, 2022
10962ae
Merge commit '1677cf705eb0f1f96e60d052df0e024bdf007b62' into gemm_add…
rocking5566 Jun 6, 2022
d6b08e3
Fix compiler error due to merge from develop
rocking5566 Jun 6, 2022
af1812f
Add tensor operation for gemm + bias + add + reduce
rocking5566 Jun 6, 2022
a5842a7
Add gemm_bais_add_reduce to ckProfiler
rocking5566 Jun 7, 2022
78f28a1
Add c1 functor
rocking5566 Jun 8, 2022
b3812da
Refine type
rocking5566 Jun 9, 2022
aa02705
Use reduceAccDataType instead of explicitly float
rocking5566 Jun 13, 2022
e3976f1
Change to use check_err()
rocking5566 Jun 13, 2022
3e4d275
Do relu in float32 instead of bhalf_t. Because bhalf_t is unsigned
rocking5566 Jun 13, 2022
46eca0a
Refactor relu. using type_trait instead of overloading
rocking5566 Jun 13, 2022
c44818e
Rename DxsReduceAccElementwiseOperation to DxsReduceAccElementwiseOpe…
rocking5566 Jun 13, 2022
edb2f81
Fix denominator
rocking5566 Jun 14, 2022
bf44347
Refine nameing
rocking5566 Jun 14, 2022
dfeade7
Fix denominator in host
rocking5566 Jun 14, 2022
6c0636c
Remove useless include header
rocking5566 Jun 14, 2022
3d091db
Use AccDataType
rocking5566 Jun 14, 2022
94d5f72
Fix static_cast order
rocking5566 Jun 14, 2022
f9d22b0
Refine type
rocking5566 Jun 14, 2022
a76eac2
[What] Remove tuple type in the base class
rocking5566 Jun 15, 2022
537b3bf
Merge branch 'develop' into gemm_add_bias_reduction
rocking5566 Jun 16, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions example/16_gemm_reduce/gemm_reduce_xdl_mean_squaremean_fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ using UnaryDivElementOp =
ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, true>;
using UnarySquareElementOp =
ck::tensor_operation::element_wise::UnarySquare<ReduceAccDataType, ReduceAccDataType, false>;
using DxsInElementOp = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOp = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;

using DGlobalMemOp =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
Expand All @@ -67,7 +67,7 @@ using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on

using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
Expand Down Expand Up @@ -204,8 +204,8 @@ int main(int argc, char* argv[])
auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()));

auto dxs_in_element_op = DxsInElementOp{};
auto dxs_out_element_op = DxsOutElementOp{M, M};
auto dxs_in_element_op = DxsInElementOps{};
auto dxs_out_element_op = DxsOutElementOps{N, N};

// do GEMM
auto gemm = DeviceGemmReduceInstance{};
Expand Down Expand Up @@ -261,14 +261,15 @@ int main(int argc, char* argv[])

for(int m = 0; m < M; ++m)
{
float d0_acc = d0_reduce_op.GetIdentityValue();
float d1_acc = d1_reduce_op.GetIdentityValue();
ReduceAccDataType d0_acc = d0_reduce_op.GetIdentityValue();
ReduceAccDataType d1_acc = d1_reduce_op.GetIdentityValue();

for(int n = 0; n < N; ++n)
{
float c_val = ck::type_convert<float>(c_m_n_host_result(m, n));
float d0_val = 0;
float d1_val = 0;
ReduceAccDataType c_val =
ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
ReduceAccDataType d0_val = 0;
ReduceAccDataType d1_val = 0;

dxs_in_element_op(ck::Number<0>{})(d0_val, c_val);
dxs_in_element_op(ck::Number<1>{})(d1_val, c_val);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ using UnaryIdenticElementOp =
ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, false>;
using UnarySquareElementOp =
ck::tensor_operation::element_wise::UnarySquare<ReduceAccDataType, ReduceAccDataType, false>;
using DxsInElementOp = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOp = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOps = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;

using DGlobalMemOp =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
Expand All @@ -63,7 +63,7 @@ using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatc
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on

using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
Expand Down Expand Up @@ -206,8 +206,8 @@ int main(int argc, char* argv[])
a_element_op,
b_element_op,
c_element_op,
DxsInElementOp{},
DxsOutElementOp{},
DxsInElementOps{},
DxsOutElementOps{},
BatchCount);

if(!batched_gemm.IsSupportedArgument(argument))
Expand Down
1 change: 1 addition & 0 deletions example/21_gemm_layernorm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_fp16 gemm_bias_relu_add_layernorm_xdl_fp16.cpp)
add_example_executable(example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp)
Loading