Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions python/tvm/topi/arm_cpu/pstate_attributes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Specialized attributes that can be added to schedules to alter
the behaviour of AArch64 codegen.
"""


class SMEAttributes:
"""
This class serves as a convenience wrapper for processor state annotations
relating to the Scalable Matrix Extension (SME). Processor state annotations
are inserted at compile time and alter some global state of the processor
during execution. For example, the streaming mode attribute can be used to
transfer some vector operations to a separate processing element. These
attributes can be added to block-level annotations in AArch64 schedules to
define a desired state.

Please refer to the following pages for more information regarding the SME
attributes and their behaviours:
- https://arm-software.github.io/acle/main/acle.html#markdown-toc-sme-attributes
- https://llvm.org/docs/AArch64SME.html

Attributes
----------
STREAMING_MODE : str
Whether execution should occur in regular mode or streaming mode. When
enabled, some vector operations may be transferred to a separate processing
element.
ZA_STORAGE : str
Defines how the ZA area of storage provided by the SME extension should be
utilized.
"""

STREAMING_MODE = "pragma_aarch64_pstate_sm"

class StreamingModeValues:
"""
Streaming mode attribute values. By default, a function is considered
'non-streaming' (often referred to as 'regular').

Attributes
----------
ENABLED : str
The processor state must be in streaming mode before executing the marked function.
COMPATIBLE : str
The marked function can be run in either streaming or non-streaming mode.
"""

ENABLED = "enabled"
COMPATIBLE = "compatible"

ZA_STORAGE = "pragma_aarch64_pstate_za"

class ZAStorageValues:
"""
ZA Storage attribure values. By default, a function has no ZA state. In other words, it
does not use the ZA storage.

