Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
147 commits
Select commit Hold shift + click to select a range
36c38ad
wmma_op + unit test
aska-0096 Oct 21, 2022
7dca846
add arch limitation to wmma test
aska-0096 Oct 21, 2022
049cc8a
change arch limitation
aska-0096 Oct 21, 2022
790e21e
Refactor + Add all type unit test(int4 compile failed)
aska-0096 Oct 28, 2022
24faa1f
Add f32_16x16x16_bf16 unit test
aska-0096 Oct 28, 2022
4fec5ad
Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/com…
aska-0096 Oct 28, 2022
ab66332
Merge develop
aska-0096 Nov 7, 2022
98ccb36
tempsave
aska-0096 Nov 16, 2022
d16063d
tempsave
aska-0096 Nov 22, 2022
b3cc22a
tempsave
aska-0096 Nov 24, 2022
9adf2e6
runtime bug, cannot find symbol
aska-0096 Nov 30, 2022
0cd587d
workaround for incorrect HIP warpSize return value
aska-0096 Dec 1, 2022
43a2099
debugging
aska-0096 Dec 2, 2022
7395995
tempsave
aska-0096 Dec 5, 2022
9bd4468
Correctness OK, waiting for optimization
aska-0096 Dec 9, 2022
289f15d
Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/com…
aska-0096 Dec 9, 2022
0a80872
Tidy up + format
aska-0096 Dec 9, 2022
9739ede
temp save
aska-0096 Dec 12, 2022
e43df26
temp save, reproduce the v_bfi_b32 issue
aska-0096 Dec 13, 2022
13af8cc
add inline asm for wmmaop test
aska-0096 Dec 13, 2022
63f8766
tidy up
aska-0096 Dec 15, 2022
b741109
Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/com…
aska-0096 Dec 15, 2022
2a0e543
clean some debug purpose code
aska-0096 Dec 15, 2022
3941bd1
discard some codes
aska-0096 Dec 15, 2022
cfb397b
clang format
aska-0096 Dec 15, 2022
5d5891b
clang format
aska-0096 Dec 15, 2022
40ec8e5
Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/com…
aska-0096 Dec 19, 2022
8efd363
compiler issue fixed + increase tile size
aska-0096 Jan 11, 2023
ccb94ce
navi3x_multipleD+example
aska-0096 Jan 13, 2023
2963dd9
temp save
aska-0096 Jan 16, 2023
c6de88b
Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/com…
aska-0096 Jan 16, 2023
07180cb
workable
aska-0096 Jan 18, 2023
abfc94b
batchedgemm[OK], groupconv[debug]
aska-0096 Jan 18, 2023
9c3c435
groupconv: Sanity check[OK], Performance[Bad]
aska-0096 Jan 18, 2023
0517cf0
navi3x_groupconv_need_optimization
aska-0096 Jan 19, 2023
3ddd357
create necessary files
aska-0096 Jan 30, 2023
a0a469e
save progress
aska-0096 Feb 3, 2023
a6b2f1c
Add Inter-Row thread transfer
aska-0096 Feb 9, 2023
5df713e
save progress
aska-0096 Feb 11, 2023
74f0d5d
save debugging progress
aska-0096 Feb 14, 2023
4ddda63
sanity check pass
aska-0096 Feb 16, 2023
27dc055
fix a host tensor bug and clean up flash-attn code
aska-0096 Feb 16, 2023
cc6a534
Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/com…
aska-0096 Feb 16, 2023
8dbb73b
format
aska-0096 Feb 16, 2023
f45099c
cancel unnecessary change
aska-0096 Feb 16, 2023
9620dbc
cancel unnecessary change
aska-0096 Feb 16, 2023
c749c26
cancel unnecessary change
aska-0096 Feb 16, 2023
c811a0e
temp save, add asm backend flag to amd_wmma
aska-0096 Feb 16, 2023
d4adc71
Mat-A LDS Bypass sanity pass
aska-0096 Feb 24, 2023
6a9d7b6
temp save
aska-0096 Feb 27, 2023
84b4ada
gemm sanity fix
aska-0096 Feb 27, 2023
7e003d3
Porting new blockwise gemm to flash attention
aska-0096 Feb 28, 2023
fbc576b
Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/com…
aska-0096 Feb 28, 2023
a045e0b
Example branch provide to compiler team
aska-0096 Feb 28, 2023
579f84c
tempsave
aska-0096 Mar 6, 2023
bdd0f64
Fix a bug
aska-0096 Mar 6, 2023
686212e
Merge branch 'lds_bypass_spilling' into lds_option_passthrough
aska-0096 Mar 6, 2023
a38ce02
batched gemm ported
aska-0096 Mar 6, 2023
f00dab9
conv A-skip lds ported
aska-0096 Mar 6, 2023
04c6a97
Skip B-Lds real gemm
aska-0096 Mar 6, 2023
060c4f3
Skip B Lds Gemm + MulD
aska-0096 Mar 6, 2023
708fd81
batched gemm, conv, skip b lds
aska-0096 Mar 6, 2023
6e28a8a
format
aska-0096 Mar 6, 2023
c5fd087
Attn, skip b lds
aska-0096 Mar 6, 2023
8e862b7
Change GridwiseOp name
aska-0096 Mar 7, 2023
e330961
Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/com…
aska-0096 Mar 7, 2023
9fb64da
Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/com…
aska-0096 Mar 23, 2023
a469434
fix a typo caused bug
aska-0096 Mar 23, 2023
dc8309d
Skip A_Lds sanity pass, Skip B_Lds scratch occurred
aska-0096 Mar 23, 2023
0583005
Bug found, intra-row permute off caused
aska-0096 Mar 23, 2023
44be643
bug found
aska-0096 Mar 23, 2023
b8e153a
a fix
aska-0096 Mar 27, 2023
0f1fca4
disable buffer load due to incorrect 3rd dword
aska-0096 Mar 27, 2023
31ca2f4
update fmha config, no scratch generated
aska-0096 Mar 29, 2023
23ad945
Merge branch 'e2e_kernellib' of https://github.com/aska-0096/navi3x_c…
aska-0096 Mar 29, 2023
82fef9e
update 3rd dword
aska-0096 Mar 29, 2023
5e30377
fmha config update
aska-0096 Apr 7, 2023
2c265eb
FMHA, add support to gfx1101/gfx1102
aska-0096 Apr 14, 2023
87ae9b7
Merge branch 'e2e_kernellib' of https://github.com/aska-0096/navi3x_c…
aska-0096 Apr 19, 2023
a29d2b7
Merge pull request #1 from aska-0096/e2e_v2
aska-0096 Apr 19, 2023
cad3212
Merge origin dev (#2)
aska-0096 Apr 19, 2023
a0058be
Disable SkipLDS & Align AIT api (#3)
aska-0096 Apr 20, 2023
394dbf8
fix layernorm, reduction Ops (#4)
aska-0096 Apr 21, 2023
bddc3af
fix typo
aska-0096 Apr 21, 2023
f677f70
Fix attention with causal mask
aska-0096 Apr 22, 2023
9e1091c
multiple fix, try ait compile
aska-0096 Apr 23, 2023
6e2c615
Add A/B not use LDS pipeline
aska-0096 Apr 27, 2023
d676da8
Clang format, Add gfx1101, gfx1102 support of FMHA example
aska-0096 Apr 27, 2023
716860e
cancel change of format script
aska-0096 Apr 27, 2023
0bb08f4
1. Enable 2-stage global Prefetch ( May cause VGPR spilling)
aska-0096 May 10, 2023
2f88070
Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/com…
aska-0096 May 10, 2023
5bf77d8
clang-format
aska-0096 May 10, 2023
2ec3f4c
1. change blockwise gemm loopover direction from kmn to mnk ( ~1% imp…
aska-0096 May 18, 2023
c713d22
Update low level abstraction of blockwise gemm wmma
aska-0096 May 19, 2023
3ccfb0a
(2/5) bilinear gemm pass, perf bug: skip a lds has lower performance …
aska-0096 May 19, 2023
12a4ea6
(3/5) batched gemm pass, perf bug: skip a lds has lower performance t…
aska-0096 May 19, 2023
fd4ff3a
(4/5) grouped conv pass
aska-0096 May 19, 2023
bee4e34
(5/5) attention pass, todo: debug lds perf bug
aska-0096 May 19, 2023
efee454
AIT Attention API refactor (#8)
aska-0096 Jun 13, 2023
e305e41
Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/com…
aska-0096 Jun 13, 2023
823c880
Merge branch 'e2e_kernellib' of https://github.com/aska-0096/navi3x_c…
aska-0096 Jun 13, 2023
d44f666
deprecate inline asm wmma
aska-0096 Jun 13, 2023
6c1aa33
Bug fix: double lds skip
aska-0096 Jun 13, 2023
83d926d
clang-format
aska-0096 Jun 15, 2023
4377795
Fix errors in
aska-0096 Jun 19, 2023
b010b09
part2 of previous commit
aska-0096 Jun 19, 2023
35e5c53
clang format
aska-0096 Jun 19, 2023
b377063
API fix of gridwisegemmpipeline
aska-0096 Jun 20, 2023
8053bca
separate array base and vector base attention tensor transformation
aska-0096 Jun 20, 2023
fd9e80c
fix gemm
aska-0096 Jun 25, 2023
1fb4a47
clang format
aska-0096 Jun 26, 2023
6e6c535
add gemm fp16 instances
aska-0096 Jul 7, 2023
febd76e
Temp save
aska-0096 Jul 20, 2023
0c51a35
fpAintB kernel compile pass
aska-0096 Jul 25, 2023
66e6107
Sanity pass.
aska-0096 Jul 28, 2023
32bac6f
Temp save
aska-0096 Aug 1, 2023
5cf73a5
debug code enabled
aska-0096 Aug 3, 2023
b5083bf
Fp16AInt8B_GEMM sanity
aska-0096 Aug 3, 2023
c3cba9c
Merge pull request #827 from ROCmSoftwarePlatform/fpAintB_clear
aska-0096 Aug 3, 2023
73e475d
MQA implementation
aska-0096 Aug 7, 2023
3cf4572
Merge pull request #832 from ROCmSoftwarePlatform/MQA
aska-0096 Aug 7, 2023
b2d5cf8
GQA-4 example
aska-0096 Aug 8, 2023
3ba0f0d
Merge pull request #833 from ROCmSoftwarePlatform/GQA
aska-0096 Aug 8, 2023
d1894bd
tempsave
aska-0096 Aug 9, 2023
061009a
Compile pass
aska-0096 Aug 15, 2023
bf75259
New implementation of fp16Aint8B Gemm, Achieve similar math throughp…
aska-0096 Aug 16, 2023
cc0ffeb
Merge pull request #851 from ROCmSoftwarePlatform/perf_opt_fpAintB
aska-0096 Aug 16, 2023
2724c51
merge develop
Feb 24, 2024
809d7df
format
Feb 24, 2024
4fe4969
Merge branch 'kaba' into navi3_rel
aska-0096 Feb 26, 2024
18d5297
Todo: fix gemm_bilinear_wmma instances compilation bug
aska-0096 Feb 26, 2024
4c102fc
Solve a bug when K1=16
aska-0096 Feb 27, 2024
924639f
remove unnecessary changes
aska-0096 Feb 27, 2024
2ab0d8f
Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/com…
aska-0096 Feb 27, 2024
b62926d
Remove tensor layout limitation to LDS usage in tensor contraction
aska-0096 Feb 27, 2024
8a6e65a
update self-attention and cross-attention
aska-0096 Feb 28, 2024
08ab9cf
fix a typo of name
aska-0096 Feb 28, 2024
0c0ddef
Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/com…
aska-0096 Feb 28, 2024
8ae8a55
Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/com…
aska-0096 Feb 29, 2024
6845924
Add arch limiter for fp8 gemm
aska-0096 Feb 29, 2024
e42f9ec
enable fp8 gemm_xdl for all gfx9 targets
illsilin Mar 8, 2024
32371ea
Merge branch 'develop' into navi3_rel
illsilin Mar 8, 2024
7b28bcb
temporarily disable gemm_xdl_fp16_fp8 on MI100/200
illsilin Mar 8, 2024
1012795
Merge branch 'navi3_rel' of github.com:ROCm/composable_kernel into na…
illsilin Mar 8, 2024
91ee125
fix the cmake logic for gemm_xdl_fp16_fp8
illsilin Mar 8, 2024
13303ed
Merge branch 'develop' into navi3_rel
illsilin Mar 8, 2024
56a6723
re-enable the gemm_xdl_fp16_fp8 on MI100/200
illsilin Mar 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions example/01_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)

add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
if(GPU_TARGETS MATCHES "gfx11")
add_custom_target(example_gemm_wmma)
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
Expand All @@ -53,12 +53,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)

add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)

add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)

add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)

