Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ static Value applyPostQKMatmulElementwise(OpBuilder &builder, Location loc,
}

static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap,
AffineMap maskMap, Value qk, Value mask) {
AffineMap maskMap, Value qk, Value mask, bool useExp2) {

SmallVector<AffineMap> compressedMaps =
compressUnusedDims(SmallVector<AffineMap>{qkMap, maskMap});
Expand Down Expand Up @@ -245,9 +245,11 @@ static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap,
maskVal = convertScalarToDtype(b, loc, maskVal, qkVal.getType(),
/*isUnsignedCast=*/false);
// Scaling to compensate for base-2 softmax
Value log2e = arith::ConstantOp::create(
b, loc, b.getFloatAttr(qkVal.getType(), M_LOG2E));
maskVal = arith::MulFOp::create(b, loc, maskVal, log2e);
if (useExp2) {
Value log2e = arith::ConstantOp::create(
b, loc, b.getFloatAttr(qkVal.getType(), M_LOG2E));
maskVal = arith::MulFOp::create(b, loc, maskVal, log2e);
}
}
// Finally, set the returned value to the qk element plus the mask
// element (or 0/-infinity if bool mask). We opt for a AddFOp (instead
Expand All @@ -260,10 +262,10 @@ static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap,
return genericOp.getResult(0);
}

