Skip to content

Commit a772633

Browse files
committed
[mlir][ArmSME] Disallow streaming mode for gathers/scatters (llvm#96209)
Ideally, this would be based on target information (but we don't really have that), so this currently errs on the side of caution. If possible gathers/scatters should be lowered regular vector loads/stores before using invoking enable-arm-streaming.
1 parent c7a4d23 commit a772633

File tree

4 files changed

+55
-22
lines changed

4 files changed

+55
-22
lines changed

mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,10 @@ def EnableArmStreaming
120120
/*default=*/"false",
121121
"Only apply the selected streaming/ZA modes if the function contains"
122122
" ops that implement the ArmSMETileOpInterface.">,
123-
Option<"ifContainsScalableVectors", "if-contains-scalable-vectors",
123+
Option<"ifScalableAndSupported", "if-scalable-and-supported",
124124
"bool", /*default=*/"false",
125125
"Only apply the selected streaming/ZA modes if the function contains"
126-
" operations that use scalable vector types.">
126+
" supported scalable vector operations.">
127127
];
128128
let dependentDialects = ["func::FuncDialect"];
129129
}

mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -55,22 +55,33 @@ namespace {
5555
constexpr StringLiteral
5656
kEnableArmStreamingIgnoreAttr("enable_arm_streaming_ignore");
5757

58+
template <typename... Ops>
59+
constexpr auto opList() {
60+
return std::array{TypeID::get<Ops>()...};
61+
}
62+
63+
bool isScalableVector(Type type) {
64+
if (auto vectorType = dyn_cast<VectorType>(type))
65+
return vectorType.isScalable();
66+
return false;
67+
}
68+
5869
struct EnableArmStreamingPass
5970
: public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
6071
EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
61-
bool ifRequiredByOps, bool ifContainsScalableVectors) {
72+
bool ifRequiredByOps, bool ifScalableAndSupported) {
6273
this->streamingMode = streamingMode;
6374
this->zaMode = zaMode;
6475
this->ifRequiredByOps = ifRequiredByOps;
65-
this->ifContainsScalableVectors = ifContainsScalableVectors;
76+
this->ifScalableAndSupported = ifScalableAndSupported;
6677
}
6778
void runOnOperation() override {
6879
auto function = getOperation();
6980

70-
if (ifRequiredByOps && ifContainsScalableVectors) {
81+
if (ifRequiredByOps && ifScalableAndSupported) {
7182
function->emitOpError(
7283
"enable-arm-streaming: `if-required-by-ops` and "
73-
"`if-contains-scalable-vectors` are mutually exclusive");
84+
"`if-scalable-and-supported` are mutually exclusive");
7485
return signalPassFailure();
7586
}
7687

@@ -87,22 +98,27 @@ struct EnableArmStreamingPass
8798
return;
8899
}
89100

90-
if (ifContainsScalableVectors) {
91-
bool foundScalableVector = false;
92-
auto isScalableVector = [&](Type type) {
93-
if (auto vectorType = dyn_cast<VectorType>(type))
94-
return vectorType.isScalable();
95-
return false;
96-
};
101+
if (ifScalableAndSupported) {
102+
// FIXME: This should be based on target information (i.e., the presence
103+
// of FEAT_SME_FA64). This currently errs on the side of caution. If
104+
// possible gathers/scatters should be lowered regular vector loads/stores
105+
// before invoking this pass.
106+
auto disallowedOperations = opList<vector::GatherOp, vector::ScatterOp>();
107+
bool isCompatibleScalableFunction = false;
97108
function.walk([&](Operation *op) {
98-
if (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
99-
llvm::any_of(op->getResultTypes(), isScalableVector)) {
100-
foundScalableVector = true;
109+
if (llvm::is_contained(disallowedOperations,
110+
op->getName().getTypeID())) {
111+
isCompatibleScalableFunction = false;
101112
return WalkResult::interrupt();
102113
}
114+
if (!isCompatibleScalableFunction &&
115+
(llvm::any_of(op->getOperandTypes(), isScalableVector) ||
116+
llvm::any_of(op->getResultTypes(), isScalableVector))) {
117+
isCompatibleScalableFunction = true;
118+
}
103119
return WalkResult::advance();
104120
});
105-
if (!foundScalableVector)
121+
if (!isCompatibleScalableFunction)
106122
return;
107123
}
108124

@@ -126,7 +142,7 @@ struct EnableArmStreamingPass
126142

127143
std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
128144
const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
129-
bool ifRequiredByOps, bool ifContainsScalableVectors) {
145+
bool ifRequiredByOps, bool ifScalableAndSupported) {
130146
return std::make_unique<EnableArmStreamingPass>(
131-
streamingMode, zaMode, ifRequiredByOps, ifContainsScalableVectors);
147+
streamingMode, zaMode, ifRequiredByOps, ifScalableAndSupported);
132148
}
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -enable-arm-streaming="if-contains-scalable-vectors if-required-by-ops" -verify-diagnostics
1+
// RUN: mlir-opt %s -enable-arm-streaming="if-scalable-and-supported if-required-by-ops" -verify-diagnostics
22

3-
// expected-error@below {{enable-arm-streaming: `if-required-by-ops` and `if-contains-scalable-vectors` are mutually exclusive}}
3+
// expected-error@below {{enable-arm-streaming: `if-required-by-ops` and `if-scalable-and-supported` are mutually exclusive}}
44
func.func @test() { return }

mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
// RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-compatible -verify-diagnostics | FileCheck %s -check-prefix=CHECK-COMPATIBLE
44
// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
55
// RUN: mlir-opt %s -enable-arm-streaming=if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
6-
// RUN: mlir-opt %s -enable-arm-streaming=if-contains-scalable-vectors -verify-diagnostics | FileCheck %s -check-prefix=IF-SCALABLE
6+
// RUN: mlir-opt %s -enable-arm-streaming=if-scalable-and-supported -verify-diagnostics | FileCheck %s -check-prefix=IF-SCALABLE
77

88
// CHECK-LABEL: @arm_streaming
99
// CHECK-SAME: attributes {arm_streaming}
@@ -53,3 +53,20 @@ func.func @no_scalable_vectors(%vec: vector<4xf32>) -> vector<4xf32> {
5353
%0 = arith.addf %vec, %vec : vector<4xf32>
5454
return %0 : vector<4xf32>
5555
}
56+
57+
// IF-SCALABLE-LABEL: @contains_gather
58+
// IF-SCALABLE-NOT: arm_streaming
59+
func.func @contains_gather(%base: memref<?xf32>, %v: vector<[4]xindex>, %mask: vector<[4]xi1>, %pass_thru: vector<[4]xf32>) -> vector<[4]xf32> {
60+
%c0 = arith.constant 0 : index
61+
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<[4]xindex>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
62+
return %0 : vector<[4]xf32>
63+
}
64+
65+
// IF-SCALABLE-LABEL: @contains_scatter
66+
// IF-SCALABLE-NOT: arm_streaming
67+
func.func @contains_scatter(%base: memref<?xf32>, %v: vector<[4]xindex>,%mask: vector<[4]xi1>, %value: vector<[4]xf32>)
68+
{
69+
%c0 = arith.constant 0 : index
70+
vector.scatter %base[%c0][%v], %mask, %value : memref<?xf32>, vector<[4]xindex>, vector<[4]xi1>, vector<[4]xf32>
71+
return
72+
}

0 commit comments

Comments
 (0)