diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp index f5ffbc2de84b..3cb319dbe17a 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp @@ -209,7 +209,7 @@ static Value applyPostQKMatmulElementwise(OpBuilder &builder, Location loc, } static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap, - AffineMap maskMap, Value qk, Value mask) { + AffineMap maskMap, Value qk, Value mask, bool useExp2) { SmallVector compressedMaps = compressUnusedDims(SmallVector{qkMap, maskMap}); @@ -245,9 +245,11 @@ static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap, maskVal = convertScalarToDtype(b, loc, maskVal, qkVal.getType(), /*isUnsignedCast=*/false); // Scaling to compensate for base-2 softmax - Value log2e = arith::ConstantOp::create( - b, loc, b.getFloatAttr(qkVal.getType(), M_LOG2E)); - maskVal = arith::MulFOp::create(b, loc, maskVal, log2e); + if (useExp2) { + Value log2e = arith::ConstantOp::create( + b, loc, b.getFloatAttr(qkVal.getType(), M_LOG2E)); + maskVal = arith::MulFOp::create(b, loc, maskVal, log2e); + } } // Finally, set the returned value to the qk element plus the mask // element (or 0/-infinity if bool mask). We opt for a AddFOp (instead @@ -260,10 +262,10 @@ static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap, return genericOp.getResult(0); } -// Compute output = exp2(output - input) -static Value computeSubAndExp2(OpBuilder &builder, Location loc, - AffineMap inputMap, AffineMap outputMap, - Value input, Value output) { +// Compute output = exp2/exp(output - input) depending on useExp2 flag. 
+static Value computeSubAndExp(OpBuilder &builder, Location loc, + AffineMap inputMap, AffineMap outputMap, + Value input, Value output, bool useExp2) { SmallVector compressedMaps = compressUnusedDims(SmallVector{inputMap, outputMap}); inputMap = compressedMaps[0]; @@ -279,8 +281,9 @@ static Value computeSubAndExp2(OpBuilder &builder, Location loc, Value in = convertScalarToDtype(b, loc, args[0], args[1].getType(), /*isUnsignedCast=*/false); Value diff = arith::SubFOp::create(b, loc, args[1], in); - Value weight = math::Exp2Op::create(b, loc, diff); - linalg::YieldOp::create(b, loc, weight); + Operation *weight = useExp2 ? math::Exp2Op::create(b, loc, diff) + : math::ExpOp::create(b, loc, diff); + linalg::YieldOp::create(b, loc, weight->getResult(0)); }); return genericOp.getResult(0); } @@ -316,15 +319,18 @@ Value computeQKAndElementwise(Location loc, OpBuilder &b, Value query, std::optional maskMap, SmallVector iterationDomain, Type sElementType, Region &elementwiseRegion, - DictionaryAttr qkAttrs, bool lowPrecision) { + DictionaryAttr qkAttrs, bool lowPrecision, + bool useExp2) { MLIRContext *ctx = b.getContext(); - // Since we use exp2 for attention instead of the original exp, we have to + // If using exp2 for attention instead of the original exp, we have to // multiply the scale by log2(e). We use exp2 instead of exp as most platforms // have better support for exp2 (we verified that we gain some speedup on // some GPUs). 
- Value log2e = arith::ConstantOp::create( - b, loc, b.getFloatAttr(scale.getType(), M_LOG2E)); - scale = arith::MulFOp::create(b, loc, scale, log2e); + if (useExp2) { + Value log2e = arith::ConstantOp::create( + b, loc, b.getFloatAttr(scale.getType(), M_LOG2E)); + scale = arith::MulFOp::create(b, loc, scale, log2e); + } auto qETy = getElementTypeOrSelf(query.getType()); @@ -392,7 +398,7 @@ Value computeQKAndElementwise(Location loc, OpBuilder &b, Value query, // S += mask if (mask != nullptr) { - s = applyMask(b, loc, sMap, *maskMap, s, mask.value()); + s = applyMask(b, loc, sMap, *maskMap, s, mask.value(), useExp2); } return s; @@ -436,9 +442,9 @@ FailureOr> AttentionOp::decomposeOperation(OpBuilder &b) { Type f32Type = b.getF32Type(); // ---- QK Matmul + elementwise math ---- - Value s = computeQKAndElementwise(loc, b, query, key, getScale(), mask, qMap, - kMap, sMap, getMaskMap(), sizes, f32Type, - getRegion(), qkAttrs, lowPrecision); + Value s = computeQKAndElementwise( + loc, b, query, key, getScale(), mask, qMap, kMap, sMap, getMaskMap(), + sizes, f32Type, getRegion(), qkAttrs, lowPrecision, /*useExp2=*/true); // ---- Softmax ---- @@ -480,7 +486,7 @@ FailureOr> AttentionOp::decomposeOperation(OpBuilder &b) { // P = exp2(S - max) AffineMap pMap = sMap; - Value p = computeSubAndExp2(b, loc, maxMap, sMap, max, s); + Value p = computeSubAndExp(b, loc, maxMap, sMap, max, s, /*useExp2=*/true); // sum = rowSum(P) Value sum = reduce(b, loc, pMap, sumMap, p, sumFill); @@ -530,9 +536,13 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { DictionaryAttr config = getDecompositionConfigAttr(); DictionaryAttr qkAttrs, pvAttrs; + bool useExp2 = true; if (config) { qkAttrs = config.getAs(getQKAttrStr()); pvAttrs = config.getAs(getPVAttrStr()); + if (auto useExp2Attr = config.getAs(getUseExp2AttrStr())) { + useExp2 = useExp2Attr.getValue(); + } } FailureOr maybeOpInfo = AttentionOpDetail::get( @@ -553,7 +563,7 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { // 
---- QK Matmul + elementwise math ---- Value s = computeQKAndElementwise( loc, b, query, key, getScale(), mask, qMap, kMap, sMap, getMaskMap(), - sizes, elementType, getRegion(), qkAttrs, lowPrecision); + sizes, elementType, getRegion(), qkAttrs, lowPrecision, useExp2); // TODO: This decomposition should be in a seperate op called // "online softmax". @@ -563,20 +573,21 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { AffineMap maxMap = getMaxMap(); Value newMax = reduce(b, loc, sMap, maxMap, s, oldMax); - // norm = exp2(oldMax - newMax) + // norm = exp2(oldMax - newMax) or exp(oldMax - newMax) depending on useExp2 // normMap = maxMap AffineMap normMap = getMaxMap(); - Value norm = computeSubAndExp2(b, loc, maxMap, normMap, newMax, oldMax); + Value norm = + computeSubAndExp(b, loc, maxMap, normMap, newMax, oldMax, useExp2); // normSum = norm * oldSum AffineMap sumMap = getSumMap(); Value normSum = elementwiseValueInPlace(b, loc, sumMap, normMap, oldSum, norm); - // P = exp2(S - newMax) + // P = exp2(S - newMax) or exp(S - newMax) depending on useExp2 // PMap = SMap AffineMap pMap = sMap; - Value p = computeSubAndExp2(b, loc, maxMap, sMap, newMax, s); + Value p = computeSubAndExp(b, loc, maxMap, sMap, newMax, s, useExp2); // newSum = normSum + rowSum(P) Value newSum = reduce(b, loc, pMap, sumMap, p, normSum); @@ -1211,11 +1222,11 @@ FailureOr> ExpReductionOp::decomposeOperation(OpBuilder &b) { Value currMax = reduce( rewriter, loc, normValMap, prevMaxMap, sValue->get(), prevMax->get()); // ex = e^{sValue - curr_max} - Value ex = computeSubAndExp2(rewriter, loc, prevMaxMap, normValMap, currMax, - sValue->get()); + Value ex = computeSubAndExp(rewriter, loc, prevMaxMap, normValMap, currMax, + sValue->get(), /*useExp2=*/true); // norm = e^(prev_max - curr_max) - Value norm = computeSubAndExp2(rewriter, loc, prevMaxMap, prevMaxMap, currMax, - prevMax->get()); + Value norm = computeSubAndExp(rewriter, loc, prevMaxMap, prevMaxMap, currMax, + prevMax->get(), 
/*useExp2=*/true); SmallVector inputs = getDpsInputs(); SmallVector normOuts(getNumDpsInits()); diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td index ac18c72e7347..7b5d58098012 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td @@ -1069,6 +1069,16 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_Op<"online_attention", it over the entire softmax reduction dimension by: x, _, sum : results x = (1 / sum) * x + + Decomposition Configuration: + The `decomposition_config` attribute is a DictionaryAttr that controls how + this operation is decomposed into lower-level operations. It supports: + - "qk_attrs": DictionaryAttr - Attributes to attach to the Q@K matmul + operation after decomposition (e.g., lowering_config, attention markers) + - "pv_attrs": DictionaryAttr - Attributes to attach to the P@V matmul + operation after decomposition + - "use_exp2": BoolAttr - If true, uses exp2 with log2(e) scaling instead + of exp. (Gives better perf on some hardware, but trades off accuracy) }]; let arguments = (ins AnyShaped:$query, @@ -1153,6 +1163,8 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_Op<"online_attention", // Attributes to set on QK and PV matmul after decomposition. static StringRef getQKAttrStr() { return "qk_attrs"; } static StringRef getPVAttrStr() { return "pv_attrs"; } + // Flag to control whether to use exp2 (with log2(e) scaling) or exp. 
+ static StringRef getUseExp2AttrStr() { return "use_exp2"; } }]; let hasCanonicalizer = 1; diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir index 7de65d86c8c7..85c21814e701 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir @@ -190,6 +190,7 @@ func.func @online_attention_f16(%query: tensor<192x1024x64xf16>, // correct number of extf/truncfs are emitted. // CHECK-LABEL: @online_attention_f16 // Q = Q * scale +// CHECK: arith.constant 1.442380e+00 : f16 // CHECK: linalg.generic // CHECK: arith.mulf // S = Q @ K @@ -419,6 +420,65 @@ func.func @online_attention_f8_masked(%query: tensor<192x1024x64xf8E4M3FNUZ>, // ----- +// Spec to decompose online attention op. +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op + transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> () + transform.yield + } +} + +#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)> +#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)> +#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)> +#mapS = affine_map<(batch, m, k1, k2, n) -> ()> +#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)> +#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)> + +func.func @online_attention_f16_noexp2(%query: tensor<192x1024x64xf16>, + %key: tensor<192x1024x64xf16>, + %value: tensor<192x1024x64xf16>, + %output: tensor<192x1024x64xf32>, + %max: tensor<192x1024xf32>, + %sum: tensor<192x1024xf32>) + -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) { + %scale = arith.constant 1.0 : f16 + + 
%out:3 = iree_linalg_ext.online_attention + {decomposition_config = {use_exp2=false}, indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] } + ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16) + outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) { + ^bb0(%score: f32): + iree_linalg_ext.yield %score: f32 + } + -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32> + + return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32> +} + +// We want to check that we're correctly using exp +// when specified so from the decomposition_config. +// CHECK-LABEL: @online_attention_f16_noexp2 +// Q = Q * scale +// CHECK: arith.constant 1.000000e+00 : f16 +// CHECK: linalg.generic +// CHECK: arith.mulf +// norm = exp (oldMax - newMax) +// CHECK: linalg.generic +// CHECK: arith.subf +// CHECK-NOT: arith.extf +// CHECK-NOT: math.exp2 +// CHECK: linalg.yield +// P = exp(S - newMax) +// CHECK: linalg.generic +// CHECK: arith.subf +// CHECK-NOT: arith.extf +// CHECK-NOT: math.exp2 +// CHECK: linalg.yield + +// ----- + // Spec to decompose exp reduction op. 
module attributes { transform.with_named_sequence } {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp
index ac7c42ab58ec..1fb50d83d52a 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp
@@ -32,8 +32,15 @@ struct DecomposeAttentionPass final
 void DecomposeAttentionPass::runOnOperation() {
   MLIRContext *context = &getContext();
   IRRewriter rewriter(context);
+
   getOperation().walk([&](OnlineAttentionOp onlineAtt) {
     rewriter.setInsertionPoint(onlineAtt);
+
+    NamedAttrList decompositionConfig(onlineAtt.getDecompositionConfigAttr());
+    decompositionConfig.set("use_exp2", rewriter.getBoolAttr(useExp2));
+    onlineAtt.setDecompositionConfigAttr(
+        decompositionConfig.getDictionary(context));
+
     FailureOr> results =
         onlineAtt.decomposeOperation(rewriter);
     if (failed(results)) {
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
index c1ce03397950..ea27c2e16fd7 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
@@ -116,6 +116,11 @@ def DecomposeAttentionPass :
     InterfacePass<"iree-linalg-ext-decompose-attention", "mlir::FunctionOpInterface"> {
   let summary = "Decomposes attention op into a sequence of linalg ops";
+  let options = [
+    Option<"useExp2", "use-exp2", "bool", /*default=*/"true",
+           "Use exp2 for computations; Tunable to allow for accurate computations "
+           "in case of accuracy losses due to fp-reassociation.">,
+  ];
 }

 def ConvertAttentionToOnlineAttentionPass :