From 2586ce1f0cfa9dea818bc1ec48a201c92368969f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 18 May 2022 18:11:49 +0900 Subject: [PATCH 1/4] [TIR] Support tensorization using ldmatrix + MMA commit 3218facf100b0dfc55715acfd1cee156764129ba Author: Masahiro Masuda Date: Wed May 18 14:04:56 2022 +0900 some clean up commit 7a235b69dc2023b3098ed44d591edb63b20a8f4e Author: Masahiro Masuda Date: Wed May 18 13:55:11 2022 +0900 parameterize over storage scope in mma store intrin commit 827ea4c434c35607b241f8e0ae2efe3214ac2458 Author: Masahiro Masuda Date: Wed May 18 13:37:38 2022 +0900 properly handle floordiv/mod in codegen commit 42d4c6f42182c9fd79566c0955f99cc82abd5144 Author: Masahiro Masuda Date: Wed May 18 09:53:57 2022 +0900 update tuned factors for fp16 commit 328d0aa36b2ea9ea1b051970d612bff82d2d20e6 Author: Masahiro Masuda Date: Wed May 18 08:43:30 2022 +0900 all tests working commit 5e086cf5fd1404ac38f85c4bfbe692687b45a16c Author: Masahiro Masuda Date: Wed May 18 07:48:43 2022 +0900 add doc for mma_fill and mma_store intrin commit 4f945c4116b6d3bdc965ecb2be2229bb46dc11ab Author: Masahiro Masuda Date: Wed May 18 06:39:01 2022 +0900 remove tests commit df7708f7f67761d9c18f9564bc15abd50c12ac69 Author: Masahiro Masuda Date: Tue May 17 19:52:14 2022 +0900 unified test commit 754c83eeb8510b31fb9652b089177f9b8e642ec0 Author: Masahiro Masuda Date: Tue May 17 19:36:24 2022 +0900 clean up LowerWarpmemory commit 178c3dcee7bfa17d5d93fec02aa858dc62151670 Author: Masahiro Masuda Date: Tue May 17 19:15:04 2022 +0900 Use IndexMap commit 07fb58910338c62847fd902b37801d09b8c673b0 Author: Masahiro Masuda Date: Tue May 17 17:51:44 2022 +0900 remove 16x8x8 test commit 2b05b5a5470ac221d559f31a31a8e2ff753b2414 Author: Masahiro Masuda Date: Tue May 17 17:31:35 2022 +0900 generate mma fill/store commit bf23fc50f0ffa99e875d9247ca66acec0c36677f Author: Masahiro Masuda Date: Tue May 17 12:23:30 2022 +0900 mma intrin generation with meta programming commit 5afb5f00afd642cb1e39872edc7965f476dcdcb7 Author: Masahiro Masuda Date: Tue May 17 05:26:14 2022 +0900 ldmatrix intrin generation with meta programming commit fb62abb3424b88ec48c697e306e05889a3ac306f Author: Masahiro Masuda Date: Mon May 16 20:30:49 2022 +0900 minor commit 5a80adce24e84d3ec6bf931b60cb9c730d243394 Author: Masahiro Masuda Date: Mon May 16 19:55:57 2022 +0900 revert some change commit e599a55078ee75f2480a721098341812db58cf6f Author: Masahiro Masuda Date: Mon May 16 19:54:18 2022 +0900 remove obsolete files commit 4b13b85ff91d0d592a7e0c01924e0b49b82f35a8 Author: Masahiro Masuda Date: Mon May 16 19:51:21 2022 +0900 wip commit 848de63455539e25cd0d43e5a65fd048636ef0f7 Author: Masahiro Masuda Date: Mon May 16 19:44:29 2022 +0900 wip commit b35bff97ed10c22559e2164eb7538db0f711ce7e Author: Masahiro Masuda Date: Mon May 16 19:31:18 2022 +0900 update parse error msg commit ad9b053ef865b1f91f03d7b15ed7aae3420ee213 Author: Masahiro Masuda Date: Mon May 16 19:26:51 2022 +0900 fix for avoiding Buffer.vload(...) 
case commit 54c686443e370edbfae860d0809b1b6182d26414 Author: Masahiro Masuda Date: Mon May 16 18:59:55 2022 +0900 wip commit 078060fe28d22f1db5f07b1c382dee438f02df60 Author: Masahiro Masuda Date: Mon May 16 18:57:34 2022 +0900 wip commit 576f8415e65e0e8a8a7808885e219b3b53867950 Author: Masahiro Masuda Date: Mon May 16 18:52:15 2022 +0900 wip commit 12a376ae2f44aa6660121e64e0358f2866624f7f Author: Masahiro Masuda Date: Mon May 16 17:54:58 2022 +0900 Squashed commit of the following: commit 48eef4981d1a55aaf3b0ac935f2a10347cb1ac2d Author: Masahiro Masuda Date: Mon May 16 17:40:48 2022 +0900 more comment commit 8f67fc87038834e9f7e2c5cd3dfe61fabf442206 Author: Masahiro Masuda Date: Mon May 16 17:11:27 2022 +0900 update test commit ad85036621c005b733763e67ceffae39c356ec99 Author: Masahiro Masuda Date: Mon May 16 16:54:01 2022 +0900 add test commit 4a5dc3ffd5d0bb4a1700e57897c9e0f26e3d2a88 Author: Masahiro Masuda Date: Mon May 16 16:40:47 2022 +0900 [TVMScript] Support function call to help construct AST commit 76c1bcf0ade45d7433a0066236add8372b1cc547 Author: Masahiro Masuda Date: Mon May 16 16:30:07 2022 +0900 simplify iterator in layout transform commit 936280324ea2c91429a6a85a1b8ee89c7b825928 Author: Masahiro Masuda Date: Sat May 14 11:31:39 2022 +0900 remove obsolet files commit 2e119b422d72d726d5f2bd20fe48a1e62fcb0510 Author: Masahiro Masuda Date: Sat May 14 10:43:59 2022 +0900 calculate mma store dst index using inverse affine map commit 9489434ee52b546e2abb2ab28173eefd51525ba4 Author: Masahiro Masuda Date: Sat May 14 10:01:12 2022 +0900 simplify store commit 1adcb77b8bba8e5d91080fe6cbfc7add7f4365c2 Author: Masahiro Masuda Date: Sat May 14 09:43:40 2022 +0900 simplified fill commit 7b13c736d23e0eac94137aa918101d788e60d4f3 Author: Masahiro Masuda Date: Sat May 14 09:22:17 2022 +0900 simplify intrin desc using index map function commit bcf212dda0f94c51f55c48921f61d92fd3b83777 Author: Masahiro Masuda Date: Sat May 14 07:16:42 2022 +0900 seems to work commit dd8ccf9ec2e48100158152e5d4590d141424e2e2 Author: Masahiro Masuda Date: Sat May 14 07:11:57 2022 +0900 poking with the parser commit 596582cbfbd08ebe23ea71aaf7a447472415ccd1 Author: Masahiro Masuda Date: Fri May 13 20:04:59 2022 +0900 16x8x32 4k trans working commit 273f89a8a6ac34f7c79147563922d34d44bffd08 Author: Masahiro Masuda Date: Fri May 13 19:52:13 2022 +0900 add 16x8x16 fp16 trans commit 8e2066cc4c6e86616bc9751324e63ba81a3b02af Author: Masahiro Masuda Date: Fri May 13 19:32:37 2022 +0900 16x8x16 4k trans working commit c2d0744051733e94f840d4517bcee9ca5d444c75 Author: Masahiro Masuda Date: Fri May 13 19:25:52 2022 +0900 16x8x16 trans working commit c2e314cdda1c3a931781e51a863901ea178dffec Author: Masahiro Masuda Date: Fri May 13 16:19:32 2022 +0900 tuned int8 4k, 91 TOPS commit 94d9d965f19ff1a2ebdd342079ef420fb537b16a Author: Masahiro Masuda Date: Fri May 13 15:59:33 2022 +0900 int8 4k tune working commit 3ca8ca02593aff7540c9655aa831348246171752 Author: Masahiro Masuda Date: Fri May 13 08:43:57 2022 +0900 mma 16x8x32 int8 working with ldmatrix b workaround commit 54f1cb731d4b42a6cbc08baf144e74646400eef5 Author: Masahiro Masuda Date: Fri May 13 18:23:27 2022 +0900 wip commit 9d2844db602dc65af4dbd06a73fdd815f486b8b9 Author: Masahiro Masuda Date: Fri May 13 16:38:53 2022 +0900 test tensorize without layout transform commit 86ee6dabc801aeb8d6917bec6de97b42025dbdd1 Author: Masahiro Masuda Date: Fri May 13 15:15:34 2022 +0900 int8 4k tensorize works commit 39f9e32c9a64222c91daba2c32969b27207a31d2 Author: Masahiro Masuda Date: Fri May 13 
12:44:39 2022 +0900 begin int8 4k tune commit 6fa91e55b5ab2ba0f901d0d35be1b2fb3ab092b0 Author: Masahiro Masuda Date: Thu May 12 18:53:20 2022 +0900 try fix ldmatrix b for int8 commit 7a962cddc4799fa3df0c0fdf3c056146d3f2cbdf Author: Masahiro Masuda Date: Thu May 12 18:28:34 2022 +0900 fixed warp_coeff commit a0afb5698f307382147a38819e004a2db7f554b1 Author: Masahiro Masuda Date: Thu May 12 12:20:01 2022 +0900 wip commit f70ccd09b07d5325454ffdc39a7619ea84aa7e06 Author: Masahiro Masuda Date: Thu May 12 12:09:57 2022 +0900 int8 tensorize working commit 20321fa4674dabc78fe55b5e0e2876c35b245d21 Author: Masahiro Masuda Date: Thu May 12 07:06:22 2022 +0900 starting 16x8x32 int8 commit 441fd193c59cdc436d87ab35896cbb8c779ddf35 Author: Masahiro Masuda Date: Thu May 12 05:50:46 2022 +0900 adding fp16 accum case commit c9d40b69b1b57bfaddffba09ea07624ae90ee465 Author: Masahiro Masuda Date: Wed May 11 17:04:29 2022 +0900 clean up commit 5b2d48635e762c77c824d1c259ac8bcbcc949421 Author: Masahiro Masuda Date: Wed May 11 16:38:19 2022 +0900 16x8x16 4k tune working commit c3cb170d85600d03da5c3f4cda03552208ca0b8c Author: Masahiro Masuda Date: Wed May 11 16:20:27 2022 +0900 tensoriz fixed commit 68039b081efcdd6aea1d132940b3745f50164974 Author: Masahiro Masuda Date: Wed May 11 15:55:25 2022 +0900 begin 16x8x16 4k tune commit ced5d8d980cc267d4735957c25cb60d71ae977d2 Author: Masahiro Masuda Date: Wed May 11 15:50:11 2022 +0900 16x8x16 worked commit 3d2c90d77c1bb2df2193e9af6cbaa2bd927a26d8 Author: Masahiro Masuda Date: Wed May 11 15:47:26 2022 +0900 fix commit 403050b03ad6b4f0ee8d45088ffb324727bbae48 Author: Masahiro Masuda Date: Wed May 11 15:45:10 2022 +0900 add 16x8x16 test commit 18e8d73661c99cd1c83021063b41a457afcb1638 Author: Masahiro Masuda Date: Wed May 11 06:50:32 2022 +0900 fixed mma store codegen for 16x8x16 commit ec81250561195705122bccb9a2372f71de68121f Author: Masahiro Masuda Date: Wed May 11 04:25:25 2022 +0900 add 16x8x16 mma store codegen commit e08df2a62a4809bcd39782949283c16e7703aa5c Author: Masahiro Masuda Date: Wed May 11 03:47:47 2022 +0900 tensorized C_warp init commit ae0678918929c1ceec73f2039467040c5bb7823b Author: Masahiro Masuda Date: Wed May 11 03:06:06 2022 +0900 mma store codegen working commit deb4d6646cc93d4cdb4f2560ce723bee4d86e144 Author: Masahiro Masuda Date: Tue May 10 19:22:57 2022 +0900 update lower warp memory commit 71fe5fe465300705fa94f9544a2e1a5070de6e0d Author: Masahiro Masuda Date: Tue May 10 09:01:42 2022 +0900 tensorizing mma store commit e80a1f148c47f2a3fac2363a733d8d4e2a2631d0 Author: Masahiro Masuda Date: Thu Apr 28 19:54:08 2022 +0900 clean up commit a9640f4b7c3c9f22b87ca74a61003438dfd8f992 Author: Masahiro Masuda Date: Thu Apr 28 19:40:55 2022 +0900 add tunable 4k test, 36 TFLOPS commit b9f7eae7041d1a9b3e434c331c874e8347e89dc4 Author: Masahiro Masuda Date: Thu Apr 28 18:01:08 2022 +0900 fixed bug in LowerWarpMemory index splitting for ldmatrix commit 00df30823f874910ed1ec1f74718100311764234 Author: Masahiro Masuda Date: Wed Apr 27 07:58:17 2022 +0900 fixed missing reverse_compute_at commit 93f9fe7e5f7ad16c8d0e6240c16c0281a0e97dec Author: Masahiro Masuda Date: Wed Apr 27 06:55:12 2022 +0900 add 4k test commit 3689ef712aa4b282a4818fa2fa2e7e349c3a5eec Author: Masahiro Masuda Date: Wed Apr 27 06:54:09 2022 +0900 temp disable high dim base indices check in tensorize commit 0c859c4f385ba0b6f9477b569b80cee80b5b7282 Author: Masahiro Masuda Date: Tue Apr 26 19:18:23 2022 +0900 clean up commit f6aadbfcfbd73c1667a6de7aedc5894232b8e750 Author: Masahiro Masuda Date: Tue Apr 26 
19:13:09 2022 +0900 Add 16x8x8 MMA + LDMatrix test commit 4cf6b20c6ca415e967ab58d80e4a77c701ad7255 Author: Masahiro Masuda Date: Tue Apr 26 18:04:17 2022 +0900 testing 16x8x8 ldmatrix tensoriation --- include/tvm/tir/builtin.h | 27 + python/tvm/tir/tensor_intrin/__init__.py | 1 + python/tvm/tir/tensor_intrin/cuda.py | 469 ++++++++++++++++++ src/target/source/codegen_cuda.cc | 76 ++- src/tir/op/builtin.cc | 6 + src/tir/transforms/lower_warp_memory.cc | 45 +- ...est_tir_schedule_tensorize_ldmatrix_mma.py | 424 ++++++++++++++++ 7 files changed, 1044 insertions(+), 4 deletions(-) create mode 100644 python/tvm/tir/tensor_intrin/cuda.py create mode 100644 tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index f33432645cc3..5fc42392c337 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -651,6 +651,33 @@ TVM_DLL const Op& ptx_cp_async(); TVM_DLL const Op& ptx_commit_group(); TVM_DLL const Op& ptx_wait_group(); +/*! + * \brief tvm intrinsic for storing the result of PTX MMA into a destination pointer. + * For example, if each thread in a warp of size 32 has 4 elements from the result of + * m16xn8xk16 MMA in its registers, this intrinsic can be used to store the result in a + * 16x8 region in shared or global memory. + * + * There is no real PTX instruction that does that, but we want to hide details of + * complex index manipulation behind this intrinsic to simplify TIR lowering passes (e.g. + * LowerWarpMemory). + * + * void mma_store(IntImm m, IntImm n, Var dst_ptr, Var src_ptr, Expr src_offset, Var dst_stride); + */ +TVM_DLL const Op& mma_store(); + +/*! + * \brief tvm intrinsic for zero-initializing an MMA accumulation register. + * For example, if each thread in a warp of size 32 has 4 elements from the result of + * m16xn8xk16 MMA in its registers, this intrinsic can be used to zero-initialize its + * 4 accumulation registers. + * + * There is no real PTX instruction that does that, but we introduce this intrinsic for the + * same reason as mma_store above. + * + * void mma_fill(IntImm local_size, Var local_ptr, Expr offset); + */ +TVM_DLL const Op& mma_fill(); + // TODO(tvm-team) replace the usage of the vector operations by Shuffle. /*! * \brief Get the high level half of the vector diff --git a/python/tvm/tir/tensor_intrin/__init__.py b/python/tvm/tir/tensor_intrin/__init__.py index 4115c3b90070..a3b47ff6d5d7 100644 --- a/python/tvm/tir/tensor_intrin/__init__.py +++ b/python/tvm/tir/tensor_intrin/__init__.py @@ -20,3 +20,4 @@ from .arm_cpu import * from .dot_product_common import * from .rocm import * +from .cuda import * diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py new file mode 100644 index 000000000000..853a37735486 --- /dev/null +++ b/python/tvm/tir/tensor_intrin/cuda.py @@ -0,0 +1,469 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,missing-function-docstring +"""Intrinsics for tensorization on NVIDIA GPU.""" +from tvm.script import tir as T +from .. import IntImm, Cast +from ..._ffi import register_func +from ...runtime import convert +from .. import TensorIntrin + + +def shared_16x16_to_ldmatrix_32x8_layout(i, j): + thread_id = 4 * (i % 8) + (j % 8) // 2 + return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2) + + +def shared_16x32_to_ldmatrix_32x16_layout(i, j): + thread_id = 4 * (i % 8) + (j % 16) // 4 + return thread_id, 8 * (j // 16) + (i // 8) * 4 + j % 4 + + +def shared_32x16_to_ldmatrix_32x16_layout(i, j): + thread_id = (i % 16) // 4 + 4 * (j % 8) + return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4 + + +@register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout") +def index_map_shared_16x16_to_ldmatrix_32x8_layout(ind): + i, j = ind[0], ind[1] + thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i, j) + return convert([thread_id, local_id]) + + +lift = convert + +M_DIM = 16 +N_DIM = 16 +WARP_SIZE = 32 +HALF_WARP = WARP_SIZE // 2 +HALF_WARP_expr = lift(HALF_WARP) + + +def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed): + local_size = (M_DIM * k_dim) // WARP_SIZE + shared_offset = None + index_map = None + + if transposed: + assert is_b, "Transposed A matrix not supported" + + ldmatrix_col_major = is_b and not transposed + + if k_dim == 16: + assert dtype == "float16" + + index_map = shared_16x16_to_ldmatrix_32x8_layout + + if transposed: + shared_offset = ( + lambda tx, stride: stride * 8 * (tx // HALF_WARP_expr) + + stride * (tx % 8) + + 8 * ((tx % HALF_WARP_expr) // 8) + ) + else: + shared_offset = lambda tx, stride: stride * (tx % HALF_WARP_expr) + 8 * ( + tx // HALF_WARP_expr + ) + else: + assert ( + k_dim == 32 and dtype == "int8" + ), "Only k_dim == 16 (float16) or k_dim == 32 (int8) supported for now" + + if ldmatrix_col_major: + index_map = shared_32x16_to_ldmatrix_32x16_layout + # A dummy offset, ldmatrix cannot be used for int8 + trans case. + # We still use the ldmatrix intrinsic, but lower it to a manual loop in the codegen. + # Only the stride information is required.
+ shared_offset = lambda _, stride: stride + elif is_b and transposed: + index_map = shared_16x32_to_ldmatrix_32x16_layout + shared_offset = ( + lambda tx, stride: stride * 8 * (tx // HALF_WARP_expr) + + (tx % 8) * stride + + 16 * ((tx % HALF_WARP_expr) // 8) + ) + else: + index_map = shared_16x32_to_ldmatrix_32x16_layout + shared_offset = lambda tx, stride: stride * (tx % 16) + 16 * (tx // 16) + + assert index_map and shared_offset + + if is_b and not transposed: + row_dim = k_dim + col_dim = M_DIM + else: + row_dim = M_DIM + col_dim = k_dim + + shmem_shape = (row_dim, col_dim) + + @T.prim_func + def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None: + shared = T.match_buffer( + shared_handle, shmem_shape, dtype, align=128, offset_factor=16, scope="shared" + ) + warp = T.match_buffer( + warp_handle, (WARP_SIZE, local_size), dtype, align=128, offset_factor=16, scope="warp" + ) + + with T.block("root"): + T.reads(shared[0:row_dim, 0:col_dim]) + T.writes(warp[0:WARP_SIZE, 0:local_size]) + + for ax0, ax1 in T.grid(row_dim, col_dim): + with T.block("shared_warp"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(shared[v0, v1]) + + thread_id, local_id = index_map(v0, v1) + T.writes(warp[thread_id, local_id]) + warp[thread_id, local_id] = shared[v0, v1] + + @T.prim_func + def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None: + s0 = T.var("int32") + s1 = T.var("int32") + shared = T.match_buffer( + shared_handle, + shmem_shape, + dtype, + align=128, + offset_factor=16, + scope="shared", + strides=[s0, s1], + ) + warp = T.match_buffer( + warp_handle, (WARP_SIZE, local_size), dtype, align=128, offset_factor=16, scope="warp" + ) + + with T.block("root"): + T.reads(shared[0:row_dim, 0:col_dim]) + T.writes(warp[0:WARP_SIZE, 0:local_size]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, WARP_SIZE) + + T.evaluate( + T.ptx_ldmatrix( + ldmatrix_col_major, + 4, # Always load 4 matrices + ".b16", + warp.data, + warp.elem_offset + lift(local_size) * tx, + shared.access_ptr("r"), + shared_offset(tx, s0), + dtype=dtype, + ) + ) + + return ldmatrix_desc, ldmatrix_impl + + +def get_mma_intrin(k_dim, out_dtype, b_transposed): + local_size = (M_DIM * k_dim) // WARP_SIZE + local_size_out = (M_DIM * N_DIM) // 32 + + index_map_C = shared_16x16_to_ldmatrix_32x8_layout + + if k_dim == 16: + index_map_A = shared_16x16_to_ldmatrix_32x8_layout + index_map_B = shared_16x16_to_ldmatrix_32x8_layout + mma_prefix = "m16n8k16" + elif k_dim == 32 and b_transposed: + index_map_A = index_map_B = shared_16x32_to_ldmatrix_32x16_layout + mma_prefix = "m16n8k32" + elif k_dim == 32 and not b_transposed: + index_map_A = shared_16x32_to_ldmatrix_32x16_layout + index_map_B = shared_32x16_to_ldmatrix_32x16_layout + mma_prefix = "m16n8k32" + else: + assert False + + out_dtype_abbrv = {"float16": "fp16", "float32": "fp32", "int32": "int32"}[out_dtype] + + if out_dtype in ["float16", "float32"]: + in_dtype = "float16" + in_dtype_abbrv = "fp16" + else: + in_dtype = "int8" + in_dtype_abbrv = "int8" + + def maybe_cast(v): + if out_dtype in ["float32", "int32"]: + return Cast(out_dtype, v) + return v + + def maybe_swap(i, j): + if b_transposed: + return j, i + return i, j + + @T.prim_func + def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer( + a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp" + ) + B = T.match_buffer( + b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp" + ) + C = T.match_buffer( + c, 
(WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp" + ) + + with T.block("root"): + T.reads( + C[0:WARP_SIZE, 0:local_size_out], + A[0:WARP_SIZE, 0:local_size], + B[0:WARP_SIZE, 0:local_size], + ) + T.writes(C[0:WARP_SIZE, 0:local_size_out]) + + for i, j, k in T.grid(M_DIM, N_DIM, k_dim): + with T.block("C"): + i, j, k = T.axis.remap("SSR", [i, j, k]) + b_row_ind, b_col_ind = maybe_swap(k, j) + + thread_id_C, local_id_C = index_map_C(i, j) + thread_id_A, local_id_A = index_map_A(i, k) + thread_id_B, local_id_B = index_map_B(b_row_ind, b_col_ind) + + T.reads( + C[thread_id_C, local_id_C], + A[thread_id_A, local_id_A], + B[thread_id_B, local_id_B], + ) + T.writes(C[thread_id_C, local_id_C]) + + C[thread_id_C, local_id_C] += maybe_cast( + A[thread_id_A, local_id_A] + ) * maybe_cast(B[thread_id_B, local_id_B]) + + @T.prim_func + def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer( + a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp" + ) + B = T.match_buffer( + b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp" + ) + C = T.match_buffer( + c, (WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp" + ) + + with T.block("root"): + T.reads( + C[0:WARP_SIZE, 0:local_size_out], + A[0:WARP_SIZE, 0:local_size], + B[0:WARP_SIZE, 0:local_size], + ) + T.writes(C[0:WARP_SIZE, 0:local_size_out]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, WARP_SIZE) + + T.evaluate( + T.ptx_mma( + mma_prefix, + "row", + "col", + in_dtype_abbrv, + in_dtype_abbrv, + out_dtype_abbrv, + A.data, + A.elem_offset + tx * lift(local_size), + B.data, + B.elem_offset + tx * lift(local_size), + C.data, + C.elem_offset + tx * lift(local_size_out), + False, + dtype=out_dtype, + ) + ) + + T.evaluate( + T.ptx_mma( + mma_prefix, + "row", + "col", + in_dtype_abbrv, + in_dtype_abbrv, + out_dtype_abbrv, + A.data, + A.elem_offset + tx * lift(local_size), + B.data, + B.elem_offset + tx * lift(local_size) + lift(local_size) // 2, + C.data, + C.elem_offset + tx * lift(local_size_out) + lift(local_size_out) // 2, + False, + dtype=out_dtype, + ) + ) + + return mma_sync_desc, mma_sync_impl + + +def get_mma_fill_intrin(dtype, local_size): + zero = IntImm("int32", 0).astype(dtype) + + # Assume M = N = 16 + index_map = shared_16x16_to_ldmatrix_32x8_layout + + @T.prim_func + def mma_fill_desc(a: T.handle) -> None: + C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, scope="warp") + + with T.block("root"): + T.reads() + T.writes(C_warp[0:WARP_SIZE, 0:local_size]) + for i0, i1 in T.grid(M_DIM, N_DIM): + with T.block("C_warp"): + i, j = T.axis.remap("SS", [i0, i1]) + thread_id, local_id = index_map(i, j) + T.reads() + T.writes(C_warp[thread_id, local_id]) + C_warp[thread_id, local_id] = zero + + @T.prim_func + def mma_fill_impl(a: T.handle) -> None: + C_warp = T.match_buffer( + a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1 + ) + + with T.block("root"): + T.reads() + T.writes(C_warp[0:WARP_SIZE, 0:local_size]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, WARP_SIZE) + + T.evaluate(T.mma_fill(local_size, C_warp.data, C_warp.elem_offset, dtype=dtype)) + + return mma_fill_desc, mma_fill_impl + + +def get_mma_store_intrin(dtype, local_size, scope="global"): + # Assume M = N = 16 + index_map = shared_16x16_to_ldmatrix_32x8_layout + + @T.prim_func + def mma_store_desc(a: T.handle, c: T.handle) -> None: + C_warp = T.match_buffer(a, [WARP_SIZE, local_size], 
dtype=dtype, scope="warp") + C = T.match_buffer(c, [M_DIM, N_DIM], dtype=dtype, scope=scope) + + with T.block("root"): + T.reads(C_warp[0:WARP_SIZE, 0:local_size]) + T.writes(C[0:M_DIM, 0:N_DIM]) + for i0, i1 in T.grid(M_DIM, N_DIM): + with T.block("C_warp"): + v0, v1 = T.axis.remap("SS", [i0, i1]) + thread_id, local_id = index_map(v0, v1) + T.reads(C_warp[thread_id, local_id]) + T.writes(C[v0, v1]) + C[v0, v1] = C_warp[thread_id, local_id] + + @T.prim_func + def mma_store_impl(a: T.handle, c: T.handle) -> None: + s0 = T.var("int32") + s1 = T.var("int32") + + C_warp = T.match_buffer( + a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1 + ) + C = T.match_buffer( + c, [M_DIM, N_DIM], dtype=dtype, scope="global", offset_factor=1, strides=[s0, s1] + ) + + with T.block("root"): + T.reads(C_warp[0:WARP_SIZE, 0:local_size]) + T.writes(C[0:M_DIM, 0:N_DIM]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, WARP_SIZE) + + T.evaluate( + T.mma_store( + M_DIM, + N_DIM, + C.access_ptr("w"), + C_warp.data, + C_warp.elem_offset, + s0, + dtype=dtype, + ) + ) + + return mma_store_desc, mma_store_impl + + +LDMATRIX_16x16_A_INTRIN = "mma.ldmatrix_16x16_a" +TensorIntrin.register(LDMATRIX_16x16_A_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False)) + +LDMATRIX_16x16_B_INTRIN = "mma.ldmatrix_16x16_b" +TensorIntrin.register(LDMATRIX_16x16_B_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False)) + +LDMATRIX_16x16_B_TRANS_INTRIN = "mma.ldmatrix_16x16_b_trans" +TensorIntrin.register( + LDMATRIX_16x16_B_TRANS_INTRIN, *get_ldmatrix_intrin(16, "float16", True, True) +) + +LDMATRIX_16x32_A_INTRIN = "mma.ldmatrix_16x32_a" +TensorIntrin.register(LDMATRIX_16x32_A_INTRIN, *get_ldmatrix_intrin(32, "int8", False, False)) + +LDMATRIX_32x16_B_INTRIN = "mma.ldmatrix_32x16_b" +TensorIntrin.register(LDMATRIX_32x16_B_INTRIN, *get_ldmatrix_intrin(32, "int8", True, False)) + +LDMATRIX_16x32_B_TRANS_INTRIN = "mma.ldmatrix_16x32_b_trans" +TensorIntrin.register(LDMATRIX_16x32_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "int8", True, True)) + +MMA_f16f16f32_INTRIN = "mma_f16f16f32" +TensorIntrin.register(MMA_f16f16f32_INTRIN, *get_mma_intrin(16, "float32", False)) + +MMA_f16f16f32_TRANS_INTRIN = "mma_f16f16f32_trans" +TensorIntrin.register(MMA_f16f16f32_TRANS_INTRIN, *get_mma_intrin(16, "float32", True)) + +MMA_f16f16f16_INTRIN = "mma_f16f16f16" +TensorIntrin.register(MMA_f16f16f16_INTRIN, *get_mma_intrin(16, "float16", False)) + +MMA_f16f16f16_TRANS_INTRIN = "mma_f16f16f16_trans" +TensorIntrin.register(MMA_f16f16f16_TRANS_INTRIN, *get_mma_intrin(16, "float16", True)) + +MMA_i8i8i32_INTRIN = "mma_i8i8i32" +TensorIntrin.register(MMA_i8i8i32_INTRIN, *get_mma_intrin(32, "int32", False)) + +MMA_i8i8i32_TRANS_INTRIN = "mma_i8i8i32_trans" +TensorIntrin.register(MMA_i8i8i32_TRANS_INTRIN, *get_mma_intrin(32, "int32", True)) + +MMA_fill_16x16_f32_INTRIN = "mma_fill_16x16_f32" +TensorIntrin.register(MMA_fill_16x16_f32_INTRIN, *get_mma_fill_intrin("float32", 8)) + +MMA_fill_16x16_f16_INTRIN = "mma_fill_16x16_f16" +TensorIntrin.register(MMA_fill_16x16_f16_INTRIN, *get_mma_fill_intrin("float16", 8)) + +MMA_fill_16x16_i32_INTRIN = "mma_fill_16x16_i32" +TensorIntrin.register(MMA_fill_16x16_i32_INTRIN, *get_mma_fill_intrin("int32", 8)) + +MMA_store_16x16_f32_global_INTRIN = "mma_store_16x16_f32_global_" +TensorIntrin.register( + MMA_store_16x16_f32_global_INTRIN, *get_mma_store_intrin("float32", 8, "global") +) + +MMA_store_16x16_f16_global_INTRIN = "mma_store_16x16_f16_global_" +TensorIntrin.register( + 
MMA_store_16x16_f16_global_INTRIN, *get_mma_store_intrin("float16", 8, "global") +) + +MMA_store_16x16_i32_global_INTRIN = "mma_store_16x16_i32_global_" +TensorIntrin.register( + MMA_store_16x16_i32_global_INTRIN, *get_mma_store_intrin("int32", 8, "global") +) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 7459d4c250ba..616e75f2e776 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -25,6 +25,7 @@ #include #include +#include #include #include @@ -818,9 +819,78 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string local_ptr = this->PrintExpr(op->args[3]); std::string local_elem_offset = this->PrintExpr(op->args[4]); std::string smem_ptr = this->PrintExpr(op->args[5]); - std::string smem_elem_offset = this->PrintExpr(op->args[6]); - this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset, - smem_ptr, smem_elem_offset); + if (trans && op->dtype.bits() == 8) { + // Since ldmatrix assumes that a matrix element is 16 bit, it cannot properly transpose an + // int8 matrix. + std::string smem_stride = this->PrintExpr(op->args[6]); + ICHECK(num == 4); + os << "for (int i = 0; i < 16; ++i) {\n"; + os << local_ptr << "[" + local_elem_offset + " + i] = " << smem_ptr + << "[(i % 8) / 4 * " + smem_stride + " * 16 + (threadIdx.x % 4) * 4 * " + smem_stride + + "+ (i % 4) * " + smem_stride + " + threadIdx.x / 4 + (i / 8) * 8];\n"; + os << "}\n"; + } else { + std::string smem_elem_offset = this->PrintExpr(op->args[6]); + this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset, + smem_ptr, smem_elem_offset); + } + } else if (op->op.same_as(builtin::mma_store())) { + int m = Downcast(op->args[0])->value; + int n = Downcast(op->args[1])->value; + std::string dst = this->PrintExpr(op->args[2]); + std::string src = this->PrintExpr(op->args[3]); + std::string src_offset = this->PrintExpr(op->args[4]); + PrimExpr stride = op->args[5]; + + ICHECK(m == 16 && n == 16) << "Only m == 16 && n == 16 case supported for now"; + + // Each thread in a warp holds a certain number of elements of an MMA output. + // For example, if we compute a 16x16 tile using MMA, each thread holds 8 elements + // in its registers. So conceptually, a warp memory is organized as a 32x8 block. + // A map from a 16x16 tile to a 32x8 block of memory is specified by the index map below. + + // To store the 32x8 output back to a 16x16 tile in shared or global memory, we invert this map + // to determine the output location for each 8 element. + + const auto* index_map_func = + runtime::Registry::Get("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout"); + ICHECK(index_map_func); + + auto inverse_index_map = + IndexMap::FromFunc(2, *index_map_func).Inverse({Range(0, m), Range(0, n)}); + auto indices_16x16 = inverse_index_map->final_indices; + + // "//" and "%" in the index map are translated to FloorDiv/Mod, but the plain Div/Mod are fine. + // FloorDiv/Mod are supposed to be lowered before they reach codegen, so manually replace them + // to the plain ones here. 
+ class LowerFloorDivMod : public ExprMutator { + public: + PrimExpr VisitExpr_(const FloorDivNode* op) { + return tir::Div(this->VisitExpr(op->a), this->VisitExpr(op->b)); + } + PrimExpr VisitExpr_(const FloorModNode* op) { + return tir::Mod(this->VisitExpr(op->a), this->VisitExpr(op->b)); + } + }; + + auto dst_ind = LowerFloorDivMod()(indices_16x16[0] * stride + indices_16x16[1]); + + var_idmap_[inverse_index_map->initial_indices[0].get()] = "threadIdx.x"; + var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id"; + + os << "for (int local_id = 0; local_id < 8; ++local_id) {\n"; + os << dst << "[" + this->PrintExpr(dst_ind) + "]" + << " = " << src << "[" << src_offset << " + local_id];\n"; + os << "}\n"; + + } else if (op->op.same_as(builtin::mma_fill())) { + std::string num_elem = this->PrintExpr(op->args[0]); + std::string dst = this->PrintExpr(op->args[1]); + std::string dst_offset = this->PrintExpr(op->args[2]); + + os << "for (int i = 0; i < " << num_elem << "; ++i) {\n"; + os << dst << "[" << dst_offset << " + i] = 0.0;"; + os << "}\n"; } else if (op->op.same_as(builtin::ptx_cp_async())) { std::string dst = this->PrintExpr(op->args[0]); std::string dst_offset = this->PrintExpr(op->args[1]); diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 0415d1bbec9e..1871a3d7bf70 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -256,6 +256,12 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_commit_group) TIR_DEFINE_BUILTIN_FUNC(ptx_wait_group) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_BUILTIN_FUNC(mma_store).set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_BUILTIN_FUNC(mma_fill).set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_BUILTIN_FUNC(vectorhigh) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 40971114d416..d8250cd09888 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -101,7 +101,7 @@ namespace tir { // Visitor to find m in pattern // store warp_mem[m * warp_index + (width * m) * y + x] -class WarpStoreCoeffFinder : private StmtVisitor { +class WarpStoreCoeffFinder : private StmtExprVisitor { public: WarpStoreCoeffFinder(const VarNode* buffer, Var warp_index, arith::Analyzer* analyzer) : buffer_(buffer), warp_index_(warp_index), analyzer_(analyzer) {} @@ -113,6 +113,18 @@ class WarpStoreCoeffFinder : private StmtVisitor { private: /// Visitor implementation + void VisitExpr_(const CallNode* op) final { + if (op->op.same_as(builtin::ptx_ldmatrix()) && op->args[3].as() == buffer_) { + UpdatePattern(op->args[4]); + } else if (op->op.same_as(builtin::mma_fill()) && op->args[1].as() == buffer_) { + auto* local_size = op->args[0].as(); + ICHECK(local_size) << "Integer expected for the first argument of mma_fill"; + warp_coeff_ = local_size->value; + } + + StmtExprVisitor::VisitExpr_(op); + } + void VisitStmt_(const StoreNode* op) final { LOG(FATAL) << "Unexpected use of deprecated StoreNode. 
Please use BufferStoreNode instead."; } @@ -245,6 +257,37 @@ class WarpAccessRewriter : protected StmtExprMutator { } protected: + PrimExpr RewriteIndicesAt(const CallNode* op, const std::vector& indices) { + Array new_args = op->args; + for (int i : indices) { + if (op->args[i].get() == buffer_) { + PrimExpr local_index = SplitIndexByGroup(op->args[i + 1]).first; + new_args.Set(i + 1, local_index); + } + } + return Call(op->dtype, op->op, new_args); + } + + PrimExpr VisitExpr_(const CallNode* op) override { + if (op->op.same_as(builtin::ptx_mma())) { + return RewriteIndicesAt(op, {6, 8, 10}); + } + + if (op->op.same_as(builtin::ptx_ldmatrix())) { + return RewriteIndicesAt(op, {3}); + } + + if (op->op.same_as(builtin::mma_store())) { + return RewriteIndicesAt(op, {3}); + } + + if (op->op.same_as(builtin::mma_fill())) { + return RewriteIndicesAt(op, {1}); + } + + return StmtExprMutator::VisitExpr_(op); + } + PrimExpr VisitExpr_(const VarNode* op) override { ICHECK(op != buffer_) << "Cannot access address of warp memory directly"; return StmtExprMutator::VisitExpr_(op); diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py new file mode 100644 index 000000000000..bc097493b761 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py @@ -0,0 +1,424 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
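The shared-to-ldmatrix index maps defined in python/tvm/tir/tensor_intrin/cuda.py above and exercised throughout this test are expected to be bijections from a shared-memory tile onto the 32-thread x local-register warp layout. A minimal standalone sanity sketch, assuming the patch is applied (the helper name check_bijective is illustrative and not part of the patch):

from tvm.tir.tensor_intrin.cuda import (
    shared_16x16_to_ldmatrix_32x8_layout,
    shared_16x32_to_ldmatrix_32x16_layout,
    shared_32x16_to_ldmatrix_32x16_layout,
)


def check_bijective(index_map, rows, cols, local_size):
    # Every (row, col) element of the tile must land on a distinct (thread_id, local_id) slot.
    slots = {index_map(i, j) for i in range(rows) for j in range(cols)}
    assert len(slots) == rows * cols == 32 * local_size
    assert all(0 <= t < 32 and 0 <= l < local_size for t, l in slots)


check_bijective(shared_16x16_to_ldmatrix_32x8_layout, 16, 16, 8)  # fp16 A/B and all accumulators
check_bijective(shared_16x32_to_ldmatrix_32x16_layout, 16, 32, 16)  # int8 A, and B when transposed
check_bijective(shared_32x16_to_ldmatrix_32x16_layout, 32, 16, 16)  # int8 B, non-transposed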
+# pylint: disable=missing-docstring +import tvm +from tvm import te +from tvm.tir.tensor_intrin.cuda import ( + LDMATRIX_16x16_A_INTRIN, + LDMATRIX_16x16_B_INTRIN, + LDMATRIX_16x16_B_TRANS_INTRIN, + LDMATRIX_16x32_A_INTRIN, + LDMATRIX_32x16_B_INTRIN, + LDMATRIX_16x32_B_TRANS_INTRIN, + MMA_f16f16f32_INTRIN, + MMA_f16f16f32_TRANS_INTRIN, + MMA_f16f16f16_INTRIN, + MMA_f16f16f16_TRANS_INTRIN, + MMA_i8i8i32_INTRIN, + MMA_i8i8i32_TRANS_INTRIN, + MMA_fill_16x16_f32_INTRIN, + MMA_fill_16x16_f16_INTRIN, + MMA_fill_16x16_i32_INTRIN, + MMA_store_16x16_f32_global_INTRIN, + MMA_store_16x16_f16_global_INTRIN, + MMA_store_16x16_i32_global_INTRIN, + shared_16x16_to_ldmatrix_32x8_layout, + shared_32x16_to_ldmatrix_32x16_layout, + shared_16x32_to_ldmatrix_32x16_layout, +) +import tvm.testing +import numpy as np + + +M = 4096 +N = 4096 +K = 4096 +measure_perf = True +gflops = (N * M * K) * 2 / 1e9 + + +def matmul(m, n, k, in_dtype, out_dtype, b_transposed): + b_shape = (n, k) if b_transposed else (k, n) + a = te.placeholder((m, k), name="A", dtype=in_dtype) + b = te.placeholder(b_shape, name="B", dtype=in_dtype) + k = te.reduce_axis((0, k), name="k") + + def maybe_cast(v): + if in_dtype != out_dtype: + return tvm.tir.Cast(out_dtype, v) + return v + + def maybe_swap(i, j): + if b_transposed: + return j, i + return i, j + + c = te.compute( + (m, n), + lambda i, j: te.sum(maybe_cast(a[i, k]) * maybe_cast(b[maybe_swap(k, j)]), axis=[k]), + name="C", + ) + return (a, b, c) + + +def run_test( + k_inner, + in_dtype, + out_dtype, + b_transposed, + i_factors, + j_factors, + k_factors, + index_map_A, + index_map_B, + index_map_C, + ldmatrix_a_intrin, + ldmatrix_b_intrin, + mma_intrin, + mma_fill_intrin, + mma_store_intrin, +): + workload = te.create_prim_func(matmul(M, N, K, in_dtype, out_dtype, b_transposed)) + ir_module = tvm.IRModule({"main": workload}) + sch = tvm.tir.Schedule(ir_module) + + block = sch.get_block("C") + i, j, k = sch.get_loops(block) + i, i_tc = sch.split(i, factors=[None, 16]) + j, j_tc = sch.split(j, factors=[None, 16]) + k, k_tc = sch.split(k, factors=[None, k_inner]) + + sch.reorder(i, j, k, i_tc, j_tc, k_tc) + + block_inner = sch.blockize(i_tc) + block_outer, block_inner = block_inner, block + + num_ty = i_factors[2] * j_factors[2] + + i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors) + j0, j1, j2, j3, j4 = sch.split(j, factors=j_factors) + k0, k1, k2 = sch.split(k, k_factors) + + sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3, k2, i4, j4) + + block_idx = sch.fuse(i0, j0) + block_idy = sch.fuse(i1, j1) + thread_idy = sch.fuse(j2, i2) + sch.bind(block_idx, "blockIdx.x") + sch.bind(block_idy, "blockIdx.y") + sch.bind(thread_idy, "threadIdx.y") + + def fetch_to_shared(block, idx, ndim): + block_read = sch.cache_read(block, idx, "shared") + sch.compute_at(block_read, k0) + vector_size = 16 if in_dtype == "int8" else 8 + warp_size = 32 + fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) + _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size]) + sch.bind(f_2, "threadIdx.x") + sch.bind(f_1, "threadIdx.y") + sch.vectorize(f_3) + offset = 8 if in_dtype == "float16" else 16 + sch.storage_align(block_read, 0, axis=-2, factor=32, offset=offset) + + return block_read + + fetch_to_shared(block_outer, 0, 2) + fetch_to_shared(block_outer, 1, 2) + + A_warp = sch.cache_read(block_outer, 0, "warp") + B_warp = sch.cache_read(block_outer, 1, "warp") + + sch.compute_at(A_warp, k1) + sch.compute_at(B_warp, k1) + + C_warp = sch.cache_write(block_outer, 0, "warp") + 
sch.reverse_compute_at(C_warp, thread_idy) + + ii, jj = sch.get_loops(C_warp)[-2:] + io, ii = sch.split(ii, factors=[None, 16]) + jo, ji = sch.split(jj, factors=[None, 16]) + sch.reorder(io, jo, ii, ji) + + sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3]) + block_init_c = sch.get_block("C_init") + + def tile_wmma_fragment(block_read, height, width): + i, j = sch.get_loops(block_read)[-2:] + i0, i1 = sch.split(i, factors=[None, height]) + j0, j1 = sch.split(j, factors=[None, width]) + sch.reorder(i0, j0, i1, j1) + return i1 + + loop_a = tile_wmma_fragment(A_warp, 16, k_inner) + + if b_transposed: + loop_b = tile_wmma_fragment(B_warp, 16, k_inner) + else: + loop_b = tile_wmma_fragment(B_warp, k_inner, 16) + + sch.transform_layout(A_warp, 0, "write", index_map_A) + sch.transform_layout(B_warp, 0, "write", index_map_B) + sch.transform_layout(C_warp, 0, "read", index_map_C) + + sch.tensorize(loop_a, ldmatrix_a_intrin) + sch.tensorize(loop_b, ldmatrix_b_intrin) + sch.tensorize(sch.get_loops(block_inner)[-3], mma_intrin) + sch.tensorize(sch.get_loops(block_init_c)[-2], mma_fill_intrin) + sch.tensorize(sch.get_loops(C_warp)[-2], mma_store_intrin) + + f = tvm.build(sch.mod["main"], target="cuda", name="dense") + dev = tvm.device("cuda", 0) + + if in_dtype == "float16": + a_np = np.random.uniform(size=(M, K)).astype("float16") + + if b_transposed: + b_np = np.random.uniform(size=(N, K)).astype("float16") + c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose()).astype( + out_dtype + ) + else: + b_np = np.random.uniform(size=(K, N)).astype("float16") + c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype(out_dtype) + else: + a_np = np.random.randint(-128, 128, (M, K)).astype("int8") + + if b_transposed: + b_np = np.random.randint(-128, 128, (N, K)).astype("int8") + c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose()).astype( + "int32" + ) + else: + b_np = np.random.randint(-128, 128, (K, N)).astype("int8") + c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype("int32") + + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.array(np.zeros((M, N), dtype=out_dtype), dev) + + f(a, b, c) + + if out_dtype != "float16": + # The numpy reference is computed with fp32 precision (otherwise too slow). + # So there is non-trivial accuracy difference if TVM result is computed with fp16 accumulation. 
+ tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) + + return lambda: f.time_evaluator(f.entry_name, dev, number=500)(a, b, c) + + +def is_ampere_or_newer(): + arch = tvm.contrib.nvcc.get_target_compute_version() + major, minor = tvm.contrib.nvcc.parse_compute_version(arch) + return major * 10 + minor >= 80 + + +def test_f16f16f32_m16n16k16(): + if not is_ampere_or_newer(): + return + + def index_map(i, j): + return ( + i // 16, + j // 16, + *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16), + ) + + k_inner = 16 + in_dtype = "float16" + out_dtype = "float32" + i_factors, j_factors, k_factors = [4, 8, 2, 4, 1], [1, 64, 2, 1, 2], [128, 2, 1] + + timer = run_test( + k_inner, + in_dtype, + out_dtype, + False, # b_transposed + i_factors, + j_factors, + k_factors, + index_map, + index_map, + index_map, + LDMATRIX_16x16_A_INTRIN, + LDMATRIX_16x16_B_INTRIN, + MMA_f16f16f32_INTRIN, + MMA_fill_16x16_f32_INTRIN, + MMA_store_16x16_f32_global_INTRIN, + ) + + if measure_perf: + print("f16f16f32_m16n16k16: %f GFLOPS" % (gflops / (timer().mean))) + + timer = run_test( + k_inner, + in_dtype, + out_dtype, + True, # b_transposed + i_factors, + j_factors, + k_factors, + index_map, + index_map, + index_map, + LDMATRIX_16x16_A_INTRIN, + LDMATRIX_16x16_B_TRANS_INTRIN, + MMA_f16f16f32_TRANS_INTRIN, + MMA_fill_16x16_f32_INTRIN, + MMA_store_16x16_f32_global_INTRIN, + ) + + if measure_perf: + print("f16f16f32_m16n16k16_trans: %f GFLOPS" % (gflops / (timer().mean))) + + +def test_f16f16f16_m16n16k16(): + if not is_ampere_or_newer(): + return + + def index_map(i, j): + return ( + i // 16, + j // 16, + *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16), + ) + + k_inner = 16 + in_dtype = "float16" + out_dtype = "float16" + i_factors, j_factors, k_factors = [16, 2, 1, 4, 2], [16, 2, 2, 1, 4], [128, 2, 1] + + timer = run_test( + k_inner, + in_dtype, + out_dtype, + False, # b_transposed + i_factors, + j_factors, + k_factors, + index_map, + index_map, + index_map, + LDMATRIX_16x16_A_INTRIN, + LDMATRIX_16x16_B_INTRIN, + MMA_f16f16f16_INTRIN, + MMA_fill_16x16_f16_INTRIN, + MMA_store_16x16_f16_global_INTRIN, + ) + + if measure_perf: + print("f16f16f16_m16n16k16: %f GFLOPS" % (gflops / (timer().mean))) + + timer = run_test( + k_inner, + in_dtype, + out_dtype, + True, # b_transposed + i_factors, + j_factors, + k_factors, + index_map, + index_map, + index_map, + LDMATRIX_16x16_A_INTRIN, + LDMATRIX_16x16_B_TRANS_INTRIN, + MMA_f16f16f16_TRANS_INTRIN, + MMA_fill_16x16_f16_INTRIN, + MMA_store_16x16_f16_global_INTRIN, + ) + + if measure_perf: + print("f16f16f16_m16n16k16_trans: %f GFLOPS" % (gflops / (timer().mean))) + + +def test_i8i8i32_m16n16k32(): + if not is_ampere_or_newer(): + return + + def index_map_A(i, j): + return ( + i // 16, + j // 32, + *shared_16x32_to_ldmatrix_32x16_layout(i % 16, j % 32), + ) + + def index_map_B(i, j): + return ( + i // 32, + j // 16, + *shared_32x16_to_ldmatrix_32x16_layout(i % 32, j % 16), + ) + + def index_map_C(i, j): + return ( + i // 16, + j // 16, + *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16), + ) + + k_inner = 32 + in_dtype = "int8" + out_dtype = "int32" + i_factors, j_factors, k_factors = [1, 32, 1, 4, 2], [8, 4, 4, 2, 1], [32, 2, 2] + + timer = run_test( + k_inner, + in_dtype, + out_dtype, + False, # b_transposed + i_factors, + j_factors, + k_factors, + index_map_A, + index_map_B, + index_map_C, + LDMATRIX_16x32_A_INTRIN, + LDMATRIX_32x16_B_INTRIN, + MMA_i8i8i32_INTRIN, + MMA_fill_16x16_i32_INTRIN, + MMA_store_16x16_i32_global_INTRIN, + ) + + if measure_perf: + 
print("i8i8i32_m16n16k32: %f GOPS" % (gflops / (timer().mean))) + + timer = run_test( + k_inner, + in_dtype, + out_dtype, + True, # b_transposed + i_factors, + j_factors, + k_factors, + index_map_A, + index_map_A, + index_map_C, + LDMATRIX_16x32_A_INTRIN, + LDMATRIX_16x32_B_TRANS_INTRIN, + MMA_i8i8i32_TRANS_INTRIN, + MMA_fill_16x16_i32_INTRIN, + MMA_store_16x16_i32_global_INTRIN, + ) + + if measure_perf: + print("i8i8i32_m16n16k32_trans: %f GOPS" % (gflops / (timer().mean))) + + +if __name__ == "__main__": + test_f16f16f32_m16n16k16() + test_f16f16f16_m16n16k16() + test_i8i8i32_m16n16k32() From 393391bb2fd0390885f2b5deff0d1b4bc8e73e28 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 18 May 2022 19:38:51 +0900 Subject: [PATCH 2/4] set measure_perf to False --- .../python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py index bc097493b761..2b888c5e86fd 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py +++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py @@ -47,7 +47,7 @@ M = 4096 N = 4096 K = 4096 -measure_perf = True +measure_perf = False gflops = (N * M * K) * 2 / 1e9 From 8263b6952332256ad66f71811f9d8b6f8e2fb2ca Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 19 May 2022 05:01:45 +0900 Subject: [PATCH 3/4] add requires_gpu decorator in tests, always test build on non-ampere --- ...est_tir_schedule_tensorize_ldmatrix_mma.py | 40 +++++++++---------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py index 2b888c5e86fd..78c615ff06c3 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py +++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py @@ -75,6 +75,12 @@ def maybe_swap(i, j): return (a, b, c) +def is_ampere_or_newer(): + arch = tvm.contrib.nvcc.get_target_compute_version() + major, minor = tvm.contrib.nvcc.parse_compute_version(arch) + return major * 10 + minor >= 80 + + def run_test( k_inner, in_dtype, @@ -182,6 +188,10 @@ def tile_wmma_fragment(block_read, height, width): sch.tensorize(sch.get_loops(C_warp)[-2], mma_store_intrin) f = tvm.build(sch.mod["main"], target="cuda", name="dense") + + if not is_ampere_or_newer(): + return None + dev = tvm.device("cuda", 0) if in_dtype == "float16": @@ -221,16 +231,8 @@ def tile_wmma_fragment(block_read, height, width): return lambda: f.time_evaluator(f.entry_name, dev, number=500)(a, b, c) -def is_ampere_or_newer(): - arch = tvm.contrib.nvcc.get_target_compute_version() - major, minor = tvm.contrib.nvcc.parse_compute_version(arch) - return major * 10 + minor >= 80 - - +@tvm.testing.requires_cuda def test_f16f16f32_m16n16k16(): - if not is_ampere_or_newer(): - return - def index_map(i, j): return ( i // 16, @@ -261,7 +263,7 @@ def index_map(i, j): MMA_store_16x16_f32_global_INTRIN, ) - if measure_perf: + if measure_perf and timer: print("f16f16f32_m16n16k16: %f GFLOPS" % (gflops / (timer().mean))) timer = run_test( @@ -282,14 +284,12 @@ def index_map(i, j): MMA_store_16x16_f32_global_INTRIN, ) - if measure_perf: + if measure_perf and timer: print("f16f16f32_m16n16k16_trans: %f GFLOPS" % (gflops / (timer().mean))) +@tvm.testing.requires_cuda def test_f16f16f16_m16n16k16(): - if not 
is_ampere_or_newer(): - return - def index_map(i, j): return ( i // 16, @@ -320,7 +320,7 @@ def index_map(i, j): MMA_store_16x16_f16_global_INTRIN, ) - if measure_perf: + if measure_perf and timer: print("f16f16f16_m16n16k16: %f GFLOPS" % (gflops / (timer().mean))) timer = run_test( @@ -341,14 +341,12 @@ def index_map(i, j): MMA_store_16x16_f16_global_INTRIN, ) - if measure_perf: + if measure_perf and timer: print("f16f16f16_m16n16k16_trans: %f GFLOPS" % (gflops / (timer().mean))) +@tvm.testing.requires_cuda def test_i8i8i32_m16n16k32(): - if not is_ampere_or_newer(): - return - def index_map_A(i, j): return ( i // 16, @@ -393,7 +391,7 @@ def index_map_C(i, j): MMA_store_16x16_i32_global_INTRIN, ) - if measure_perf: + if measure_perf and timer: print("i8i8i32_m16n16k32: %f GOPS" % (gflops / (timer().mean))) timer = run_test( @@ -414,7 +412,7 @@ def index_map_C(i, j): MMA_store_16x16_i32_global_INTRIN, ) - if measure_perf: + if measure_perf and timer: print("i8i8i32_m16n16k32_trans: %f GOPS" % (gflops / (timer().mean))) From 7bf882cc4351f29cb353cbe1ae20ac55c2ee3833 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 19 May 2022 08:14:52 +0900 Subject: [PATCH 4/4] skip cuda compile on old gpu --- .../unittest/test_tir_schedule_tensorize_ldmatrix_mma.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py index 78c615ff06c3..67e8ae0ad836 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py +++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py @@ -77,8 +77,8 @@ def maybe_swap(i, j): def is_ampere_or_newer(): arch = tvm.contrib.nvcc.get_target_compute_version() - major, minor = tvm.contrib.nvcc.parse_compute_version(arch) - return major * 10 + minor >= 80 + major, _ = tvm.contrib.nvcc.parse_compute_version(arch) + return major >= 8 def run_test( @@ -187,11 +187,11 @@ def tile_wmma_fragment(block_read, height, width): sch.tensorize(sch.get_loops(block_init_c)[-2], mma_fill_intrin) sch.tensorize(sch.get_loops(C_warp)[-2], mma_store_intrin) - f = tvm.build(sch.mod["main"], target="cuda", name="dense") - if not is_ampere_or_newer(): return None + f = tvm.build(sch.mod["main"], target="cuda", name="dense") + dev = tvm.device("cuda", 0) if in_dtype == "float16":
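As a companion to the mma_store lowering added to codegen_cuda.cc in patch 1/4: the generated store loop finds each thread's destination by inverting the 16x16 -> 32x8 index map, which IndexMap::FromFunc(...).Inverse(...) does symbolically on the C++ side. The same inverse can be reproduced by enumeration in plain Python; a sketch assuming the patch is applied (mma_store_dst_offset is an illustrative name, not part of the patch):

from tvm.tir.tensor_intrin.cuda import shared_16x16_to_ldmatrix_32x8_layout

# Invert the forward layout by enumeration: (thread_id, local_id) -> (i, j) in the 16x16 tile.
inverse = {}
for i in range(16):
    for j in range(16):
        thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i, j)
        inverse[(thread_id, local_id)] = (i, j)


def mma_store_dst_offset(tx, local_id, row_stride):
    # Each of the 32 threads writes its 8 accumulator registers to element (i, j) of the
    # 16x16 output tile; with row stride `row_stride` the flat offset is i * row_stride + j,
    # which corresponds to the `dst[...] = src[src_offset + local_id]` loop emitted by the codegen.
    i, j = inverse[(tx, local_id)]
    return i * row_stride + j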