list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
Expand All @@ -72,5 +66,12 @@ foreach(gpu IN LISTS GPU_TARGETS)
endif()
endforeach()

add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)

add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)

add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)

47 changes: 41 additions & 6 deletions example/01_gemm/gemm_wmma_fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,50 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// GEMM specialization handed to the DeviceGemmInstance argument list in this file.
// NOTE: despite its name, GemmDefault is bound to GemmSpecialization::MNKPadding,
// not to a "Default" specialization.
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

// clang-format off
// Device GEMM instance used by this example: DeviceGemmWmma_CShuffle with the argument
// list that follows.
// NOTE(review): two argument lists appear below — a single-line list under the column
// header, and a multi-line annotated list. They look like the two sides of a change;
// only one of them can follow the template name, so confirm which one is meant to stay.
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmMNKPadding, 256, 128, 256, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, 1>;
< ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
BElementOp,
CElementOp,
GemmDefault,
1, // Prefetch stage
128, // BlockSize
64, // MPerBlock
128, // NPerBlock
64, // KPerBlock
8, // K1
16, // MPerWmma
16, // NPerWmma
// NOTE(review): the wave counts are presumably MPerBlock / (MPerWmma * M-Repeat) and
// NPerBlock / (NPerWmma * N-Repeat), i.e. 64 / (16 * 2) = 2 and 128 / (16 * 4) = 2 here;
// the earlier "PerWmma / Repeat = Wave" wording does not match these numbers — confirm.
2, // M-Repeat
4, // N-Repeat
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1, // C shuffle (M Repeat) Per store
1, // C shuffle (N Repeat) Per store
S<1, 32, 1, 4>,
8>;
// clang-format on

