# Matrix Multiplication

In this chapter, we will extend :numref:`ch_block_matmul_cpu` to optimize matrix multiplication on GPUs.

```{.python .input n=39}
import d2ltvm
import numpy as np
import timeit
import tvm
```

## Setup

We will use MXNet as our baseline, which calls cuBLAS to compute the results.

```{.python .input n=39}
# Save to the d2ltvm package.
def matmul_timer_mxnet(n, ctx):
    """The matrix multiplication timer for MXNet

    n : width and height of inputs
    ctx : device
    """
    timer = timeit.Timer(
        setup='import d2ltvm\n'
        'import mxnet as mx\n'
        'a, b, c, = d2ltvm.get_abc((%d, %d), lambda x: mx.nd.array(x, ctx=mx.%s()))\n'
        'mx.nd.waitall()' % (n, n, ctx),
        stmt='mx.nd.dot(a, b, out=c); c.wait_to_read()')
    return timer.timeit
```

Then benchmark the MXNet baseline and compute its GFLOPS.

```{.python .input n=37}
sizes = 2**np.arange(8, 15, 1)
times = [d2ltvm.bench_workload(matmul_timer_mxnet(int(n), 'gpu'))
         for n in sizes]
mxnet_gflops = 2 * sizes**3 / 1e9 / np.array(times)
```
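
The numerator comes from the fact that an $n\times n$ matrix multiplication performs roughly $2n^3$ floating-point operations: about $2n$ per output element ($n$ multiplications and $n-1$ additions) for $n^2$ output elements. A quick check with a toy size:

```python
# FLOP count behind the formula above.
n = 256
print(2 * n**3 / 1e9)  # about 0.034 GFLOP for a single 256 x 256 matmul
```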

## Blocked Matrix Multiplication for GPU

We will follow :numref:`ch_block_matmul_cpu` to split the matrix $C$ into blocks, and have each core (streaming multiprocessor) compute a block at a time. This can be done by assigning a block to a thread block as we did in :numref:`ch_vector_add_gpu`. As mentioned in :numref:`ch_gpu_arch`, a GPU core has a finer architecture, so we need to split a block further for every thread in the thread block. The simplest way was illustrated in :numref:`ch_vector_add_gpu`; here we will explore the shared and local memory within a core, together with 2-D thread indexing.

### Shared Memory

Within a GPU core, there is a shared memory that can be accessed by all threads running on that core. We mentioned before that there is an L1 cache within each core, which is managed by the compiler and hardware. Unlike the cache, we can allocate memory on the shared memory directly, just as we do for the main memory and the global GPU memory.

In the TVM abstraction, we also call the shared memory a cache to simplify the concept. To create a read-only cache of $A$ on the shared memory that will be used to compute $C$, we can call `s.cache_read(A, "shared", [C])`.
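
A minimal sketch of what this call does, using the pre-0.7 TVM API that the rest of this chapter uses and the `d2ltvm.matmul` helper from the package: `cache_read` adds a new stage that stages $A$ in shared memory before $C$ consumes it.

```python
# Sketch only: the sizes are arbitrary, and no thread binding is done yet.
A, B, C = d2ltvm.matmul(64, 64, 64)
s = tvm.create_schedule(C.op)
A_shared = s.cache_read(A, "shared", [C])  # new stage in "shared" scope, read by C
# The lowered code now contains an extra A.shared buffer filled before C is computed.
print(tvm.lower(s, [A, B, C], simple_mode=True))
```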

:label:`fig_matmul_block_gpu_shared`

In :numref:`ch_block_matmul_cpu`, we created a write cache for an output block. Here, we will explore the opportunity to create read caches for the input blocks as well. We redraw :numref:`fig_matmul_block` as :numref:`fig_matmul_block_gpu_shared`, which shows how an output block is computed through a sequence of matrix multiplications over input blocks. Since all threads in a thread block cooperate to compute this output block, we can cache the input blocks in the shared memory. The block computation in :numref:`ch_block_matmul_cpu` can then be rewritten as:

```python
for k in range(0, n, tk):
    A_shared = A[y:y+ty, k:k+tk]  # cache in shared memory
    B_shared = B[k:k+tk, x:x+tx]  # cache in shared memory
    # use all threads in the thread block
    C[y:y+ty, x:x+tx] += dot(A_shared, B_shared)
```

Here `tx`, `ty` and `tk` are the tile sizes. The only difference from the CPU version is that we place the input blocks in the shared memory.
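
To double-check the tiling logic, here is a small NumPy sketch of the loop above, using the NumPy import from the setup and made-up toy sizes; `dot` corresponds to `np.dot`, and the two `*_shared` temporaries merely stand in for the blocks that would live in shared memory.

```python
n, tx, ty, tk = 8, 2, 4, 2  # toy sizes chosen so the tiles divide n evenly
A = np.random.normal(size=(n, n))
B = np.random.normal(size=(n, n))
C = np.zeros((n, n))
for y in range(0, n, ty):
    for x in range(0, n, tx):
        for k in range(0, n, tk):
            A_shared = A[y:y+ty, k:k+tk]  # would be cached in shared memory
            B_shared = B[k:k+tk, x:x+tx]  # would be cached in shared memory
            C[y:y+ty, x:x+tx] += np.dot(A_shared, B_shared)
np.testing.assert_allclose(C, np.dot(A, B), atol=1e-10)
```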

Assume `tx=64`, `ty=128` and `tk=32`; then for each core, we will cache two matrices of sizes $128\times 32$ and $32\times 64$ on the shared memory, with a total size of 24 KB. We can query the shared memory size per block (in KB) of the GPU we are using to make sure that these two matrices can fit into the shared memory.

```{.python .input}
ctx = tvm.gpu()
ctx.max_shared_memory_per_block/1024
```
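
As a quick sanity check of the 24 KB figure above (assuming 32-bit floats, i.e. 4 bytes per element):

```python
ty_block, tk_block, tx_block = 128, 32, 64  # block tile sizes from the example above
nbytes = (ty_block * tk_block + tk_block * tx_block) * 4
print(nbytes / 1024)  # 24.0 KB
```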

### Thread Block and Registers

Next let's explore how to compute an output block in parallel efficiently. We can apply the same idea again: further split the block into smaller tiles, and have each thread compute one tile. :numref:`fig_matmul_thread_block` shows how a $128 \times 64$ output block is split into a $16 \times 16$ grid of tiles, each an $8\times 4$ matrix. We then create 256 threads within this thread block. Since the output is a matrix, we use 2-D thread indexing, with `blockDim.x = blockDim.y = 16`. In addition, for each thread we will move its inputs, two vectors of lengths 8 and 4, respectively, and its output, an $8\times 4$ matrix, into the local memory.

:label:`fig_matmul_thread_block`

Local memory refers to memory declared inside the kernel, which can only be accessed by the single thread executing that kernel. From the hardware perspective, this space is allocated on the global memory, but the compiler will try to place it in registers, which are even faster than the shared memory, if it fits. For each thread, we will allocate three matrices of sizes $8\times 1$, $1\times 4$ and $8\times 4$, 44 32-bit floats in total. This fits comfortably within the limit of 255 32-bit registers per thread.
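
A quick check of these per-thread numbers:

```python
a_local = 8 * 1  # column slice of A read by one thread
b_local = 1 * 4  # row slice of B read by one thread
c_local = 8 * 4  # output tile accumulated by one thread
print(a_local + b_local + c_local)  # 44 values, under the 255-register limit per thread
```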

### Cooperative Fetching

Finally, loading the blocks `A_shared` and `B_shared` into the shared memory is time consuming. We can accelerate it with multi-threading, namely using all threads in a thread block to load the data cooperatively, as sketched below.
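
In a NumPy-style sketch (a hypothetical helper, not the code TVM generates), the tile is divided into `block_size` by `block_size` contiguous chunks and each thread copies one chunk; the schedule below expresses this by splitting the shared-cache stages with `nparts=block_size` and binding the outer axes to thread indices.

```python
def fetch_cooperatively(src, dst, ty_idx, tx_idx, block_size):
    """Copy the chunk owned by thread (ty_idx, tx_idx) from src into dst.

    src and dst are 2-D arrays of the same tile shape; dst stands in for
    the shared-memory buffer. Running this for all block_size**2 thread
    indices fills the whole tile, which the threads do in parallel.
    """
    cy = src.shape[0] // block_size  # rows handled by one thread
    cx = src.shape[1] // block_size  # columns handled by one thread
    ys, xs = ty_idx * cy, tx_idx * cx
    dst[ys:ys+cy, xs:xs+cx] = src[ys:ys+cy, xs:xs+cx]
```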

## Implementation

We first implement two utility functions: one splits an axis by a list of factors, and the other binds a list of axes to thread axes.

```{.python .input n=40}
# Save into the d2ltvm package.
def split(stage, axis, factors):
    """Split an axis by a list of factors in a reverse order
    """
    axes = []
    for f in reversed(factors):
        axis, x = stage.split(axis, f)
        axes.append(x)
    return list(reversed(axes+[axis]))

# Save into the d2ltvm package.
def bind_thread(stage, axes, tags):
    """Bind a list of axes to thread axes
    """
    for axis, tag in zip(axes, tags):
        stage.bind(axis, tvm.thread_axis(tag))
```
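
As a quick usage example of `split` (a sketch reusing the `d2ltvm.matmul` helper used below), splitting a length-64 axis by the factors `(4, 2)` returns three axes of extents 8, 4 and 2, from outermost to innermost:

```python
A, B, C = d2ltvm.matmul(64, 64, 64)
s = tvm.create_schedule(C.op)
x, y = s[C].op.axis
print(split(s[C], x, (4, 2)))  # three iteration variables with extents 8, 4, 2
```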

Next we set the hyperparameters to the values described before. Note that `tx` and `ty` here are the per-thread tile sizes; the corresponding per-block tile sizes are `block_size*tx = 128` and `block_size*ty = 64`.

```{.python .input}
block_size = 16  # the number of threads for one dimension in a thread block.
tx, ty, tk = 8, 4, 32  # tile sizes for one CUDA thread
```

Now we can implement our schedule. There are two things worth mentioning. One is that we denote by `x` the rows and `y` the columns, so an element is accessed by `C[x,y]`; in CUDA thread indexing, however, `x` is used for the innermost dimension, i.e. the columns. Therefore you will see that we bind axis `yb` (split from `y`) to `blockIdx.x` instead of `blockIdx.y`. The other is that we need to partition the axes of `A_shared` and `B_shared` into `block_size` parts, so that we can reuse the threads bound to `xo` and `yo` for cooperative fetching. Otherwise TVM may not synchronize the threads properly, leading to wrong results.

```{.python .input n=69}
def matmul_gpu(n):
    A, B, C = d2ltvm.matmul(n, n, n)
    s = tvm.create_schedule(C.op)
    # Create caches.
    A_shared = s.cache_read(A, "shared", [C])
    A_local = s.cache_read(A_shared, "local", [C])
    B_shared = s.cache_read(B, "shared", [C])
    B_local = s.cache_read(B_shared, "local", [C])
    C_local = s.cache_write(C, "local")
    # Split each axis into block axis, thread axis, and inner axis.
    x, y = s[C].op.axis
    xb, xo, xi = split(s[C], x, (block_size, tx))
    yb, yo, yi = split(s[C], y, (block_size, ty))
    s[C].reorder(xb, yb, xo, yo, xi, yi)
    # Note that we bind yb to blockIdx.x instead of blockIdx.y.
    bind_thread(s[C], (yb, xb, yo, xo),
                ("blockIdx.x", "blockIdx.y", "threadIdx.x", "threadIdx.y"))
    # Optimize C_local.
    s[C_local].compute_at(s[C], yo)
    yi, xi = s[C_local].op.axis
    k, = s[C_local].op.reduce_axis
    ko, ki = s[C_local].split(k, tk)
    s[C_local].reorder(ko, ki, yi, xi)
    # Optimize read caches of A and B with cooperative fetching.
    def optimize_read_cache(shared, local):
        s[shared].compute_at(s[C_local], ko)
        s[local].compute_at(s[C_local], ki)
        y, x = s[shared].op.axis
        # Note that we must split into block_size parts to reuse
        # the previous axis threads.
        yo, yi = s[shared].split(y, nparts=block_size)
        xo, xi = s[shared].split(x, nparts=block_size)
        s[shared].reorder(yo, xo, yi, xi)
        bind_thread(s[shared], (yo, xo), ("threadIdx.y", "threadIdx.x"))
    optimize_read_cache(A_shared, A_local)
    optimize_read_cache(B_shared, B_local)
    return s, (A, B, C)
```

Let's verify the correctness of the schedule. First, print the pseudo code. Since we didn't unroll the loops, the pseudo code is relatively compact, and we can check the allocated cache sizes and how each stage is computed.

```{.python .input}
n = 2048
s, args = matmul_gpu(n)
tvm.lower(s, args, simple_mode=True)
```

Next we compare the results against NumPy to check the correctness.

```{.python .input}
target, ctx = 'cuda', tvm.gpu()
mod = tvm.build(s, args, target)
a, b, c, = d2ltvm.get_abc((n, n), lambda x: tvm.nd.array(x, ctx=ctx))
mod(a, b, c)
np.testing.assert_allclose(
    c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), atol=1e-2)
```

Finally, we measure the performance and compare it to our baseline. You can see that our schedule works well for small matrices but is consistently slower for large ones. The reasons might be that 1) we didn't consider bank conflicts when reading the shared memory; 2) the CUDA code generated by TVM may not be ideal; and 3) previous works show that assembly code provides more flexibility and often outperforms CUDA code :cite:`Nath.Tomov.Dongarra.2010,Lai.Seznec.2013`.

```{.python .input}
tvm_gflops = d2ltvm.bench_matmul_tvm(matmul_gpu, sizes, 'cuda')
d2ltvm.plot_gflops(sizes, [mxnet_gflops, tvm_gflops], legend=['MXNet', 'TVM'])
```

## Summary

- We use a two-level block tiling to parallelize matrix multiplication on GPUs.
- We load the data used by a thread block into shared memory, and the data used by a single CUDA thread into registers.
- The shared data within a thread block is loaded by cooperative fetching.