Skip to content

Commit 4e9ca2a

Browse files
committed
[Exec] Add a script to test GPU memory bandwidth (apache#15287)
This PR adds a script to test GPU memory bandwidth in TVM. Example usage: python -m tvm.exec.gpu_memory_bandwidth \ "nvidia/geforce-rtx-3090-ti" \ --dtype "float32" \ --bx "8,16,32,64,128,256" \ --tx "32,64,128,256,512,1024" \ --vec "1,2,4"
1 parent f2d5ad6 commit 4e9ca2a

File tree

1 file changed

+192
-0
lines changed

1 file changed

+192
-0
lines changed
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""A script to measure GPU memory bandwidth"""
18+
import argparse
19+
import itertools
20+
21+
import numpy as np
22+
23+
import tvm
24+
from tvm import te, tir
25+
from tvm.meta_schedule.runner import EvaluatorConfig
26+
from tvm.testing import local_run
27+
28+
29+
def _parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the bandwidth benchmark.

    Parameters
    ----------
    argv : Optional[List[str]]
        The argument list to parse. When left as ``None``, ``argparse`` reads
        ``sys.argv[1:]``, which is what the script entry point relies on.

    Returns
    -------
    args : argparse.Namespace
        Namespace holding ``target``, ``xo``, ``k``, ``xi``, ``dtype`` and the
        integer lists ``bx``, ``tx`` and ``vec``.
    """

    def _parse_list_int(source: str):
        # "8,16,32" -> [8, 16, 32]. A malformed entry raises ValueError, which
        # argparse turns into a regular usage error.
        return [int(i) for i in source.split(",")]

    class _HelpFormatter(
        argparse.ArgumentDefaultsHelpFormatter,
        argparse.RawDescriptionHelpFormatter,
    ):
        """Append defaults to option help and keep the description verbatim."""

    # The example is assembled line by line with explicitly escaped
    # backslashes: inside a plain triple-quoted string a trailing backslash is
    # a line continuation and would silently glue the lines together.
    example = "\n".join(
        [
            "Example:",
            '  python -m tvm.exec.gpu_memory_bandwidth "nvidia/geforce-rtx-3090-ti" \\',
            '      --dtype "float32" \\',
            '      --bx "8,16,32,64,128,256" \\',
            '      --tx "32,64,128,256,512,1024" \\',
            '      --vec "1,2,4"',
        ]
    )

    parser = argparse.ArgumentParser(
        prog="GPU memory bandwidth testing",
        description=example,
        formatter_class=_HelpFormatter,
    )
    parser.add_argument(
        "target",
        type=str,
        help="The target to be benchmarked",
    )
    parser.add_argument(
        "--xo",
        type=int,
        default=1024,
        help="The value of `XO` in [XO, K, XI] => [XO, XI] reduction",
    )
    parser.add_argument(
        "--k",
        type=int,
        default=64,
        help="The value of `K` in [XO, K, XI] => [XO, XI] reduction",
    )
    parser.add_argument(
        "--xi",
        type=int,
        default=4096,
        help="The value of `XI` in [XO, K, XI] => [XO, XI] reduction",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float32",
        help="The data type to be used in the workload",
    )
    parser.add_argument(
        "--bx",
        type=_parse_list_int,
        default=[8, 16, 32, 64, 128, 256],
        help="The value to be used to split `XO` into [BX, _]",
    )
    parser.add_argument(
        "--tx",
        type=_parse_list_int,
        default=[32, 64, 128, 256, 512, 1024],
        help="Number of threads to be used",
    )
    parser.add_argument(
        "--vec",
        type=_parse_list_int,
        default=[1, 2, 4],
        help="Vector length to be used in vectorized load",
    )
    return parser.parse_args(argv)
92+
93+
94+
def _workload(
    len_xo: int,
    len_k: int,
    len_xi: int,
    dtype: str,
):
    """Build the reduction workload [XO, K, XI] => [XO, XI].

    Parameters
    ----------
    len_xo : int
        Extent of the outer spatial axis `XO`.
    len_k : int
        Extent of the reduced axis `K`.
    len_xi : int
        Extent of the inner spatial axis `XI`.
    dtype : str
        Element type of the input placeholder.

    Returns
    -------
    func
        The result of `te.create_prim_func` over the input tensor "A" and
        the output tensor "B", where B[i, j] is the sum of A[i, k, j] over k.
    """
    source = te.placeholder((len_xo, len_k, len_xi), dtype=dtype, name="A")
    reduce_k = te.reduce_axis((0, len_k), "k")

    def _sum_over_k(i, j):
        # Collapse the middle axis of the input for one output element.
        return te.sum(source[i, reduce_k, j], axis=reduce_k)

    # The tensor is named "B" because the schedule looks the block up by
    # that name.
    result = te.compute((len_xo, len_xi), _sum_over_k, name="B")
    return te.create_prim_func([source, result])
110+
111+
112+
def _schedule(
    sch: tir.Schedule,
    len_bx: int,
    len_tx: int,
    len_vec: int,
):
    """Apply the bandwidth-test schedule to block "B" of `sch` in place.

    Parameters
    ----------
    sch : tir.Schedule
        The schedule to transform; it must contain a block named "B" with
        exactly three loops.
    len_bx : int
        Outer factor used to split the first loop.
    len_tx : int
        Extent of the loop bound to "threadIdx.x".
    len_vec : int
        Extent of the innermost loop split off the second loop; the cached
        read's last loop is vectorized.
    """
    reduce_block = sch.get_block("B")
    loop_xo, loop_xi, loop_k = sch.get_loops(reduce_block)

    # Tile the two spatial loops: XO -> [BX, rest], XI -> [rest, TX, VEC].
    block_x, loop_xo = sch.split(loop_xo, factors=[len_bx, None])
    loop_xi, thread_x, loop_vec = sch.split(
        loop_xi,
        factors=[None, len_tx, len_vec],
    )
    sch.reorder(block_x, loop_xi, thread_x, loop_xo, loop_k, loop_vec)

    # Fuse the two outermost loops and bind them to blockIdx.x; the TX loop
    # is bound to threadIdx.x.
    block_x = sch.fuse(block_x, loop_xi)
    sch.bind(block_x, "blockIdx.x")
    sch.bind(thread_x, "threadIdx.x")

    # Stage the block's first read buffer through a "local" scope cache
    # placed under the reduction loop, and vectorize the cache's last loop.
    local_load = sch.cache_read(reduce_block, 0, "local")
    sch.compute_at(local_load, loop_k, preserve_unit_loops=True)
    vector_loop = sch.get_loops(local_load)[-1]
    sch.vectorize(vector_loop)

    sch.decompose_reduction(reduce_block, loop_k)
132+
133+
134+
def main():  # pylint: disable=too-many-locals
    """Entry point.

    Parses the command line, builds the [XO, K, XI] => [XO, XI] reduction,
    then builds and times it for every (bx, tx, vec) combination, printing
    the bandwidth of each run and the best one at the end.
    """
    args = _parse_args()
    # pylint: disable=invalid-name
    target = tvm.target.Target(args.target)
    dtype = args.dtype

    a = np.random.uniform(-1, 1, (args.xo, args.k, args.xi)).astype(dtype)
    b = np.zeros((args.xo, args.xi), dtype=dtype)
    # Bytes of the input `a` plus the output `b`.
    num_bytes = a.size * a.itemsize + b.size * b.itemsize
    # The size in MB is the same for every configuration; compute it once.
    mbs = num_bytes / 1024 / 1024
    print("###### Bandwidth Test ######")
    print(
        f"Workload [XO, K, XI] => [XO, XI]. "
        f"[{args.xo}, {args.k}, {args.xi}] => [{args.xo}, {args.xi}]"
    )
    print(f"Input size: {num_bytes / 1048576} MB")
    print(f"Target: {target}")

    # pylint: enable=invalid-name
    # The workload depends only on the shape and dtype, not on the schedule
    # parameters, so it is built once; each iteration wraps it in a fresh
    # tir.Schedule.
    func = _workload(
        len_xo=args.xo,
        len_k=args.k,
        len_xi=args.xi,
        dtype=dtype,
    )
    evaluator_config = EvaluatorConfig(
        number=10,
        repeat=1,
        min_repeat_ms=100,
        enable_cpu_cache_flush=False,
    )
    best_bandwidth = -1
    for len_bx, len_tx, len_vec in itertools.product(
        args.bx,
        args.tx,
        args.vec,
    ):
        sch = tir.Schedule(func)
        _schedule(sch, len_bx, len_tx, len_vec)

        _, profile_result = local_run(
            tvm.build(sch.mod, target=target),
            target.kind.name,
            [a, b],
            evaluator_config=evaluator_config,
        )
        # Bytes moved divided by the mean run time, scaled by 1024**3.
        bandwidth = num_bytes / profile_result.mean / (1024**3)
        # Extent of the fused loop that the schedule binds to blockIdx.x.
        bx = len_bx * args.xi // (len_tx * len_vec)  # pylint: disable=invalid-name
        print(
            f"bandwidth = {bandwidth:.3f} GB/s, bx = {bx}, tx = {len_tx}, "
            f"len_vec = {len_vec}, bytes = {mbs} MB"
        )
        if bandwidth > best_bandwidth:
            best_bandwidth = bandwidth
    print(f"peak bandwidth: {best_bandwidth:.3f} GB/s")
189+
190+
191+
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)