|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | +# pylint: disable=invalid-name, unused-variable, too-many-locals |
| 18 | +# pylint: disable=unused-argument, redefined-builtin |
| 19 | +"""GeMM dense schedule on AArch64""" |
| 20 | +import tvm |
| 21 | +from tvm import te |
| 22 | +from tvm.topi import nn |
| 23 | +from tvm.topi.arm_cpu.arm_utils import get_tiling_A, get_tiling_B_transformed, pad_dim_to_multiple |
| 24 | +from ..utils import get_const_tuple, traverse_inline |
| 25 | +from .. import tag |
| 26 | + |
| 27 | +# Compute function |
| 28 | +def dense_gemm_compute( |
| 29 | + cfg, data, weight, bias=None, out_dtype=None, transpose_a=False, transpose_b=True |
| 30 | +): |
| 31 | + """ |
| 32 | + Compute dense using GeMM. |
| 33 | +
|
| 34 | + Parameters |
| 35 | + ---------- |
| 36 | + cfg : Autotvm tuning space config file, |
| 37 | + empty in this case, but it's needed as an arg. |
| 38 | +
|
| 39 | + data : tvm.te.Tensor |
| 40 | + 2-D with shape [M, K] or [K, M]. |
| 41 | +
|
| 42 | + weight : tvm.te.Tensor |
| 43 | + 2-D with shape [K, N] or [N, K]. |
| 44 | +
|
| 45 | + bias : Optional[tvm.te.Tensor] |
| 46 | + 1-D with shape [N] |
| 47 | +
|
| 48 | +
|
| 49 | + out_dtype : Optional[str] |
| 50 | + Specifies the output data type. |
| 51 | +
|
| 52 | + transpose_a : Optional[bool] = False |
| 53 | + Whether the data tensor is in transposed format. |
| 54 | +
|
| 55 | + transpose_b : Optional[bool] = True |
| 56 | + Whether the weight tensor is in transposed format. |
| 57 | +
|
| 58 | + Returns |
| 59 | + ------- |
| 60 | + out : tvm.te.Tensor |
| 61 | + 1-D with shape [out_dim] |
| 62 | + """ |
| 63 | + |
| 64 | + if out_dtype is None: |
| 65 | + out_dtype = data.dtype |
| 66 | + M, K = get_const_tuple(data.shape) # batch, in_dim |
| 67 | + if bool(transpose_b): # out_dim |
| 68 | + (N, _) = get_const_tuple(weight.shape) |
| 69 | + else: |
| 70 | + (_, N) = get_const_tuple(weight.shape) |
| 71 | + |
| 72 | + tile_M, tile_K = get_tiling_A(False, out_dtype) |
| 73 | + tile_N, _ = get_tiling_B_transformed(False, out_dtype, False) |
| 74 | + |
| 75 | + M_padded, pad_M = pad_dim_to_multiple(M, tile_M) |
| 76 | + K_padded, pad_K = pad_dim_to_multiple(K, tile_K) |
| 77 | + N_padded, pad_N = pad_dim_to_multiple(N, tile_N) |
| 78 | + m_pad_after = (pad_M, pad_K) |
| 79 | + n_pad_after = (pad_N, pad_K) if transpose_b else (pad_K, pad_N) |
| 80 | + |
| 81 | + if pad_M != 0 or pad_K != 0: |
| 82 | + data = nn.pad(data, pad_before=(0, 0), pad_after=m_pad_after, name="data_padded") |
| 83 | + |
| 84 | + k = te.reduce_axis((0, K_padded), name="k") |
| 85 | + |
| 86 | + if bool(transpose_b): |
| 87 | + weight = te.compute( |
| 88 | + (K_padded, N_padded), lambda x, y: weight[y, x], name="weight_transposed" |
| 89 | + ) |
| 90 | + |
| 91 | + if pad_N != 0 or pad_K != 0: |
| 92 | + weight = nn.pad(weight, pad_before=(0, 0), pad_after=n_pad_after, name="weight_padded") |
| 93 | + |
| 94 | + C = te.compute( |
| 95 | + (M_padded, N_padded), |
| 96 | + lambda x, y: te.sum( |
| 97 | + data[x, k].astype(out_dtype) * weight[k, y].astype(out_dtype), |
| 98 | + axis=k, |
| 99 | + ).astype(out_dtype), |
| 100 | + name="C", |
| 101 | + ) |
| 102 | + |
| 103 | + if bias is not None: |
| 104 | + C = te.compute( |
| 105 | + (M_padded, N_padded), |
| 106 | + lambda i, j: C[i, j] + bias[j].astype(out_dtype), |
| 107 | + tag=tag.BROADCAST, |
| 108 | + name="dense_biased_output", |
| 109 | + ) |
| 110 | + |
| 111 | + # We need to ensure that infer bound pass does not remove the padding |
| 112 | + # which is necessary for the tensorizations to work. So we need to |
| 113 | + # add a dummy reference to the padding area of the result |
| 114 | + zero = ( |
| 115 | + tvm.tir.const(1, C.dtype) * C[0, N_padded - 1] |
| 116 | + - tvm.tir.const(1, C.dtype) * C[0, N_padded - 1] |
| 117 | + ) |
| 118 | + |
| 119 | + out = te.compute( |
| 120 | + (M, N), lambda x, y: (C[x, y] + zero).astype(out_dtype), name="dense_gemm_output" |
| 121 | + ) |
| 122 | + |
| 123 | + return out |
| 124 | + |
| 125 | + |
| 126 | +def _dense_gemm_schedule(s, out): |
| 127 | + C = out.op.input_tensors[0] |
| 128 | + A = C.op.input_tensors[0] |
| 129 | + out_type = A.dtype |
| 130 | + tile_M, tile_K = get_tiling_A(False, out_type) |
| 131 | + tile_N, _ = get_tiling_B_transformed(False, out_type, False) |
| 132 | + |
| 133 | + if C.op.name == "dense_biased_output": |
| 134 | + s[C].compute_inline() |
| 135 | + C = C.op.input_tensors[0] |
| 136 | + x, y = s[C].op.axis |
| 137 | + (k,) = s[C].op.reduce_axis |
| 138 | + |
| 139 | + k_outer, k_inner = s[C].split(k, factor=tile_K) |
| 140 | + x_outer, x_inner = s[C].split(x, factor=tile_M) |
| 141 | + y_outer, y_inner = s[C].split(y, factor=tile_N) |
| 142 | + y_inner_outer, y_inner_inner = s[C].split(y_inner, nparts=4) |
| 143 | + s[C].parallel(x_outer) |
| 144 | + s[C].reorder( |
| 145 | + x_outer, |
| 146 | + y_outer, |
| 147 | + k_outer, |
| 148 | + k_inner, |
| 149 | + y_inner_outer, |
| 150 | + x_inner, |
| 151 | + y_inner_inner, |
| 152 | + ) |
| 153 | + s[C].unroll(y_inner_outer) |
| 154 | + s[C].unroll(x_inner) |
| 155 | + s[C].vectorize(y_inner_inner) |
| 156 | + |
| 157 | + return s |
| 158 | + |
| 159 | + |
| 160 | +def dense_gemm_schedule(cfg, outs): |
| 161 | + """Schedule the dense_gemm strategy""" |
| 162 | + s = te.create_schedule([x.op for x in outs]) |
| 163 | + out = outs[0] |
| 164 | + x, y = out.op.axis |
| 165 | + _, inner = s[out].split(y, 4) |
| 166 | + s[out].parallel(x) |
| 167 | + s[out].vectorize(inner) |
| 168 | + |
| 169 | + def _callback(op): |
| 170 | + if "dense_gemm_output" in op.name: |
| 171 | + _dense_gemm_schedule(s, op.output(0)) |
| 172 | + |
| 173 | + traverse_inline(s, out.op, _callback) |
| 174 | + return s |
0 commit comments