# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
| 17 | +"""A script to measure GPU memory bandwidth""" |
| 18 | +import argparse |
| 19 | +import itertools |
| 20 | + |
| 21 | +import numpy as np |
| 22 | + |
| 23 | +import tvm |
| 24 | +from tvm import te, tir |
| 25 | +from tvm.meta_schedule.runner import EvaluatorConfig |
| 26 | +from tvm.testing import local_run |
| 27 | + |
| 28 | + |
def _parse_args() -> argparse.Namespace:
    def _parse_list_int(source: str):
        return [int(i) for i in source.split(",")]

    parser = argparse.ArgumentParser(
        prog="GPU memory bandwidth testing",
        description="""Example:
python -m tvm.exec.gpu_memory_bandwidth "nvidia/geforce-rtx-3090-ti" \
    --dtype "float32" \
    --bx "8,16,32,64,128,256" \
    --tx "32,64,128,256,512,1024" \
    --vec "1,2,4"
""",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "target",
        type=str,
        help="The target to be benchmarked",
    )
    parser.add_argument(
        "--xo",
        type=int,
        default=1024,
        help="The value of `XO` in [XO, K, XI] => [XO, XI] reduction",
    )
    parser.add_argument(
        "--k",
        type=int,
        default=64,
        help="The value of `K` in [XO, K, XI] => [XO, XI] reduction",
    )
    parser.add_argument(
        "--xi",
        type=int,
        default=4096,
        help="The value of `XI` in [XO, K, XI] => [XO, XI] reduction",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float32",
        help="The data type to be used in the workload",
    )
    parser.add_argument(
        "--bx",
        type=_parse_list_int,
        default=[8, 16, 32, 64, 128, 256],
        help="The value to be used to split `XO` into [BX, _]",
    )
    parser.add_argument(
        "--tx",
        type=_parse_list_int,
        default=[32, 64, 128, 256, 512, 1024],
        help="Number of threads to be used",
    )
    parser.add_argument(
        "--vec",
        type=_parse_list_int,
        default=[1, 2, 4],
        help="Vector length to be used in vectorized load",
    )
    return parser.parse_args()


def _workload(
    len_xo: int,
    len_k: int,
    len_xi: int,
    dtype: str,
):
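    """Create a PrimFunc that reduces a [XO, K, XI] tensor over `K` into a [XO, XI] tensor."""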
    # pylint: disable=invalid-name
    A = te.placeholder((len_xo, len_k, len_xi), dtype=dtype, name="A")
    k = te.reduce_axis((0, len_k), "k")
    B = te.compute(
        (len_xo, len_xi),
        lambda i, j: te.sum(A[i, k, j], axis=k),
        name="B",
    )
    # pylint: enable=invalid-name
    return te.create_prim_func([A, B])


def _schedule(
    sch: tir.Schedule,
    len_bx: int,
    len_tx: int,
    len_vec: int,
):
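    """Bind the reduction to the GPU: split `XO` by `len_bx`, split `XI` into
    [_, len_tx, len_vec], fuse the outer pieces into `blockIdx.x`, bind `len_tx`
    threads to `threadIdx.x`, and stage the input through a local cache so each
    thread issues vectorized loads of length `len_vec`.
    """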
    # pylint: disable=invalid-name
    block = sch.get_block("B")
    xo, xi, k = sch.get_loops(block)
    # Tile the spatial loops: XO -> [bx, xo], XI -> [xi, tx, vec]
    bx, xo = sch.split(xo, factors=[len_bx, None])
    xi, tx, vec = sch.split(xi, factors=[None, len_tx, len_vec])
    sch.reorder(bx, xi, tx, xo, k, vec)
    # Map the fused outer loops to thread blocks and `tx` to threads
    bx = sch.fuse(bx, xi)
    sch.bind(bx, "blockIdx.x")
    sch.bind(tx, "threadIdx.x")
    # Stage A into local scope and vectorize the innermost (length-`len_vec`) load
    ldg = sch.cache_read(block, 0, "local")
    sch.compute_at(ldg, k, preserve_unit_loops=True)
    sch.vectorize(sch.get_loops(ldg)[-1])
    sch.decompose_reduction(block, k)
    # pylint: enable=invalid-name


def main():  # pylint: disable=too-many-locals
    """Entry point"""
    args = _parse_args()
    # pylint: disable=invalid-name
    target = tvm.target.Target(args.target)
    dtype = args.dtype

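    # Random input A and zero-filled output B; `num_bytes` counts one full read
    # of A plus one full write of B, i.e. the bytes each kernel launch must move.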
    a = np.random.uniform(-1, 1, (args.xo, args.k, args.xi)).astype(dtype)
    b = np.zeros((args.xo, args.xi), dtype=dtype)
    num_bytes = a.size * a.itemsize + b.size * b.itemsize
    print("###### Bandwidth Test ######")
    print(
        f"Workload [XO, K, XI] => [XO, XI]. "
        f"[{args.xo}, {args.k}, {args.xi}] => [{args.xo}, {args.xi}]"
    )
    print(f"Input size: {num_bytes / 1048576} MB")
    print(f"Target: {target}")

    # pylint: enable=invalid-name
    best_bandwidth = -1
    for len_bx, len_tx, len_vec in itertools.product(
        args.bx,
        args.tx,
        args.vec,
    ):
        func = _workload(
            len_xo=args.xo,
            len_k=args.k,
            len_xi=args.xi,
            dtype=dtype,
        )
        sch = tir.Schedule(func)
        _schedule(sch, len_bx, len_tx, len_vec)

        _, profile_result = local_run(
            tvm.build(sch.mod, target=target),
            target.kind.name,
            [a, b],
            evaluator_config=EvaluatorConfig(
                number=10,
                repeat=1,
                min_repeat_ms=100,
                enable_cpu_cache_flush=False,
            ),
        )
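        # Effective bandwidth = bytes moved / mean kernel time, scaled by 1024**3
        # (i.e. GiB/s, printed below as GB/s).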
        bandwidth = num_bytes / profile_result.mean / (1024**3)
        bx = len_bx * args.xi // (len_tx * len_vec)  # pylint: disable=invalid-name
        mbs = num_bytes / 1024 / 1024
        print(
            f"bandwidth = {bandwidth:.3f} GB/s, bx = {bx}, tx = {len_tx}, "
            f"len_vec = {len_vec}, bytes = {mbs} MB"
        )
        if bandwidth > best_bandwidth:
            best_bandwidth = bandwidth
            print(f"peak bandwidth: {best_bandwidth:.3f} GB/s")


if __name__ == "__main__":
    main()