From 99e4aa963dbe36c00dfc0b2f841457fcb471b9f3 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Wed, 26 Nov 2025 21:57:12 -0800 Subject: [PATCH 01/12] Added toggle for using useexp2 for onlineAttention Decomposition Signed-off-by: Keshav Vinayak Jha --- .../IR/AggregatedOpInterfaceImpl.cpp | 46 ++++++---- .../Dialect/LinalgExt/IR/LinalgExtOps.td | 2 + .../IR/test/decompose_aggregate_op.mlir | 92 +++++++++++++++++++ 3 files changed, 123 insertions(+), 17 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp index f5ffbc2de84b..2d10da0063b1 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp @@ -260,10 +260,10 @@ static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap, return genericOp.getResult(0); } -// Compute output = exp2(output - input) -static Value computeSubAndExp2(OpBuilder &builder, Location loc, - AffineMap inputMap, AffineMap outputMap, - Value input, Value output) { +// Compute output = exp2/exp(output - input) depending on useExp2 flag. +static Value computeSubAndExp(OpBuilder &builder, Location loc, + AffineMap inputMap, AffineMap outputMap, + Value input, Value output, bool useExp2) { SmallVector compressedMaps = compressUnusedDims(SmallVector{inputMap, outputMap}); inputMap = compressedMaps[0]; @@ -279,8 +279,9 @@ static Value computeSubAndExp2(OpBuilder &builder, Location loc, Value in = convertScalarToDtype(b, loc, args[0], args[1].getType(), /*isUnsignedCast=*/false); Value diff = arith::SubFOp::create(b, loc, args[1], in); - Value weight = math::Exp2Op::create(b, loc, diff); - linalg::YieldOp::create(b, loc, weight); + Operation *weight = useExp2 ? math::Exp2Op::create(b, loc, diff) + : math::ExpOp::create(b, loc, diff); + linalg::YieldOp::create(b, loc, weight->getResult(0)); }); return genericOp.getResult(0); } @@ -316,12 +317,18 @@ Value computeQKAndElementwise(Location loc, OpBuilder &b, Value query, std::optional maskMap, SmallVector iterationDomain, Type sElementType, Region &elementwiseRegion, - DictionaryAttr qkAttrs, bool lowPrecision) { + DictionaryAttr qkAttrs, bool lowPrecision, + bool useExp2) { MLIRContext *ctx = b.getContext(); - // Since we use exp2 for attention instead of the original exp, we have to + // If using exp2 for attention instead of the original exp, we have to // multiply the scale by log2(e). We use exp2 instead of exp as most platforms // have better support for exp2 (we verified that we gain some speedup on // some GPUs). + if (useExp2) { + Value log2e = arith::ConstantOp::create( + b, loc, b.getFloatAttr(scale.getType(), M_LOG2E)); + scale = arith::MulFOp::create(b, loc, scale, log2e); + } Value log2e = arith::ConstantOp::create( b, loc, b.getFloatAttr(scale.getType(), M_LOG2E)); scale = arith::MulFOp::create(b, loc, scale, log2e); @@ -436,9 +443,9 @@ FailureOr> AttentionOp::decomposeOperation(OpBuilder &b) { Type f32Type = b.getF32Type(); // ---- QK Matmul + elementwise math ---- - Value s = computeQKAndElementwise(loc, b, query, key, getScale(), mask, qMap, - kMap, sMap, getMaskMap(), sizes, f32Type, - getRegion(), qkAttrs, lowPrecision); + Value s = computeQKAndElementwise( + loc, b, query, key, getScale(), mask, qMap, kMap, sMap, getMaskMap(), + sizes, f32Type, getRegion(), qkAttrs, lowPrecision, /*useExp2=*/true); // ---- Softmax ---- @@ -480,7 +487,7 @@ FailureOr> AttentionOp::decomposeOperation(OpBuilder &b) { // P = exp2(S - max) AffineMap pMap = sMap; - Value p = computeSubAndExp2(b, loc, maxMap, sMap, max, s); + Value p = computeSubAndExp(b, loc, maxMap, sMap, max, s, /*useExp2=*/true); // sum = rowSum(P) Value sum = reduce(b, loc, pMap, sumMap, p, sumFill); @@ -530,9 +537,13 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { DictionaryAttr config = getDecompositionConfigAttr(); DictionaryAttr qkAttrs, pvAttrs; + bool useExp2 = true; if (config) { qkAttrs = config.getAs(getQKAttrStr()); pvAttrs = config.getAs(getPVAttrStr()); + if (auto useExp2Attr = config.getAs(getUseExp2AttrStr())) { + useExp2 = useExp2Attr.getValue(); + } } FailureOr maybeOpInfo = AttentionOpDetail::get( @@ -553,7 +564,7 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { // ---- QK Matmul + elementwise math ---- Value s = computeQKAndElementwise( loc, b, query, key, getScale(), mask, qMap, kMap, sMap, getMaskMap(), - sizes, elementType, getRegion(), qkAttrs, lowPrecision); + sizes, elementType, getRegion(), qkAttrs, lowPrecision, useExp2); // TODO: This decomposition should be in a seperate op called // "online softmax". @@ -563,20 +574,21 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { AffineMap maxMap = getMaxMap(); Value newMax = reduce(b, loc, sMap, maxMap, s, oldMax); - // norm = exp2(oldMax - newMax) + // norm = exp2(oldMax - newMax) or exp(oldMax - newMax) depending on useExp2 // normMap = maxMap AffineMap normMap = getMaxMap(); - Value norm = computeSubAndExp2(b, loc, maxMap, normMap, newMax, oldMax); + Value norm = + computeSubAndExp(b, loc, maxMap, normMap, newMax, oldMax, useExp2); // normSum = norm * oldSum AffineMap sumMap = getSumMap(); Value normSum = elementwiseValueInPlace(b, loc, sumMap, normMap, oldSum, norm); - // P = exp2(S - newMax) + // P = exp2(S - newMax) or exp(S - newMax) depending on useExp2 // PMap = SMap AffineMap pMap = sMap; - Value p = computeSubAndExp2(b, loc, maxMap, sMap, newMax, s); + Value p = computeSubAndExp(b, loc, maxMap, sMap, newMax, s, useExp2); // newSum = normSum + rowSum(P) Value newSum = reduce(b, loc, pMap, sumMap, p, normSum); diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td index ac18c72e7347..a9d2f41f8c78 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td @@ -1153,6 +1153,8 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_Op<"online_attention", // Attributes to set on QK and PV matmul after decomposition. static StringRef getQKAttrStr() { return "qk_attrs"; } static StringRef getPVAttrStr() { return "pv_attrs"; } + // Flag to control whether to use exp2 (with log2(e) scaling) or exp. + static StringRef getUseExp2AttrStr() { return "use_exp2"; } }]; let hasCanonicalizer = 1; diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir index 7de65d86c8c7..c973621d57af 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir @@ -419,6 +419,98 @@ func.func @online_attention_f8_masked(%query: tensor<192x1024x64xf8E4M3FNUZ>, // ----- +// Spec to decompose online attention op. +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op + transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> () + transform.yield + } +} + +#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)> +#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)> +#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)> +#mapS = affine_map<(batch, m, k1, k2, n) -> ()> +#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)> +#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)> + +func.func @online_attention_f16_noexp2(%query: tensor<192x1024x64xf16>, + %key: tensor<192x1024x64xf16>, + %value: tensor<192x1024x64xf16>, + %output: tensor<192x1024x64xf32>, + %max: tensor<192x1024xf32>, + %sum: tensor<192x1024xf32>) + -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) { + %scale = arith.constant 1.0 : f16 + + %out:3 = iree_linalg_ext.online_attention + {decomposition_config = {use_exp2=false}, indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] } + ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16) + outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) { + ^bb0(%score: f32): + iree_linalg_ext.yield %score: f32 + } + -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32> + + return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32> +} + +// We want to check that we're correctly using exp +// when specified so from the decomposition_config. +// CHECK-LABEL: @online_attention_f16_noexp2 +// Q = Q * scale +// CHECK: linalg.generic +// CHECK: arith.mulf +// S = Q @ K +// CHECK: linalg.generic +// CHECK: arith.extf +// CHECK: arith.extf +// CHECK: arith.mulf +// CHECK: arith.addf +// CHECK: linalg.yield +// newMax = max(oldMax, rowMax(S)) +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.maximumf +// CHECK: linalg.yield +// norm = exp (oldMax - newMax) +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.subf +// CHECK-NOT: math.exp2 +// CHECK: linalg.yield +// normSum = norm * oldSum +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.mulf +// CHECK: linalg.yield +// P = exp(S - newMax) +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.subf +// CHECK-NOT: math.exp2 +// CHECK: linalg.yield +// newSum = normSum + rowSum(P) +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.addf +// CHECK: linalg.yield +// newAcc = norm * oldAcc +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.mulf +// CHECK: linalg.yield +// newAcc = P @ V + newAcc +// CHECK: linalg.generic +// CHECK: arith.extf +// CHECK: arith.extf +// CHECK: arith.mulf +// CHECK: arith.addf +// CHECK: linalg.yield + +// ----- + // Spec to decompose exp reduction op. module attributes { transform.with_named_sequence } { transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { From 81691492826722cb92bcdfd7cc80174212787dbe Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Wed, 26 Nov 2025 22:24:09 -0800 Subject: [PATCH 02/12] Added useExp2 as pass option to DecomposeAttention Signed-off-by: Keshav Vinayak Jha --- .../Dialect/LinalgExt/Transforms/DecomposeAttention.cpp | 8 ++++++++ .../iree/compiler/Dialect/LinalgExt/Transforms/Passes.td | 5 +++++ 2 files changed, 13 insertions(+) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp index ac7c42ab58ec..97a0d8740710 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp @@ -32,8 +32,16 @@ struct DecomposeAttentionPass final void DecomposeAttentionPass::runOnOperation() { MLIRContext *context = &getContext(); IRRewriter rewriter(context); + + SmallVector decompositionConfigAttrs; + decompositionConfigAttrs.push_back( + rewriter.getNamedAttr("use_exp2", rewriter.getBoolAttr(useExp2))); + DictionaryAttr decompositionConfig = + rewriter.getDictionaryAttr(decompositionConfigAttrs); + getOperation().walk([&](OnlineAttentionOp onlineAtt) { rewriter.setInsertionPoint(onlineAtt); + onlineAtt.setDecompositionConfigAttr(decompositionConfig); FailureOr> results = onlineAtt.decomposeOperation(rewriter); if (failed(results)) { diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td index c1ce03397950..ea27c2e16fd7 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td @@ -116,6 +116,11 @@ def DecomposeAttentionPass : InterfacePass<"iree-linalg-ext-decompose-attention", "mlir::FunctionOpInterface"> { let summary = "Decomposes attention op into a sequence of linalg ops"; + let options = [ + Option<"useExp2", "use-exp2", "bool", /*default=*/"true", + "Use exp2 for computations; Tunable to allow for accuracte computations" + "in case of accuracy losses due to fp-reassociation.">, + ]; } def ConvertAttentionToOnlineAttentionPass : From 299e1eebd0d25e90fc4851dc1594b0b037f6343d Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Thu, 27 Nov 2025 02:29:39 -0800 Subject: [PATCH 03/12] Removed Typo Signed-off-by: Keshav Vinayak Jha --- .../Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp index 2d10da0063b1..170fcfc367ce 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp @@ -329,9 +329,6 @@ Value computeQKAndElementwise(Location loc, OpBuilder &b, Value query, b, loc, b.getFloatAttr(scale.getType(), M_LOG2E)); scale = arith::MulFOp::create(b, loc, scale, log2e); } - Value log2e = arith::ConstantOp::create( - b, loc, b.getFloatAttr(scale.getType(), M_LOG2E)); - scale = arith::MulFOp::create(b, loc, scale, log2e); auto qETy = getElementTypeOrSelf(query.getType()); From 616c36c7f203053ba59a42ca8bb3c41014db8b07 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 16 Dec 2025 00:12:05 -0800 Subject: [PATCH 04/12] Simplified FileChecks; Added check of log2e vs 1.0 scaling Signed-off-by: Keshav Vinayak Jha --- .../IR/test/decompose_aggregate_op.mlir | 23 ++++--------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir index c973621d57af..21185db2d70c 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir @@ -190,6 +190,7 @@ func.func @online_attention_f16(%query: tensor<192x1024x64xf16>, // correct number of extf/truncfs are emitted. // CHECK-LABEL: @online_attention_f16 // Q = Q * scale +// CHECK: arith.constant 1.442380e+00 : f16 // CHECK: linalg.generic // CHECK: arith.mulf // S = Q @ K @@ -460,35 +461,19 @@ func.func @online_attention_f16_noexp2(%query: tensor<192x1024x64xf16>, // when specified so from the decomposition_config. // CHECK-LABEL: @online_attention_f16_noexp2 // Q = Q * scale +// CHECK: arith.constant 1.000000e+00 : f16 // CHECK: linalg.generic // CHECK: arith.mulf -// S = Q @ K -// CHECK: linalg.generic -// CHECK: arith.extf -// CHECK: arith.extf -// CHECK: arith.mulf -// CHECK: arith.addf -// CHECK: linalg.yield -// newMax = max(oldMax, rowMax(S)) -// CHECK: linalg.generic -// CHECK-NOT: arith.extf -// CHECK: arith.maximumf -// CHECK: linalg.yield // norm = exp (oldMax - newMax) // CHECK: linalg.generic -// CHECK-NOT: arith.extf // CHECK: arith.subf -// CHECK-NOT: math.exp2 -// CHECK: linalg.yield -// normSum = norm * oldSum -// CHECK: linalg.generic // CHECK-NOT: arith.extf -// CHECK: arith.mulf +// CHECK-NOT: math.exp2 // CHECK: linalg.yield // P = exp(S - newMax) // CHECK: linalg.generic -// CHECK-NOT: arith.extf // CHECK: arith.subf +// CHECK-NOT: arith.extf // CHECK-NOT: math.exp2 // CHECK: linalg.yield // newSum = normSum + rowSum(P) From acf8c232942fa8fb780f06e6a2c53779f84f9ff4 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 16 Dec 2025 00:12:42 -0800 Subject: [PATCH 05/12] Newline at EOF Signed-off-by: Keshav Vinayak Jha --- .../IR/test/decompose_aggregate_op.mlir | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir index 21185db2d70c..7eff2e014b79 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir @@ -476,23 +476,6 @@ func.func @online_attention_f16_noexp2(%query: tensor<192x1024x64xf16>, // CHECK-NOT: arith.extf // CHECK-NOT: math.exp2 // CHECK: linalg.yield -// newSum = normSum + rowSum(P) -// CHECK: linalg.generic -// CHECK-NOT: arith.extf -// CHECK: arith.addf -// CHECK: linalg.yield -// newAcc = norm * oldAcc -// CHECK: linalg.generic -// CHECK-NOT: arith.extf -// CHECK: arith.mulf -// CHECK: linalg.yield -// newAcc = P @ V + newAcc -// CHECK: linalg.generic -// CHECK: arith.extf -// CHECK: arith.extf -// CHECK: arith.mulf -// CHECK: arith.addf -// CHECK: linalg.yield // ----- @@ -572,4 +555,4 @@ func.func @exp_reduction( // CHECK-SAME: outs(%[[acc_norm]] // CHECK: arith.mulf // CHECK: arith.addf -// CHECK: return %[[M]], %[[SUM]], %[[PV]] +// CHECK: return %[[M]], %[[SUM]], %[[PV]] \ No newline at end of file From 2267a06728b5b3270492d95f3276eded22aef8c7 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 16 Dec 2025 01:28:55 -0800 Subject: [PATCH 06/12] Mask scaling is conditional to useExp2 Signed-off-by: Keshav Vinayak Jha --- .../LinalgExt/IR/AggregatedOpInterfaceImpl.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp index 170fcfc367ce..2a2b00500da8 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp @@ -209,7 +209,7 @@ static Value applyPostQKMatmulElementwise(OpBuilder &builder, Location loc, } static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap, - AffineMap maskMap, Value qk, Value mask) { + AffineMap maskMap, Value qk, Value mask, bool useExp2) { SmallVector compressedMaps = compressUnusedDims(SmallVector{qkMap, maskMap}); @@ -245,9 +245,11 @@ static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap, maskVal = convertScalarToDtype(b, loc, maskVal, qkVal.getType(), /*isUnsignedCast=*/false); // Scaling to compensate for base-2 softmax - Value log2e = arith::ConstantOp::create( - b, loc, b.getFloatAttr(qkVal.getType(), M_LOG2E)); - maskVal = arith::MulFOp::create(b, loc, maskVal, log2e); + if (useExp2) { + Value log2e = arith::ConstantOp::create( + b, loc, b.getFloatAttr(qkVal.getType(), M_LOG2E)); + maskVal = arith::MulFOp::create(b, loc, maskVal, log2e); + } } // Finally, set the returned value to the qk element plus the mask // element (or 0/-infinity if bool mask). We opt for a AddFOp (instead @@ -396,7 +398,7 @@ Value computeQKAndElementwise(Location loc, OpBuilder &b, Value query, // S += mask if (mask != nullptr) { - s = applyMask(b, loc, sMap, *maskMap, s, mask.value()); + s = applyMask(b, loc, sMap, *maskMap, s, mask.value(), useExp2); } return s; From bec3f0da31cdff184c7c76759e2ebc5533f5dabb Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 16 Dec 2025 03:25:39 -0800 Subject: [PATCH 07/12] Bug in code: Overwriting the existing DecompositionAttr, we want to add use_exp2 not overwrite Signed-off-by: Keshav Vinayak Jha --- .../LinalgExt/Transforms/DecomposeAttention.cpp | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp index 97a0d8740710..1fb50d83d52a 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp @@ -33,15 +33,14 @@ void DecomposeAttentionPass::runOnOperation() { MLIRContext *context = &getContext(); IRRewriter rewriter(context); - SmallVector decompositionConfigAttrs; - decompositionConfigAttrs.push_back( - rewriter.getNamedAttr("use_exp2", rewriter.getBoolAttr(useExp2))); - DictionaryAttr decompositionConfig = - rewriter.getDictionaryAttr(decompositionConfigAttrs); - getOperation().walk([&](OnlineAttentionOp onlineAtt) { rewriter.setInsertionPoint(onlineAtt); - onlineAtt.setDecompositionConfigAttr(decompositionConfig); + + NamedAttrList decompositionConfig(onlineAtt.getDecompositionConfigAttr()); + decompositionConfig.set("use_exp2", rewriter.getBoolAttr(useExp2)); + onlineAtt.setDecompositionConfigAttr( + decompositionConfig.getDictionary(context)); + FailureOr> results = onlineAtt.decomposeOperation(rewriter); if (failed(results)) { From f14ccfe61fa8a62ff7144e482d54fa28c5ad8a01 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 16 Dec 2025 03:35:28 -0800 Subject: [PATCH 08/12] Added docs for Decomposition Configuration: Signed-off-by: Keshav Vinayak Jha --- .../iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td index a9d2f41f8c78..729f8d0cac73 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td @@ -1069,6 +1069,16 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_Op<"online_attention", it over the entire softmax reduction dimension by: x, _, sum : results x = (1 / sum) * x + + Decomposition Configuration: + The `decomposition_config` attribute is a DictionaryAttr that controls how + this operation is decomposed into lower-level operations. It supports: + - "qk_attrs": DictionaryAttr - Attributes to attach to the Q@K matmul + operation after decomposition (e.g., lowering_config, attention markers) + - "pv_attrs": DictionaryAttr - Attributes to attach to the P@V matmul + operation after decomposition + - "use_exp2": BoolAttr - If true, uses exp2 with log2(e) scaling instead + of exp. (Might be better accuracy-wise on some hardware) }]; let arguments = (ins AnyShaped:$query, From 9811d72ce4feb466d829fc876fb003fb0c2c3598 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 16 Dec 2025 07:39:42 -0800 Subject: [PATCH 09/12] Nit comment Signed-off-by: Keshav Vinayak Jha --- compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td index 729f8d0cac73..7b5d58098012 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td @@ -1078,7 +1078,7 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_Op<"online_attention", - "pv_attrs": DictionaryAttr - Attributes to attach to the P@V matmul operation after decomposition - "use_exp2": BoolAttr - If true, uses exp2 with log2(e) scaling instead - of exp. (Might be better accuracy-wise on some hardware) + of exp. (Gives better perf on some hardware, but trades off accuracy) }]; let arguments = (ins AnyShaped:$query, From 45195f969eced6943e83e5947f668541fa648226 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha <31160700+keshavvinayak01@users.noreply.github.com> Date: Tue, 23 Dec 2025 13:27:40 +0530 Subject: [PATCH 10/12] Refactor computeSubAndExp2 to computeSubAndExp Updated computeSubAndExp2 calls to computeSubAndExp with useExp2 flag. Signed-off-by: Keshav Vinayak Jha --- .../Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp index 2a2b00500da8..d2d35fb5dbed 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp @@ -1222,11 +1222,11 @@ FailureOr> ExpReductionOp::decomposeOperation(OpBuilder &b) { Value currMax = reduce( rewriter, loc, normValMap, prevMaxMap, sValue->get(), prevMax->get()); // ex = e^{sValue - curr_max} - Value ex = computeSubAndExp2(rewriter, loc, prevMaxMap, normValMap, currMax, - sValue->get()); + Value ex = computeSubAndExp(rewriter, loc, prevMaxMap, normValMap, currMax, + sValue->get(), /*useExp2=*/true); // norm = e^(prev_max - curr_max) - Value norm = computeSubAndExp2(rewriter, loc, prevMaxMap, prevMaxMap, currMax, - prevMax->get()); + Value norm = computeSubAndExp(rewriter, loc, prevMaxMap, prevMaxMap, currMax, + prevMax->get(), /*useExp2=*/true); SmallVector inputs = getDpsInputs(); SmallVector normOuts(getNumDpsInits()); From 56e06170399b49c318fb347bfd663ab06d16fc10 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 23 Dec 2025 08:11:26 +0000 Subject: [PATCH 11/12] Formatting Signed-off-by: Keshav Vinayak Jha --- .../Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp index d2d35fb5dbed..3cb319dbe17a 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp @@ -1223,10 +1223,10 @@ FailureOr> ExpReductionOp::decomposeOperation(OpBuilder &b) { rewriter, loc, normValMap, prevMaxMap, sValue->get(), prevMax->get()); // ex = e^{sValue - curr_max} Value ex = computeSubAndExp(rewriter, loc, prevMaxMap, normValMap, currMax, - sValue->get(), /*useExp2=*/true); + sValue->get(), /*useExp2=*/true); // norm = e^(prev_max - curr_max) Value norm = computeSubAndExp(rewriter, loc, prevMaxMap, prevMaxMap, currMax, - prevMax->get(), /*useExp2=*/true); + prevMax->get(), /*useExp2=*/true); SmallVector inputs = getDpsInputs(); SmallVector normOuts(getNumDpsInits()); From 7d4f0b0b814ca6af7faad21c4e21815364f10293 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 20 Jan 2026 06:01:56 +0000 Subject: [PATCH 12/12] Formatting Signed-off-by: Keshav Vinayak Jha --- .../Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir index 7eff2e014b79..85c21814e701 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir @@ -555,4 +555,4 @@ func.func @exp_reduction( // CHECK-SAME: outs(%[[acc_norm]] // CHECK: arith.mulf // CHECK: arith.addf -// CHECK: return %[[M]], %[[SUM]], %[[PV]] \ No newline at end of file +// CHECK: return %[[M]], %[[SUM]], %[[PV]]