using ReferenceGemmInstance = ck::tensor_operation::host::
Expand Down
16 changes: 16 additions & 0 deletions example/01_gemm/run_gemm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,22 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
case 2:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
break;
case 3:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
case 4:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
break;
case 5:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-2.f, 2.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-2.f, 2.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-0.1f, 0.1f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-0.1f, 0.1f}(b_k_n);
Expand Down
87 changes: 44 additions & 43 deletions example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,48 +65,49 @@ using CDEElementOp = AlphaBetaAdd;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

using DeviceOpInstance =
ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle<ALayout,
BLayout,
ck::Tuple<DLayout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<DDataType>,
EDataType,
AccDataType,
CShuffleDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
256,
128,
256,
8,
8,
16,
16,
4,
4,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1,
1,
S<1, 32, 1, 8>,
8>;
// Device op used by this example: DeviceGemmMultipleD_Wmma_CShuffle instantiated with the
// layouts, data types, element-wise ops and tuning values listed below. Trailing comments
// name the tuning arguments; arguments without a comment are left unlabelled here.
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle<
ALayout,
BLayout,
ck::Tuple<DLayout>,
ELayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<DDataType>,
EDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
2, // Prefetch stage
128, // BlockSize
128, // MPerBlock
64, // NPerBlock
64, // KPerBlock
8, // K1
16, // MPerWmma
16, // NPerWmma
// NOTE(review): the wave counts are presumably MPerBlock / (MPerWmma * M-Repeat) and
// NPerBlock / (NPerWmma * N-Repeat), i.e. 128 / (16 * 4) = 2 and 64 / (16 * 2) = 2 here;
// the earlier "PerWmma / Repeat = Wave" wording does not match these numbers — confirm.
4, // M-Repeat
2, // N-Repeat
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1, // C shuffle (M Repeat) Per store
1, // C shuffle (N Repeat) Per store
S<1, 32, 1, 4>,
8>;