// Compute output = exp2(output - input)
static Value computeSubAndExp2(OpBuilder &builder, Location loc,
AffineMap inputMap, AffineMap outputMap,
Value input, Value output) {
// Compute output = exp2/exp(output - input) depending on useExp2 flag.
static Value computeSubAndExp(OpBuilder &builder, Location loc,
AffineMap inputMap, AffineMap outputMap,
Value input, Value output, bool useExp2) {
SmallVector<AffineMap> compressedMaps =
compressUnusedDims(SmallVector<AffineMap>{inputMap, outputMap});
inputMap = compressedMaps[0];
Expand All @@ -279,8 +281,9 @@ static Value computeSubAndExp2(OpBuilder &builder, Location loc,
Value in = convertScalarToDtype(b, loc, args[0], args[1].getType(),
/*isUnsignedCast=*/false);
Value diff = arith::SubFOp::create(b, loc, args[1], in);
Value weight = math::Exp2Op::create(b, loc, diff);
linalg::YieldOp::create(b, loc, weight);
Operation *weight = useExp2 ? math::Exp2Op::create(b, loc, diff)
: math::ExpOp::create(b, loc, diff);
linalg::YieldOp::create(b, loc, weight->getResult(0));
});
return genericOp.getResult(0);
}
Expand Down Expand Up @@ -316,15 +319,18 @@ Value computeQKAndElementwise(Location loc, OpBuilder &b, Value query,
std::optional<AffineMap> maskMap,
SmallVector<OpFoldResult> iterationDomain,
Type sElementType, Region &elementwiseRegion,
DictionaryAttr qkAttrs, bool lowPrecision) {
DictionaryAttr qkAttrs, bool lowPrecision,
bool useExp2) {
MLIRContext *ctx = b.getContext();
// Since we use exp2 for attention instead of the original exp, we have to
// If using exp2 for attention instead of the original exp, we have to
// multiply the scale by log2(e). We use exp2 instead of exp as most platforms
// have better support for exp2 (we verified that we gain some speedup on
// some GPUs).
Value log2e = arith::ConstantOp::create(
b, loc, b.getFloatAttr(scale.getType(), M_LOG2E));
scale = arith::MulFOp::create(b, loc, scale, log2e);
if (useExp2) {
Value log2e = arith::ConstantOp::create(
b, loc, b.getFloatAttr(scale.getType(), M_LOG2E));
scale = arith::MulFOp::create(b, loc, scale, log2e);
}

auto qETy = getElementTypeOrSelf(query.getType());

Expand Down Expand Up @@ -392,7 +398,7 @@ Value computeQKAndElementwise(Location loc, OpBuilder &b, Value query,

// S += mask
if (mask != nullptr) {
s = applyMask(b, loc, sMap, *maskMap, s, mask.value());
s = applyMask(b, loc, sMap, *maskMap, s, mask.value(), useExp2);
}

return s;
Expand Down Expand Up @@ -436,9 +442,9 @@ FailureOr<SmallVector<Value>> AttentionOp::decomposeOperation(OpBuilder &b) {
Type f32Type = b.getF32Type();

// ---- QK Matmul + elementwise math ----
Value s = computeQKAndElementwise(loc, b, query, key, getScale(), mask, qMap,
kMap, sMap, getMaskMap(), sizes, f32Type,
getRegion(), qkAttrs, lowPrecision);
Value s = computeQKAndElementwise(
loc, b, query, key, getScale(), mask, qMap, kMap, sMap, getMaskMap(),
sizes, f32Type, getRegion(), qkAttrs, lowPrecision, /*useExp2=*/true);

// ---- Softmax ----

Expand Down Expand Up @@ -480,7 +486,7 @@ FailureOr<SmallVector<Value>> AttentionOp::decomposeOperation(OpBuilder &b) {

// P = exp2(S - max)
AffineMap pMap = sMap;
Value p = computeSubAndExp2(b, loc, maxMap, sMap, max, s);
Value p = computeSubAndExp(b, loc, maxMap, sMap, max, s, /*useExp2=*/true);

// sum = rowSum(P)
Value sum = reduce<arith::AddFOp>(b, loc, pMap, sumMap, p, sumFill);
Expand Down Expand Up @@ -530,9 +536,13 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
DictionaryAttr config = getDecompositionConfigAttr();

DictionaryAttr qkAttrs, pvAttrs;
bool useExp2 = true;
if (config) {
qkAttrs = config.getAs<DictionaryAttr>(getQKAttrStr());
pvAttrs = config.getAs<DictionaryAttr>(getPVAttrStr());
if (auto useExp2Attr = config.getAs<BoolAttr>(getUseExp2AttrStr())) {
useExp2 = useExp2Attr.getValue();
}
}

FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get(
Expand All @@ -553,7 +563,7 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
// ---- QK Matmul + elementwise math ----
Value s = computeQKAndElementwise(
loc, b, query, key, getScale(), mask, qMap, kMap, sMap, getMaskMap(),
sizes, elementType, getRegion(), qkAttrs, lowPrecision);
sizes, elementType, getRegion(), qkAttrs, lowPrecision, useExp2);

// TODO: This decomposition should be in a separate op called
// "online softmax".
Expand All @@ -563,20 +573,21 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
AffineMap maxMap = getMaxMap();
Value newMax = reduce<arith::MaximumFOp>(b, loc, sMap, maxMap, s, oldMax);

// norm = exp2(oldMax - newMax)
// norm = exp2(oldMax - newMax) or exp(oldMax - newMax) depending on useExp2
// normMap = maxMap
AffineMap normMap = getMaxMap();
Value norm = computeSubAndExp2(b, loc, maxMap, normMap, newMax, oldMax);
Value norm =
computeSubAndExp(b, loc, maxMap, normMap, newMax, oldMax, useExp2);

// normSum = norm * oldSum
AffineMap sumMap = getSumMap();
Value normSum = elementwiseValueInPlace<arith::MulFOp>(b, loc, sumMap,
normMap, oldSum, norm);

// P = exp2(S - newMax)
// P = exp2(S - newMax) or exp(S - newMax) depending on useExp2
// PMap = SMap
AffineMap pMap = sMap;
Value p = computeSubAndExp2(b, loc, maxMap, sMap, newMax, s);
Value p = computeSubAndExp(b, loc, maxMap, sMap, newMax, s, useExp2);

// newSum = normSum + rowSum(P)
Value newSum = reduce<arith::AddFOp>(b, loc, pMap, sumMap, p, normSum);
Expand Down Expand Up @@ -1211,11 +1222,11 @@ FailureOr<SmallVector<Value>> ExpReductionOp::decomposeOperation(OpBuilder &b) {
Value currMax = reduce<arith::MaximumFOp>(
rewriter, loc, normValMap, prevMaxMap, sValue->get(), prevMax->get());
// ex = 2^{sValue - curr_max} (base-2: useExp2 is hardcoded to true below)
Value ex = computeSubAndExp2(rewriter, loc, prevMaxMap, normValMap, currMax,
sValue->get());
Value ex = computeSubAndExp(rewriter, loc, prevMaxMap, normValMap, currMax,
sValue->get(), /*useExp2=*/true);
// norm = 2^(prev_max - curr_max) (base-2: useExp2 is hardcoded to true below)
Value norm = computeSubAndExp2(rewriter, loc, prevMaxMap, prevMaxMap, currMax,
prevMax->get());
Value norm = computeSubAndExp(rewriter, loc, prevMaxMap, prevMaxMap, currMax,
prevMax->get(), /*useExp2=*/true);

SmallVector<Value> inputs = getDpsInputs();
SmallVector<Value> normOuts(getNumDpsInits());
Expand Down
12 changes: 12 additions & 0 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1069,6 +1069,16 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_Op<"online_attention",
it over the entire softmax reduction dimension by:
x, _, sum : results
x = (1 / sum) * x

Decomposition Configuration:
The `decomposition_config` attribute is a DictionaryAttr that controls how
this operation is decomposed into lower-level operations. It supports:
- "qk_attrs": DictionaryAttr - Attributes to attach to the Q@K matmul
operation after decomposition (e.g., lowering_config, attention markers)
- "pv_attrs": DictionaryAttr - Attributes to attach to the P@V matmul
operation after decomposition
- "use_exp2": BoolAttr - If true, uses exp2 with log2(e) scaling instead
of exp. (Gives better perf on some hardware, but trades off accuracy)
}];

let arguments = (ins AnyShaped:$query,
Expand Down Expand Up @@ -1153,6 +1163,8 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_Op<"online_attention",
// Attributes to set on QK and PV matmul after decomposition.
static StringRef getQKAttrStr() { return "qk_attrs"; }
static StringRef getPVAttrStr() { return "pv_attrs"; }
// Flag to control whether to use exp2 (with log2(e) scaling) or exp.
static StringRef getUseExp2AttrStr() { return "use_exp2"; }
}];

let hasCanonicalizer = 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ func.func @online_attention_f16(%query: tensor<192x1024x64xf16>,
// correct number of extf/truncfs are emitted.
// CHECK-LABEL: @online_attention_f16
// Q = Q * scale
// CHECK: arith.constant 1.442380e+00 : f16
// CHECK: linalg.generic
// CHECK: arith.mulf
// S = Q @ K
Expand Down Expand Up @@ -419,6 +420,65 @@ func.func @online_attention_f8_masked(%query: tensor<192x1024x64xf8E4M3FNUZ>,

// -----

// Spec to decompose online attention op.
module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> ()
transform.yield
}
}

#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>

func.func @online_attention_f16_noexp2(%query: tensor<192x1024x64xf16>,
%key: tensor<192x1024x64xf16>,
%value: tensor<192x1024x64xf16>,
%output: tensor<192x1024x64xf32>,
%max: tensor<192x1024xf32>,
%sum: tensor<192x1024xf32>)
-> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) {
%scale = arith.constant 1.0 : f16

%out:3 = iree_linalg_ext.online_attention
{decomposition_config = {use_exp2=false}, indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] }
ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16)
outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) {
^bb0(%score: f32):
iree_linalg_ext.yield %score: f32
}
-> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>

