From ed1eb3a7b1636e03d1cd7a02876b8cd6f9962773 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Thu, 21 Mar 2024 19:01:12 +0000 Subject: [PATCH 1/2] [mlir][arith] Refine the verifier for arith.constant Disallows initialization of scalable vectors with an attribute of arbitrary values, e.g.: ```mlir %c = arith.constant dense<[0, 1]> : vector<[2] x i32> ``` Initialization using vector splats remains allowed (i.e. when all the init values are identical): ```mlir %c = arith.constant dense<[1, 1]> : vector<[2] x i32> ``` --- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 8 ++++++++ mlir/test/Dialect/Arith/invalid.mlir | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 9f64a07f31e3a..22aa4ec0cdb08 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -190,6 +190,14 @@ LogicalResult arith::ConstantOp::verify() { return emitOpError( "value must be an integer, float, or elements attribute"); } + + // Intializing scalable vectors with elements attribute is not supported + // unless it's a vector splot. + auto vecType = dyn_cast(type); + auto val = dyn_cast(getValue()); + if ((vecType && val) && vecType.isScalable() && !val.isSplat()) + return emitOpError( + "using elements attribute to initialize a scalable vector"); return success(); } diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir index 6d8ac0ada52be..ac28075df33e9 100644 --- a/mlir/test/Dialect/Arith/invalid.mlir +++ b/mlir/test/Dialect/Arith/invalid.mlir @@ -215,6 +215,14 @@ func.func @func_with_ops() { // ----- +func.func @func_with_ops() { +^bb0: + // expected-error@+1 {{op failed to verify that result type has i1 element type and same shape as operands}} + %c = arith.constant dense<[0, 1]> : vector<[2] x i32> +} + +// ----- + func.func @invalid_cmp_shape(%idx : () -> ()) { // expected-error@+1 {{'lhs' must be signless-integer-like, but got '() -> ()'}} %cmp = arith.cmpi eq, %idx, %idx : () -> () From f8532c33b1c9620cf79b228c7c1518d2fdc8536b Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Fri, 22 Mar 2024 14:01:58 +0000 Subject: [PATCH 2/2] fixup! [mlir][arith] Refine the verifier for arith.constant Address PR comments --- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 8 +++----- mlir/test/Dialect/Arith/invalid.mlir | 17 +++++++++-------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 22aa4ec0cdb08..036fdf555f041 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -191,13 +191,11 @@ LogicalResult arith::ConstantOp::verify() { "value must be an integer, float, or elements attribute"); } - // Intializing scalable vectors with elements attribute is not supported - // unless it's a vector splot. auto vecType = dyn_cast(type); - auto val = dyn_cast(getValue()); - if ((vecType && val) && vecType.isScalable() && !val.isSplat()) + if (vecType && vecType.isScalable() && !isa(getValue())) return emitOpError( - "using elements attribute to initialize a scalable vector"); + "intializing scalable vectors with elements attribute is not supported" + " unless it's a vector splat"); return success(); } diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir index ac28075df33e9..fdc907a7c6af1 100644 --- a/mlir/test/Dialect/Arith/invalid.mlir +++ b/mlir/test/Dialect/Arith/invalid.mlir @@ -64,6 +64,15 @@ func.func @constant_out_of_range() { // ----- +func.func @constant_invalid_scalable_vec_initialization() { +^bb0: + // expected-error@+1 {{'arith.constant' op intializing scalable vectors with elements attribute is not supported unless it's a vector splat}} + %c = arith.constant dense<[0, 1]> : vector<[2] x i32> + return +} + +// ----- + func.func @constant_wrong_type() { ^bb: %x = "arith.constant"(){value = 10.} : () -> f32 // expected-error {{'arith.constant' op failed to verify that all of {value, result} have same type}} @@ -215,14 +224,6 @@ func.func @func_with_ops() { // ----- -func.func @func_with_ops() { -^bb0: - // expected-error@+1 {{op failed to verify that result type has i1 element type and same shape as operands}} - %c = arith.constant dense<[0, 1]> : vector<[2] x i32> -} - -// ----- - func.func @invalid_cmp_shape(%idx : () -> ()) { // expected-error@+1 {{'lhs' must be signless-integer-like, but got '() -> ()'}} %cmp = arith.cmpi eq, %idx, %idx : () -> ()