diff --git a/python/tvm/topi/arm_cpu/pstate_attributes.py b/python/tvm/topi/arm_cpu/pstate_attributes.py new file mode 100644 index 000000000000..439337bac5b2 --- /dev/null +++ b/python/tvm/topi/arm_cpu/pstate_attributes.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Specialized attributes that can be added to schedules to alter +the behaviour of AArch64 codegen. +""" + + +class SMEAttributes: + """ + This class serves as a convenience wrapper for processor state annotations + relating to the Scalable Matrix Extension (SME). Processor state annotations + are inserted at compile time and alter some global state of the processor + during execution. For example, the streaming mode attribute can be used to + transfer some vector operations to a separate processing element. These + attributes can be added to block-level annotations in AArch64 schedules to + define a desired state. + + Please refer to the following pages for more information regarding the SME + attributes and their behaviours: + - https://arm-software.github.io/acle/main/acle.html#markdown-toc-sme-attributes + - https://llvm.org/docs/AArch64SME.html + + Attributes + ---------- + STREAMING_MODE : str + Whether execution should occur in regular mode or streaming mode. When + enabled, some vector operations may be transferred to a separate processing + element. + ZA_STORAGE : str + Defines how the ZA area of storage provided by the SME extension should be + utilized. + """ + + STREAMING_MODE = "pragma_aarch64_pstate_sm" + + class StreamingModeValues: + """ + Streaming mode attribute values. By default, a function is considered + 'non-streaming' (often referred to as 'regular'). + + Attributes + ---------- + ENABLED : str + The processor state must be in streaming mode before executing the marked function. + COMPATIBLE : str + The marked function can be run in either streaming or non-streaming mode. + """ + + ENABLED = "enabled" + COMPATIBLE = "compatible" + + ZA_STORAGE = "pragma_aarch64_pstate_za" + + class ZAStorageValues: + """ + ZA Storage attribure values. By default, a function has no ZA state. In other words, it + does not use the ZA storage. + + Attributes + ---------- + NEW : str + A new ZA state is created "from scratch". + SHARED : str + The ZA state is shared with the calling function. + """ + + NEW = "new" + SHARED = "shared" diff --git a/src/target/llvm/codegen_aarch64.cc b/src/target/llvm/codegen_aarch64.cc new file mode 100644 index 000000000000..94ad34bbcff2 --- /dev/null +++ b/src/target/llvm/codegen_aarch64.cc @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/target/llvm/codegen_aarch64.cc + * \brief AArch64 specific LLVM code generator. + */ +#ifdef TVM_LLVM_VERSION + +#include +#include +#include + +#include "codegen_cpu.h" +#include "llvm_instance.h" + +namespace tvm { +namespace codegen { + +class CodeGenAArch64 final : public CodeGenCPU { + public: + CodeGenAArch64() = default; + virtual ~CodeGenAArch64() = default; + + void VisitStmt_(const AttrStmtNode* op); + void AddFunction(const GlobalVar& gvar, const PrimFunc& f); + + bool func_has_pstate_sm = false; + bool func_has_pstate_za = false; +}; + +void CodeGenAArch64::AddFunction(const GlobalVar& gvar, const PrimFunc& f) { + func_has_pstate_sm = false; + func_has_pstate_za = false; + CodeGenCPU::AddFunction(gvar, f); +} + +/*! + * \brief Visit and handle AArch64 specific pragmas. To be AArch64 specific, + * the expectation is that they are prepended with "pragma_aarch64". + */ +void CodeGenAArch64::VisitStmt_(const AttrStmtNode* op) { + std::string attr_key = op->attr_key; + + if (!tir::attr::IsPragmaKey(attr_key)) { + CodeGenCPU::VisitStmt_(op); + return; + } + bool is_aarch64_specific_pragma = attr_key.substr(7, 7) == "aarch64"; + if (!is_aarch64_specific_pragma) { + CodeGenCPU::VisitStmt_(op); + return; + } + + const auto* attr_value = op->value.as(); + ICHECK(attr_value) << "Expect " << attr_key << " to have a String value but was " + << op->value->GetTypeKey(); + + std::string aarch64_attr_key = attr_key.substr(7); + if (aarch64_attr_key == "aarch64_pstate_sm") { + ICHECK(!func_has_pstate_sm) << "Multiple definitions of " << op->attr_key + << " attribute found in the function " + << function_->getName().data(); + function_->addFnAttr({aarch64_attr_key + "_" + attr_value->value}); + func_has_pstate_sm = true; + } else if (aarch64_attr_key == "aarch64_pstate_za") { + ICHECK(!func_has_pstate_za) << "Multiple definitions of " << op->attr_key + << " attribute found in the function " + << function_->getName().data(); + function_->addFnAttr({aarch64_attr_key + "_" + attr_value->value}); + func_has_pstate_za = true; + } else { + LOG(WARNING) << "Unknown pragma " << op->attr_key; + } + this->VisitStmt(op->body); +} + +TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_aarch64") + .set_body([](const TVMArgs& targs, TVMRetValue* rv) { + *rv = static_cast(new CodeGenAArch64()); + }); + +} // namespace codegen +} // namespace tvm + +#endif // TVM_LLVM_VERSION diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index 773c113f4a42..80aedd60b3f7 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -14,11 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +import re + +import pytest + import tvm from tvm import te from tvm.script import tir as T -import re -import pytest +from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes from tvm.target.codegen import llvm_version_major @@ -533,5 +537,113 @@ def my_func(a: T.handle): assert re.findall(r" store ", llvm), "No scalable store in generated LLVM." +@pytest.mark.skipif( + llvm_version_major() < 16, reason="Test requires an LLVM version of at least 16 to target SME" +) +@pytest.mark.parametrize( + "attr_key,attr_value,expected", + [ + ( + SMEAttributes.STREAMING_MODE, + SMEAttributes.StreamingModeValues.ENABLED, + "aarch64_pstate_sm_enabled", + ), + ( + SMEAttributes.STREAMING_MODE, + SMEAttributes.StreamingModeValues.COMPATIBLE, + "aarch64_pstate_sm_compatible", + ), + (SMEAttributes.ZA_STORAGE, SMEAttributes.ZAStorageValues.NEW, "aarch64_pstate_za_new"), + ( + SMEAttributes.ZA_STORAGE, + SMEAttributes.ZAStorageValues.SHARED, + "aarch64_pstate_za_shared", + ), + ], +) +def test_function_attributes(attr_key, attr_value, expected): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sme" + + @T.prim_func + def prim_func(a: T.handle, c: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + A = T.match_buffer(a, (16,), "float32") + C = T.match_buffer(c, (1,), "float32") + + with T.block("extern"): + T.block_attr({attr_key: attr_value}) + for i in range(16): + C[0] += A[i] + + func = tvm.build(prim_func, target=target) + ll = func.get_source("ll") + + # Check that the attribute exists + attr = re.findall(rf".*{expected}*.", ll) + assert attr, f"Function attribute {expected} was not found in generated LLVM IR" + + # Check this attribute is used on the "compute" function + func_attr_label = attr[0].split(" ")[1] + found_compute_func = False + for match in re.findall(rf".*{func_attr_label}*.", ll): + if "_compute_" in match: + found_compute_func = True + + assert found_compute_func, ( + f"The attribute {expected} was found to be under the label {func_attr_label}, " + "but it was not used by the 'compute' scope function." + ) + + +def test_unsupported_function_attribute_type(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sme" + + @T.prim_func + def prim_func(a: T.handle, c: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + A = T.match_buffer(a, (16,), "float32") + C = T.match_buffer(c, (1,), "float32") + + with T.block("extern"): + T.block_attr({SMEAttributes.STREAMING_MODE: True}) + with T.block("root"): + for i in range(16): + C[0] += A[i] + + err_msg = f"Expect {SMEAttributes.STREAMING_MODE} to have a String value but was IntImm" + with pytest.raises(tvm.error.TVMError, match=err_msg): + tvm.build(prim_func, target=target) + + +@pytest.mark.parametrize( + "attr_key,attr_value", + [ + (SMEAttributes.STREAMING_MODE, SMEAttributes.StreamingModeValues.ENABLED), + (SMEAttributes.ZA_STORAGE, SMEAttributes.ZAStorageValues.NEW), + ], +) +def test_unsupported_multiple_function_attributes(attr_key, attr_value): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sme" + + @T.prim_func + def prim_func(a: T.handle, c: T.handle): + A = T.match_buffer(a, (16,), "float32") + C = T.match_buffer(c, (1,), "float32") + + with T.block("root"): + with T.block("extern"): + T.block_attr({attr_key: attr_value}) + for i in range(16): + C[0] += A[i] * 2 + with T.block("extern2"): + T.block_attr({attr_key: attr_value}) + for i in range(16): + C[0] += A[i] * 3 + + err_msg = f"Multiple definitions of {attr_key} attribute found in the function default_function_compute_" + with pytest.raises(tvm.error.TVMError, match=err_msg): + tvm.build(prim_func, target=target) + + if __name__ == "__main__": tvm.testing.main()