diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp index 7f18e169f157..5b52e01e2b0d 100644 --- a/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp +++ b/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp @@ -6,7 +6,16 @@ #include "compiler/plugins/input/Torch/InputConversion/Passes.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" +#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h" +#include "llvm/ADT/APFloat.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" @@ -163,6 +172,353 @@ struct FftRfftOpConversion } }; +struct FlexAttentionOpConversion + : public OpRewritePattern<torch::Torch::AtenFlexAttentionOp> { + using OpRewritePattern::OpRewritePattern; + + // Attention tensors are 4-D, e.g. the score/mask is [batch, head, query_seq, key_seq]. + static constexpr int kAttentionRank = 4; + // Modification functions receive 4 index arguments: (b, h, m, n). + static constexpr int kNumModificationIndices = 4; + + // Makes it convenient to pass around commonly used types. + struct TypeInfo { + Type i32Type; + Type si32Type; + RankedTensorType scalarTensorType; + RankedTensorType i32ScalarTensorType; + RankedTensorType boolScalarTensorType; + torch::Torch::ValueTensorType torchScalarType; + torch::Torch::ValueTensorType torchI32ScalarType; + torch::Torch::ValueTensorType torchBoolScalarType; + } mutable typeInfo; + + LogicalResult matchAndRewrite(torch::Torch::AtenFlexAttentionOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + MLIRContext *ctx = getContext(); + Value query = op.getQuery(); + Value key = op.getKey(); + Value value = op.getValue(); + Value scaleValue = op.getScale(); + auto scoreModSymbol = op.getScoreModFn(); + auto maskModSymbol = op.getMaskModFn(); + + bool returnLseValue; + if (!matchPattern(op.getReturnLse(), + torch::Torch::m_TorchConstantBool(&returnLseValue))) { + return rewriter.notifyMatchFailure( + op, "expected return_lse to be a constant bool"); + } + + auto queryType = cast<torch::Torch::ValueTensorType>(query.getType()); + auto keyType = cast<torch::Torch::ValueTensorType>(key.getType()); + auto valueType = cast<torch::Torch::ValueTensorType>(value.getType()); + + if (!queryType.hasSizes() || !keyType.hasSizes() || !valueType.hasSizes()) { + return rewriter.notifyMatchFailure( + op, "expected input types to have sizes"); + } + + ArrayRef<int64_t> queryShape = queryType.getSizes(); + ArrayRef<int64_t> valueShape = valueType.getSizes(); + + // Query shape: [B, H, M, E]. + if (queryShape.size() != kAttentionRank) { + return rewriter.notifyMatchFailure(op, "expected 4D query tensor"); + } + + int64_t batch = queryShape[0]; + int64_t numHeads = queryShape[1]; + int64_t seqLenQ = queryShape[2]; + int64_t headDim = queryShape[3]; + int64_t seqLenKV = keyType.getSizes()[2]; + int64_t valueDim = valueShape[3]; + + if (headDim == torch::Torch::kUnknownSize) { + return rewriter.notifyMatchFailure(op, "NYI: dynamic head dimension"); + } + + // Check if the element type is a float.
+ Type elementType = queryType.getOptionalDtype(); + auto floatType = dyn_cast<FloatType>(elementType); + if (!floatType) { + return rewriter.notifyMatchFailure(op, "expected float element type"); + } + + // Default scale: 1.0 / sqrt(head_dim). + double scaleVal; + if (!matchPattern(scaleValue, + torch::Torch::m_TorchConstantFloat(&scaleVal))) { + scaleVal = 1.0 / std::sqrt(static_cast<double>(headDim)); + } + + Value scale = arith::ConstantOp::create( + rewriter, loc, floatType, rewriter.getFloatAttr(floatType, scaleVal)); + + Value builtinQuery = convertToBuiltinTensor(rewriter, loc, query); + Value builtinKey = convertToBuiltinTensor(rewriter, loc, key); + Value builtinValue = convertToBuiltinTensor(rewriter, loc, value); + + // Declare common types for mask and score modification regions. + setTypeInfo(rewriter, floatType); + Value zero = arith::ConstantFloatOp::create( + rewriter, loc, floatType, + llvm::APFloat::getZero(floatType.getFloatSemantics())); + Value mask; + if (maskModSymbol) { + FlatSymbolRefAttr maskModRef = + FlatSymbolRefAttr::get(ctx, *maskModSymbol); + mask = createModifiedMask(rewriter, loc, ctx, maskModRef, batch, numHeads, + seqLenQ, seqLenKV, floatType, builtinQuery, + builtinKey, zero); + } + + // Create output tensor for attention. + SmallVector<Value> outputDynSizes; + SmallVector<int64_t> outputShape = {batch, numHeads, seqLenQ, valueDim}; + computeDynamicSizes(rewriter, loc, outputShape, outputDynSizes, + builtinQuery, builtinValue); + + // Initialize output tensor with identity value (0.0 for addition). + Value outputInit = arith::getIdentityValue(arith::AtomicRMWKind::addf, + floatType, rewriter, loc, + /*useOnlyFiniteValue=*/true); + Value outputTensor = tensor::SplatOp::create(rewriter, loc, outputInit, + outputShape, outputDynSizes); + + // Build indexing maps for attention. + // Standard maps: Q, K, V, scale, [mask], output. + AffineExpr b, h, m, n, k1, k2; + bindDims(ctx, b, h, m, n, k1, k2); + + auto qMap = AffineMap::get(6, 0, {b, h, m, k1}, ctx); + auto kMap = AffineMap::get(6, 0, {b, h, n, k1}, ctx); + auto vMap = AffineMap::get(6, 0, {b, h, n, k2}, ctx); + auto sMap = AffineMap::get(6, 0, {}, ctx); + auto oMap = AffineMap::get(6, 0, {b, h, m, k2}, ctx); + + SmallVector<AffineMap> indexingMaps = {qMap, kMap, vMap, sMap}; + if (mask) { + indexingMaps.push_back(AffineMap::get(6, 0, {b, h, m, n}, ctx)); + } + + indexingMaps.push_back(oMap); + + // Create attention op. + auto attentionOp = IREE::LinalgExt::AttentionOp::create( + rewriter, loc, outputTensor.getType(), builtinQuery, builtinKey, + builtinValue, scale, outputTensor, + rewriter.getAffineMapArrayAttr(indexingMaps), mask); + + createScoreModificationRegion(rewriter, loc, attentionOp, scoreModSymbol, + floatType); + + rewriter.setInsertionPointAfter(attentionOp); + + Value normalizedOutput = attentionOp.getResult(0); + + auto outputTorchType = + queryType.getWithSizesAndDtype(outputShape, elementType); + Value torchOutput = torch::TorchConversion::FromBuiltinTensorOp::create( + rewriter, loc, outputTorchType, normalizedOutput); + + // Handle logsumexp. + // Note: AttentionOp doesn't expose intermediate max/sum + // values needed for LSE calculation. Return a dummy tensor - logsumexp + // shape is output_shape[:-1] (remove last dim).
+ if (returnLseValue) { + op.emitWarning("FlexAttention: logsumexp output is a dummy (zeros), " + "actual values are not available from AttentionOp"); + } + SmallVector<int64_t> lseShape = outputShape; + lseShape.pop_back(); + + SmallVector<Value> lseDynSizes = outputDynSizes; + if (!outputDynSizes.empty()) { + lseDynSizes.pop_back(); + } + + Value lseTensor = + tensor::SplatOp::create(rewriter, loc, zero, lseShape, lseDynSizes); + + auto lseTorchType = queryType.getWithSizesAndDtype(lseShape, elementType); + Value torchLogsumexp = torch::TorchConversion::FromBuiltinTensorOp::create( + rewriter, loc, lseTorchType, lseTensor); + + rewriter.replaceOp(op, {torchOutput, torchLogsumexp}); + return success(); + } + + Value convertToBuiltinTensor(PatternRewriter &rewriter, Location loc, + Value torchTensor) const { + auto torchType = cast<torch::Torch::ValueTensorType>(torchTensor.getType()); + return torch::TorchConversion::ToBuiltinTensorOp::create( + rewriter, loc, torchType.toBuiltinTensor(), torchTensor); + } + + void setTypeInfo(PatternRewriter &rewriter, FloatType floatType) const { + typeInfo.i32Type = rewriter.getI32Type(); + typeInfo.si32Type = + IntegerType::get(rewriter.getContext(), 32, IntegerType::Signed); + typeInfo.scalarTensorType = RankedTensorType::get({}, floatType); + typeInfo.i32ScalarTensorType = RankedTensorType::get({}, typeInfo.i32Type); + typeInfo.boolScalarTensorType = + RankedTensorType::get({}, rewriter.getI1Type()); + typeInfo.torchScalarType = rewriter.getType<torch::Torch::ValueTensorType>( + ArrayRef<int64_t>{}, floatType); + typeInfo.torchI32ScalarType = + rewriter.getType<torch::Torch::ValueTensorType>(ArrayRef<int64_t>{}, + typeInfo.si32Type); + typeInfo.torchBoolScalarType = + rewriter.getType<torch::Torch::ValueTensorType>(ArrayRef<int64_t>{}, + rewriter.getI1Type()); + } + + void computeDynamicSizes(PatternRewriter &rewriter, Location loc, + const SmallVector<int64_t> &shape, + SmallVector<Value> &dynSizes, Value first, + Value second) const { + for (int i = 0; i < kAttentionRank; ++i) { + if (shape[i] == torch::Torch::kUnknownSize) { + Value idx = + arith::ConstantIndexOp::create(rewriter, loc, std::min(i, 2)); + Value dim = + tensor::DimOp::create(rewriter, loc, i < 3 ? first : second, idx); + dynSizes.push_back(dim); + } + } + } + + // Creates a modified mask tensor. + Value createModifiedMask(PatternRewriter &rewriter, Location loc, + MLIRContext *ctx, FlatSymbolRefAttr maskModRef, + int64_t batch, int64_t numHeads, int64_t seqLenQ, + int64_t seqLenKV, FloatType floatType, + Value builtinQuery, Value builtinKey, + Value zero) const { + // Create mask tensor [B, H, M, N] with values 0.0 (attend) or -inf + // (mask). + SmallVector<int64_t> maskShape = {batch, numHeads, seqLenQ, seqLenKV}; + SmallVector<Value> maskDynSizes; + + computeDynamicSizes(rewriter, loc, maskShape, maskDynSizes, builtinQuery, + builtinKey); + + Value maskTensor = tensor::EmptyOp::create(rewriter, loc, maskShape, + floatType, maskDynSizes); + // Create linalg.generic to materialize mask. + SmallVector<AffineMap> maskMaps; + maskMaps.push_back(AffineMap::getMultiDimIdentityMap(kAttentionRank, ctx)); + + SmallVector<utils::IteratorType> iteratorTypes( + kAttentionRank, utils::IteratorType::parallel); + + Value negInf = arith::ConstantFloatOp::create( + rewriter, loc, floatType, + llvm::APFloat::getInf(floatType.getFloatSemantics(), + /*Negative=*/true)); + + auto maskGeneric = linalg::GenericOp::create( + rewriter, loc, TypeRange{maskTensor.getType()}, ValueRange{}, + ValueRange{maskTensor}, maskMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + // Get indices and convert to torch tensors.
+ SmallVector<Value> torchIndices; + for (unsigned i = 0; i < kNumModificationIndices; ++i) { + Value idx = linalg::IndexOp::create(b, loc, i); + Value idxI32 = + arith::IndexCastOp::create(b, loc, typeInfo.i32Type, idx); + Value idxTensor = tensor::FromElementsOp::create( + b, loc, typeInfo.i32ScalarTensorType, ValueRange{idxI32}); + Value torchIdx = + torch::TorchConversion::FromBuiltinTensorOp::create( + b, loc, typeInfo.torchI32ScalarType, idxTensor); + torchIndices.push_back(torchIdx); + } + + // Call mask_mod_fn(b, h, q_idx, kv_idx). + auto callOp = func::CallOp::create( + b, loc, TypeRange{typeInfo.torchBoolScalarType}, maskModRef, + ValueRange(torchIndices)); + Value torchMaskResult = callOp.getResult(0); + + Value maskResult = torch::TorchConversion::ToBuiltinTensorOp::create( + b, loc, typeInfo.boolScalarTensorType, torchMaskResult); + + Value maskBool = + tensor::ExtractOp::create(b, loc, maskResult, ValueRange{}); + + Value maskValue = + arith::SelectOp::create(b, loc, maskBool, zero, negInf); + + linalg::YieldOp::create(b, loc, maskValue); + }); + + return maskGeneric.getResult(0); + } + + // Adds a score modification region to the attention op. + void + createScoreModificationRegion(PatternRewriter &rewriter, Location loc, + IREE::LinalgExt::AttentionOp attentionOp, + std::optional<StringRef> scoreModSymbol, + FloatType floatType) const { + OpBuilder::InsertionGuard g(rewriter); + Block *block = rewriter.createBlock(&attentionOp.getRegion()); + + // Add block arguments: score (floatType), b, h, m, n (all index type). + Type indexType = rewriter.getIndexType(); + block->addArgument(floatType, loc); + for (int i = 0; i < kNumModificationIndices; ++i) { + block->addArgument(indexType, loc); + } + rewriter.setInsertionPointToStart(block); + + Value score = block->getArgument(0); + SmallVector<Value> indices; + for (int i = 0; i < kNumModificationIndices; ++i) { + indices.push_back(block->getArgument(i + 1)); + } + Value modifiedScore = score; + + if (scoreModSymbol) { + // The score_mod_fn takes (score, b, h, m, n) where m=q_idx, n=kv_idx.
+ + Value scoreTensor = tensor::FromElementsOp::create( + rewriter, loc, typeInfo.scalarTensorType, ValueRange{score}); + Value torchScore = torch::TorchConversion::FromBuiltinTensorOp::create( + rewriter, loc, typeInfo.torchScalarType, scoreTensor); + + SmallVector<Value> callArgs; + callArgs.push_back(torchScore); + + for (Value idx : indices) { + Value idxI32 = + arith::IndexCastOp::create(rewriter, loc, typeInfo.i32Type, idx); + Value idxTensor = tensor::FromElementsOp::create( + rewriter, loc, typeInfo.i32ScalarTensorType, ValueRange{idxI32}); + Value torchIdx = torch::TorchConversion::FromBuiltinTensorOp::create( + rewriter, loc, typeInfo.torchI32ScalarType, idxTensor); + callArgs.push_back(torchIdx); + } + + auto callOp = func::CallOp::create( + rewriter, loc, TypeRange{typeInfo.torchScalarType}, + scoreModSymbol.value(), ValueRange(callArgs)); + Value torchResult = callOp.getResult(0); + + Value resultTensor = torch::TorchConversion::ToBuiltinTensorOp::create( + rewriter, loc, typeInfo.scalarTensorType, torchResult); + + modifiedScore = + tensor::ExtractOp::create(rewriter, loc, resultTensor, ValueRange{}); + } + + IREE::LinalgExt::YieldOp::create(rewriter, loc, modifiedScore); + } +}; + class ConvertTorchUnstructuredToLinalgExtPass final : public impl::ConvertTorchUnstructuredToLinalgExtPassBase< ConvertTorchUnstructuredToLinalgExtPass> { @@ -171,13 +527,14 @@ class ConvertTorchUnstructuredToLinalgExtPass final registry.insert(); } void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); - patterns.add<FftRfftOpConversion>(context); + patterns.add<FftRfftOpConversion, FlexAttentionOpConversion>(context); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); diff --git a/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir b/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir index a568966d906b..8769abb0a5c1 100644 --- a/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir +++ b/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(torch-iree-torch-unstructured-to-linalg-ext))" %s | FileCheck %s +// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(torch-iree-torch-unstructured-to-linalg-ext))" --verify-diagnostics %s | FileCheck %s // CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d2)> // CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> @@ -52,6 +52,7 @@ func.func @fft_rfft.with_transpose(%arg0: !torch.vtensor<[3,8,16],f32>) -> !torc // CHECK: return %[[VAR14]] : !torch.vtensor<[3,5,16],complex<f32>> // ----- + func.func @fft_rfft.last(%arg0: !torch.vtensor<[3,8,16],f32>) -> !torch.vtensor<[3,8,9],complex<f32>> { %int-1 = torch.constant.int -1 %none = torch.constant.none @@ -99,3 +100,215 @@ func.func @fft_rfft.last(%arg0: !torch.vtensor<[3,8,16],f32>) -> !torch.vtensor< // CHECK: %[[VAR12:.*]] = torch.aten.cat %[[VAR11]], %[[INTM1]] : !torch.list<vtensor<[3,8,9,1],f32>>, !torch.int -> !torch.vtensor<[3,8,9,2],f32> // CHECK: %[[VAR13:.*]] = torch.aten.view_as_complex %[[VAR12]] : !torch.vtensor<[3,8,9,2],f32> -> !torch.vtensor<[3,8,9],complex<f32>> // CHECK: return %[[VAR13]] : !torch.vtensor<[3,8,9],complex<f32>> + +// ----- + +// flex_attention with both score_mod and mask_mod.
+func.func @flex_attn_with_scoremod_and_maskmod(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>,!torch.vtensor<[4,8,1024],f32>) attributes {torch.assume_strict_symbolic_shapes} { + %float1.000000e00 = torch.constant.float 1.000000e+00 + %true = torch.constant.bool true + // expected-warning @+1 {{FlexAttention: logsumexp output is a dummy (zeros), actual values are not available from AttentionOp}} + %output, %logsumexp = torch.aten.flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %true {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32> + return %output, %logsumexp : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32> +} +func.func private @sdpa_score0(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32> { + %0 = torch.aten.tanh %arg0 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> +} +func.func private @sdpa_mask0(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>) -> !torch.vtensor<[],i1> { + %0 = torch.aten.ge.Tensor %arg2, %arg3 : !torch.vtensor<[],si32>, !torch.vtensor<[],si32> -> !torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> +} +// CHECK-LABEL: func.func @flex_attn_with_scoremod_and_maskmod( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG1:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG2:.*]]: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>) +// CHECK: %[[LSE_BUF:.*]] = arith.constant dense<0.000000e+00> : tensor<4x8x1024xf32> +// CHECK: %[[OUT_BUF:.*]] = arith.constant dense<0.000000e+00> : tensor<4x8x1024x64xf32> +// CHECK: %[[NEG_INF:.*]] = arith.constant 0xFF800000 : f32 +// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[Q:.*]] = torch_c.to_builtin_tensor %[[ARG0]] +// CHECK: %[[K:.*]] = torch_c.to_builtin_tensor %[[ARG1]] +// CHECK: %[[V:.*]] = torch_c.to_builtin_tensor %[[ARG2]] +// CHECK: %[[MASK_BUF:.*]] = tensor.empty() : tensor<4x8x1024x1024xf32> +// CHECK: %[[MASK_T:.*]] = linalg.generic +// CHECK-SAME: outs(%[[MASK_BUF]] : tensor<4x8x1024x1024xf32>) +// CHECK: ^bb0(%{{.*}}: f32): +// CHECK: %{{.*}} = func.call @sdpa_mask0 +// CHECK: %{{.*}} = torch_c.to_builtin_tensor %{{.*}} : !torch.vtensor<[],i1> -> tensor<i1> +// CHECK: %{{.*}} = tensor.extract %{{.*}}[] : tensor<i1> +// CHECK: %[[SEL:.*]] = arith.select %{{.*}}, %[[ZERO]], %[[NEG_INF]] : f32 +// CHECK: linalg.yield %[[SEL]] : f32 +// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention +// CHECK-SAME: ins(%[[Q]], %[[K]], %[[V]], %[[ONE]], %[[MASK_T]] +// CHECK-SAME: outs(%[[OUT_BUF]] : tensor<4x8x1024x64xf32>) { +// CHECK: ^bb0(%[[SCORE:.*]]: f32, %[[B:.*]]: index, %[[H:.*]]: index, %[[M:.*]]: index, %[[N:.*]]: index): +// CHECK: %[[SCORE_T:.*]] = tensor.from_elements %[[SCORE]] : tensor<f32> +// CHECK: %[[TORCH_SCORE:.*]] = torch_c.from_builtin_tensor %[[SCORE_T]] : tensor<f32> -> !torch.vtensor<[],f32> +// CHECK: %[[B_I32:.*]] = arith.index_cast %[[B]] : index to i32 +// CHECK:
%[[B_T:.*]] = tensor.from_elements %[[B_I32]] : tensor<i32> +// CHECK: %[[TORCH_B:.*]] = torch_c.from_builtin_tensor %[[B_T]] : tensor<i32> -> !torch.vtensor<[],si32> +// CHECK: %[[H_I32:.*]] = arith.index_cast %[[H]] : index to i32 +// CHECK: %[[H_T:.*]] = tensor.from_elements %[[H_I32]] : tensor<i32> +// CHECK: %[[TORCH_H:.*]] = torch_c.from_builtin_tensor %[[H_T]] : tensor<i32> -> !torch.vtensor<[],si32> +// CHECK: %[[M_I32:.*]] = arith.index_cast %[[M]] : index to i32 +// CHECK: %[[M_T:.*]] = tensor.from_elements %[[M_I32]] : tensor<i32> +// CHECK: %[[TORCH_M:.*]] = torch_c.from_builtin_tensor %[[M_T]] : tensor<i32> -> !torch.vtensor<[],si32> +// CHECK: %[[N_I32:.*]] = arith.index_cast %[[N]] : index to i32 +// CHECK: %[[N_T:.*]] = tensor.from_elements %[[N_I32]] : tensor<i32> +// CHECK: %[[TORCH_N:.*]] = torch_c.from_builtin_tensor %[[N_T]] : tensor<i32> -> !torch.vtensor<[],si32> +// CHECK: %[[CALL:.*]] = func.call @sdpa_score0(%[[TORCH_SCORE]], %[[TORCH_B]], %[[TORCH_H]], %[[TORCH_M]], %[[TORCH_N]]) +// CHECK: %[[CALL_T:.*]] = torch_c.to_builtin_tensor %[[CALL]] : !torch.vtensor<[],f32> -> tensor<f32> +// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[CALL_T]][] : tensor<f32> +// CHECK: iree_linalg_ext.yield %[[EXTRACT]] : f32 +// CHECK: } -> tensor<4x8x1024x64xf32> +// CHECK: %[[T_OUT:.*]] = torch_c.from_builtin_tensor %[[ATTN]] +// CHECK: %[[T_LSE:.*]] = torch_c.from_builtin_tensor %[[LSE_BUF]] +// CHECK: return %[[T_OUT]], %[[T_LSE]] + +// ----- + +// flex_attention with mask_mod only. +func.func @flex_attn_with_maskmod_only(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>,!torch.vtensor<[4,8,1024],f32>) attributes {torch.assume_strict_symbolic_shapes} { + %float1.000000e00 = torch.constant.float 1.000000e+00 + %true = torch.constant.bool true + // expected-warning @+1 {{FlexAttention: logsumexp output is a dummy (zeros), actual values are not available from AttentionOp}} + %output, %logsumexp = torch.aten.flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %true {mask_mod_fn = @sdpa_mask0} : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32> + return %output, %logsumexp : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32> +} +func.func private @sdpa_mask0(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>) -> !torch.vtensor<[],i1> { + %0 = torch.aten.ge.Tensor %arg2, %arg3 : !torch.vtensor<[],si32>, !torch.vtensor<[],si32> -> !torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> +} +// CHECK-LABEL: func.func @flex_attn_with_maskmod_only( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG1:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG2:.*]]: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>) +// CHECK: %[[LSE_BUF:.*]] = arith.constant dense<0.000000e+00> : tensor<4x8x1024xf32> +// CHECK: %[[OUT_BUF:.*]] = arith.constant dense<0.000000e+00> : tensor<4x8x1024x64xf32> +// CHECK: %[[NEG_INF:.*]] = arith.constant 0xFF800000 : f32 +// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[Q:.*]] = torch_c.to_builtin_tensor %[[ARG0]] +// CHECK: %[[K:.*]] = torch_c.to_builtin_tensor %[[ARG1]] +// CHECK: %[[V:.*]] =
torch_c.to_builtin_tensor %[[ARG2]] +// CHECK: %[[MASK_BUF:.*]] = tensor.empty() : tensor<4x8x1024x1024xf32> +// CHECK: %[[MASK_T:.*]] = linalg.generic +// CHECK-SAME: outs(%[[MASK_BUF]] : tensor<4x8x1024x1024xf32>) +// CHECK: ^bb0(%{{.*}}: f32): +// CHECK: %{{.*}} = func.call @sdpa_mask0 +// CHECK: %{{.*}} = torch_c.to_builtin_tensor %{{.*}} : !torch.vtensor<[],i1> -> tensor<i1> +// CHECK: %{{.*}} = tensor.extract %{{.*}}[] : tensor<i1> +// CHECK: %[[SEL:.*]] = arith.select %{{.*}}, %[[ZERO]], %[[NEG_INF]] : f32 +// CHECK: linalg.yield %[[SEL]] : f32 +// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention +// CHECK-SAME: ins(%[[Q]], %[[K]], %[[V]], %[[ONE]], %[[MASK_T]] +// CHECK-SAME: outs(%[[OUT_BUF]] : tensor<4x8x1024x64xf32>) { +// CHECK: ^bb0(%[[SCORE:.*]]: f32, %[[B:.*]]: index, %[[H:.*]]: index, %[[M:.*]]: index, %[[N:.*]]: index): +// CHECK: iree_linalg_ext.yield %[[SCORE]] : f32 +// CHECK: } -> tensor<4x8x1024x64xf32> +// CHECK: %[[T_OUT:.*]] = torch_c.from_builtin_tensor %[[ATTN]] +// CHECK: %[[T_LSE:.*]] = torch_c.from_builtin_tensor %[[LSE_BUF]] +// CHECK: return %[[T_OUT]], %[[T_LSE]] + +// ----- + +// flex_attention with score_mod only. +func.func @flex_attn_with_scoremod_only(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>,!torch.vtensor<[4,8,1024],f32>) attributes {torch.assume_strict_symbolic_shapes} { + %float1.000000e00 = torch.constant.float 1.000000e+00 + %true = torch.constant.bool true + // expected-warning @+1 {{FlexAttention: logsumexp output is a dummy (zeros), actual values are not available from AttentionOp}} + %output, %logsumexp = torch.aten.flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %true {score_mod_fn = @sdpa_score0} : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32> + return %output, %logsumexp : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32> +} +func.func private @sdpa_score0(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32> { + %0 = torch.aten.tanh %arg0 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> +} +// CHECK-LABEL: func.func @flex_attn_with_scoremod_only( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG1:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG2:.*]]: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>) +// CHECK: %[[LSE_BUF:.*]] = arith.constant dense<0.000000e+00> : tensor<4x8x1024xf32> +// CHECK: %[[OUT_BUF:.*]] = arith.constant dense<0.000000e+00> : tensor<4x8x1024x64xf32> +// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[Q:.*]] = torch_c.to_builtin_tensor %[[ARG0]] +// CHECK: %[[K:.*]] = torch_c.to_builtin_tensor %[[ARG1]] +// CHECK: %[[V:.*]] = torch_c.to_builtin_tensor %[[ARG2]] +// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention +// CHECK-SAME: ins(%[[Q]], %[[K]], %[[V]], %[[ONE]] +// CHECK-SAME: outs(%[[OUT_BUF]] : tensor<4x8x1024x64xf32>) { +// CHECK: ^bb0(%[[SCORE:.*]]: f32, %[[B:.*]]: index, %[[H:.*]]: index, %[[M:.*]]: index, %[[N:.*]]: index): +// CHECK: %[[SCORE_T:.*]] = tensor.from_elements %[[SCORE]] : tensor<f32> +// CHECK: %[[TORCH_SCORE:.*]] = torch_c.from_builtin_tensor
%[[SCORE_T]] : tensor<f32> -> !torch.vtensor<[],f32> +// CHECK: %[[B_I32:.*]] = arith.index_cast %[[B]] : index to i32 +// CHECK: %[[B_T:.*]] = tensor.from_elements %[[B_I32]] : tensor<i32> +// CHECK: %[[TORCH_B:.*]] = torch_c.from_builtin_tensor %[[B_T]] : tensor<i32> -> !torch.vtensor<[],si32> +// CHECK: %[[H_I32:.*]] = arith.index_cast %[[H]] : index to i32 +// CHECK: %[[H_T:.*]] = tensor.from_elements %[[H_I32]] : tensor<i32> +// CHECK: %[[TORCH_H:.*]] = torch_c.from_builtin_tensor %[[H_T]] : tensor<i32> -> !torch.vtensor<[],si32> +// CHECK: %[[M_I32:.*]] = arith.index_cast %[[M]] : index to i32 +// CHECK: %[[M_T:.*]] = tensor.from_elements %[[M_I32]] : tensor<i32> +// CHECK: %[[TORCH_M:.*]] = torch_c.from_builtin_tensor %[[M_T]] : tensor<i32> -> !torch.vtensor<[],si32> +// CHECK: %[[N_I32:.*]] = arith.index_cast %[[N]] : index to i32 +// CHECK: %[[N_T:.*]] = tensor.from_elements %[[N_I32]] : tensor<i32> +// CHECK: %[[TORCH_N:.*]] = torch_c.from_builtin_tensor %[[N_T]] : tensor<i32> -> !torch.vtensor<[],si32> +// CHECK: %[[CALL:.*]] = func.call @sdpa_score0(%[[TORCH_SCORE]], %[[TORCH_B]], %[[TORCH_H]], %[[TORCH_M]], %[[TORCH_N]]) +// CHECK: %[[CALL_T:.*]] = torch_c.to_builtin_tensor %[[CALL]] : !torch.vtensor<[],f32> -> tensor<f32> +// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[CALL_T]][] : tensor<f32> +// CHECK: iree_linalg_ext.yield %[[EXTRACT]] : f32 +// CHECK: } -> tensor<4x8x1024x64xf32> +// CHECK: %[[T_OUT:.*]] = torch_c.from_builtin_tensor %[[ATTN]] +// CHECK: %[[T_LSE:.*]] = torch_c.from_builtin_tensor %[[LSE_BUF]] +// CHECK: return %[[T_OUT]], %[[T_LSE]] + +// ----- + +// flex_attention with neither score_mod nor mask_mod. +func.func @flex_attn_no_mods(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>,!torch.vtensor<[4,8,1024],f32>) attributes {torch.assume_strict_symbolic_shapes} { + %float1.000000e00 = torch.constant.float 1.000000e+00 + %true = torch.constant.bool true + // expected-warning @+1 {{FlexAttention: logsumexp output is a dummy (zeros), actual values are not available from AttentionOp}} + %output, %logsumexp = torch.aten.flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %true : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32> + return %output, %logsumexp : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32> +} +// CHECK-LABEL: func.func @flex_attn_no_mods( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG1:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG2:.*]]: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>) +// CHECK: %[[LSE_BUF:.*]] = arith.constant dense<0.000000e+00> : tensor<4x8x1024xf32> +// CHECK: %[[OUT_BUF:.*]] = arith.constant dense<0.000000e+00> : tensor<4x8x1024x64xf32> +// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[Q:.*]] = torch_c.to_builtin_tensor %[[ARG0]] +// CHECK: %[[K:.*]] = torch_c.to_builtin_tensor %[[ARG1]] +// CHECK: %[[V:.*]] = torch_c.to_builtin_tensor %[[ARG2]] +// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention +// CHECK-SAME: ins(%[[Q]], %[[K]], %[[V]], %[[ONE]] +// CHECK-SAME: outs(%[[OUT_BUF]] : tensor<4x8x1024x64xf32>) { +// CHECK: ^bb0(%[[SCORE:.*]]: f32, %[[B:.*]]: index, %[[H:.*]]: index, %[[M:.*]]: index, %[[N:.*]]: index): +// CHECK: iree_linalg_ext.yield %[[SCORE]] : f32 +// CHECK: } ->
tensor<4x8x1024x64xf32> +// CHECK: %[[T_OUT:.*]] = torch_c.from_builtin_tensor %[[ATTN]] +// CHECK: %[[T_LSE:.*]] = torch_c.from_builtin_tensor %[[LSE_BUF]] +// CHECK: return %[[T_OUT]], %[[T_LSE]] + +// ----- + +// flex_attention without return_lse. +// No warning expected (the logsumexp value is not meaningful when return_lse is false). +func.func @flex_attn_no_lse(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>,!torch.vtensor<[4,8,1024],f32>) attributes {torch.assume_strict_symbolic_shapes} { + %float1.000000e00 = torch.constant.float 1.000000e+00 + %false = torch.constant.bool false + %output, %logsumexp = torch.aten.flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %false : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32> + return %output, %logsumexp : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32> +} +// CHECK-LABEL: func.func @flex_attn_no_lse( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG1:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG2:.*]]: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>) +// CHECK: %[[LSE_BUF:.*]] = arith.constant dense<0.000000e+00> : tensor<4x8x1024xf32> +// CHECK: %[[OUT_BUF:.*]] = arith.constant dense<0.000000e+00> : tensor<4x8x1024x64xf32> +// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[Q:.*]] = torch_c.to_builtin_tensor %[[ARG0]] +// CHECK: %[[K:.*]] = torch_c.to_builtin_tensor %[[ARG1]] +// CHECK: %[[V:.*]] = torch_c.to_builtin_tensor %[[ARG2]] +// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention +// CHECK-SAME: ins(%[[Q]], %[[K]], %[[V]], %[[ONE]] +// CHECK-SAME: outs(%[[OUT_BUF]] : tensor<4x8x1024x64xf32>) { +// CHECK: ^bb0(%[[SCORE:.*]]: f32, %[[B:.*]]: index, %[[H:.*]]: index, %[[M:.*]]: index, %[[N:.*]]: index): +// CHECK: iree_linalg_ext.yield %[[SCORE]] : f32 +// CHECK: } -> tensor<4x8x1024x64xf32> +// CHECK: %[[T_OUT:.*]] = torch_c.from_builtin_tensor %[[ATTN]] +// CHECK: %[[T_LSE:.*]] = torch_c.from_builtin_tensor %[[LSE_BUF]] +// CHECK: return %[[T_OUT]], %[[T_LSE]] diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp index dc28a51e5251..3bf1a1a05954 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp @@ -191,18 +191,49 @@ static Value applyPostQKMatmulElementwise(OpBuilder &builder, Location loc, SmallVector<AffineMap> indexingMaps{identityMap}; SmallVector<utils::IteratorType> iteratorTypes(rank, utils::IteratorType::parallel); - auto genericOp = - linalg::GenericOp::create(builder, loc, value.getType(), ValueRange{}, - value, indexingMaps, iteratorTypes); - auto &dstRegion = genericOp.getRegion(); - builder.cloneRegionBefore(region, dstRegion, dstRegion.end()); - { - OpBuilder::InsertionGuard withinRegion(builder); - builder.setInsertionPoint(dstRegion.back().getTerminator()); - linalg::YieldOp::create(builder, loc, - dstRegion.back().getTerminator()->getOperands()); - dstRegion.back().getTerminator()->erase(); - } + auto genericOp = linalg::GenericOp::create( + builder, loc, value.getType(), ValueRange{}, value, indexingMaps, + iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange
args) { + Value score = args[0]; + + // If the region is empty (no score modification), just yield the score. + if (region.empty() || region.front().empty()) { + linalg::YieldOp::create(b, loc, score); + return; + } + + // Build index arguments if the region expects them. + SmallVector<Value> regionArgs; + regionArgs.push_back(score); + + if (region.front().getNumArguments() > 1) { + unsigned numExpectedIndices = region.front().getNumArguments() - 1; + for (unsigned i = 0; i < numExpectedIndices && i < rank; ++i) { + Value idx = b.create<linalg::IndexOp>(loc, i); + regionArgs.push_back(idx); + } + // For missing dimensions, pass zero constants as dummy indices. + for (unsigned i = rank; i < numExpectedIndices; ++i) { + Value zeroIdx = b.create<arith::ConstantIndexOp>(loc, 0); + regionArgs.push_back(zeroIdx); + } + } + + // Clone the region body inline. + IRMapping mapping; + for (auto [arg, regionArg] : + llvm::zip_equal(regionArgs, region.front().getArguments())) { + mapping.map(regionArg, arg); + } + for (Operation &op : region.front().without_terminator()) { + b.clone(op, mapping); + } + auto yieldOp = + cast<IREE::LinalgExt::YieldOp>(region.front().getTerminator()); + Value result = mapping.lookup(yieldOp.getOperand(0)); + linalg::YieldOp::create(b, loc, result); + }); + return genericOp.getResult(0); } @@ -258,15 +289,14 @@ static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap, return genericOp.getResult(0); } -// Compute output = exp2(output - input) -static Value computeSubAndExp2(OpBuilder &builder, Location loc, - AffineMap inputMap, AffineMap outputMap, - Value input, Value output) { +// Compute output = exp2/exp(output - input) depending on useExp2 flag. +static Value computeSubAndExp(OpBuilder &builder, Location loc, + AffineMap inputMap, AffineMap outputMap, + Value input, Value output, bool useExp2) { SmallVector<AffineMap> compressedMaps = compressUnusedDims(SmallVector<AffineMap>{inputMap, outputMap}); inputMap = compressedMaps[0]; outputMap = compressedMaps[1]; - SmallVector<utils::IteratorType> iteratorTypes(inputMap.getNumDims(), utils::IteratorType::parallel); auto genericOp = linalg::GenericOp::create( @@ -277,8 +307,9 @@ static Value computeSubAndExp2(OpBuilder &builder, Location loc, Value in = convertScalarToDtype(b, loc, args[0], args[1].getType(), /*isUnsignedCast=*/false); Value diff = arith::SubFOp::create(b, loc, args[1], in); - Value weight = math::Exp2Op::create(b, loc, diff); - linalg::YieldOp::create(b, loc, weight); + Operation *weight = useExp2 ? math::Exp2Op::create(b, loc, diff) + : math::ExpOp::create(b, loc, diff); + linalg::YieldOp::create(b, loc, weight->getResult(0)); }); return genericOp.getResult(0); } @@ -314,15 +345,18 @@ Value computeQKAndElementwise(Location loc, OpBuilder &b, Value query, std::optional<AffineMap> maskMap, SmallVector<OpFoldResult> iterationDomain, Type sElementType, Region &elementwiseRegion, - DictionaryAttr qkAttrs, bool lowPrecision) { + DictionaryAttr qkAttrs, bool lowPrecision, + bool useExp2) { MLIRContext *ctx = b.getContext(); - // Since we use exp2 for attention instead of the original exp, we have to + // If using exp2 for attention instead of the original exp, we have to // multiply the scale by log2(e). We use exp2 instead of exp as most platforms // have better support for exp2 (we verified that we gain some speedup on // some GPUs).
- Value log2e = arith::ConstantOp::create( - b, loc, b.getFloatAttr(scale.getType(), M_LOG2E)); - scale = arith::MulFOp::create(b, loc, scale, log2e); + if (useExp2) { + Value log2e = arith::ConstantOp::create( + b, loc, b.getFloatAttr(scale.getType(), M_LOG2E)); + scale = arith::MulFOp::create(b, loc, scale, log2e); + } auto qETy = getElementTypeOrSelf(query.getType()); @@ -434,9 +468,9 @@ FailureOr<SmallVector<Value>> AttentionOp::decomposeOperation(OpBuilder &b) { Type f32Type = b.getF32Type(); // ---- QK Matmul + elementwise math ---- - Value s = computeQKAndElementwise(loc, b, query, key, getScale(), mask, qMap, - kMap, sMap, getMaskMap(), sizes, f32Type, - getRegion(), qkAttrs, lowPrecision); + Value s = computeQKAndElementwise( + loc, b, query, key, getScale(), mask, qMap, kMap, sMap, getMaskMap(), + sizes, f32Type, getRegion(), qkAttrs, lowPrecision, /*useExp2=*/true); // ---- Softmax ---- @@ -476,9 +510,9 @@ FailureOr<SmallVector<Value>> AttentionOp::decomposeOperation(OpBuilder &b) { // max = rowMax(S) Value max = reduce(b, loc, sMap, maxMap, s, maxFill); - // P = exp2(S - max) + // P = exp2(S - max). AffineMap pMap = sMap; - Value p = computeSubAndExp2(b, loc, maxMap, sMap, max, s); + Value p = computeSubAndExp(b, loc, maxMap, sMap, max, s, /*useExp2=*/true); // sum = rowSum(P) Value sum = reduce(b, loc, pMap, sumMap, p, sumFill); @@ -528,9 +562,12 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { DictionaryAttr config = getDecompositionConfigAttr(); DictionaryAttr qkAttrs, pvAttrs; + bool useExp2 = true; if (config) { qkAttrs = config.getAs<DictionaryAttr>(getQKAttrStr()); pvAttrs = config.getAs<DictionaryAttr>(getPVAttrStr()); + if (auto useExp2Attr = config.getAs<BoolAttr>(getUseExp2AttrStr())) + useExp2 = useExp2Attr.getValue(); } FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get( @@ -551,7 +588,7 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { // ---- QK Matmul + elementwise math ---- Value s = computeQKAndElementwise( loc, b, query, key, getScale(), mask, qMap, kMap, sMap, getMaskMap(), - sizes, elementType, getRegion(), qkAttrs, lowPrecision); + sizes, elementType, getRegion(), qkAttrs, lowPrecision, useExp2); // TODO: This decomposition should be in a separate op called // "online softmax".
@@ -561,20 +598,21 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { AffineMap maxMap = getMaxMap(); Value newMax = reduce(b, loc, sMap, maxMap, s, oldMax); - // norm = exp2(oldMax - newMax) + // norm = exp2(oldMax - newMax) or exp(oldMax - newMax) depending on useExp2 // normMap = maxMap AffineMap normMap = getMaxMap(); - Value norm = computeSubAndExp2(b, loc, maxMap, normMap, newMax, oldMax); + Value norm = + computeSubAndExp(b, loc, maxMap, normMap, newMax, oldMax, useExp2); // normSum = norm * oldSum AffineMap sumMap = getSumMap(); Value normSum = elementwiseValueInPlace(b, loc, sumMap, normMap, oldSum, norm); - // P = exp2(S - newMax) + // P = exp2(S - newMax) or exp(S - newMax) depending on useExp2 // PMap = SMap AffineMap pMap = sMap; - Value p = computeSubAndExp2(b, loc, maxMap, sMap, newMax, s); + Value p = computeSubAndExp(b, loc, maxMap, sMap, newMax, s, useExp2); // newSum = normSum + rowSum(P) Value newSum = reduce(b, loc, pMap, sumMap, p, normSum); diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp index 200d6fb546ae..769de1397eeb 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp @@ -2041,9 +2041,24 @@ LogicalResult AttentionOp::verify() { auto &block = getRegion().front(); auto blockTys = block.getArgumentTypes(); + if (blockTys.size() != 1 && blockTys.size() != 5) { + return attnOp->emitOpError( + "expects either 1 block argument (score) or 5 block arguments " + "(score, b, h, m, n)"); + } + if (!isa<FloatType>(blockTys[0])) return attnOp->emitOpError("block argument 0 should be float"); + // If 5 arguments, verify the indices are of index type + if (blockTys.size() == 5) { + for (unsigned i = 1; i < 5; ++i) { + if (!blockTys[i].isIndex()) { + return attnOp->emitOpError("block arguments 1-4 should be index type"); + } + } + } + auto yieldOp = dyn_cast<IREE::LinalgExt::YieldOp>(block.getTerminator()); if (!yieldOp) { return attnOp->emitOpError("expected linalg_ext.yield"); @@ -2220,14 +2235,25 @@ LogicalResult OnlineAttentionOp::verify() { Block &block = attnOp.getRegion().front(); auto blockTys = block.getArgumentTypes(); - if (blockTys.size() != 1) { - return attnOp->emitOpError("expects single block argument for score"); + if (blockTys.size() != 1 && blockTys.size() != 5) { + return attnOp->emitOpError( + "expects either 1 block argument (score) or 5 block arguments " + "(score, b, h, m, n)"); } if (!isa<FloatType>(blockTys[0])) { return attnOp->emitOpError("block argument 0 should be float"); } + // If 5 arguments, verify the indices are of index type + if (blockTys.size() == 5) { + for (unsigned i = 1; i < 5; ++i) { + if (!blockTys[i].isIndex()) { + return attnOp->emitOpError("block arguments 1-4 should be index type"); + } + } + } + auto yieldOp = dyn_cast<IREE::LinalgExt::YieldOp>(block.getTerminator()); if (!yieldOp) { return attnOp->emitOpError("expected linalg_ext.yield"); diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td index c44ffcf3eb2e..92e96bada2e4 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td @@ -789,6 +789,12 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_Op<"attention", If an additional mask argument M is included, the result of the first matmul is modified according to: Q @ K.T += M + + Region: + The region body can receive either 1 or 5
block arguments: + - 1 argument (legacy): score (element type of output) + - 5 arguments: score, b (batch index), h (head index), m (query seq index), n (key/value seq index) + The region should yield a single value (the modified score). }]; let arguments = (ins AnyShaped:$query, @@ -914,6 +920,15 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_Op<"online_attention", it over the entire softmax reduction dimension by: x, _, sum : results x = (1 / sum) * x + + Region: + The region body receives the following block arguments: + - score: the computed score value from Q @ K.T (element type of output) + - b: batch index (index type) + - h: head index (index type) + - m: query sequence index (index type) + - n: key/value sequence index (index type) + The region should yield a single value (the modified score). }]; let arguments = (ins AnyShaped:$query, @@ -998,6 +1013,8 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_Op<"online_attention", // Attributes to set on QK and PV matmul after decomposition. static StringRef getQKAttrStr() { return "qk_attrs"; } static StringRef getPVAttrStr() { return "pv_attrs"; } + // Flag to control whether to use exp2 (with log2(e) scaling) or exp. + static StringRef getUseExp2AttrStr() { return "use_exp2"; } }]; let hasCanonicalizer = 1; diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir index fae9e5b76b23..acf1c0f66419 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir @@ -416,3 +416,96 @@ func.func @online_attention_f8_masked(%query: tensor<192x1024x64xf8E4M3FNUZ>, // CHECK: linalg.generic // CHECK: arith.addf // CHECK: linalg.yield + + +// ----- + +// Spec to decompose online attention op. 
+module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op + transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> () + transform.yield + } +} + +#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)> +#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)> +#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)> +#mapS = affine_map<(batch, m, k1, k2, n) -> ()> +#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)> +#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)> + +func.func @online_attention_f16_noexp2(%query: tensor<192x1024x64xf16>, + %key: tensor<192x1024x64xf16>, + %value: tensor<192x1024x64xf16>, + %output: tensor<192x1024x64xf32>, + %max: tensor<192x1024xf32>, + %sum: tensor<192x1024xf32>) + -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) { + %scale = arith.constant 1.0 : f16 + + %out:3 = iree_linalg_ext.online_attention + {decomposition_config = {use_exp2=false}, indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] } + ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16) + outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) { + ^bb0(%score: f32): + iree_linalg_ext.yield %score: f32 + } + -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32> + + return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32> +} + +// We want to check that we're correctly using exp +// when specified so from the decomposition_config. +// CHECK-LABEL: @online_attention_f16_noexp2 +// Q = Q * scale +// CHECK: linalg.generic +// CHECK: arith.mulf +// S = Q @ K +// CHECK: linalg.generic +// CHECK: arith.extf +// CHECK: arith.extf +// CHECK: arith.mulf +// CHECK: arith.addf +// CHECK: linalg.yield +// newMax = max(oldMax, rowMax(S)) +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.maximumf +// CHECK: linalg.yield +// norm = exp2(oldMax - newMax) +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.subf +// CHECK-NOT: math.exp2 +// CHECK: linalg.yield +// normSum = norm * oldSum +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.mulf +// CHECK: linalg.yield +// P = exp2(S - newMax) +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.subf +// CHECK-NOT: math.exp2 +// CHECK: linalg.yield +// newSum = normSum + rowSum(P) +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.addf +// CHECK: linalg.yield +// newAcc = norm * oldAcc +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.mulf +// CHECK: linalg.yield +// newAcc = P @ V + newAcc +// CHECK: linalg.generic +// CHECK: arith.extf +// CHECK: arith.extf +// CHECK: arith.mulf +// CHECK: arith.addf +// CHECK: linalg.yield diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp index ac7c42ab58ec..4cfa1154aa62 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp @@ -32,8 +32,15 @@ struct DecomposeAttentionPass final void DecomposeAttentionPass::runOnOperation() { MLIRContext *context = &getContext(); IRRewriter 
rewriter(context); + SmallVector<NamedAttribute> decompositionConfigAttrs; + decompositionConfigAttrs.push_back( + rewriter.getNamedAttr("use_exp2", rewriter.getBoolAttr(useExp2))); + DictionaryAttr decompositionConfig = + rewriter.getDictionaryAttr(decompositionConfigAttrs); + getOperation().walk([&](OnlineAttentionOp onlineAtt) { rewriter.setInsertionPoint(onlineAtt); + onlineAtt.setDecompositionConfigAttr(decompositionConfig); FailureOr<SmallVector<Value>> results = onlineAtt.decomposeOperation(rewriter); if (failed(results)) { diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td index 841018d34ebe..60207d14e199 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td @@ -103,6 +103,11 @@ def DecomposeAttentionPass : InterfacePass<"iree-linalg-ext-decompose-attention", "mlir::FunctionOpInterface"> { let summary = "Decomposes attention op into a sequence of linalg ops"; + let options = [ + Option<"useExp2", "use-exp2", "bool", /*default=*/"true", + "Use exp2 for computations; tunable to allow for accurate computations " + "in case of accuracy losses due to fp-reassociation.">, + ]; } def ConvertAttentionToOnlineAttentionPass :