diff --git a/externals/llvm-project b/externals/llvm-project index bb1f220d534b..3ca2a5fc0b84 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit bb1f220d534b0f6d80bea36662f5188ff11c2e54 +Subproject commit 3ca2a5fc0b84762f0e7d8a0e613fd69f7e344219 diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index 31bc717914fd..20fb70b2ddb3 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -507,7 +507,7 @@ def Torch_PrimCallMethodOp : Torch_Op<"prim.CallMethod", []> { } def Torch_PrimLoopOp : Torch_Op<"prim.Loop", [ - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods]> { let summary = "TorchScript prim::Loop op"; let description = [{ This op (together with prim.Loop.condition) define a looping construct @@ -559,7 +559,7 @@ def Torch_PrimLoopConditionOp : Torch_Op<"prim.Loop.condition", [ } def Torch_PrimIfOp : Torch_Op<"prim.If", [ - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods]> { let summary = "TorchScript prim::If op"; let description = [{ This op (together with prim.If.yield) define a conditional control flow @@ -1183,7 +1183,7 @@ def Torch_RuntimeAssertOp: Torch_Op<"runtime.assert", [ //===----------------------------------------------------------------------===// def Torch_ShapeCalculateOp : Torch_Op<"shape.calculate", [ - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods]> { let summary = "Shape calculation encapsulation op"; let description = [{ The `torch.shape.calculate` op captures a shape calculation @@ -1263,7 +1263,7 @@ def Torch_ShapeCalculateYieldShapesOp : Torch_Op<"shape.calculate.yield.shapes", //===----------------------------------------------------------------------===// def Torch_DtypeCalculateOp : Torch_Op<"dtype.calculate", [ - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods]> { let summary = "Dtype calculation encapsulation op"; let description = [{ The `torch.dtype.calculate` op captures a dtype calculation diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 815ab1bd8444..440313af17fa 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -423,12 +423,17 @@ void PrimLoopOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { Region ®ion = getRegion(); if (!point.getTerminatorPredecessorOrNull()) { - regions.emplace_back(®ion, region.getArguments().slice(1)); + regions.emplace_back(®ion); return; } assert(point.getTerminatorPredecessorOrNull()->getParentRegion() == ®ion); - regions.emplace_back(®ion, region.getArguments().slice(1)); - regions.emplace_back(getOperation(), getResults()); + regions.emplace_back(®ion); + regions.emplace_back(RegionSuccessor::parent()); +} + +ValueRange PrimLoopOp::getSuccessorInputs(RegionSuccessor successor) { + return successor.isParent() ? ValueRange(getResults()) + : ValueRange(getRegion().getArguments().slice(1)); } bool PrimLoopOp::isForLike() { @@ -494,7 +499,7 @@ void PrimIfOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { // The `then` and the `else` region branch back to the parent operation. if (point.getTerminatorPredecessorOrNull()) { - regions.push_back(RegionSuccessor(getOperation(), getResults())); + regions.push_back(RegionSuccessor::parent()); return; } @@ -512,6 +517,10 @@ void PrimIfOp::getSuccessorRegions(RegionBranchPoint point, return; } +ValueRange PrimIfOp::getSuccessorInputs(RegionSuccessor successor) { + return successor.isParent() ? ValueRange(getResults()) : ValueRange(); +} + /// Replaces the given op with the contents of the given single-block region, /// using the operands of the block terminator to replace operation results. static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, @@ -5376,7 +5385,7 @@ getSuccessorRegionsForCalculateOp(CalculateOp op, RegionBranchPoint point, Region *region = point.getTerminatorPredecessorOrNull()->getParentRegion(); if (region == &op.getBody()) { // Body returns control to the outer op, passing through results. - regions.emplace_back(op.getOperation(), op.getResults()); + regions.emplace_back(RegionSuccessor::parent()); return; } assert(region == &op.getCalculation()); @@ -5389,6 +5398,10 @@ void ShapeCalculateOp::getSuccessorRegions( getSuccessorRegionsForCalculateOp(*this, point, regions); } +ValueRange ShapeCalculateOp::getSuccessorInputs(RegionSuccessor successor) { + return successor.isParent() ? ValueRange(getResults()) : ValueRange(); +} + //===----------------------------------------------------------------------===// // DtypeCalculateOp //===----------------------------------------------------------------------===// @@ -5398,6 +5411,10 @@ void DtypeCalculateOp::getSuccessorRegions( getSuccessorRegionsForCalculateOp(*this, point, regions); } +ValueRange DtypeCalculateOp::getSuccessorInputs(RegionSuccessor successor) { + return successor.isParent() ? ValueRange(getResults()) : ValueRange(); +} + //===----------------------------------------------------------------------===// // ShapeCalculateYieldShapesOp //===----------------------------------------------------------------------===//