From 9b01077659ec64d7e3db6e8310134ac950d427aa Mon Sep 17 00:00:00 2001 From: Neil Hickey Date: Wed, 18 Jan 2023 16:11:58 +0000 Subject: [PATCH 1/2] [SVE] Adding codegen tests for SVE Now that ci_arm contains LLVM 15, add tests to check that code for SVE can be correctly generated. Add tests to cover a selection of arithmetic ops and gather loads. The gather load is currently xfailed until work is completed to enable them through TVM and LLVM. --- .../unittest/test_target_codegen_aarch64.py | 545 ++++++++++++++++++ 1 file changed, 545 insertions(+) create mode 100644 tests/python/unittest/test_target_codegen_aarch64.py diff --git a/tests/python/unittest/test_target_codegen_aarch64.py b/tests/python/unittest/test_target_codegen_aarch64.py new file mode 100644 index 000000000000..78bc604612a1 --- /dev/null +++ b/tests/python/unittest/test_target_codegen_aarch64.py @@ -0,0 +1,545 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import te +from tvm.script import tir as TIR +import re +import os +import ctypes +import pytest + +from tvm.target.codegen import llvm_version_major + + +@pytest.mark.skipif( + llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" +) +def test_mul(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + + def check_correct_assembly(type): + m = te.var("m") + A = te.placeholder(m, dtype=type, name="A") + B = te.placeholder(m, dtype=type, name="B") + C = te.compute((m), lambda i: A[i] * B[i], name="C") + s = te.create_schedule([C.op]) + + f = tvm.build(s, [A, B, C], target) + + # Verify we see SVE load instructions and mul instructions using z registers + assembly = f.get_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + matches = re.findall( + r"mul\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly + ) + + assert len(loads) > 1 + assert len(matches) > 1 + + check_correct_assembly("float") + check_correct_assembly("float16") + check_correct_assembly("uint8") + check_correct_assembly("uint16") + check_correct_assembly("uint32") + check_correct_assembly("uint64") + check_correct_assembly("int8") + check_correct_assembly("int16") + check_correct_assembly("int32") + check_correct_assembly("int64") + + +@pytest.mark.skipif( + llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" +) +def test_add(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + + def check_correct_assembly(type): + m = te.var("m") + A = te.placeholder(m, dtype=type, name="A") + B = te.placeholder(m, dtype=type, name="B") + C = te.compute((m), lambda i: A[i] + B[i], name="C") + s = te.create_schedule([C.op]) + + f = tvm.build(s, [A, B, C], target) + + # Verify we see SVE load instructions and add instructions using z registers + assembly = f.get_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + matches = re.findall( + r"add\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly + ) + + assert len(loads) > 1 + assert len(matches) > 1 + + check_correct_assembly("float") + check_correct_assembly("float16") + check_correct_assembly("uint8") + check_correct_assembly("uint16") + check_correct_assembly("uint32") + check_correct_assembly("uint64") + check_correct_assembly("int8") + check_correct_assembly("int16") + check_correct_assembly("int32") + check_correct_assembly("int64") + + +@pytest.mark.skipif( + llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" +) +def test_sub(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + + def check_correct_assembly(type): + m = te.var("m") + A = te.placeholder(m, dtype=type, name="A") + B = te.placeholder(m, dtype=type, name="B") + C = te.compute((m), lambda i: A[i] - B[i], name="C") + s = te.create_schedule([C.op]) + + f = tvm.build(s, [A, B, C], target) + + # Verify we see SVE load instructions and sub instructions using z registers + assembly = f.get_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + matches = re.findall( + r"sub\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly + ) + + assert len(loads) > 1 + assert len(matches) > 1 + + check_correct_assembly("float") + check_correct_assembly("float16") + check_correct_assembly("uint8") + check_correct_assembly("uint16") + check_correct_assembly("uint32") + check_correct_assembly("uint64") + check_correct_assembly("int8") + check_correct_assembly("int16") + check_correct_assembly("int32") + check_correct_assembly("int64") + + +@pytest.mark.skipif( + llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" +) +def test_muladd(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + + def check_correct_assembly(type): + m = te.var("m") + A = te.placeholder(m, dtype=type, name="A") + B = te.placeholder(m, dtype=type, name="B") + C = te.placeholder(m, dtype=type, name="C") + D = te.compute((m), lambda i: A[i] * B[i] + C[i], name="D") + s = te.create_schedule([D.op]) + + f = tvm.build(s, [A, B, C, D], target) + + # Verify we see SVE load instructions and either mad or mla instructions using z registers + assembly = f.get_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + matches = re.findall( + r"mad|mla\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly + ) + + assert len(loads) > 1 + assert len(matches) > 1 + + check_correct_assembly("float") + check_correct_assembly("float16") + check_correct_assembly("uint8") + check_correct_assembly("uint16") + check_correct_assembly("uint32") + check_correct_assembly("uint64") + check_correct_assembly("int8") + check_correct_assembly("int16") + check_correct_assembly("int32") + check_correct_assembly("int64") + + +@pytest.mark.skipif( + llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" +) +def test_max(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + + def check_correct_assembly(type): + m = te.var("m") + A = te.placeholder(m, dtype=type, name="A") + B = te.placeholder(m, dtype=type, name="B") + C = te.compute((m), lambda i: tvm.te.max(A[i], B[i])) + s = te.create_schedule([C.op]) + + f = tvm.build(s, [A, B, C], target) + + # Verify we see SVE load instructions and cmgt + sel instructions or a max instruction, all using z registers + assembly = f.get_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + compare = re.findall( + r"cmgt\tp[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly + ) + select = re.findall("sel\tz[0-9].[shdb], p[0-9], z[0-9].[shdb], z[0-9].[shdb]", assembly) + max = re.findall( + r"max\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly + ) + + assert len(loads) > 1 + assert (len(compare) > 1 and len(select) == len(compare)) or len(max) > 1 + + check_correct_assembly("float") + check_correct_assembly("float16") + check_correct_assembly("uint8") + check_correct_assembly("uint16") + check_correct_assembly("uint32") + check_correct_assembly("uint64") + check_correct_assembly("int8") + check_correct_assembly("int16") + check_correct_assembly("int32") + check_correct_assembly("int64") + + +@pytest.mark.skipif( + llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" +) +def test_min(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + + def check_correct_assembly(type): + m = te.var("m") + A = te.placeholder(m, dtype=type, name="A") + B = te.placeholder(m, dtype=type, name="B") + C = te.compute((m), lambda i: tvm.te.min(A[i], B[i])) + s = te.create_schedule([C.op]) + + f = tvm.build(s, [A, B, C], target) + + # Verify we see SVE load instructions and cmgt + sel instructions or a min instruction, all using z registers + assembly = f.get_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + compare = re.findall( + r"cmgt\tp[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly + ) + select = re.findall("sel\tz[0-9].[shdb], p[0-9], z[0-9].[shdb], z[0-9].[shdb]", assembly) + min = re.findall( + r"min\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly + ) + + assert len(loads) > 1 + assert (len(compare) > 1 and len(select) == len(compare)) or len(min) > 1 + + check_correct_assembly("float") + check_correct_assembly("float16") + check_correct_assembly("uint8") + check_correct_assembly("uint16") + check_correct_assembly("uint32") + check_correct_assembly("uint64") + check_correct_assembly("int8") + check_correct_assembly("int16") + check_correct_assembly("int32") + check_correct_assembly("int64") + + +@pytest.mark.skipif( + llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" +) +def test_div(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + + def check_correct_assembly(type): + m = te.var("m") + A = te.placeholder(m, dtype=type, name="A") + B = te.placeholder(m, dtype=type, name="B") + C = te.compute((m), lambda i: tvm.te.div(A[i], B[i])) + s = te.create_schedule([C.op]) + + f = tvm.build(s, [A, B, C], target) + + # Verify we see SVE load instructions and div instructions using z registers + assembly = f.get_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + matches = re.findall( + r"div\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly + ) + + assert len(loads) > 1 + assert len(matches) > 1 + + check_correct_assembly("float") + check_correct_assembly("float16") + check_correct_assembly("uint8") + check_correct_assembly("uint16") + check_correct_assembly("uint32") + check_correct_assembly("uint64") + check_correct_assembly("int8") + check_correct_assembly("int16") + check_correct_assembly("int32") + check_correct_assembly("int64") + + +@pytest.mark.skipif( + llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" +) +def test_mod(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + + def check_correct_assembly(type): + m = te.var("m") + A = te.placeholder(m, dtype=type, name="A") + B = te.placeholder(m, dtype=type, name="B") + C = te.compute((m), lambda i: tvm.te.floormod(A[i], B[i]), name="C") + s = te.create_schedule([C.op]) + + f = tvm.build(s, [A, B, C], target) + + # Verify we see SVE load instructions and mls instructions using z registers + assembly = f.get_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + matches = re.findall( + r"mls\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly + ) + + assert len(loads) > 1 + assert len(matches) > 0 + + check_correct_assembly("uint8") + check_correct_assembly("uint16") + check_correct_assembly("uint32") + check_correct_assembly("uint64") + check_correct_assembly("int8") + check_correct_assembly("int16") + check_correct_assembly("int32") + check_correct_assembly("int64") + + +@pytest.mark.skipif( + llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" +) +def test_eq(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + + def check_correct_assembly(type): + m = te.var("m") + A = te.placeholder(m, dtype=type, name="A") + B = te.placeholder(m, dtype=type, name="B") + C = te.compute((m), lambda i: A[i] == B[i], name="C") + s = te.create_schedule([C.op]) + + f = tvm.build(s, [A, B, C], target) + + # Verify we see SVE load instructions and cmpeq or cmeq instructions using z registers + assembly = f.get_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + matches = re.findall( + r"cm(p)?eq\tp[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly + ) + + assert len(loads) > 1 + assert len(matches) > 1 + + check_correct_assembly("float") + check_correct_assembly("float16") + check_correct_assembly("uint8") + check_correct_assembly("uint16") + check_correct_assembly("uint32") + check_correct_assembly("uint64") + check_correct_assembly("int8") + check_correct_assembly("int16") + check_correct_assembly("int32") + check_correct_assembly("int64") + + +@pytest.mark.skipif( + llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" +) +def test_neq(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + + def check_correct_assembly(type): + m = te.var("m") + A = te.placeholder(m, dtype=type, name="A") + B = te.placeholder(m, dtype=type, name="B") + C = te.compute((m), lambda i: A[i] != B[i], name="C") + s = te.create_schedule([C.op]) + + f = tvm.build(s, [A, B, C], target) + + # Verify we see SVE load instructions and cmpgt, cmgt, cmpne or cmne instructions, all using z registers + assembly = f.get_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + matches = re.findall( + r"cm(p)?(gt|ne)\tp[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly + ) + + assert len(loads) > 1 + assert len(matches) > 1 + + check_correct_assembly("float") + check_correct_assembly("float16") + check_correct_assembly("uint8") + check_correct_assembly("uint16") + check_correct_assembly("uint32") + check_correct_assembly("uint64") + check_correct_assembly("int8") + check_correct_assembly("int16") + check_correct_assembly("int32") + check_correct_assembly("int64") + + +@pytest.mark.skipif( + llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" +) +def test_or(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + + def check_correct_assembly(type): + m = te.var("m") + A = te.placeholder(m, dtype=type, name="A") + B = te.placeholder(m, dtype=type, name="B") + C = te.compute((m), lambda i: A[i] | B[i], name="C") + s = te.create_schedule([C.op]) + + f = tvm.build(s, [A, B, C], target) + + # Verify we see SVE load instructions and orr instructions using z registers + assembly = f.get_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + matches = re.findall( + r"orr\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly + ) + + assert len(loads) > 1 + assert len(matches) > 1 + + check_correct_assembly("uint8") + check_correct_assembly("uint16") + check_correct_assembly("uint32") + check_correct_assembly("uint64") + check_correct_assembly("int8") + check_correct_assembly("int16") + check_correct_assembly("int32") + check_correct_assembly("int64") + + +@pytest.mark.skipif( + llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" +) +def test_and(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + + def check_correct_assembly(type): + m = te.var("m") + A = te.placeholder(m, dtype=type, name="A") + B = te.placeholder(m, dtype=type, name="B") + C = te.compute((m), lambda i: A[i] & B[i], name="C") + s = te.create_schedule([C.op]) + + f = tvm.build(s, [A, B, C], target) + + # Verify we see SVE load instructions and and instructions using z registers + assembly = f.get_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + matches = re.findall( + r"and\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly + ) + + assert len(loads) > 1 + assert len(matches) > 1 + + check_correct_assembly("uint8") + check_correct_assembly("uint16") + check_correct_assembly("uint32") + check_correct_assembly("uint64") + check_correct_assembly("int8") + check_correct_assembly("int16") + check_correct_assembly("int32") + check_correct_assembly("int64") + + +@pytest.mark.skipif( + llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" +) +def test_not(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + + def check_correct_assembly(type): + m = te.var("m") + A = te.placeholder(m, dtype=type, name="A") + C = te.compute((m), lambda i: ~A[i], name="C") + s = te.create_schedule([C.op]) + + f = tvm.build(s, [A, C], target) + + # Verify we see SVE load instructions and eor instructions using z registers + assembly = f.get_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + matches = re.findall( + r"eor\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly + ) + + assert len(loads) > 1 + assert len(matches) > 1 + + check_correct_assembly("uint8") + check_correct_assembly("uint16") + check_correct_assembly("uint32") + check_correct_assembly("uint64") + check_correct_assembly("int8") + check_correct_assembly("int16") + check_correct_assembly("int32") + check_correct_assembly("int64") + + +@pytest.mark.skipif( + llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" +) +@pytest.mark.xfail( + reason="Awaiting llvm support for gathered loads", + strict=True, +) +def test_memcpy(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + + def check_correct_assembly(type): + m = te.var("m") + A = te.placeholder(m, dtype=type, name="A") + B = te.placeholder(m, dtype="int32", name="B") + C = te.compute((m), lambda i: A[B[i]], name="C") + s = te.create_schedule([C.op]) + + f = tvm.build(s, [A, B, C], target) + + # Verify we see gather instructions in the assembly + assembly = f.get_source("asm") + loads = re.findall("ld1[whdb] { z", assembly) + + assert len(loads) > 0 + + check_correct_assembly("uint8") + check_correct_assembly("uint16") + check_correct_assembly("uint32") + check_correct_assembly("uint64") + check_correct_assembly("int8") + check_correct_assembly("int16") + check_correct_assembly("int32") + check_correct_assembly("int64") + + +if __name__ == "__main__": + tvm.testing.main() From b9cb29527a8036dace2123a075d697c31aad79cf Mon Sep 17 00:00:00 2001 From: Neil Hickey Date: Thu, 9 Mar 2023 13:46:57 +0000 Subject: [PATCH 2/2] Reworking so dtype is passed as a pytest param --- .../unittest/test_target_codegen_aarch64.py | 223 +++++++----------- 1 file changed, 79 insertions(+), 144 deletions(-) diff --git a/tests/python/unittest/test_target_codegen_aarch64.py b/tests/python/unittest/test_target_codegen_aarch64.py index 78bc604612a1..e873bce52bdf 100644 --- a/tests/python/unittest/test_target_codegen_aarch64.py +++ b/tests/python/unittest/test_target_codegen_aarch64.py @@ -28,7 +28,11 @@ @pytest.mark.skipif( llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" ) -def test_mul(): +@pytest.mark.parametrize( + "dtype", + ["float", "float16", "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"], +) +def test_mul(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" def check_correct_assembly(type): @@ -50,22 +54,17 @@ def check_correct_assembly(type): assert len(loads) > 1 assert len(matches) > 1 - check_correct_assembly("float") - check_correct_assembly("float16") - check_correct_assembly("uint8") - check_correct_assembly("uint16") - check_correct_assembly("uint32") - check_correct_assembly("uint64") - check_correct_assembly("int8") - check_correct_assembly("int16") - check_correct_assembly("int32") - check_correct_assembly("int64") + check_correct_assembly(type=dtype) @pytest.mark.skipif( llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" ) -def test_add(): +@pytest.mark.parametrize( + "dtype", + ["float", "float16", "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"], +) +def test_add(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" def check_correct_assembly(type): @@ -87,22 +86,17 @@ def check_correct_assembly(type): assert len(loads) > 1 assert len(matches) > 1 - check_correct_assembly("float") - check_correct_assembly("float16") - check_correct_assembly("uint8") - check_correct_assembly("uint16") - check_correct_assembly("uint32") - check_correct_assembly("uint64") - check_correct_assembly("int8") - check_correct_assembly("int16") - check_correct_assembly("int32") - check_correct_assembly("int64") + check_correct_assembly(type=dtype) @pytest.mark.skipif( llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" ) -def test_sub(): +@pytest.mark.parametrize( + "dtype", + ["float", "float16", "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"], +) +def test_sub(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" def check_correct_assembly(type): @@ -124,22 +118,17 @@ def check_correct_assembly(type): assert len(loads) > 1 assert len(matches) > 1 - check_correct_assembly("float") - check_correct_assembly("float16") - check_correct_assembly("uint8") - check_correct_assembly("uint16") - check_correct_assembly("uint32") - check_correct_assembly("uint64") - check_correct_assembly("int8") - check_correct_assembly("int16") - check_correct_assembly("int32") - check_correct_assembly("int64") + check_correct_assembly(type=dtype) @pytest.mark.skipif( llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" ) -def test_muladd(): +@pytest.mark.parametrize( + "dtype", + ["float", "float16", "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"], +) +def test_muladd(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" def check_correct_assembly(type): @@ -162,22 +151,17 @@ def check_correct_assembly(type): assert len(loads) > 1 assert len(matches) > 1 - check_correct_assembly("float") - check_correct_assembly("float16") - check_correct_assembly("uint8") - check_correct_assembly("uint16") - check_correct_assembly("uint32") - check_correct_assembly("uint64") - check_correct_assembly("int8") - check_correct_assembly("int16") - check_correct_assembly("int32") - check_correct_assembly("int64") + check_correct_assembly(type=dtype) @pytest.mark.skipif( llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" ) -def test_max(): +@pytest.mark.parametrize( + "dtype", + ["float", "float16", "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"], +) +def test_max(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" def check_correct_assembly(type): @@ -203,22 +187,17 @@ def check_correct_assembly(type): assert len(loads) > 1 assert (len(compare) > 1 and len(select) == len(compare)) or len(max) > 1 - check_correct_assembly("float") - check_correct_assembly("float16") - check_correct_assembly("uint8") - check_correct_assembly("uint16") - check_correct_assembly("uint32") - check_correct_assembly("uint64") - check_correct_assembly("int8") - check_correct_assembly("int16") - check_correct_assembly("int32") - check_correct_assembly("int64") + check_correct_assembly(type=dtype) @pytest.mark.skipif( llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" ) -def test_min(): +@pytest.mark.parametrize( + "dtype", + ["float", "float16", "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"], +) +def test_min(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" def check_correct_assembly(type): @@ -244,22 +223,17 @@ def check_correct_assembly(type): assert len(loads) > 1 assert (len(compare) > 1 and len(select) == len(compare)) or len(min) > 1 - check_correct_assembly("float") - check_correct_assembly("float16") - check_correct_assembly("uint8") - check_correct_assembly("uint16") - check_correct_assembly("uint32") - check_correct_assembly("uint64") - check_correct_assembly("int8") - check_correct_assembly("int16") - check_correct_assembly("int32") - check_correct_assembly("int64") + check_correct_assembly(type=dtype) @pytest.mark.skipif( llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" ) -def test_div(): +@pytest.mark.parametrize( + "dtype", + ["float", "float16", "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"], +) +def test_div(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" def check_correct_assembly(type): @@ -281,22 +255,16 @@ def check_correct_assembly(type): assert len(loads) > 1 assert len(matches) > 1 - check_correct_assembly("float") - check_correct_assembly("float16") - check_correct_assembly("uint8") - check_correct_assembly("uint16") - check_correct_assembly("uint32") - check_correct_assembly("uint64") - check_correct_assembly("int8") - check_correct_assembly("int16") - check_correct_assembly("int32") - check_correct_assembly("int64") + check_correct_assembly(type=dtype) @pytest.mark.skipif( llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" ) -def test_mod(): +@pytest.mark.parametrize( + "dtype", ["uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"] +) +def test_mod(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" def check_correct_assembly(type): @@ -318,20 +286,17 @@ def check_correct_assembly(type): assert len(loads) > 1 assert len(matches) > 0 - check_correct_assembly("uint8") - check_correct_assembly("uint16") - check_correct_assembly("uint32") - check_correct_assembly("uint64") - check_correct_assembly("int8") - check_correct_assembly("int16") - check_correct_assembly("int32") - check_correct_assembly("int64") + check_correct_assembly(type=dtype) @pytest.mark.skipif( llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" ) -def test_eq(): +@pytest.mark.parametrize( + "dtype", + ["float", "float16", "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"], +) +def test_eq(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" def check_correct_assembly(type): @@ -353,22 +318,17 @@ def check_correct_assembly(type): assert len(loads) > 1 assert len(matches) > 1 - check_correct_assembly("float") - check_correct_assembly("float16") - check_correct_assembly("uint8") - check_correct_assembly("uint16") - check_correct_assembly("uint32") - check_correct_assembly("uint64") - check_correct_assembly("int8") - check_correct_assembly("int16") - check_correct_assembly("int32") - check_correct_assembly("int64") + check_correct_assembly(type=dtype) @pytest.mark.skipif( llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" ) -def test_neq(): +@pytest.mark.parametrize( + "dtype", + ["float", "float16", "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"], +) +def test_neq(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" def check_correct_assembly(type): @@ -390,22 +350,16 @@ def check_correct_assembly(type): assert len(loads) > 1 assert len(matches) > 1 - check_correct_assembly("float") - check_correct_assembly("float16") - check_correct_assembly("uint8") - check_correct_assembly("uint16") - check_correct_assembly("uint32") - check_correct_assembly("uint64") - check_correct_assembly("int8") - check_correct_assembly("int16") - check_correct_assembly("int32") - check_correct_assembly("int64") + check_correct_assembly(type=dtype) @pytest.mark.skipif( llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" ) -def test_or(): +@pytest.mark.parametrize( + "dtype", ["uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"] +) +def test_or(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" def check_correct_assembly(type): @@ -427,20 +381,16 @@ def check_correct_assembly(type): assert len(loads) > 1 assert len(matches) > 1 - check_correct_assembly("uint8") - check_correct_assembly("uint16") - check_correct_assembly("uint32") - check_correct_assembly("uint64") - check_correct_assembly("int8") - check_correct_assembly("int16") - check_correct_assembly("int32") - check_correct_assembly("int64") + check_correct_assembly(type=dtype) @pytest.mark.skipif( llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" ) -def test_and(): +@pytest.mark.parametrize( + "dtype", ["uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"] +) +def test_and(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" def check_correct_assembly(type): @@ -462,20 +412,16 @@ def check_correct_assembly(type): assert len(loads) > 1 assert len(matches) > 1 - check_correct_assembly("uint8") - check_correct_assembly("uint16") - check_correct_assembly("uint32") - check_correct_assembly("uint64") - check_correct_assembly("int8") - check_correct_assembly("int16") - check_correct_assembly("int32") - check_correct_assembly("int64") + check_correct_assembly(type=dtype) @pytest.mark.skipif( llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE" ) -def test_not(): +@pytest.mark.parametrize( + "dtype", ["uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"] +) +def test_not(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" def check_correct_assembly(type): @@ -496,14 +442,7 @@ def check_correct_assembly(type): assert len(loads) > 1 assert len(matches) > 1 - check_correct_assembly("uint8") - check_correct_assembly("uint16") - check_correct_assembly("uint32") - check_correct_assembly("uint64") - check_correct_assembly("int8") - check_correct_assembly("int16") - check_correct_assembly("int32") - check_correct_assembly("int64") + check_correct_assembly(type=dtype) @pytest.mark.skipif( @@ -513,7 +452,10 @@ def check_correct_assembly(type): reason="Awaiting llvm support for gathered loads", strict=True, ) -def test_memcpy(): +@pytest.mark.parametrize( + "dtype", ["uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"] +) +def test_memcpy(dtype): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" def check_correct_assembly(type): @@ -531,14 +473,7 @@ def check_correct_assembly(type): assert len(loads) > 0 - check_correct_assembly("uint8") - check_correct_assembly("uint16") - check_correct_assembly("uint32") - check_correct_assembly("uint64") - check_correct_assembly("int8") - check_correct_assembly("int16") - check_correct_assembly("int32") - check_correct_assembly("int64") + check_correct_assembly(type=dtype) if __name__ == "__main__":