Attributes
----------
NEW : str
A new ZA state is created "from scratch".
SHARED : str
The ZA state is shared with the calling function.
"""

NEW = "new"
SHARED = "shared"
102 changes: 102 additions & 0 deletions src/target/llvm/codegen_aarch64.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/target/llvm/codegen_aarch64.cc
* \brief AArch64 specific LLVM code generator.
*/
#ifdef TVM_LLVM_VERSION

#include <llvm/IR/Intrinsics.h>
#include <llvm/Target/TargetMachine.h>
#include <tvm/runtime/registry.h>

#include "codegen_cpu.h"
#include "llvm_instance.h"

namespace tvm {
namespace codegen {

class CodeGenAArch64 final : public CodeGenCPU {
public:
CodeGenAArch64() = default;
virtual ~CodeGenAArch64() = default;

void VisitStmt_(const AttrStmtNode* op);
void AddFunction(const GlobalVar& gvar, const PrimFunc& f);

bool func_has_pstate_sm = false;
bool func_has_pstate_za = false;
};

void CodeGenAArch64::AddFunction(const GlobalVar& gvar, const PrimFunc& f) {
func_has_pstate_sm = false;
func_has_pstate_za = false;
CodeGenCPU::AddFunction(gvar, f);
}

/*!
* \brief Visit and handle AArch64 specific pragmas. To be AArch64 specific,
* the expectation is that they are prepended with "pragma_aarch64".
*/
void CodeGenAArch64::VisitStmt_(const AttrStmtNode* op) {
std::string attr_key = op->attr_key;

if (!tir::attr::IsPragmaKey(attr_key)) {
CodeGenCPU::VisitStmt_(op);
return;
}
bool is_aarch64_specific_pragma = attr_key.substr(7, 7) == "aarch64";
if (!is_aarch64_specific_pragma) {
CodeGenCPU::VisitStmt_(op);
return;
}

const auto* attr_value = op->value.as<StringImmNode>();
ICHECK(attr_value) << "Expect " << attr_key << " to have a String value but was "
<< op->value->GetTypeKey();

std::string aarch64_attr_key = attr_key.substr(7);
if (aarch64_attr_key == "aarch64_pstate_sm") {
ICHECK(!func_has_pstate_sm) << "Multiple definitions of " << op->attr_key
<< " attribute found in the function "
<< function_->getName().data();
function_->addFnAttr({aarch64_attr_key + "_" + attr_value->value});
func_has_pstate_sm = true;
} else if (aarch64_attr_key == "aarch64_pstate_za") {
ICHECK(!func_has_pstate_za) << "Multiple definitions of " << op->attr_key
<< " attribute found in the function "
<< function_->getName().data();
function_->addFnAttr({aarch64_attr_key + "_" + attr_value->value});
func_has_pstate_za = true;
} else {
LOG(WARNING) << "Unknown pragma " << op->attr_key;
}
this->VisitStmt(op->body);
}

TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_aarch64")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
*rv = static_cast<void*>(new CodeGenAArch64());
});

} // namespace codegen
} // namespace tvm

#endif // TVM_LLVM_VERSION
116 changes: 114 additions & 2 deletions tests/python/codegen/test_target_codegen_aarch64.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import re

import pytest

import tvm
from tvm import te
from tvm.script import tir as T
import re
import pytest
from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes

from tvm.target.codegen import llvm_version_major

Expand Down Expand Up @@ -533,5 +537,113 @@ def my_func(a: T.handle):
assert re.findall(r" store <vscale x 4 x float>", llvm), "No scalable store in generated LLVM."


@pytest.mark.skipif(
llvm_version_major() < 16, reason="Test requires an LLVM version of at least 16 to target SME"
)
@pytest.mark.parametrize(
"attr_key,attr_value,expected",
[
(
SMEAttributes.STREAMING_MODE,
SMEAttributes.StreamingModeValues.ENABLED,
"aarch64_pstate_sm_enabled",
),
(
SMEAttributes.STREAMING_MODE,
SMEAttributes.StreamingModeValues.COMPATIBLE,
"aarch64_pstate_sm_compatible",
),
(SMEAttributes.ZA_STORAGE, SMEAttributes.ZAStorageValues.NEW, "aarch64_pstate_za_new"),
(
SMEAttributes.ZA_STORAGE,
SMEAttributes.ZAStorageValues.SHARED,
"aarch64_pstate_za_shared",
),
],
)
def test_function_attributes(attr_key, attr_value, expected):
target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sme"

@T.prim_func
def prim_func(a: T.handle, c: T.handle):
T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
A = T.match_buffer(a, (16,), "float32")
C = T.match_buffer(c, (1,), "float32")

with T.block("extern"):
T.block_attr({attr_key: attr_value})
for i in range(16):
C[0] += A[i]

func = tvm.build(prim_func, target=target)
ll = func.get_source("ll")

# Check that the attribute exists
attr = re.findall(rf".*{expected}*.", ll)
assert attr, f"Function attribute {expected} was not found in generated LLVM IR"

# Check this attribute is used on the "compute" function
func_attr_label = attr[0].split(" ")[1]
found_compute_func = False
for match in re.findall(rf".*{func_attr_label}*.", ll):
if "_compute_" in match:
found_compute_func = True

assert found_compute_func, (
f"The attribute {expected} was found to be under the label {func_attr_label}, "
"but it was not used by the 'compute' scope function."
)


def test_unsupported_function_attribute_type():
target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sme"

@T.prim_func
def prim_func(a: T.handle, c: T.handle):
T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
A = T.match_buffer(a, (16,), "float32")
C = T.match_buffer(c, (1,), "float32")

with T.block("extern"):
T.block_attr({SMEAttributes.STREAMING_MODE: True})
with T.block("root"):
for i in range(16):
C[0] += A[i]

err_msg = f"Expect {SMEAttributes.STREAMING_MODE} to have a String value but was IntImm"
with pytest.raises(tvm.error.TVMError, match=err_msg):
tvm.build(prim_func, target=target)


@pytest.mark.parametrize(
"attr_key,attr_value",
[
(SMEAttributes.STREAMING_MODE, SMEAttributes.StreamingModeValues.ENABLED),
(SMEAttributes.ZA_STORAGE, SMEAttributes.ZAStorageValues.NEW),
],
)
def test_unsupported_multiple_function_attributes(attr_key, attr_value):
target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sme"

@T.prim_func
def prim_func(a: T.handle, c: T.handle):
A = T.match_buffer(a, (16,), "float32")
C = T.match_buffer(c, (1,), "float32")

with T.block("root"):
with T.block("extern"):
T.block_attr({attr_key: attr_value})
for i in range(16):
C[0] += A[i] * 2
with T.block("extern2"):
T.block_attr({attr_key: attr_value})
for i in range(16):
C[0] += A[i] * 3

err_msg = f"Multiple definitions of {attr_key} attribute found in the function default_function_compute_"
with pytest.raises(tvm.error.TVMError, match=err_msg):
tvm.build(prim_func, target=target)


if __name__ == "__main__":
tvm.testing.main()