Commit 1ee8659

Tanyo Kwok authored

[MHLO] fix tensor mode aten.div op pattern (#1160)

See RFC #999

Co-authored-by: Bairen Yi <[email protected]>
Co-authored-by: Jiawei Wu <[email protected]>
Co-authored-by: Tianyou Guo <[email protected]>
Co-authored-by: Xu Yan <[email protected]>
Co-authored-by: Ziheng Jiang <[email protected]>

1 parent 5618890 commit 1ee8659

File tree: 4 files changed (+72 -7 lines)

lib/Conversion/TorchToMhlo/Basic.cpp (+32 -4)
@@ -208,7 +208,6 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
           "only floating-point or integer datatype legalization supported");
     }
 
-    Value lhsTensor = lhs;
     if (std::is_same<AtenOpT, AtenSquareOp>()) {
       rhs = lhs;
     } else if (!rhsType) {
@@ -217,8 +216,37 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
     DenseIntElementsAttr bcastDimensions;
     lhs = mhlo::promoteType(rewriter, lhs, outType);
     rhs = mhlo::promoteType(rewriter, rhs, outType);
-    rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
-                                         bcastDimensions);
+    auto loc = op.getLoc();
+    Value result =
+        rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
+
+    if (!isa<AtenDivTensorModeOp>(op)) {
+      rewriter.replaceOp(op, result);
+      return success();
+    }
+
+    AtenDivTensorModeOp divTensorModeOp =
+        llvm::dyn_cast<AtenDivTensorModeOp>(op.getOperation());
+    std::string roundingMode;
+    if (!matchPattern(divTensorModeOp.rounding_mode(),
+                      m_TorchConstantStr(roundingMode)))
+      return rewriter.notifyMatchFailure(
+          op, "only support constant str rounding mode");
+
+    if (roundingMode == "trunc") {
+      // "trunc" - rounds the results of the division towards zero. Equivalent
+      // to C-style integer division.
+      auto sign = rewriter.create<mhlo::SignOp>(loc, result);
+      auto abs = rewriter.create<mhlo::AbsOp>(loc, result);
+      auto floor = rewriter.create<mhlo::FloorOp>(loc, abs);
+      result = rewriter.create<mhlo::MulOp>(loc, sign, floor).getResult();
+    }
+    if (roundingMode == "floor") {
+      // "floor" - rounds the results of the division down. Equivalent to
+      // floor division in Python (the // operator)
+      result = rewriter.create<mhlo::FloorOp>(loc, result).getResult();
+    }
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -554,7 +582,6 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
   RankedTensorType outputType = getTypeConverter()
                                     ->convertType(op->getResult(0).getType())
                                     .cast<RankedTensorType>();
-  auto outputShape = outputType.getShape();
   auto outputElemType = outputType.getElementType();
   Value mhloTensor =
       mhlo::scalarToMhloTensor(rewriter, op, adaptor.a(), outputElemType);
@@ -968,6 +995,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
   INSERT_BINARY_MULDIV_PATTERN(AtenMulTensorOp, chlo::BroadcastMulOp);
   INSERT_BINARY_MULDIV_PATTERN(AtenMulScalarOp, chlo::BroadcastMulOp);
   INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp);
+  INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorModeOp, chlo::BroadcastDivOp);
   INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarOp, chlo::BroadcastDivOp);
 #undef INSERT_BINARY_MULDIV_PATTERN
 
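The "trunc" branch above derives the truncated quotient from the ordinary division result via the identity trunc(q) = sign(q) * floor(|q|), which is exactly the mhlo.sign / mhlo.abs / mhlo.floor / mhlo.multiply sequence the pattern emits (mhlo.sign yields 0 at q == 0, where the product is still the correct 0). A standalone C++ sketch of the identity, illustrative only and not torch-mlir code:

#include <cmath>
#include <cstdio>

int main() {
  double q = -7.0 / 2.0; // -3.5
  // Same shape as the emitted ops: sign(q) * floor(|q|).
  double signQ = (q > 0.0) - (q < 0.0);
  double truncQ = signQ * std::floor(std::fabs(q)); // -3, rounds toward zero
  double floorQ = std::floor(q);                    // -4, rounds down
  std::printf("q=%g trunc=%g floor=%g\n", q, truncQ, floorQ);
}
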
lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp (+4 -1)
@@ -2167,8 +2167,11 @@ class DecomposeAtenFloorDivideOp : public OpRewritePattern<AtenFloorDivideOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenFloorDivideOp op,
                                 PatternRewriter &rewriter) const override {
+    // https://pytorch.org/docs/stable/generated/torch.floor_divide.html
+    // PyTorch aten.floor_divide is a misnomer because it actually rounds
+    // the quotient towards zero instead of taking its floor.
     Value cstStrFloor =
-        rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "floor");
+        rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "trunc");
     rewriter.replaceOpWithNewOp<AtenDivTensorModeOp>(
         op, op.getType(), op.self(), op.other(),
         /*rounding_mode=*/cstStrFloor);
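Because aten.floor_divide actually truncates despite its name, the decomposition now emits "trunc" instead of "floor"; the two modes differ only when the quotient is negative. A minimal standalone check of the two conventions, illustrative only and not torch-mlir code:

#include <cassert>
#include <cmath>

int main() {
  // C-style integer division truncates toward zero, i.e. "trunc" mode.
  assert(-7 / 2 == -3);
  // Python's // operator floors instead, i.e. "floor" mode: -7 // 2 == -4.
  assert(static_cast<int>(std::floor(-7.0 / 2.0)) == -4);
  // For non-negative quotients the two modes agree.
  assert(7 / 2 == 3 && static_cast<int>(std::floor(7.0 / 2.0)) == 3);
}
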
test/Conversion/TorchToMhlo/elementwise.mlir (+34)
@@ -540,3 +540,37 @@ func.func @torch.aten.gt.scalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
   return %0 : !torch.vtensor<[?,?],i1>
 }
 
+// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$trunc(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK: %[[STR:.*]] = torch.constant.str "trunc"
+// CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK: %[[T3:.*]] = mhlo.sign %[[T2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T4:.*]] = mhlo.abs %[[T2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T5:.*]] = mhlo.floor %[[T4]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T6:.*]] = mhlo.multiply %[[T3]], %[[T5]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
+// CHECK: return %[[T7]] : !torch.vtensor<[?,?,?,?],f32>
+func.func @torch.aten.div.Tensor_mode$trunc(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+  %str = torch.constant.str "trunc"
+  %0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.str -> !torch.vtensor<[?,?,?,?],f32>
+  return %0 : !torch.vtensor<[?,?,?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$floor(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK: %[[STR:.*]] = torch.constant.str "floor"
+// CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK: %[[T3:.*]] = mhlo.floor %[[T2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
+// CHECK: return %[[T4]] : !torch.vtensor<[?,?,?,?],f32>
+func.func @torch.aten.div.Tensor_mode$floor(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+  %str = torch.constant.str "floor"
+  %0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.str -> !torch.vtensor<[?,?,?,?],f32>
+  return %0 : !torch.vtensor<[?,?,?,?],f32>
+}

test/Dialect/Torch/decompose-complex-ops.mlir (+2 -2)
@@ -1113,8 +1113,8 @@ func.func @torch.aten.baddbmm(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.
 // CHECK-LABEL: func @torch.aten.floor_divide(
 // CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?],f32>,
 // CHECK-SAME: %[[OTHER:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[CSTFLOOR:.*]] = torch.constant.str "floor"
-// CHECK: %[[OUT:.*]] = torch.aten.div.Tensor_mode %[[SELF]], %[[OTHER]], %[[CSTFLOOR]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.str -> !torch.vtensor<[?,?],f32>
+// CHECK: %[[CSTTRUNC:.*]] = torch.constant.str "trunc"
+// CHECK: %[[OUT:.*]] = torch.aten.div.Tensor_mode %[[SELF]], %[[OTHER]], %[[CSTTRUNC]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.str -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[OUT]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.floor_divide(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
   %0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>