Commit 8e5bc34

add matmul for gpu
1 parent e86a018 commit 8e5bc34

11 files changed (+963, -32 lines)

chapter_cpu_schedules/vector_add.md (+2, -2)

@@ -16,11 +16,11 @@ import tvm

 We first define reusable plot functions to draw multiple lines, which generalize the plot function defined in :numref:`ch_call_overhead`.

-```{.python .input n=10}
+```{.python .input n=1}
 # Save to the d2ltvm package.
 def plot(X, Y, xlabel=None, ylabel=None, legend=[], xlim=None,
          ylim=None, xscale='linear', yscale='linear', fmts=None,
-         figsize=(6, 4)):
+         figsize=(4.5, 3)):
     """Plot multiple lines"""
     display.set_matplotlib_formats('svg')
     plt.rcParams['figure.figsize'] = figsize

chapter_getting_started/from_mxnet.md (+2, -2)

@@ -11,11 +11,11 @@ import tvm
 from tvm import relay
 ```

-Here three additional modules are imported than the previous chapter. We will use `PIL` to read images, `MXNet` to obtain pre-trained neural networks, and the `relay` module in TVM to convert and optimize a neural network.
+Here we import three more modules than in the previous chapter. We will use `PIL` to read images, `MXNet` to obtain pre-trained neural networks, and the `relay` module :cite:`Roesch.Lyubomirsky.Kirisame.ea.2019` in TVM to convert and optimize a neural network.

 ## Obtaining Pre-trained Models

-A pre-trained model means a neural network with parameters trained on a data set. Here we download and load a ResNet-18 model by specifying `pretrained=True` from MXNet's model zoo. If you want to know details about this model, please refer to [Chapter 7.6 in D2L](http://d2l.ai/chapter_convolutional-modern/resnet.html). You can find more models on the [MXNet model zoo](https://mxnet.apache.org/api/python/docs/api/gluon/model_zoo/index.html) page, or refer to [GluonCV](https://gluon-cv.mxnet.io/model_zoo/index.html) and [GluonNLP](http://gluon-nlp.mxnet.io/model_zoo/index.html) for more computer vision and natural language models.
+A pre-trained model means a neural network with parameters trained on a data set. Here we download and load a ResNet-18 model by specifying `pretrained=True` from MXNet's model zoo :cite:`Chen.Li.Li.ea.2015`. If you want to know details about this model, please refer to [Chapter 7.6 in D2L](http://d2l.ai/chapter_convolutional-modern/resnet.html). You can find more models on the [MXNet model zoo](https://mxnet.apache.org/api/python/docs/api/gluon/model_zoo/index.html) page, or refer to [GluonCV](https://gluon-cv.mxnet.io/model_zoo/index.html) and [GluonNLP](http://gluon-nlp.mxnet.io/model_zoo/index.html) for more computer vision and natural language models.

 ```{.python .input n=2}
 model = mx.gluon.model_zoo.vision.resnet18_v2(pretrained=True)

chapter_getting_started/install.md (+10, -2)

@@ -17,11 +17,12 @@ wget http://tvm.d2l.ai/d2l-tvm.zip
 unzip d2l-tvm.zip -d d2l-tvm
 ```

+
 ## Installing Running Environment

 If you have both Python 3.5 or later and pip installed, the easiest way to
 install the running environment is through pip. The following packages are needed:
-`d2ltvm` for all dependencies such as Jupyter and saved code blocks, and `tvm`
+`d2ltvm` for all dependencies such as Jupyter and saved code blocks, and `tvm` :cite:`Chen.Moreau.Jiang.ea.2018`
 for the deep learning compiler we are using. Some chapters use `mxnet` as
 a baseline.

@@ -31,6 +32,7 @@ First install `d2ltvm`:
 pip install git+https://github.com/d2l-ai/d2l-tvm
 ```

+
 Then compile `tvm` from source codes. TVM doesn't have a pip package because it
 highly depends on the libraries available on your system. Please follow the
 instructions on
@@ -40,6 +42,8 @@ book requires at least
 ```bash
 set(USE_LLVM ON)
 ```
+
+
 Also
 don't forget to enable `cython`, which accelerates performance. You just
 need to run `make cython` in the TVM source folder.
@@ -51,13 +55,15 @@ may use the pre-built library that is for evaluating this book:
 pip install https://tvm-repo.s3-us-west-2.amazonaws.com/tvm-0.6.dev0-cp37-cp37m-linux_x86_64.whl
 ```

-Finally, install MXNet's CUDA version if GPUs are available. Assume you are have
+
+Finally, install MXNet's CUDA version if GPUs are available :cite:`Chen.Li.Li.ea.2015`. Assuming you have
 CUDA 10.1 installed, then

 ```bash
 pip install mxnet-cu101
 ```

+
 You can change the `101` to match your CUDA version.

 Once all packages are installed, we now open the Jupyter notebook by
@@ -66,6 +72,7 @@ Once all packages are installed, we now open the Jupyter notebook by
 jupyter notebook
 ```

+
 At this point open http://localhost:8888 (which usually opens automatically) in the browser, then you can view and run the code in each section of the book.


@@ -86,3 +93,4 @@ from matplotlib import pyplot as plt
 from IPython import display
 import mxnet as mx
 ```
+

chapter_gpu_schedules/index.md (+1, -1)

@@ -1,11 +1,11 @@
 # Operator Optimizations on GPUs
 :label:`ch_gpu_schedules`

-
 ```toc
 :maxdepth: 2
 :numbered:

 arch
 vector_add
+matmul
 ```

chapter_gpu_schedules/matmul.md (new file, +191)

# Matrix Multiplication

In this chapter, we will extend :numref:`ch_block_matmul_cpu` to optimize matrix multiplication on GPUs.

```{.python .input n=39}
import d2ltvm
import numpy as np
import timeit
import tvm
```

## Setup

We will use MXNet as our baseline, which calls cuBLAS to compute the results.

```{.python .input n=39}
# Save to the d2ltvm package.
def matmul_timer_mxnet(n, ctx):
    """The matrix multiplication timer for MXNet

    n : width and height of inputs
    ctx : device
    """
    timer = timeit.Timer(
        setup='import d2ltvm\n'
        'import mxnet as mx\n'
        'a, b, c, = d2ltvm.get_abc((%d, %d), lambda x: mx.nd.array(x, ctx=mx.%s()))\n'
        'mx.nd.waitall()' % (n, n, ctx),
        stmt='mx.nd.dot(a, b, out=c); c.wait_to_read()')
    return timer.timeit
```

Then compute the GFLOPS, using the fact that multiplying two $n\times n$ matrices takes $2n^3$ floating-point operations.

```{.python .input n=37}
sizes = 2**np.arange(8, 15, 1)
times = [d2ltvm.bench_workload(matmul_timer_mxnet(int(n), 'gpu'))
         for n in sizes]
mxnet_gflops = 2 * sizes**3 / 1e9 / np.array(times)
```

## Blocked Matrix Multiplication for GPU

We will follow :numref:`ch_block_matmul_cpu` to split the matrix $C$ into blocks and have each core (streaming multiprocessor) compute a block at a time. This can be done by assigning a block to a thread block, as we did in :numref:`ch_vector_add_gpu`. As mentioned in :numref:`ch_gpu_arch`, the GPU core has a finer-grained architecture, so we need to split a block further for the threads within the thread block. The simplest way was illustrated in :numref:`ch_vector_add_gpu`; here we will explore the local memory within a core and 2-D thread indexing.

### Shared Memory

Within a GPU core, there is a shared memory that can be accessed by all threads on that core. We mentioned that there is also an L1 cache within each core, which is managed by the compiler and hardware. Unlike the cache, we can allocate memory explicitly on the shared memory, just as we do on other memories such as the main memory and the global GPU memory.

In the TVM abstraction, we also call it a cache to simplify the concept. To create a read-only cache of $A$ on the shared memory that will be used by $C$, we can call `s.cache_read(A, "shared", [C])`.

![Blocked tiling for matrix multiplication with $A$ and $B$ on shared memory.](../img/matmul_block_gpu1.svg)
:label:`fig_matmul_block_gpu_shared`

In :numref:`ch_block_matmul_cpu`, we created a write cache of an output block. Here, we will explore the opportunity to create read caches for the input blocks. We redraw :numref:`fig_matmul_block` in :numref:`fig_matmul_block_gpu_shared`, which shows how an output block is computed through a series of matrix multiplications over input blocks. Since we will use all threads in a thread block to compute this block, we can cache the input blocks in the shared memory. Now we can rewrite the block computation in :numref:`ch_block_matmul_cpu` as:

```python
for k in range(0, n, tk):
    A_shared = A[y:y+ty, k:k+tk]  # cache in shared memory
    B_shared = B[k:k+tk, x:x+tx]  # cache in shared memory
    # use all threads in the thread block
    C[y:y+ty, x:x+tx] += dot(A_shared, B_shared)
```

Here `tx`, `ty` and `tk` are the tile sizes. The only difference from before is that we put the input blocks in the shared cache.

Assume `tx=64`, `ty=128` and `tk=32`; then for each core we will cache two matrices of sizes $128\times 32$ and $32\times 64$ on the shared memory, with a total size of 24 KB. We can query the shared memory size (in KB) of the GPU we are using to make sure that these two matrices fit into the shared memory.

```{.python .input}
ctx = tvm.gpu()
ctx.max_shared_memory_per_block/1024
```
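As a quick sanity check (this snippet is ours, not part of the original text), we can compute the shared memory the two cached blocks require and compare it against the limit queried above:

```{.python .input}
# Shared memory needed by A_shared (128x32) and B_shared (32x64), in KB.
# The 6144 32-bit floats take 24 KB, which should be below the per-block limit.
required_kb = (128*32 + 32*64) * 4 / 1024
required_kb, required_kb <= ctx.max_shared_memory_per_block/1024
```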
### Thread Block and Registers

Next let's explore how to compute an output block in parallel efficiently. We can use the same idea: further splitting the block into smaller tiles and having each thread compute one tile. :numref:`fig_matmul_thread_block` shows how a $128 \times 64$ output block is split into $16 \times 16$ tiles, each tile being an $8\times 4$ matrix. Then we will create 256 threads within this thread block. Since the output is a matrix, we use 2-D thread indexing, with `blockDim.x = blockDim.y = 16`. In addition, for each thread we will move the inputs, two vectors of lengths 8 and 4 respectively, and the output, an $8\times 4$ matrix, into the local memory.

![Blocked tiling for matrix multiplication.](../img/matmul_thread_block.svg)
:label:`fig_matmul_thread_block`

The local memory is memory created in the kernel, which can only be accessed by the single thread executing that kernel. From the hardware aspect, this space is allocated on the global memory, but the compiler will try to place it in registers, which are even faster than the shared memory, if it fits. For each thread, we will allocate three matrices of sizes $8\times 1$, $1\times 4$ and $8\times 4$, which is 44 32-bit floats in total. This fits within the constraint that each thread can use at most 255 32-bit registers.
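The arithmetic is easy to verify (this check is ours, not part of the original text):

```{.python .input}
# per-thread local storage: an 8x1 slice of A, a 1x4 slice of B, and the
# 8x4 output tile, all 32-bit floats
8*1 + 1*4 + 8*4
```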
### Cooperative Fetching

Finally, loading the blocks of $A$ and $B$ into the shared memory (`A_shared` and `B_shared`) is time consuming. We can accelerate it through multi-threading, namely using all threads in a thread block to do the loading cooperatively.
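In the same pseudo code style as above, the idea looks roughly like this (our sketch; the actual schedule below expresses it by splitting the cache stages' axes into `block_size` parts and binding them to the thread indices):

```python
# each thread (threadIdx.y, threadIdx.x) in the 16x16 thread block loads
# a strided subset of the shared block, so the loads proceed in parallel
for i in range(threadIdx.y, block_rows, block_size):
    for j in range(threadIdx.x, block_cols, block_size):
        A_shared[i, j] = A[row_start + i, k + j]
```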
## Implementation

We first implement two utility functions: one splits an axis with a list of factors, and the other binds a list of axes to threads.

```{.python .input n=40}
# Save into the d2ltvm package.
def split(stage, axis, factors):
    """Split an axis by a list of factors in a reverse order
    """
    axes = []
    for f in reversed(factors):
        axis, x = stage.split(axis, f)
        axes.append(x)
    return list(reversed(axes+[axis]))

# Save into the d2ltvm package.
def bind_thread(stage, axes, tags):
    """Bind a list of axes to thread axes
    """
    for axis, tag in zip(axes, tags):
        stage.bind(axis, tvm.thread_axis(tag))
```
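To see what these helpers do, here is a toy use of them (ours, for illustration only, on a hypothetical 64-element elementwise computation): `split` returns the axes from outermost to innermost, and `bind_thread` binds them to the given CUDA thread tags.

```{.python .input}
a_toy = tvm.placeholder((64,), name='a')
b_toy = tvm.compute((64,), lambda i: a_toy[i] + 1, name='b')
sch = tvm.create_schedule(b_toy.op)
# split 64 by factors (4, 8): outer extent 2, middle extent 4, inner extent 8
outer, mid, inner = split(sch[b_toy], b_toy.op.axis[0], (4, 8))
# bind the two outer axes; the innermost axis stays as a serial loop of 8
bind_thread(sch[b_toy], (outer, mid), ("blockIdx.x", "threadIdx.x"))
tvm.lower(sch, [a_toy, b_toy], simple_mode=True)
```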
Next we set the hyperparameters to the values we described before.

```{.python .input}
block_size = 16  # the number of threads for one dimension in a thread block.
tx, ty, tk = 8, 4, 32  # tile sizes for one CUDA thread
```

Now we can implement our schedule. There are two things worth mentioning. One is that we denote by `x` the rows and by `y` the columns, so an element is accessed by `C[x,y]`; in CUDA thread indexing, however, `x` is used for the innermost dimension, i.e. the columns, which is why you will see us bind axis `yb` (split from `y`) to `blockIdx.x` instead of `blockIdx.y`. The other is that we need to partition the axes of `A_shared` and `B_shared` into `block_size` parts, so that we can reuse the threads bound to `xo` and `yo` for cooperative fetching. Otherwise TVM may not properly synchronize the threads, which leads to wrong results.

```{.python .input n=69}
def matmul_gpu(n):
    A, B, C = d2ltvm.matmul(n, n, n)
    s = tvm.create_schedule(C.op)
    # Create caches
    A_shared = s.cache_read(A, "shared", [C])
    A_local = s.cache_read(A_shared, "local", [C])
    B_shared = s.cache_read(B, "shared", [C])
    B_local = s.cache_read(B_shared, "local", [C])
    C_local = s.cache_write(C, "local")
    # Split each axis into block axis, thread axis, and inner axis.
    x, y = s[C].op.axis
    xb, xo, xi = split(s[C], x, (block_size, tx))
    yb, yo, yi = split(s[C], y, (block_size, ty))
    s[C].reorder(xb, yb, xo, yo, xi, yi)
    # Note that we bind yb to blockIdx.x instead of blockIdx.y.
    bind_thread(s[C], (yb, xb, yo, xo),
                ("blockIdx.x", "blockIdx.y", "threadIdx.x", "threadIdx.y"))
    # Optimize C_local
    s[C_local].compute_at(s[C], yo)
    yi, xi = s[C_local].op.axis
    k, = s[C_local].op.reduce_axis
    ko, ki = s[C_local].split(k, tk)
    s[C_local].reorder(ko, ki, yi, xi)
    # Optimize the read caches of A and B with cooperative fetching
    def optimize_read_cache(shared, local, i):
        s[shared].compute_at(s[C_local], ko)
        s[local].compute_at(s[C_local], ki)
        y, x = s[shared].op.axis
        # Note that we must split into block_size parts to reuse
        # the previous axis threads.
        yo, yi = s[shared].split(y, nparts=block_size)
        xo, xi = s[shared].split(x, nparts=block_size)
        s[shared].reorder(yo, xo, yi, xi)
        bind_thread(s[shared], (yo, xo), ("threadIdx.y", "threadIdx.x"))
    optimize_read_cache(A_shared, A_local, True)
    optimize_read_cache(B_shared, B_local, False)
    return s, (A, B, C)
```
Let's verify the correctness of the schedule. First, print the pseudo code. Since we didn't unroll the loops, the pseudo code is relatively compact, so we can check the allocated cache sizes and how each stage is computed.

```{.python .input}
n = 2048
s, args = matmul_gpu(n)
tvm.lower(s, args, simple_mode=True)
```

Next we compare the results against NumPy to check the correctness.

```{.python .input}
target, ctx = 'cuda', tvm.gpu()
mod = tvm.build(s, args, target)
a, b, c, = d2ltvm.get_abc((n, n), lambda x: tvm.nd.array(x, ctx=ctx))
mod(a, b, c)
np.testing.assert_allclose(
    c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), atol=1e-2)
```
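Optionally (this step is ours, not part of the original text), we can peek at the CUDA kernel TVM generated, which is useful when reasoning about the performance gap discussed next:

```{.python .input}
# The device module holds the generated CUDA source; print its beginning.
dev_module = mod.imported_modules[0]
print(dev_module.get_source()[:1000])
```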
Finally, we measure the performance and compare it to our baseline. You can see that our schedule works well for small matrices but is consistently slower for large ones. The reasons might be that 1) we didn't consider bank conflicts when reading the shared memory, 2) the CUDA code generated by TVM may not be ideal, and 3) previous works show that assembly code provides more flexibility and often outperforms CUDA code :cite:`Nath.Tomov.Dongarra.2010,Lai.Seznec.2013`.

```{.python .input}
tvm_gflops = d2ltvm.bench_matmul_tvm(matmul_gpu, sizes, 'cuda')
d2ltvm.plot_gflops(sizes, [mxnet_gflops, tvm_gflops], legend=['MXNet', 'TVM'])
```

## Summary

- We use a two-level block tiling to parallelize matrix multiplication on GPUs.
- We load the data used by a thread block into the shared memory, and the data used by a CUDA thread into registers.
- The shared data within a thread block is loaded by cooperative fetching.

chapter_references/zreferences.md (new file, +10)

```eval_rst

.. only:: html

   References
   ==========

```

:bibliography:`../d2ltvm.bib`

d2ltvm.bib (new file, +81)

@Article{Chen.Li.Li.ea.2015,
  title   = {MXNet: A flexible and efficient machine learning library for heterogeneous distributed systems},
  author  = {Chen, Tianqi and Li, Mu and Li, Yutian and Lin, Min and Wang, Naiyan and Wang, Minjie and Xiao, Tianjun and Xu, Bing and Zhang, Chiyuan and Zhang, Zheng},
  journal = {arXiv preprint arXiv:1512.01274},
  year    = {2015}
}

@InProceedings{Chen.Moreau.Jiang.ea.2018,
  title     = {TVM: An automated end-to-end optimizing compiler for deep learning},
  author    = {Chen, Tianqi and Moreau, Thierry and Jiang, Ziheng and Zheng, Lianmin and Yan, Eddie and Shen, Haichen and Cowan, Meghan and Wang, Leyuan and Hu, Yuwei and Ceze, Luis and others},
  booktitle = {13th USENIX Symposium on Operating Systems Design and Implementation (OSDI 18)},
  pages     = {578--594},
  year      = {2018}
}

@Article{Roesch.Lyubomirsky.Kirisame.ea.2019,
  title   = {Relay: A High-Level IR for Deep Learning},
  author  = {Roesch, Jared and Lyubomirsky, Steven and Kirisame, Marisa and Pollock, Josh and Weber, Logan and Jiang, Ziheng and Chen, Tianqi and Moreau, Thierry and Tatlock, Zachary},
  journal = {arXiv preprint arXiv:1904.08368},
  year    = {2019}
}

@InProceedings{Wang.Chen.Liu.ea.2019,
  title        = {A Unified Optimization Approach for CNN Model Inference on Integrated GPUs},
  author       = {Wang, Leyuan and Chen, Zhi and Liu, Yizhi and Wang, Yao and Zheng, Lianmin and Li, Mu and Wang, Yida},
  booktitle    = {Proceedings of the 48th International Conference on Parallel Processing},
  pages        = {99},
  year         = {2019},
  organization = {ACM}
}

@InProceedings{Liu.Wang.Yu.ea.2019,
  title     = {Optimizing CNN Model Inference on CPUs},
  author    = {Liu, Yizhi and Wang, Yao and Yu, Ruofei and Li, Mu and Sharma, Vin and Wang, Yida},
  booktitle = {2019 USENIX Annual Technical Conference (USENIX ATC 19)},
  pages     = {1025--1040},
  year      = {2019}
}

@InProceedings{Jiang.Chen.Li.2018,
  title  = {Efficient Deep Learning Inference on Edge Devices},
  author = {Jiang, Ziheng and Chen, Tianqi and Li, Mu},
  year   = {2018}
}

@InProceedings{Lai.Seznec.2013,
  title        = {Performance upper bound analysis and optimization of SGEMM on Fermi and Kepler GPUs},
  author       = {Lai, Junjie and Seznec, Andre},
  booktitle    = {Proceedings of the 2013 IEEE/ACM International Symposium on Code Generation and Optimization (CGO)},
  pages        = {1--10},
  year         = {2013},
  organization = {IEEE}
}

@Article{Nath.Tomov.Dongarra.2010,
  title     = {An improved MAGMA GEMM for Fermi graphics processing units},
  author    = {Nath, Rajib and Tomov, Stanimire and Dongarra, Jack},
  journal   = {The International Journal of High Performance Computing Applications},
  volume    = {24},
  number    = {4},
  pages     = {511--515},
  year      = {2010},
  publisher = {SAGE Publications Sage UK: London, England}
}
