diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 3ec7354562d23..28e562c813eb3 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -1018,15 +1018,8 @@ void TosaValidation::runOnOperation() { if (op->getDialect() != tosaDialect) return; - // Profile-Extension based validation should be performed at the beginning. - if (strictOpSpecAlignment && - failed(profileComp.checkProfile(op, targetEnv))) - return signalPassFailure(); - - if (strictOpSpecAlignment && - failed(profileComp.checkExtension(op, targetEnv))) - return signalPassFailure(); - + // perform valid element type check at the beginning to + // protect rest of code against quantized element types for (Value operand : op->getOperands()) { auto elementTy = getElementTypeOrSelf(operand); if (!isValidElementType(elementTy)) { @@ -1044,6 +1037,14 @@ void TosaValidation::runOnOperation() { } } + if (strictOpSpecAlignment && + failed(profileComp.checkProfile(op, targetEnv))) + return signalPassFailure(); + + if (strictOpSpecAlignment && + failed(profileComp.checkExtension(op, targetEnv))) + return signalPassFailure(); + if (!allowInvalidOpDatatypeCombinations && failed(profileComp.checkInvalid(op))) { op->emitOpError("illegal: operand/result data types not supported"); diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 8cf6d4b154792..12b2379a592c3 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -253,6 +253,15 @@ func.func @test_conv2d_quant_any_acc(%arg0: tensor<1x4x4x4x!quant.any>> return %0 : tensor<1x4x4x8x!quant.any>> } +// ----- +// CHECK-LABEL: conv2d_quant_any +func.func @test_conv2d_quant_any(%arg0: tensor<1x4x4x4x!quant.any>>, %arg1: tensor<8x1x1x4x!quant.any>>, %arg2: tensor<8x!quant.any>>) -> tensor<1x4x4x8x!quant.any>> { + %zp = "tosa.const" () { values = dense<0> : tensor<1xi8> } : () -> tensor<1xi8> + // expected-error@+1 {{'tosa.conv2d' op is not profile-aligned: element type '!quant.any>'}} + %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array, pad = array, stride = array, local_bound = true} : (tensor<1x4x4x4x!quant.any>>, tensor<8x1x1x4x!quant.any>>, tensor<8x!quant.any>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8x!quant.any>> + return %0 : tensor<1x4x4x8x!quant.any>> +} + // ----- func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor {