37 changes: 36 additions & 1 deletion lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -4016,8 +4016,28 @@ LogicalResult ConvertAtenOp<AtenTransposeIntOp>::matchAndRewrite(
transposedDims[dim0] = dim1;
transposedDims[dim1] = dim0;

Type resultType = getTypeConverter()->convertType(op.getType());
if (auto rankedSelf = dyn_cast<RankedTensorType>(selfType)) {
SmallVector<int64_t> transposedShape(rankedSelf.getRank(),
ShapedType::kDynamic);
if (rankedSelf.hasStaticShape()) {
auto staticShape =
llvm::to_vector(makeShapeTorchCompatible(rankedSelf.getShape()));
auto dim0Index = static_cast<size_t>(dim0);
auto dim1Index = static_cast<size_t>(dim1);
if (dim0Index < staticShape.size() && dim1Index < staticShape.size())
std::swap(staticShape[dim0Index], staticShape[dim1Index]);
for (size_t i = 0; i < staticShape.size(); ++i)
transposedShape[i] = staticShape[i];
}
auto rankedResult = RankedTensorType::get(
makeShapeLLVMCompatible(transposedShape), rankedSelf.getElementType());
if (auto converted = getTypeConverter()->convertType(rankedResult))
resultType = converted;
}

rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
op, resultType, adaptor.getSelf(),
rewriter.getDenseI32ArrayAttr(transposedDims));

return success();
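The refined result type above is simply the input's static shape with dim0 and dim1 swapped, falling back to an all-dynamic shape of the same rank when any dimension is unknown. A minimal sketch of that shape rule, in plain Python rather than MLIR types (illustrative only):

```python
# Illustrative restatement of the shape refinement above; kDynamic stands in
# for ShapedType::kDynamic.
kDynamic = -1

def transpose_result_shape(shape, dim0, dim1):
    # Any unknown dimension makes the pattern fall back to an all-dynamic shape.
    if any(d == kDynamic for d in shape):
        return [kDynamic] * len(shape)
    out = list(shape)
    out[dim0], out[dim1] = out[dim1], out[dim0]
    return out

print(transpose_result_shape([2, 3, 4], 0, 2))         # [4, 3, 2]
print(transpose_result_shape([2, kDynamic, 4], 0, 2))  # [-1, -1, -1]
```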
@@ -9402,6 +9422,21 @@ void populateTorchToTosaConversionLegalOps(ConversionTarget &target) {
target.addLegalOp<ConstantDeviceOp>();
target.addLegalOp<PrimListConstructOp>();
target.addLegalOp<PrimTupleConstructOp>();
target.addDynamicallyLegalOp<tensor::CastOp>([](tensor::CastOp op) -> bool {
auto sourceType = dyn_cast<RankedTensorType>(op.getSource().getType());
auto resultType = dyn_cast<RankedTensorType>(op.getType());
if (!sourceType || !resultType)
return true;
if (sourceType.getElementType() != resultType.getElementType())
return true;
if (!sourceType.hasStaticShape())
return true;
if (!resultType.hasStaticShape())
return true;
if (sourceType == resultType)
return true;
return false;
});
}

std::set<StringRef> populateTorchToTosaConversionPatternsAndIllegalOps(
228 changes: 219 additions & 9 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2295,17 +2295,223 @@ class DecomposeAtenTraceOp : public OpRewritePattern<AtenTraceOp> {
};
} // namespace

static Value getSoftmaxResult(Operation *op, Value self, Value dim,
Type resultType, Type accumulatorType,
PatternRewriter &rewriter);

namespace {
Review comment (Member): Is this decomposition producing any different IR compared to leveraging the decomposition of sdpa with ExportedProgram.run_decompositions (https://docs.pytorch.org/docs/stable/export.html#export-ir-decompositions) -- see https://discord.com/channels/636084430946959380/742573221882364009/1446121930922004623 for reference. I am wondering if the sdpa op should be added to the default decomposition list (DEFAULT_DECOMPOSITIONS) instead?

// Decompose scaled dot product attention into a matmul/softmax pipeline when
// there is no masking, dropout, causal, or GQA behaviour.
class DecomposeAtenScaledDotProductAttentionOp
: public OpRewritePattern<AtenScaledDotProductAttentionOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenScaledDotProductAttentionOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();

if (!isa<Torch::NoneType>(op.getAttnMask().getType()))
return rewriter.notifyMatchFailure(
op, "attention mask decomposition not implemented");

double dropoutP;
if (!matchPattern(op.getDropoutP(), m_TorchConstantFloat(&dropoutP)) ||
dropoutP != 0.0)
return rewriter.notifyMatchFailure(
op, "expected dropout_p to be the constant 0.0");

bool isCausal;
if (!matchPattern(op.getIsCausal(), m_TorchConstantBool(&isCausal)) ||
isCausal)
return rewriter.notifyMatchFailure(op,
"causal attention not supported yet");

bool enableGqa;
if (!matchPattern(op.getEnableGqa(), m_TorchConstantBool(&enableGqa)) ||
enableGqa)
return rewriter.notifyMatchFailure(op,
"grouped-query attention unsupported");

Value query = op.getQuery();
Value key = op.getKey();
Value value = op.getValue();

auto queryTensorType = dyn_cast<BaseTensorType>(query.getType());
auto keyTensorType = dyn_cast<BaseTensorType>(key.getType());
auto valueTensorType = dyn_cast<BaseTensorType>(value.getType());
if (!queryTensorType || !keyTensorType || !valueTensorType)
return rewriter.notifyMatchFailure(op, "expected tensor inputs");
if (!queryTensorType.hasSizes() || !keyTensorType.hasSizes() ||
!valueTensorType.hasSizes())
return rewriter.notifyMatchFailure(
op, "expected tensor inputs to have known shapes");
auto queryValueTensorType = dyn_cast<ValueTensorType>(queryTensorType);
auto keyValueTensorType = dyn_cast<ValueTensorType>(keyTensorType);
auto valueValueTensorType = dyn_cast<ValueTensorType>(valueTensorType);
if (!queryValueTensorType || !keyValueTensorType || !valueValueTensorType)
return rewriter.notifyMatchFailure(op, "expected value tensor semantics");
if (!queryValueTensorType.hasDtype() || !keyValueTensorType.hasDtype() ||
!valueValueTensorType.hasDtype())
return rewriter.notifyMatchFailure(
op, "expected tensor inputs to have dtypes");
Type queryDtype = queryValueTensorType.getOptionalDtype();
if (queryDtype != keyValueTensorType.getOptionalDtype() ||
queryDtype != valueValueTensorType.getOptionalDtype())
return rewriter.notifyMatchFailure(
op, "expected query, key, and value to share dtype");

Value oneInt =
ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1));
Value zeroInt =
ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0));
Value rank = AtenDimOp::create(rewriter, loc, query);
Value lastDim = AtenSubIntOp::create(rewriter, loc, rank, oneInt);
Value headDim = AtenSizeIntOp::create(rewriter, loc, query, lastDim);
Value seqDimIndex = AtenSubIntOp::create(rewriter, loc, lastDim, oneInt);
Value seqLen = AtenSizeIntOp::create(rewriter, loc, query, seqDimIndex);
Value keySeqLen = AtenSizeIntOp::create(rewriter, loc, key, seqDimIndex);
ArrayRef<int64_t> querySizes = queryValueTensorType.getSizes();
int64_t queryRank = querySizes.size();
if (queryRank < 3 || queryRank > 4)
return rewriter.notifyMatchFailure(
op, "expected query tensor rank to be 3 or 4");
ArrayRef<int64_t> keySizes = keyValueTensorType.getSizes();
ArrayRef<int64_t> valueSizes = valueValueTensorType.getSizes();
if (static_cast<int64_t>(keySizes.size()) != queryRank ||
static_cast<int64_t>(valueSizes.size()) != queryRank)
return rewriter.notifyMatchFailure(
op, "expected query, key, and value to share rank");
bool hasExplicitHeadDim = queryRank == 4;
Value numHeadsSize =
hasExplicitHeadDim
? (Value)AtenSizeIntOp::create(rewriter, loc, query, oneInt)
: oneInt;
Value batchSize = AtenSizeIntOp::create(rewriter, loc, query, zeroInt);
auto listIntType =
Torch::ListType::get(Torch::IntType::get(rewriter.getContext()));

auto getDimValue = [&](int64_t staticDim, Value fallback) -> Value {
if (staticDim != Torch::kUnknownSize)
return ConstantIntOp::create(rewriter, loc,
rewriter.getI64IntegerAttr(staticDim));
return fallback;
};

Value scaleFloat;
if (isa<Torch::NoneType>(op.getScale().getType())) {
Value sqrtHeadDim = AtenSqrtIntOp::create(rewriter, loc, headDim);
Value oneFloat =
ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1.0));
scaleFloat = AtenDivFloatOp::create(rewriter, loc, oneFloat, sqrtHeadDim);
} else {
scaleFloat = op.getScale();
}

Value negTwo =
ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-2));
Value negOne =
ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-1));

SmallVector<int64_t> keyTransposedSizes(keySizes.begin(), keySizes.end());
if (keyTransposedSizes.size() < 2)
return rewriter.notifyMatchFailure(
op, "expected key tensor rank >= 2 for transpose");
std::swap(keyTransposedSizes[keyTransposedSizes.size() - 1],
keyTransposedSizes[keyTransposedSizes.size() - 2]);
ArrayRef<int64_t> keyTransposedRef(keyTransposedSizes);
std::optional<ArrayRef<int64_t>> keyTransposedOpt(keyTransposedRef);
Type keyTransposedType = keyValueTensorType.getWithSizesAndDtypeAndSparsity(
keyTransposedSizes, keyValueTensorType.getOptionalDtype(),
keyValueTensorType.getOptionalSparsity());
Value keyTransposed = AtenTransposeIntOp::create(
rewriter, loc, keyTransposedType, key, negTwo, negOne);
SmallVector<Value> keyDims;
auto getOrFallback = [&](ArrayRef<int64_t> staticDims, unsigned idx,
Value fallback) -> Value {
return getDimValue(idx < staticDims.size() ? staticDims[idx]
: Torch::kUnknownSize,
fallback);
};
keyDims.push_back(getOrFallback(keyTransposedSizes, 0, batchSize));
if (hasExplicitHeadDim) {
keyDims.push_back(getOrFallback(keyTransposedSizes, 1, numHeadsSize));
keyDims.push_back(getOrFallback(keyTransposedSizes, 2, headDim));
keyDims.push_back(getOrFallback(keyTransposedSizes, 3, keySeqLen));
} else {
keyDims.push_back(getOrFallback(keyTransposedSizes, 1, headDim));
keyDims.push_back(getOrFallback(keyTransposedSizes, 2, keySeqLen));
}
Value keyTransposeShapeList =
PrimListConstructOp::create(rewriter, loc, listIntType, keyDims);
keyTransposed = AtenViewOp::create(rewriter, loc, keyTransposedType,
keyTransposed, keyTransposeShapeList);

auto getStaticDim = [](ArrayRef<int64_t> sizes, int64_t index) {
if (index < 0)
index += sizes.size();
if (index < 0 || index >= static_cast<int64_t>(sizes.size()))
return Torch::kUnknownSize;
return sizes[index];
};
int64_t queryBatchStatic = getStaticDim(querySizes, 0);
int64_t querySeqStatic = getStaticDim(querySizes, -2);
int64_t keySeqStatic = getStaticDim(keySizes, -2);
int64_t queryHeadsStatic =
hasExplicitHeadDim ? getStaticDim(querySizes, 1) : 1;
SmallVector<int64_t, 4> scoresSizes;
if (hasExplicitHeadDim)
scoresSizes.assign(
{queryBatchStatic, queryHeadsStatic, querySeqStatic, keySeqStatic});
else
scoresSizes.assign({queryBatchStatic, querySeqStatic, keySeqStatic});
Type scoresType = ValueTensorType::get(
op->getContext(),
ArrayRef<int64_t>(scoresSizes.begin(), scoresSizes.end()),
queryValueTensorType.getOptionalDtype(),
queryValueTensorType.getOptionalSparsity());
Value scores =
AtenMatmulOp::create(rewriter, loc, scoresType, query, keyTransposed);
SmallVector<Value> scoresDims;
scoresDims.push_back(getDimValue(scoresSizes[0], batchSize));
unsigned seqIndex = 1;
if (hasExplicitHeadDim) {
scoresDims.push_back(getDimValue(scoresSizes[1], numHeadsSize));
seqIndex = 2;
}
scoresDims.push_back(getDimValue(scoresSizes[seqIndex], seqLen));
scoresDims.push_back(getDimValue(scoresSizes.back(), keySeqLen));
Value scoresShapeList =
PrimListConstructOp::create(rewriter, loc, listIntType, scoresDims);
scores =
AtenViewOp::create(rewriter, loc, scoresType, scores, scoresShapeList);
Value scaledScores =
AtenMulScalarOp::create(rewriter, loc, scoresType, scores, scaleFloat);

Value softmax = getSoftmaxResult(op.getOperation(), scaledScores, negOne,
scoresType, scoresType, rewriter);
if (!softmax)
return rewriter.notifyMatchFailure(op,
"failed to compute softmax scores");

Value output =
AtenMatmulOp::create(rewriter, loc, op.getType(), softmax, value);

rewriter.replaceOp(op, output);
return success();
}
};
} // namespace
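For orientation, the pattern above emits the same computation as this eager PyTorch sketch of the unmasked, non-causal, dropout-free case (a reference model only, not the Torch-dialect IR the pass actually builds):

```python
import math
import torch

def sdpa_reference(query, key, value, scale=None):
    # Mirrors the decomposition above: attn_mask is None, dropout_p == 0.0,
    # is_causal == False, enable_gqa == False.
    if scale is None:
        scale = 1.0 / math.sqrt(query.shape[-1])
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale
    attn = torch.softmax(scores, dim=-1)
    return torch.matmul(attn, value)

q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
expected = torch.nn.functional.scaled_dot_product_attention(q, k, v)
assert torch.allclose(sdpa_reference(q, k, v), expected, atol=1e-5)
```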

// Calculates the softmax function on the given `input` tensor. Softmax(x) =
// exp(x)/sum(exp(x)).
// To avoid overflow we use the following decomposition rule:
// x_max = max(input, dim, keepdim = True)
// unnorm = aten.exp(input - x_max)
// softmax = unnorm / sum(unnorm, dim, keepdim = True)
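A quick numerical illustration of why the max is subtracted first, following the rule documented in the comment above (plain PyTorch, illustrative only):

```python
import torch

def stable_softmax(x, dim):
    # x_max = max(input, dim, keepdim=True); unnorm = exp(input - x_max);
    # softmax = unnorm / sum(unnorm, dim, keepdim=True)
    x_max = x.amax(dim=dim, keepdim=True)
    unnorm = torch.exp(x - x_max)
    return unnorm / unnorm.sum(dim=dim, keepdim=True)

x = torch.tensor([[1000.0, 1001.0, 1002.0]])
print(torch.exp(x) / torch.exp(x).sum(-1, keepdim=True))  # nan: exp overflows
print(stable_softmax(x, -1))                               # finite, matches torch.softmax(x, -1)
```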
template <typename OpTy>
static Value getSoftmaxResult(OpTy op, Value self, Type resultType,
Type accumulatorType, PatternRewriter &rewriter) {
Location loc = op.getLoc();
Value dim = op.getDim();
static Value getSoftmaxResult(Operation *op, Value self, Value dim,
Type resultType, Type accumulatorType,
PatternRewriter &rewriter) {
Location loc = op->getLoc();
if (resultType != accumulatorType)
self = convertTensorToDtype(rewriter, loc, self, accumulatorType);
Value xMax =
@@ -2362,8 +2568,9 @@ class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {

Type accumulatorTensorType = getDefaultAccType(rewriter, resultTensorDtype);

Value result = getSoftmaxResult(op, self, resultTensorType,
accumulatorTensorType, rewriter);
Value result =
getSoftmaxResult(op.getOperation(), self, op.getDim(), resultTensorType,
accumulatorTensorType, rewriter);
if (!result)
return failure();
rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
@@ -2411,8 +2618,9 @@ class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {

Type accumulatorTensorType = getDefaultAccType(rewriter, resultTensorDtype);

Value result = getSoftmaxResult(op, self, resultTensorType,
accumulatorTensorType, rewriter);
Value result =
getSoftmaxResult(op.getOperation(), self, op.getDim(), resultTensorType,
accumulatorTensorType, rewriter);
if (!result)
return op.emitError("failed to get softmax result");
rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, resultTensorType,
@@ -13084,6 +13292,8 @@ class DecomposeComplexOpsPass
legalOpsSet.clear();
legalOpsSet.insert(legalOps.begin(), legalOps.end());

patterns.add<DecomposeAtenScaledDotProductAttentionOp>(context);

addPatternIfTargetOpIsIllegal<DecomposeAten_WeightNormInterfaceOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns);
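Regarding the review comment earlier in this file about using ExportedProgram.run_decompositions instead: a hypothetical sketch of that export-time flow for comparison (the module, shapes, and choice of decomposition table here are assumptions, not part of this PR):

```python
import torch
from torch._decomp import get_decompositions
from torch.export import export

class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)

q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
ep = export(Attn(), (q, k, v))
# Assumption: explicitly requesting a decomposition for the sdpa overload; whether
# it should instead live in torch-mlir's DEFAULT_DECOMPOSITIONS is the open question.
table = get_decompositions([torch.ops.aten.scaled_dot_product_attention.default])
decomposed = ep.run_decompositions(table)
# Inspect what sdpa was lowered to before handing the program to torch-mlir.
print(decomposed.graph_module.graph)
```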
20 changes: 20 additions & 0 deletions lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp
@@ -8,6 +8,7 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"

using namespace mlir;
@@ -40,6 +41,25 @@ static void setupValueTensorToBuiltinTensorConversion(
return {};
return ToBuiltinTensorOp::create(builder, loc, type, inputs[0]);
});
typeConverter.addTargetMaterialization([](OpBuilder &builder, Type type,
ValueRange inputs,
Location loc) -> Value {
if (inputs.size() != 1)
return Value();
auto fromType = dyn_cast<RankedTensorType>(inputs[0].getType());
auto toType = dyn_cast<RankedTensorType>(type);
if (!fromType || !toType)
return Value();
if (fromType == toType)
return inputs[0];
if (fromType.getElementType() != toType.getElementType())
return Value();
if (!toType.hasStaticShape())
return Value();
if (!tensor::CastOp::areCastCompatible(inputs[0].getType(), toType))
return Value();
return tensor::CastOp::create(builder, loc, toType, inputs[0]);
});
auto sourceMaterialization = [](OpBuilder &builder,
Torch::ValueTensorType type,
ValueRange inputs, Location loc) -> Value {
10 changes: 0 additions & 10 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -50,11 +50,8 @@
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionMaskModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScaledDotProductAttentionSameModule_basic",
}

LINALG_CRASHING_SET = {
@@ -953,11 +950,8 @@
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionMaskModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScaledDotProductAttentionSameModule_basic",
"SubIntModule_basic",
"TensorToIntZeroRank_basic",
"UpSampleNearest2dDynamicFactor_basic",
@@ -3978,11 +3972,8 @@
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionMaskModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScaledDotProductAttentionSameModule_basic",
"ScaledDotProductAttentionGQAModule_basic",
# error: 'tosa.scatter' op requires dimensions K >= W
"IndexPut1DFloatNonAccumulateModule_basic",
@@ -4887,7 +4878,6 @@
# REMOVE WHEN ENABLE_GQA IS ADDED
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScatterAddDynamicModule_basic",
"ScatterReduceFloatMaxModule",
"ScatterReduceFloatMaxModuleIncludeSelf",