@@ -55,22 +55,33 @@ namespace {
5555constexpr StringLiteral
5656 kEnableArmStreamingIgnoreAttr (" enable_arm_streaming_ignore" );
5757
58+ template <typename ... Ops>
59+ constexpr auto opList () {
60+ return std::array{TypeID::get<Ops>()...};
61+ }
62+
63+ bool isScalableVector (Type type) {
64+ if (auto vectorType = dyn_cast<VectorType>(type))
65+ return vectorType.isScalable ();
66+ return false ;
67+ }
68+
5869struct EnableArmStreamingPass
5970 : public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
6071 EnableArmStreamingPass (ArmStreamingMode streamingMode, ArmZaMode zaMode,
61- bool ifRequiredByOps, bool ifContainsScalableVectors ) {
72+ bool ifRequiredByOps, bool ifScalableAndSupported ) {
6273 this ->streamingMode = streamingMode;
6374 this ->zaMode = zaMode;
6475 this ->ifRequiredByOps = ifRequiredByOps;
65- this ->ifContainsScalableVectors = ifContainsScalableVectors ;
76+ this ->ifScalableAndSupported = ifScalableAndSupported ;
6677 }
6778 void runOnOperation () override {
6879 auto function = getOperation ();
6980
70- if (ifRequiredByOps && ifContainsScalableVectors ) {
81+ if (ifRequiredByOps && ifScalableAndSupported ) {
7182 function->emitOpError (
7283 " enable-arm-streaming: `if-required-by-ops` and "
73- " `if-contains- scalable-vectors ` are mutually exclusive" );
84+ " `if-scalable-and-supported ` are mutually exclusive" );
7485 return signalPassFailure ();
7586 }
7687
@@ -87,22 +98,27 @@ struct EnableArmStreamingPass
8798 return ;
8899 }
89100
90- if (ifContainsScalableVectors ) {
91- bool foundScalableVector = false ;
92- auto isScalableVector = [&](Type type) {
93- if ( auto vectorType = dyn_cast<VectorType>(type))
94- return vectorType. isScalable ();
95- return false ;
96- } ;
101+ if (ifScalableAndSupported ) {
102+ // FIXME: This should be based on target information (i.e., the presence
103+ // of FEAT_SME_FA64). This currently errs on the side of caution. If
104+ // possible gathers/scatters should be lowered regular vector loads/stores
105+ // before invoking this pass.
106+ auto disallowedOperations = opList<vector::GatherOp, vector::ScatterOp>() ;
107+ bool isCompatibleScalableFunction = false ;
97108 function.walk ([&](Operation *op) {
98- if (llvm::any_of (op-> getOperandTypes (), isScalableVector) ||
99- llvm::any_of ( op->getResultTypes (), isScalableVector )) {
100- foundScalableVector = true ;
109+ if (llvm::is_contained (disallowedOperations,
110+ op->getName (). getTypeID () )) {
111+ isCompatibleScalableFunction = false ;
101112 return WalkResult::interrupt ();
102113 }
114+ if (!isCompatibleScalableFunction &&
115+ (llvm::any_of (op->getOperandTypes (), isScalableVector) ||
116+ llvm::any_of (op->getResultTypes (), isScalableVector))) {
117+ isCompatibleScalableFunction = true ;
118+ }
103119 return WalkResult::advance ();
104120 });
105- if (!foundScalableVector )
121+ if (!isCompatibleScalableFunction )
106122 return ;
107123 }
108124
@@ -126,7 +142,7 @@ struct EnableArmStreamingPass
126142
127143std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass (
128144 const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
129- bool ifRequiredByOps, bool ifContainsScalableVectors ) {
145+ bool ifRequiredByOps, bool ifScalableAndSupported ) {
130146 return std::make_unique<EnableArmStreamingPass>(
131- streamingMode, zaMode, ifRequiredByOps, ifContainsScalableVectors );
147+ streamingMode, zaMode, ifRequiredByOps, ifScalableAndSupported );
132148}
0 commit comments