diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 4a1527cd0369f..8f2915aa76e7c 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -24,6 +24,7 @@ #include "mlir/Interfaces/FunctionImplementation.h" #include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/APFloat.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/IR/Function.h" #include "llvm/IR/Type.h" @@ -3187,6 +3188,18 @@ static int64_t getNumElements(Type t) { return 1; } +/// Determine the element type of `type`. Supported types are `VectorType`, +/// `TensorType`, and `LLVMArrayType`. Everything else is treated as a scalar. +static Type getElementType(Type type) { + while (auto arrayType = dyn_cast(type)) + type = arrayType.getElementType(); + if (auto vecType = dyn_cast(type)) + return vecType.getElementType(); + if (auto tenType = dyn_cast(type)) + return tenType.getElementType(); + return type; +} + /// Check if the given type is a scalable vector type or a vector/array type /// that contains a nested scalable vector type. static bool hasScalableVectorType(Type t) { @@ -3281,60 +3294,69 @@ LogicalResult LLVM::ConstantOp::verify() { } if (auto structType = dyn_cast(getType())) { auto arrayAttr = dyn_cast(getValue()); - if (!arrayAttr) { - return emitOpError() << "expected array attribute for a struct constant"; - } + if (!arrayAttr) + return emitOpError() << "expected array attribute for struct type"; ArrayRef elementTypes = structType.getBody(); if (arrayAttr.size() != elementTypes.size()) { return emitOpError() << "expected array attribute of size " << elementTypes.size(); } - for (auto elementTy : elementTypes) { - if (!isa(elementTy)) { + for (auto [i, attr, type] : llvm::enumerate(arrayAttr, elementTypes)) { + if (!type.isSignlessIntOrIndexOrFloat()) { return emitOpError() << "expected struct element types to be floating " "point type or integer type"; } - } - - for (size_t i = 0; i < elementTypes.size(); ++i) { - Attribute element = arrayAttr[i]; - if (!isa(element)) { - return emitOpError() - << "expected struct element attribute types to be floating " - "point type or integer type"; + if (!isa(attr)) { + return emitOpError() << "expected element of array attribute to be " + "floating point or integer"; } - auto elementType = cast(element).getType(); - if (elementType != elementTypes[i]) { + if (cast(attr).getType() != type) return emitOpError() << "struct element at index " << i << " is of wrong type"; - } } return success(); } - if (auto targetExtType = dyn_cast(getType())) { + if (auto targetExtType = dyn_cast(getType())) return emitOpError() << "does not support target extension type."; - } + + // Check that an attribute whose element type has floating point semantics + // `attributeFloatSemantics` is compatible with a type whose element type + // is `constantElementType`. + // + // Requirement is that either + // 1) They have identical floating point types. + // 2) `constantElementType` is an integer type of the same width as the float + // attribute. This is to support builtin MLIR float types without LLVM + // equivalents, see comments in getLLVMConstant for more details. + auto verifyFloatSemantics = + [this](const llvm::fltSemantics &attributeFloatSemantics, + Type constantElementType) -> LogicalResult { + if (auto floatType = dyn_cast(constantElementType)) { + if (&floatType.getFloatSemantics() != &attributeFloatSemantics) { + return emitOpError() + << "attribute and type have different float semantics"; + } + return success(); + } + unsigned floatWidth = APFloat::getSizeInBits(attributeFloatSemantics); + if (isa(constantElementType)) { + if (!constantElementType.isInteger(floatWidth)) + return emitOpError() << "expected integer type of width " << floatWidth; + + return success(); + } + return success(); + }; // Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr. - if (auto intAttr = dyn_cast(getValue())) { + if (isa(getValue())) { if (!llvm::isa(getType())) return emitOpError() << "expected integer type"; } else if (auto floatAttr = dyn_cast(getValue())) { - const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics(); - unsigned floatWidth = APFloat::getSizeInBits(sem); - if (auto floatTy = dyn_cast(getType())) { - if (floatTy.getWidth() != floatWidth) { - return emitOpError() << "expected float type of width " << floatWidth; - } - } - // See the comment for getLLVMConstant for more details about why 8-bit - // floats can be represented by integers. - if (isa(getType()) && !getType().isInteger(floatWidth)) { - return emitOpError() << "expected integer type of width " << floatWidth; - } - } else if (isa(getValue())) { + return verifyFloatSemantics(floatAttr.getValue().getSemantics(), getType()); + } else if (auto elementsAttr = dyn_cast(getValue())) { if (hasScalableVectorType(getType())) { // The exact number of elements of a scalable vector is unknown, so we // allow only splat attributes. @@ -3346,18 +3368,32 @@ LogicalResult LLVM::ConstantOp::verify() { } if (!isa(getType())) return emitOpError() << "expected vector or array type"; + // The number of elements of the attribute and the type must match. - if (auto elementsAttr = dyn_cast(getValue())) { - int64_t attrNumElements = elementsAttr.getNumElements(); - if (getNumElements(getType()) != attrNumElements) - return emitOpError() - << "type and attribute have a different number of elements: " - << getNumElements(getType()) << " vs. " << attrNumElements; + int64_t attrNumElements = elementsAttr.getNumElements(); + if (getNumElements(getType()) != attrNumElements) { + return emitOpError() + << "type and attribute have a different number of elements: " + << getNumElements(getType()) << " vs. " << attrNumElements; + } + + Type attrElmType = getElementType(elementsAttr.getType()); + Type resultElmType = getElementType(getType()); + if (auto floatType = dyn_cast(attrElmType)) + return verifyFloatSemantics(floatType.getFloatSemantics(), resultElmType); + + if (isa(attrElmType) && !isa(resultElmType)) { + return emitOpError( + "expected integer element type for integer elements attribute"); } } else if (auto arrayAttr = dyn_cast(getValue())) { + + // The case where the constant is LLVMStructType has already been handled. auto arrayType = dyn_cast(getType()); if (!arrayType) - return emitOpError() << "expected array type"; + return emitOpError() + << "expected array or struct type for array attribute"; + // When the attribute is an ArrayAttr, check that its nesting matches the // corresponding ArrayType or VectorType nesting. return verifyStructArrayConstant(*this, arrayType, arrayAttr, /*dim=*/0); diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 7f2c8c72e5cf9..ac1737444fcf0 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -394,7 +394,7 @@ llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> { // ----- llvm.func @struct_wrong_attribute_type() -> !llvm.struct<(f64, f64)> { - // expected-error @+1 {{expected array attribute}} + // expected-error @+1 {{expected array attribute for struct type}} %0 = llvm.mlir.constant(1.0 : f64) : !llvm.struct<(f64, f64)> llvm.return %0 : !llvm.struct<(f64, f64)> } @@ -439,6 +439,111 @@ llvm.func @scalable_vec_requires_splat() -> vector<[4]xf64> { llvm.return %0 : vector<[4]xf64> } + +// ----- + +llvm.func @int_attr_requires_int_type() -> f32 { + // expected-error @below{{expected integer type}} + %0 = llvm.mlir.constant(1 : index) : f32 + llvm.return %0 : f32 +} + +// ----- + +llvm.func @vector_int_attr_requires_int_type() -> vector<2xf32> { + // expected-error @below{{expected integer element type}} + %0 = llvm.mlir.constant(dense<[1, 2]> : vector<2xi32>) : vector<2xf32> + llvm.return %0 : vector<2xf32> +} + +// ----- + +llvm.func @float_attr_and_type_required_same() -> f16 { + // expected-error @below{{attribute and type have different float semantics}} + %cst = llvm.mlir.constant(1.0 : bf16) : f16 + llvm.return %cst : f16 +} + +// ----- + +llvm.func @vector_float_attr_and_type_required_same() -> vector<2xf16> { + // expected-error @below{{attribute and type have different float semantics}} + %cst = llvm.mlir.constant(dense<[1.0, 2.0]> : vector<2xbf16>) : vector<2xf16> + llvm.return %cst : vector<2xf16> +} + +// ----- + +llvm.func @incompatible_integer_type_for_float_attr() -> i32 { + // expected-error @below{{expected integer type of width 16}} + %cst = llvm.mlir.constant(1.0 : f16) : i32 + llvm.return %cst : i32 +} + +// ----- + +llvm.func @vector_incompatible_integer_type_for_float_attr() -> vector<2xi8> { + // expected-error @below{{expected integer type of width 16}} + %cst = llvm.mlir.constant(dense<[1.0, 2.0]> : vector<2xf16>) : vector<2xi8> + llvm.return %cst : vector<2xi8> +} + +// ----- + +llvm.func @vector_with_non_vector_type() -> f32 { + // expected-error @below{{expected vector or array type}} + %cst = llvm.mlir.constant(dense<100.0> : vector<1xf64>) : f32 + llvm.return %cst : f32 +} + +// ----- + +llvm.func @array_attr_with_invalid_type() -> i32 { + // expected-error @below{{expected array or struct type for array attribute}} + %0 = llvm.mlir.constant([1 : i32]) : i32 + llvm.return %0 : i32 +} + +// ----- + +llvm.func @elements_attribute_incompatible_nested_array_struct1_type() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> { + // expected-error @below{{expected integer element type for integer elements attribute}} + %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> + llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> +} + +// ----- + +llvm.func @elements_attribute_incompatible_nested_array_struct3_type() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> { + // expected-error @below{{expected integer element type for integer elements attribute}} + %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> + llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> +} + +// ----- + +llvm.func @invalid_struct_element_type() -> !llvm.struct<(f64, array<2 x i32>)> { + // expected-error @below{{expected struct element types to be floating point type or integer type}} + %0 = llvm.mlir.constant([1.0 : f64, dense<[1, 2]> : tensor<2xi32>]) : !llvm.struct<(f64, array<2 x i32>)> + llvm.return %0 : !llvm.struct<(f64, array<2 x i32>)> +} + +// ----- + +llvm.func @wrong_struct_element_attr_type() -> !llvm.struct<(f64, f64)> { + // expected-error @below{{expected element of array attribute to be floating point or integer}} + %0 = llvm.mlir.constant([dense<[1, 2]> : tensor<2xi32>, 2.0 : f64]) : !llvm.struct<(f64, f64)> + llvm.return %0 : !llvm.struct<(f64, f64)> +} + +// ----- + +llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> { + // expected-error @below{{struct element at index 0 is of wrong type}} + %0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)> + llvm.return %0 : !llvm.struct<(f64, f64)> +} + // ----- func.func @insertvalue_non_llvm_type(%a : i32, %b : i32) { @@ -484,13 +589,13 @@ func.func @extractvalue_invalid_type(%a : !llvm.array<4 x vector<8xf32>>) -> !ll return %b : !llvm.array<4 x vector<8xf32>> } - // ----- func.func @extractvalue_non_llvm_type(%a : i32, %b : tensor<*xi32>) { // expected-error@+2 {{expected LLVM IR Dialect type}} llvm.extractvalue %b[0] : tensor<*xi32> } + // ----- func.func @extractvalue_struct_out_of_bounds() { @@ -659,6 +764,7 @@ func.func @atomicrmw_scalable_vector(%ptr : !llvm.ptr, %f32_vec : vector<[2]xf32 %0 = llvm.atomicrmw fadd %ptr, %f32_vec unordered : !llvm.ptr, vector<[2]xf32> llvm.return } + // ----- func.func @atomicrmw_vector_expected_float(%ptr : !llvm.ptr, %i32_vec : vector<3xi32>) { @@ -1667,7 +1773,6 @@ func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: ! return } - // ----- func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) { diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir index a8ef401fff27e..b09ceeeb86cc0 100644 --- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir @@ -7,78 +7,6 @@ func.func @foo() { // ----- -llvm.func @vector_with_non_vector_type() -> f32 { - // expected-error @below{{expected vector or array type}} - %cst = llvm.mlir.constant(dense<100.0> : vector<1xf64>) : f32 - llvm.return %cst : f32 -} - -// ----- - -llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> { - // expected-error @below{{expected an array attribute for a struct constant}} - %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> - llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> -} - -// ----- - -llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> { - // expected-error @below{{expected an array attribute for a struct constant}} - %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> - llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> -} - -// ----- - -llvm.func @invalid_struct_element_type() -> !llvm.struct<(f64, array<2 x i32>)> { - // expected-error @below{{expected struct element types to be floating point type or integer type}} - %0 = llvm.mlir.constant([1.0 : f64, dense<[1, 2]> : tensor<2xi32>]) : !llvm.struct<(f64, array<2 x i32>)> - llvm.return %0 : !llvm.struct<(f64, array<2 x i32>)> -} - -// ----- - -llvm.func @wrong_struct_element_attr_type() -> !llvm.struct<(f64, f64)> { - // expected-error @below{{expected struct element attribute types to be floating point type or integer type}} - %0 = llvm.mlir.constant([dense<[1, 2]> : tensor<2xi32>, 2.0 : f64]) : !llvm.struct<(f64, f64)> - llvm.return %0 : !llvm.struct<(f64, f64)> -} - -// ----- - -llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> { - // expected-error @below{{struct element at index 0 is of wrong type}} - %0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)> - llvm.return %0 : !llvm.struct<(f64, f64)> -} - -// ----- - -llvm.func @integer_with_float_type() -> f32 { - // expected-error @+1 {{expected integer type}} - %0 = llvm.mlir.constant(1 : index) : f32 - llvm.return %0 : f32 -} - -// ----- - -llvm.func @incompatible_float_attribute_type() -> f32 { - // expected-error @below{{expected float type of width 64}} - %cst = llvm.mlir.constant(1.0 : f64) : f32 - llvm.return %cst : f32 -} - -// ----- - -llvm.func @incompatible_integer_type_for_float_attr() -> i32 { - // expected-error @below{{expected integer type of width 16}} - %cst = llvm.mlir.constant(1.0 : f16) : i32 - llvm.return %cst : i32 -} - -// ----- - // expected-error @below{{LLVM attribute 'readonly' does not expect a value}} llvm.func @passthrough_unexpected_value() attributes {passthrough = [["readonly", "42"]]}