7 changes: 7 additions & 0 deletions python/tvm/relay/op/strategy/arm_cpu.py
@@ -110,6 +110,8 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
"""conv2d arm cpu strategy"""
strategy = _op.OpStrategy()
data, kernel = inputs
data_shape = data.shape
kernel_shape = kernel.shape
dilation_h, dilation_w = attrs.get_int_tuple("dilation")
stride_h, stride_w = attrs.get_int_tuple("strides")
padding = attrs.get_int_tuple("padding")
@@ -258,6 +260,11 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
target.features.has_sme
and kernel.dtype == data.dtype
and out_type.dtype == "float32"
and data_shape[0] == 1
# The schedule uses tensorization which does not work when the
# reduction axis of the gemm has unit iters. See
# https://github.com/apache/tvm/issues/16566
and (data_shape[3] * kernel_shape[0] * kernel_shape[1]) > 1
):
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid_SME),
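The new conditions restrict the SME implementation to unit-batch workloads whose im2col GEMM has a non-unit reduction axis: for NHWC data the reduction extent is in_channels * kernel_h * kernel_w. A minimal sketch of the guard, assuming NHWC data and HWIO kernel layouts with made-up shapes (not part of the diff):

    def sme_conv2d_applicable(data_shape, kernel_shape):
        """Hedged sketch of the new SME applicability check above."""
        batch = data_shape[0]                                   # N in NHWC
        in_channels = data_shape[3]                             # C in NHWC
        kernel_h, kernel_w = kernel_shape[0], kernel_shape[1]   # assuming HWIO kernels
        # The GEMM reduction axis covers C * KH * KW; tensorization fails when it has
        # unit iters (https://github.com/apache/tvm/issues/16566).
        return batch == 1 and in_channels * kernel_h * kernel_w > 1

    print(sme_conv2d_applicable((1, 32, 32, 3), (3, 3, 3, 16)))  # True
    print(sme_conv2d_applicable((2, 32, 32, 3), (3, 3, 3, 16)))  # False: batched input
    print(sme_conv2d_applicable((1, 32, 32, 1), (1, 1, 1, 8)))   # False: unit reduction axis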
134 changes: 110 additions & 24 deletions python/tvm/tir/tensor_intrin/arm_cpu.py
@@ -176,7 +176,51 @@ def _create_ptrue_mask(dtype):
return T.broadcast(T.bool(True), tir.get_vscale_expr(dtype))


def get_sme_transpose_interleave_2svlx2svl_fp32_intrin():
def _create_active_lane_mask(tensor, relative_offsets, vertical_limit):
"""
Get the active lane mask intrinsic call for predicated accesses.

Parameters
----------
tensor : tvm.tir.Buffer
The tensor the buffer access will be performed on.
relative_offsets : Tuple[PrimExpr, PrimExpr]
The vertical and horizontal offsets into the accumulator tile.
vertical_limit : PrimExpr
The number of valid rows in the tensor; accesses to rows at or beyond this limit are masked off.

Returns
-------
PrimExpr
The active lane mask intrinsic.
"""
vertical_offset, horizontal_offset = relative_offsets
stride = tensor.strides[0]

# The base is the offset of the first value we wish to store
base = T.int32(tensor.offset_of([vertical_offset, horizontal_offset])[0])

# The limit is the maximum offset in the current row of 'base' at which we still allow
# values to be stored. Calculating this limit is a bit tricky since we can only request
# offsets of elements in the tensorized tile of the output tensor. One way to calculate it
# is to find the offset of the first value in the row of the output tensor that 'base'
# belongs to and add 'stride' to it.
limit = (
base
- T.int32(horizontal_offset)
- T.int32((tensor.offset_of([0, 0])[0] % stride))
+ T.int32(stride)
)
limit = T.Min(limit, T.Cast("int32", vertical_limit) * stride)

return T.get_active_lane_mask(
"uint1xvscalex4",
T.Cast("int32", base),
T.Cast("int32", limit),
)


def get_sme_transpose_interleave_2svlx2svl_fp32_intrin(cols, rows):
"""
Transpose a matrix of size 2SVL x 2SVL (where 'SVL' is the Scalable Vector Length) using
the Scalable Matrix Extension (SME).
@@ -247,9 +291,6 @@ def impl():
strides=[T.int32(), 1],
)

# Disable predication
ptrue = _create_ptrue_mask("float32")

with T.block("root"):
T.reads(A[0:SVF2, 0:SVF2])
T.writes(A_t[0:SVF2, 0:SVF2])
@@ -263,19 +304,22 @@

input_ptr = A.access_ptr("r", offset=offset)
sub_tile = T.int32(sub_tile_idx)
predicate = _create_active_lane_mask(
A, (row_offset + slice_idx, col_offset), cols
)
T.evaluate(
T.call_llvm_intrin(
"void",
"llvm.aarch64.sme.ld1w.horiz",
T.uint32(4),
ptrue,
predicate,
input_ptr,
sub_tile,
slice_idx,
)
)

# Store columns to the ouptut matrix
# Store columns to the output matrix
with T.serial(0, SVF) as slice_idx:
for sub_tile_idx in range(0, sub_tile_count):
col_offset = SVF if sub_tile_idx >= (sub_tile_count // 2) else 0
@@ -284,12 +328,15 @@

output_ptr = A_t.access_ptr("w", offset=offset)
sub_tile = T.int32(sub_tile_idx)
predicate = _create_active_lane_mask(
A_t, (row_offset + slice_idx, col_offset), rows
)
T.evaluate(
T.call_llvm_intrin(
"void",
"llvm.aarch64.sme.st1w.vert",
T.uint32(4),
ptrue,
predicate,
output_ptr,
sub_tile,
slice_idx,
@@ -445,7 +492,24 @@ def impl():
return desc, impl()


def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K, in_dtype):
def get_transpose_interleave_intrin_name(in_dtype, out_dtype, extent_cols, extent_rows):
if in_dtype == "float32" and out_dtype == "float32":
sme_transpose_interleave_intrin_name = (
ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE + f"_{extent_cols}_{extent_rows}"
)
tir.TensorIntrin.register(
sme_transpose_interleave_intrin_name,
*get_sme_transpose_interleave_2svlx2svl_fp32_intrin(extent_cols, extent_rows),
override=True,
)
return sme_transpose_interleave_intrin_name
elif in_dtype == "float16" and out_dtype == "float32":
return ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE
else:
raise ValueError("Input/output data type combination not supported.")


def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(M, K, in_dtype):
"""
Compute a GEMM of size 2SVL x 2SVL (where 'SVL' is the Scalable Vector Length) using
outer product operations from the Scalable Matrix Extension (SME).
@@ -579,15 +643,39 @@ def impl():
k_row = k * rows_per_iter
in_dtype_svf = tir.get_vscale_expr(in_dtype)

a_low = T.BufferLoad(A, [k_row, T.Ramp(0, 1, in_dtype_svf)])
b_low = T.BufferLoad(B, [k_row, T.Ramp(0, 1, in_dtype_svf)])

# Ideally we'd rely on predicating the loads and use the same predicate
# for the outer product operation. However, predicated buffers are not
# currently supported by several lowering passes such as
# "LowerMatchBuffer", so the predicate is passed directly to the
# outer product operation for now.
if in_dtype == "float32":
a_high = T.BufferLoad(A, [k_row, T.Ramp(in_dtype_svf, 1, in_dtype_svf)])
b_high = T.BufferLoad(B, [k_row, T.Ramp(in_dtype_svf, 1, in_dtype_svf)])
a_low = (
T.BufferLoad(A, [k_row, T.Ramp(0, 1, in_dtype_svf)]),
_create_active_lane_mask(A, (k_row, 0), K),
)
b_low = (
T.BufferLoad(B, [k_row, T.Ramp(0, 1, in_dtype_svf)]),
_create_active_lane_mask(B, (k_row, 0), K),
)
a_high = (
T.BufferLoad(A, [k_row, T.Ramp(in_dtype_svf, 1, in_dtype_svf)]),
_create_active_lane_mask(A, (k_row, in_dtype_svf), K),
)
b_high = (
T.BufferLoad(B, [k_row, T.Ramp(in_dtype_svf, 1, in_dtype_svf)]),
_create_active_lane_mask(B, (k_row, in_dtype_svf), K),
)
else:
a_high = T.BufferLoad(A, [k_row + 1, T.Ramp(0, 1, in_dtype_svf)])
b_high = T.BufferLoad(B, [k_row + 1, T.Ramp(0, 1, in_dtype_svf)])
a_low = (T.BufferLoad(A, [k_row, T.Ramp(0, 1, in_dtype_svf)]), ptrue)
b_low = (T.BufferLoad(B, [k_row, T.Ramp(0, 1, in_dtype_svf)]), ptrue)
a_high = (
T.BufferLoad(A, [k_row + 1, T.Ramp(0, 1, in_dtype_svf)]),
ptrue,
)
b_high = (
T.BufferLoad(B, [k_row + 1, T.Ramp(0, 1, in_dtype_svf)]),
ptrue,
)

input_combinations = [
(a_low, b_low),
@@ -606,10 +694,10 @@ def impl():
fmopa_intrin,
T.uint32(5),
sub_tile,
ptrue,
ptrue,
input_1,
input_2,
input_1[1],
input_2[1],
input_1[0],
input_2[0],
)
)

@@ -626,7 +714,9 @@ def impl():
"void",
"llvm.aarch64.sme.st1w.horiz",
T.uint32(4),
_create_ptrue_mask("float32"),
_create_active_lane_mask(
C, (vert_offset + slice_idx, horiz_offset), M
),
output_ptr,
T.int32(sub_tile_idx),
T.int32(slice_idx),
@@ -691,10 +781,6 @@ def impl(c: T.handle) -> None:
# in versions of LLVM >= 15. Installations with older versions of LLVM will
# not be able to use them.
if llvm_version_major() >= 15:
TensorIntrin.register(
ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE,
*get_sme_transpose_interleave_2svlx2svl_fp32_intrin(),
)
TensorIntrin.register(
ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE,
*get_sme_transpose_interleave_block2_2svl_fp16_intrin(),
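The predicates built by _create_active_lane_mask follow the semantics of T.get_active_lane_mask(base, limit): lane i of the mask is active when base + i < limit, so trailing lanes that would touch elements past the valid extent of a row are disabled. A rough NumPy model of that behaviour, with illustrative (assumed) values for the vector length and buffer stride, not part of the diff:

    import numpy as np

    def active_lane_mask(base, limit, lanes):
        # Lane i is active when base + i < limit.
        return (base + np.arange(lanes)) < limit

    # Assume an SVL of 128 bits (4 fp32 lanes per vector) and an accumulator row with a
    # physical stride of 8 elements, of which only 6 columns are valid.
    lanes = 4
    stride, valid_cols, row = 8, 6, 2
    base = row * stride + 4              # offset of the high half of row 2
    limit = row * stride + valid_cols    # first offset past the valid columns
    print(active_lane_mask(base, limit, lanes))  # [ True  True False False]

The same mask is what gets handed to the SME load/store intrinsics and, for the fp32 path, passed directly to the fmopa outer-product intrinsic in place of the all-true ptrue predicate.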
40 changes: 21 additions & 19 deletions python/tvm/topi/arm_cpu/conv2d.py
@@ -24,7 +24,6 @@
from tvm.script import tir as T
import tvm.contrib.nnpack
from tvm.tir.schedule.analysis import has_block
from tvm.topi.arm_cpu.matmul import _get_transpose_interleave_intrin_name

from ..utils import traverse_inline, get_const_tuple
from .. import nn
@@ -773,10 +772,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA,
ARM_SME_INIT,
get_sme_gemm_interleaved_mopa_2svlx2svl_intrin,
)

transpose_interleave_intrin_name = _get_transpose_interleave_intrin_name(
in_dtype, out_dtype
get_transpose_interleave_intrin_name,
)

# Interleave the padded im2col matrix utilizing the matrix tile
@@ -787,7 +783,9 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True)
sch.parallel(b)
sch.reorder(b, ko, mo, ki, mi)
sch.tensorize(ki, transpose_interleave_intrin_name)
sch.tensorize(
ki, get_transpose_interleave_intrin_name(in_dtype, out_dtype, M_padded, K_padded)
)

# Interleave the padded weights matrix utilizing the matrix tile
if in_dtype == "float16":
@@ -797,7 +795,9 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True)
no, ni = sch.split(n, factors=(None, tile_N), disable_predication=True)
sch.reorder(ko, no, ki, ni)
sch.tensorize(ki, transpose_interleave_intrin_name)
sch.tensorize(
ki, get_transpose_interleave_intrin_name(in_dtype, out_dtype, M_padded, K_padded)
)

# Split and reorder the loops of the GeMM for tensorization
b, m, n, k = sch.get_loops(gemm_block)
@@ -816,11 +816,11 @@

# Tensorize the GeMM update
sme_gemm_interleaved_intrin_name = (
ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{K_padded}_{in_dtype}"
ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{M_padded}_{K_padded}_{in_dtype}"
)
tvm.tir.TensorIntrin.register(
sme_gemm_interleaved_intrin_name,
*get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded, in_dtype),
*get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(M_padded, K_padded, in_dtype),
override=True,
)
sch.tensorize(mi, sme_gemm_interleaved_intrin_name)
@@ -922,16 +922,18 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
reshape_block = func_blocks["T_reshape"]
A_pad_block = func_blocks["A_padded_K"] if func_blocks["A_padded_K"] else None
A_pad_block = func_blocks["A_padded_M"] if func_blocks["A_padded_M"] else A_pad_block
if use_sme:
sch.compute_inline(reshape_block)
elif A_pad_block:
sch.compute_inline(reshape_block)
b, m, k = sch.get_loops(A_pad_block)
_, k_inner = sch.split(k, [None, tile_N])
sch.vectorize(k_inner)
sch.compute_at(A_pad_block, mi)
else:
sch.compute_at(reshape_block, mi)
use_explicit_predication = use_sme and in_dtype == "float32"
if not use_explicit_predication:
if use_sme:
sch.compute_inline(reshape_block)
elif A_pad_block:
sch.compute_inline(reshape_block)
b, m, k = sch.get_loops(A_pad_block)
_, k_inner = sch.split(k, [None, tile_N])
sch.vectorize(k_inner)
sch.compute_at(A_pad_block, mi)
else:
sch.compute_at(reshape_block, mi)

# Weight flattening
if func_blocks["weight_flatten"]:
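Because the transpose-interleave and GEMM intrinsics are now parameterised by the matrix extents, the schedule registers a freshly specialised TensorIntrin per workload instead of relying on a single module-level registration: the extents become part of the intrinsic name and override=True allows the name to be re-registered when the same workload is compiled again. A hedged sketch of that register-then-tensorize pattern (the base name and wrapper below are hypothetical):

    from tvm import tir

    def tensorize_with_extents(sch, loop_rv, base_name, make_desc_impl, *extents):
        # Extents are baked into the intrinsic name so each specialisation gets its own
        # registration; override=True avoids name clashes on repeated scheduling.
        intrin_name = base_name + "_" + "_".join(str(e) for e in extents)
        tir.TensorIntrin.register(intrin_name, *make_desc_impl(*extents), override=True)
        sch.tensorize(loop_rv, intrin_name)

    # e.g. tensorize_with_extents(sch, mi, "sme_gemm_mopa",
    #                             get_sme_gemm_interleaved_mopa_2svlx2svl_intrin,
    #                             M_padded, K_padded, in_dtype)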
39 changes: 26 additions & 13 deletions python/tvm/topi/arm_cpu/conv2d_gemm.py
@@ -133,23 +133,25 @@ def compute_conv2d_gemm_without_weight_transform(
)

# Pad to tiles (if necessary)
pad_M, pad_K = arm_utils.get_conv2d_im2col_padding(M, K, tile_M, tile_K_A)
pad_N, _ = arm_utils.get_conv2d_weights_padding(N, K, tile_N, tile_K_B)
use_explicit_predication = use_sme and in_dtype == "float32"
if not use_explicit_predication:
pad_M, pad_K = arm_utils.get_conv2d_im2col_padding(M, K, tile_M, tile_K_A)
pad_N, _ = arm_utils.get_conv2d_weights_padding(N, K, tile_N, tile_K_B)

M_padded = M + pad_M
K_padded = K + pad_K
N_padded = N + pad_N
M_padded = M + pad_M
K_padded = K + pad_K
N_padded = N + pad_N

pad_before = (0, 0, 0)
pad_after = (0, pad_M, pad_K)
pad_before = (0, 0, 0)
pad_after = (0, pad_M, pad_K)

if pad_K != 0:
A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded_K")
elif pad_M != 0:
A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded_M")
if pad_K != 0:
A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded_K")
elif pad_M != 0:
A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded_M")

idxm = tvm.tir.indexmod
k = te.reduce_axis((0, K_padded), "k")
k = te.reduce_axis((0, K if use_explicit_predication else K_padded), "k")

# Determine matrix multiplication compute definition
target = Target.current(allow_none=False)
@@ -300,7 +302,18 @@ def compute_conv2d_gemm_without_weight_transform(
name="C",
)
zero = tvm.tir.const(0)
elif use_scalable_vectors or use_sme:
elif use_explicit_predication:
assert len(B_interleaved_t.shape) == 2
C = te.compute(
(batches, M, N),
lambda b, x, y: te.sum(
A[b, x, k].astype(in_dtype) * B_interleaved_t[k, y].astype(in_dtype),
axis=k,
),
name="C",
)
zero = tvm.tir.const(0)
elif use_scalable_vectors:
assert len(B_interleaved_t.shape) == 2
C = te.compute(
(batches, M_padded, N_padded),
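With explicit predication (SME and fp32 inputs) the im2col and weight matrices are no longer padded up to tile multiples: the GEMM is defined directly over M, N and the true reduction extent K, and bounds handling moves into the predicated SME intrinsics. For comparison, a sketch of roughly how much padding the non-predicated path would still introduce, assuming a 128-bit vector length so a 2SVL fp32 tile side is 8 elements (values are made up):

    def round_up(value, tile):
        return -(-value // tile) * tile

    tile_m = tile_k = 8          # 2SVL for fp32, assuming SVL = 128 bits
    M, K = 10, 6                 # made-up im2col GEMM extents
    pad_M = round_up(M, tile_m) - M
    pad_K = round_up(K, tile_k) - K
    print(pad_M, pad_K)          # 6 2 -> the padded GEMM would compute over 16 x 8

With use_explicit_predication both pads are skipped, the reduce axis stays at K == 6, and the active-lane masks inside the SME intrinsics mask off the ragged tile edges instead.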