19 commits
9280d3f
Initial support for aten::flex_attention and rewrite to linalgext.onl…
keshavvinayak01 Oct 28, 2025
857c180
Pass iteration indices as block arguments to OnlineAttentionOp region
keshavvinayak01 Nov 8, 2025
1cd5d58
Initialized reduction vals; Corrected func::call op init
keshavvinayak01 Nov 9, 2025
eea83bd
Added lit test (lse, nolse)
keshavvinayak01 Nov 9, 2025
53796c6
Removed previous support added for linalgext::index op to nest in att…
keshavvinayak01 Nov 10, 2025
336f6f5
Formatting
keshavvinayak01 Nov 10, 2025
b6a6bb6
Added changes to use IREE::AttentionOp; support reqiured for multi-di…
keshavvinayak01 Nov 11, 2025
fca1fe8
Removed unused comment
keshavvinayak01 Nov 11, 2025
f860ccc
Added changes for correct computations:
keshavvinayak01 Nov 12, 2025
06e9499
.create -> OpTy::create
keshavvinayak01 Nov 12, 2025
eaa0b93
Removed gqa argument
keshavvinayak01 Nov 12, 2025
1f5107d
- Removed redundant/debug comments
keshavvinayak01 Nov 14, 2025
3383d55
Added missing newline
keshavvinayak01 Nov 17, 2025
8e929ce
Addressed comments and changes:
keshavvinayak01 Nov 17, 2025
f527211
Formattingh
keshavvinayak01 Nov 17, 2025
2b54533
Restored ReshapeFusion.cpp to original format
keshavvinayak01 Nov 17, 2025
abdc3da
Added 4 more lit tests covering with/without score_mod, mask_mod, and…
keshavvinayak01 Nov 18, 2025
1036526
Smaller matchAndRewrite calls to specific utility functions; easily r…
keshavvinayak01 Nov 18, 2025
64aae75
Replaced magic numbers with meaningful consts
keshavvinayak01 Nov 18, 2025
Two files with large diffs are not rendered by default.
@@ -191,18 +191,49 @@ static Value applyPostQKMatmulElementwise(OpBuilder &builder, Location loc,
SmallVector<AffineMap> indexingMaps{identityMap};
SmallVector<utils::IteratorType> iteratorTypes(rank,
utils::IteratorType::parallel);
auto genericOp =
linalg::GenericOp::create(builder, loc, value.getType(), ValueRange{},
value, indexingMaps, iteratorTypes);
auto &dstRegion = genericOp.getRegion();
builder.cloneRegionBefore(region, dstRegion, dstRegion.end());
{
OpBuilder::InsertionGuard withinRegion(builder);
builder.setInsertionPoint(dstRegion.back().getTerminator());
linalg::YieldOp::create(builder, loc,
dstRegion.back().getTerminator()->getOperands());
dstRegion.back().getTerminator()->erase();
}
auto genericOp = linalg::GenericOp::create(
builder, loc, value.getType(), ValueRange{}, value, indexingMaps,
iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
Value score = args[0];

// If the region is empty (no score modification), just yield the score
if (region.empty() || region.front().empty()) {
linalg::YieldOp::create(b, loc, score);
return;
}

// Build index arguments if region expects them
SmallVector<Value> regionArgs;
regionArgs.push_back(score);

if (region.front().getNumArguments() > 1) {
unsigned numExpectedIndices = region.front().getNumArguments() - 1;
for (unsigned i = 0; i < numExpectedIndices && i < rank; ++i) {
Value idx = b.create<linalg::IndexOp>(loc, i);
regionArgs.push_back(idx);
}
// For missing dimensions, pass zero constants as dummy indices
for (unsigned i = rank; i < numExpectedIndices; ++i) {
Value zeroIdx = b.create<arith::ConstantIndexOp>(loc, 0);
regionArgs.push_back(zeroIdx);
}
}

// Clone the region body inline
IRMapping mapping;
for (auto [arg, regionArg] :
llvm::zip_equal(regionArgs, region.front().getArguments())) {
mapping.map(regionArg, arg);
}
for (Operation &op : region.front().without_terminator()) {
b.clone(op, mapping);
}
auto yieldOp =
cast<IREE::LinalgExt::YieldOp>(region.front().getTerminator());
Value result = mapping.lookup(yieldOp.getOperand(0));
linalg::YieldOp::create(b, loc, result);
});

return genericOp.getResult(0);
}
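For intuition, the following is a hedged sketch (not verbatim compiler output) of the kind of linalg.generic the body-builder lambda above produces for a rank-3 score tensor whose score_mod region takes the 5-argument form: linalg.index supplies one index per available dimension and a constant 0 fills in any remaining expected index argument. Shapes, the function name, and the inlined bias body are illustrative assumptions.

#identity3 = affine_map<(b, m, n) -> (b, m, n)>

func.func @score_mod_generic_sketch(%s: tensor<192x1024x1024xf32>)
    -> tensor<192x1024x1024xf32> {
  %mod = linalg.generic
      {indexing_maps = [#identity3],
       iterator_types = ["parallel", "parallel", "parallel"]}
      outs(%s : tensor<192x1024x1024xf32>) {
  ^bb0(%score: f32):
    // One linalg.index per available dimension of the score tensor...
    %b = linalg.index 0 : index
    %m = linalg.index 1 : index
    %n = linalg.index 2 : index
    // ...and a dummy zero for any index argument the region expects beyond
    // the tensor's rank.
    %zero = arith.constant 0 : index
    // Inlined score_mod body (illustrative): bias the score by the key index.
    %n_i64 = arith.index_cast %n : index to i64
    %n_f32 = arith.sitofp %n_i64 : i64 to f32
    %biased = arith.addf %score, %n_f32 : f32
    linalg.yield %biased : f32
  } -> tensor<192x1024x1024xf32>
  return %mod : tensor<192x1024x1024xf32>
}

Note that the binding of loop indices to region arguments is purely positional; handling operations with more than one batch dimension (as raised in the reviewer comment further down) would need a more careful mapping.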

@@ -258,15 +289,14 @@ static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap,
return genericOp.getResult(0);
}

// Compute output = exp2(output - input)
static Value computeSubAndExp2(OpBuilder &builder, Location loc,
AffineMap inputMap, AffineMap outputMap,
Value input, Value output) {
// Compute output = exp2/exp(output - input) depending on useExp2 flag.
static Value computeSubAndExp(OpBuilder &builder, Location loc,
AffineMap inputMap, AffineMap outputMap,
Value input, Value output, bool useExp2) {
SmallVector<AffineMap> compressedMaps =
compressUnusedDims(SmallVector<AffineMap>{inputMap, outputMap});
inputMap = compressedMaps[0];
outputMap = compressedMaps[1];

SmallVector<utils::IteratorType> iteratorTypes(inputMap.getNumDims(),
utils::IteratorType::parallel);
auto genericOp = linalg::GenericOp::create(
@@ -277,8 +307,9 @@ static Value computeSubAndExp2(OpBuilder &builder, Location loc,
Value in = convertScalarToDtype(b, loc, args[0], args[1].getType(),
/*isUnsignedCast=*/false);
Value diff = arith::SubFOp::create(b, loc, args[1], in);
Value weight = math::Exp2Op::create(b, loc, diff);
linalg::YieldOp::create(b, loc, weight);
Operation *weight = useExp2 ? math::Exp2Op::create(b, loc, diff)
: math::ExpOp::create(b, loc, diff);
linalg::YieldOp::create(b, loc, weight->getResult(0));
});
return genericOp.getResult(0);
}
@@ -314,15 +345,18 @@ Value computeQKAndElementwise(Location loc, OpBuilder &b, Value query,
std::optional<AffineMap> maskMap,
SmallVector<OpFoldResult> iterationDomain,
Type sElementType, Region &elementwiseRegion,
DictionaryAttr qkAttrs, bool lowPrecision) {
DictionaryAttr qkAttrs, bool lowPrecision,
bool useExp2) {
MLIRContext *ctx = b.getContext();
// Since we use exp2 for attention instead of the original exp, we have to
// If using exp2 for attention instead of the original exp, we have to
// multiply the scale by log2(e). We use exp2 instead of exp as most platforms
// have better support for exp2 (we verified that we gain some speedup on
// some GPUs).
Value log2e = arith::ConstantOp::create(
b, loc, b.getFloatAttr(scale.getType(), M_LOG2E));
scale = arith::MulFOp::create(b, loc, scale, log2e);
if (useExp2) {
Value log2e = arith::ConstantOp::create(
b, loc, b.getFloatAttr(scale.getType(), M_LOG2E));
scale = arith::MulFOp::create(b, loc, scale, log2e);
}

auto qETy = getElementTypeOrSelf(query.getType());

@@ -434,9 +468,9 @@ FailureOr<SmallVector<Value>> AttentionOp::decomposeOperation(OpBuilder &b) {
Type f32Type = b.getF32Type();

// ---- QK Matmul + elementwise math ----
Value s = computeQKAndElementwise(loc, b, query, key, getScale(), mask, qMap,
kMap, sMap, getMaskMap(), sizes, f32Type,
getRegion(), qkAttrs, lowPrecision);
Value s = computeQKAndElementwise(
loc, b, query, key, getScale(), mask, qMap, kMap, sMap, getMaskMap(),
sizes, f32Type, getRegion(), qkAttrs, lowPrecision, /*useExp2=*/true);

// ---- Softmax ----

@@ -476,9 +510,9 @@ FailureOr<SmallVector<Value>> AttentionOp::decomposeOperation(OpBuilder &b) {
// max = rowMax(S)
Value max = reduce<arith::MaximumFOp>(b, loc, sMap, maxMap, s, maxFill);

// P = exp2(S - max)
// P = exp2(S - max).
AffineMap pMap = sMap;
Value p = computeSubAndExp2(b, loc, maxMap, sMap, max, s);
Value p = computeSubAndExp(b, loc, maxMap, sMap, max, s, /*useExp2=*/true);

// sum = rowSum(P)
Value sum = reduce<arith::AddFOp>(b, loc, pMap, sumMap, p, sumFill);
@@ -528,9 +562,12 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
DictionaryAttr config = getDecompositionConfigAttr();

DictionaryAttr qkAttrs, pvAttrs;
bool useExp2 = true;
if (config) {
qkAttrs = config.getAs<DictionaryAttr>(getQKAttrStr());
pvAttrs = config.getAs<DictionaryAttr>(getPVAttrStr());
if (auto useExp2Attr = config.getAs<BoolAttr>(getUseExp2AttrStr()))
useExp2 = useExp2Attr.getValue();
}

FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get(
@@ -551,7 +588,7 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
// ---- QK Matmul + elementwise math ----
Value s = computeQKAndElementwise(
loc, b, query, key, getScale(), mask, qMap, kMap, sMap, getMaskMap(),
sizes, elementType, getRegion(), qkAttrs, lowPrecision);
sizes, elementType, getRegion(), qkAttrs, lowPrecision, useExp2);

// TODO: This decomposition should be in a separate op called
// "online softmax".
@@ -561,20 +598,21 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
AffineMap maxMap = getMaxMap();
Value newMax = reduce<arith::MaximumFOp>(b, loc, sMap, maxMap, s, oldMax);

// norm = exp2(oldMax - newMax)
// norm = exp2(oldMax - newMax) or exp(oldMax - newMax) depending on useExp2
// normMap = maxMap
AffineMap normMap = getMaxMap();
Value norm = computeSubAndExp2(b, loc, maxMap, normMap, newMax, oldMax);
Value norm =
computeSubAndExp(b, loc, maxMap, normMap, newMax, oldMax, useExp2);

// normSum = norm * oldSum
AffineMap sumMap = getSumMap();
Value normSum = elementwiseValueInPlace<arith::MulFOp>(b, loc, sumMap,
normMap, oldSum, norm);

// P = exp2(S - newMax)
// P = exp2(S - newMax) or exp(S - newMax) depending on useExp2
// PMap = SMap
AffineMap pMap = sMap;
Value p = computeSubAndExp2(b, loc, maxMap, sMap, newMax, s);
Value p = computeSubAndExp(b, loc, maxMap, sMap, newMax, s, useExp2);

// newSum = normSum + rowSum(P)
Value newSum = reduce<arith::AddFOp>(b, loc, pMap, sumMap, p, normSum);
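For reference, the running (online) softmax update emitted by OnlineAttentionOp::decomposeOperation, and checked step by step in the lit test below, can be summarized by the following sketch. The accumulator lines correspond to the PV part of the decomposition, which lies outside the hunk shown above. With use_exp2 = true the same recurrence is used with exp replaced by exp2, which is valid because the scale was pre-multiplied by log2(e), i.e. exp(x) = exp2(x * log2(e)).

\begin{aligned}
m_{\text{new}}          &= \max\bigl(m_{\text{old}},\ \operatorname{rowMax}(S)\bigr) \\
\text{norm}             &= \exp\bigl(m_{\text{old}} - m_{\text{new}}\bigr) \\
\text{normSum}          &= \text{norm} \cdot \ell_{\text{old}} \\
P                       &= \exp\bigl(S - m_{\text{new}}\bigr) \\
\ell_{\text{new}}       &= \text{normSum} + \operatorname{rowSum}(P) \\
\text{acc}_{\text{new}} &= \text{norm} \cdot \text{acc}_{\text{old}} + P\,V
\end{aligned}

Here m is the running row max and \ell the running row sum of the softmax denominator.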
@@ -2041,9 +2041,24 @@ LogicalResult AttentionOp::verify() {

auto &block = getRegion().front();
auto blockTys = block.getArgumentTypes();
if (blockTys.size() != 1 && blockTys.size() != 5) {
return attnOp->emitOpError(
"expects either 1 block argument (score) or 5 block arguments "
"(score, b, h, m, n)");
}

if (!isa<FloatType>(blockTys[0]))
return attnOp->emitOpError("block argument 0 should be float");

// If 5 arguments, verify the indices are of index type
if (blockTys.size() == 5) {
for (unsigned i = 1; i < 5; ++i) {
if (!blockTys[i].isIndex()) {
return attnOp->emitOpError("block arguments 1-4 should be index type");
}
}
}

auto yieldOp = dyn_cast<IREE::LinalgExt::YieldOp>(block.getTerminator());
if (!yieldOp) {
return attnOp->emitOpError("expected linalg_ext.yield");
@@ -2220,14 +2235,25 @@ LogicalResult OnlineAttentionOp::verify() {

Block &block = attnOp.getRegion().front();
auto blockTys = block.getArgumentTypes();
if (blockTys.size() != 1) {
return attnOp->emitOpError("expects single block argument for score");
if (blockTys.size() != 1 && blockTys.size() != 5) {
return attnOp->emitOpError(
"expects either 1 block argument (score) or 5 block arguments "
"(score, b, h, m, n)");
}

if (!isa<FloatType>(blockTys[0])) {
return attnOp->emitOpError("block argument 0 should be float");
}

// If 5 arguments, verify the indices are of index type
if (blockTys.size() == 5) {
for (unsigned i = 1; i < 5; ++i) {
if (!blockTys[i].isIndex()) {
return attnOp->emitOpError("block arguments 1-4 should be index type");
}
}
}

auto yieldOp = dyn_cast<IREE::LinalgExt::YieldOp>(block.getTerminator());
if (!yieldOp) {
return attnOp->emitOpError("expected linalg_ext.yield");
17 changes: 17 additions & 0 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -789,6 +789,12 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_Op<"attention",
If an additional mask argument M is included, the result of the first matmul is modified according to:

Q @ K.T += M

Region:
The region body can receive either 1 or 5 block arguments:
- 1 argument (legacy): score (element type of output)
- 5 arguments: score, b (batch index), h (head index), m (query seq index), n (key/value seq index)
The region should yield a single value (the modified score).
}];

let arguments = (ins AnyShaped:$query,
@@ -914,6 +920,15 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_Op<"online_attention",
it over the entire softmax reduction dimension by:
x, _, sum : results
x = (1 / sum) * x

Region:
The region body receives the following block arguments:
- score: the computed score value from Q @ K.T (element type of output)
Reviewer comment (Collaborator):
Nit: Can you make this such that the index operands come first and then the score....

Not Nit: The way the attention operation is set up, it can increase the dimensionality of the operation. So there isn't necessarily one batch dimension; there could be multiple batch dimensions. Same for head/m/n, etc. That probably needs to be accounted for in the change. Just having a single batch dimension is not going to work.

- b: batch index (index type)
- h: head index (index type)
- m: query sequence index (index type)
- n: key/value sequence index (index type)
The region should yield a single value (the modified score).
}];

let arguments = (ins AnyShaped:$query,
@@ -998,6 +1013,8 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_Op<"online_attention",
// Attributes to set on QK and PV matmul after decomposition.
static StringRef getQKAttrStr() { return "qk_attrs"; }
static StringRef getPVAttrStr() { return "pv_attrs"; }
// Flag to control whether to use exp2 (with log2(e) scaling) or exp.
static StringRef getUseExp2AttrStr() { return "use_exp2"; }
}];

let hasCanonicalizer = 1;
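To make the region forms documented above concrete, here is a minimal, purely illustrative sketch of an attention op with the 5-argument score-modification region. The shapes, indexing maps, function name, and the particular bias computation are assumptions chosen for illustration; they are not taken from this PR. The online_attention form is analogous, and the lit test below additionally shows the decomposition_config used to select exp versus exp2.

#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>

func.func @attention_with_score_mod(%query: tensor<192x1024x64xf16>,
                                    %key: tensor<192x1024x64xf16>,
                                    %value: tensor<192x1024x64xf16>,
                                    %output: tensor<192x1024x64xf32>)
    -> tensor<192x1024x64xf32> {
  %scale = arith.constant 1.0 : f16
  %0 = iree_linalg_ext.attention
      {indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO]}
      ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16)
      outs(%output : tensor<192x1024x64xf32>) {
  ^bb0(%score: f32, %b: index, %h: index, %m: index, %n: index):
    // Illustrative score_mod: a crude relative-position bias that penalizes
    // keys far from the query position.
    %m_i64 = arith.index_cast %m : index to i64
    %n_i64 = arith.index_cast %n : index to i64
    %d_i64 = arith.subi %m_i64, %n_i64 : i64
    %d_f32 = arith.sitofp %d_i64 : i64 to f32
    %slope = arith.constant -1.000000e-03 : f32
    %bias = arith.mulf %d_f32, %slope : f32
    %modified = arith.addf %score, %bias : f32
    iree_linalg_ext.yield %modified : f32
  } -> tensor<192x1024x64xf32>
  return %0 : tensor<192x1024x64xf32>
}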
@@ -416,3 +416,96 @@ func.func @online_attention_f8_masked(%query: tensor<192x1024x64xf8E4M3FNUZ>,
// CHECK: linalg.generic
// CHECK: arith.addf
// CHECK: linalg.yield


// -----

// Spec to decompose online attention op.
module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> ()
transform.yield
}
}

#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>

func.func @online_attention_f16_noexp2(%query: tensor<192x1024x64xf16>,
%key: tensor<192x1024x64xf16>,
%value: tensor<192x1024x64xf16>,
%output: tensor<192x1024x64xf32>,
%max: tensor<192x1024xf32>,
%sum: tensor<192x1024xf32>)
-> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) {
%scale = arith.constant 1.0 : f16

%out:3 = iree_linalg_ext.online_attention
{decomposition_config = {use_exp2=false}, indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] }
ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16)
outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) {
^bb0(%score: f32):
iree_linalg_ext.yield %score: f32
}
-> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>

return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
}

// We want to check that we're correctly using exp (rather than exp2)
// when use_exp2 = false is specified in the decomposition_config.
// CHECK-LABEL: @online_attention_f16_noexp2
// Q = Q * scale
// CHECK: linalg.generic
// CHECK: arith.mulf
// S = Q @ K
// CHECK: linalg.generic
// CHECK: arith.extf
// CHECK: arith.extf
// CHECK: arith.mulf
// CHECK: arith.addf
// CHECK: linalg.yield
// newMax = max(oldMax, rowMax(S))
// CHECK: linalg.generic
// CHECK-NOT: arith.extf
// CHECK: arith.maximumf
// CHECK: linalg.yield
// norm = exp(oldMax - newMax)
// CHECK: linalg.generic
// CHECK-NOT: arith.extf
// CHECK: arith.subf
// CHECK-NOT: math.exp2
// CHECK: linalg.yield
// normSum = norm * oldSum
// CHECK: linalg.generic
// CHECK-NOT: arith.extf
// CHECK: arith.mulf
// CHECK: linalg.yield
// P = exp(S - newMax)
// CHECK: linalg.generic
// CHECK-NOT: arith.extf
// CHECK: arith.subf
// CHECK-NOT: math.exp2
// CHECK: linalg.yield
// newSum = normSum + rowSum(P)
// CHECK: linalg.generic
// CHECK-NOT: arith.extf
// CHECK: arith.addf
// CHECK: linalg.yield
// newAcc = norm * oldAcc
// CHECK: linalg.generic
// CHECK-NOT: arith.extf
// CHECK: arith.mulf
// CHECK: linalg.yield
// newAcc = P @ V + newAcc
// CHECK: linalg.generic
// CHECK: arith.extf
// CHECK: arith.extf
// CHECK: arith.mulf
// CHECK: arith.addf
// CHECK: linalg.yield
@@ -32,8 +32,15 @@ struct DecomposeAttentionPass final
void DecomposeAttentionPass::runOnOperation() {
MLIRContext *context = &getContext();
IRRewriter rewriter(context);
SmallVector<NamedAttribute> decompositionConfigAttrs;
decompositionConfigAttrs.push_back(
rewriter.getNamedAttr("use_exp2", rewriter.getBoolAttr(useExp2)));
DictionaryAttr decompositionConfig =
rewriter.getDictionaryAttr(decompositionConfigAttrs);

getOperation().walk([&](OnlineAttentionOp onlineAtt) {
rewriter.setInsertionPoint(onlineAtt);
onlineAtt.setDecompositionConfigAttr(decompositionConfig);
FailureOr<SmallVector<Value>> results =
onlineAtt.decomposeOperation(rewriter);
if (failed(results)) {