int main(int argc, char* argv[])
{
Expand Down Expand Up @@ -264,7 +265,7 @@ int main(int argc, char* argv[])
float gb_per_sec = num_btype / 1.E6 / ave_time;

std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
<< device_op.GetTypeString() << std::endl;

e_device_buf.FromDevice(e_m_n_device_result.mData.data());

Expand Down
87 changes: 44 additions & 43 deletions example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ using DDataType = I8;
using EDataType = I8;

using ALayout = Row;
using BLayout = Row;
using BLayout = Col;
using DLayout = Row;
using ELayout = Row;

Expand All @@ -65,48 +65,49 @@ using CDEElementOp = AlphaBetaAdd;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;

using DeviceOpInstance =
ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle<ALayout,
BLayout,
ck::Tuple<DLayout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<DDataType>,
EDataType,
AccDataType,
CShuffleDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
32,
16,
16,
4,
16,
16,
16,
1,
1,
S<2, 16, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
16,
16,
1,
S<4, 1, 8>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
16,
2,
1,
1,
1,
S<1, 16, 1, 2>,
8>;
// Device op used by this example: DeviceGemmMultipleD_Wmma_CShuffle instantiated with the
// layouts, data types, element-wise ops and tuning values listed below. Trailing comments
// name the tuning arguments; arguments without a comment are left unlabelled here.
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle<
ALayout,
BLayout,
ck::Tuple<DLayout>,
ELayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<DDataType>,
EDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
2, // Prefetch stage
128, // BlockSize
128, // MPerBlock
64, // NPerBlock
64, // KPerBlock
8, // K1
16, // MPerWmma
16, // NPerWmma
// NOTE(review): the wave counts are presumably MPerBlock / (MPerWmma * M-Repeat) and
// NPerBlock / (NPerWmma * N-Repeat), i.e. 128 / (16 * 4) = 2 and 64 / (16 * 2) = 2 here;
// the earlier "PerWmma / Repeat = Wave" wording does not match these numbers — confirm.
4, // M-Repeat
2, // N-Repeat
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1, // C shuffle (M Repeat) Per store
1, // C shuffle (N Repeat) Per store
S<1, 32, 1, 4>,
8>;

int main(int argc, char* argv[])
{
Expand Down
2 changes: 1 addition & 1 deletion example/29_batched_gemm_bias_e_permute/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp)

if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
if(GPU_TARGETS MATCHES "gfx11")
add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
endif()
Loading