return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
}

// We want to check that we're correctly using exp
// when specified so from the decomposition_config.
// CHECK-LABEL: @online_attention_f16_noexp2
// Q = Q * scale
// CHECK: arith.constant 1.000000e+00 : f16
// CHECK: linalg.generic
// CHECK: arith.mulf
// norm = exp (oldMax - newMax)
// CHECK: linalg.generic
// CHECK: arith.subf
// CHECK-NOT: arith.extf
// CHECK-NOT: math.exp2
// CHECK: linalg.yield
// P = exp(S - newMax)
// CHECK: linalg.generic
// CHECK: arith.subf
// CHECK-NOT: arith.extf
// CHECK-NOT: math.exp2
// CHECK: linalg.yield

// -----

// Spec to decompose exp reduction op.
module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,15 @@ struct DecomposeAttentionPass final
void DecomposeAttentionPass::runOnOperation() {
MLIRContext *context = &getContext();
IRRewriter rewriter(context);

getOperation().walk([&](OnlineAttentionOp onlineAtt) {
rewriter.setInsertionPoint(onlineAtt);

NamedAttrList decompositionConfig(onlineAtt.getDecompositionConfigAttr());
decompositionConfig.set("use_exp2", rewriter.getBoolAttr(useExp2));
onlineAtt.setDecompositionConfigAttr(
decompositionConfig.getDictionary(context));

FailureOr<SmallVector<Value>> results =
onlineAtt.decomposeOperation(rewriter);
if (failed(results)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,11 @@ def DecomposeAttentionPass :
InterfacePass<"iree-linalg-ext-decompose-attention", "mlir::FunctionOpInterface"> {
let summary =
"Decomposes attention op into a sequence of linalg ops";
let options = [
Option<"useExp2", "use-exp2", "bool", /*default=*/"true",
"Use exp2 for computations; tunable to allow for accurate computations "
"in case of accuracy losses due to fp-reassociation.">,
];
}

def ConvertAttentionToOnlineAttentionPass :
Expand Down
Loading