Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion externals/llvm-project
Submodule llvm-project updated 6783 files
8 changes: 4 additions & 4 deletions include/torch-mlir/Dialect/Torch/IR/TorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ def Torch_PrimCallMethodOp : Torch_Op<"prim.CallMethod", []> {
}

def Torch_PrimLoopOp : Torch_Op<"prim.Loop", [
DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getEntrySuccessorOperands"]>]> {
DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getEntrySuccessorOperands", "getSuccessorInputs"]>]> {
let summary = "TorchScript prim::Loop op";
let description = [{
This op (together with prim.Loop.condition) define a looping construct
Expand Down Expand Up @@ -559,7 +559,7 @@ def Torch_PrimLoopConditionOp : Torch_Op<"prim.Loop.condition", [
}

def Torch_PrimIfOp : Torch_Op<"prim.If", [
DeclareOpInterfaceMethods<RegionBranchOpInterface>]> {
DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorInputs"]>]> {
let summary = "TorchScript prim::If op";
let description = [{
This op (together with prim.If.yield) define a conditional control flow
Expand Down Expand Up @@ -1183,7 +1183,7 @@ def Torch_RuntimeAssertOp: Torch_Op<"runtime.assert", [
//===----------------------------------------------------------------------===//

def Torch_ShapeCalculateOp : Torch_Op<"shape.calculate", [
DeclareOpInterfaceMethods<RegionBranchOpInterface>]> {
DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorInputs"]>]> {
let summary = "Shape calculation encapsulation op";
let description = [{
The `torch.shape.calculate` op captures a shape calculation
Expand Down Expand Up @@ -1263,7 +1263,7 @@ def Torch_ShapeCalculateYieldShapesOp : Torch_Op<"shape.calculate.yield.shapes",
//===----------------------------------------------------------------------===//

def Torch_DtypeCalculateOp : Torch_Op<"dtype.calculate", [
DeclareOpInterfaceMethods<RegionBranchOpInterface>]> {
DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorInputs"]>]> {
let summary = "Dtype calculation encapsulation op";
let description = [{
The `torch.dtype.calculate` op captures a dtype calculation
Expand Down
27 changes: 22 additions & 5 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,12 +423,17 @@ void PrimLoopOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
Region &region = getRegion();
if (!point.getTerminatorPredecessorOrNull()) {
regions.emplace_back(&region, region.getArguments().slice(1));
regions.emplace_back(&region);
return;
}
assert(point.getTerminatorPredecessorOrNull()->getParentRegion() == &region);
regions.emplace_back(&region, region.getArguments().slice(1));
regions.emplace_back(getOperation(), getResults());
regions.emplace_back(&region);
regions.emplace_back(RegionSuccessor::parent());
}

ValueRange PrimLoopOp::getSuccessorInputs(RegionSuccessor successor) {
return successor.isParent() ? ValueRange(getResults())
: ValueRange(getRegion().getArguments().slice(1));
}

bool PrimLoopOp::isForLike() {
Expand Down Expand Up @@ -494,7 +499,7 @@ void PrimIfOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `then` and the `else` region branch back to the parent operation.
if (point.getTerminatorPredecessorOrNull()) {
regions.push_back(RegionSuccessor(getOperation(), getResults()));
regions.push_back(RegionSuccessor::parent());
return;
}

Expand All @@ -512,6 +517,10 @@ void PrimIfOp::getSuccessorRegions(RegionBranchPoint point,
return;
}

ValueRange PrimIfOp::getSuccessorInputs(RegionSuccessor successor) {
return successor.isParent() ? ValueRange(getResults()) : ValueRange();
}

/// Replaces the given op with the contents of the given single-block region,
/// using the operands of the block terminator to replace operation results.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
Expand Down Expand Up @@ -5376,7 +5385,7 @@ getSuccessorRegionsForCalculateOp(CalculateOp op, RegionBranchPoint point,
Region *region = point.getTerminatorPredecessorOrNull()->getParentRegion();
if (region == &op.getBody()) {
// Body returns control to the outer op, passing through results.
regions.emplace_back(op.getOperation(), op.getResults());
regions.emplace_back(RegionSuccessor::parent());
return;
}
assert(region == &op.getCalculation());
Expand All @@ -5389,6 +5398,10 @@ void ShapeCalculateOp::getSuccessorRegions(
getSuccessorRegionsForCalculateOp(*this, point, regions);
}

ValueRange ShapeCalculateOp::getSuccessorInputs(RegionSuccessor successor) {
return successor.isParent() ? ValueRange(getResults()) : ValueRange();
}

//===----------------------------------------------------------------------===//
// DtypeCalculateOp
//===----------------------------------------------------------------------===//
Expand All @@ -5398,6 +5411,10 @@ void DtypeCalculateOp::getSuccessorRegions(
getSuccessorRegionsForCalculateOp(*this, point, regions);
}

ValueRange DtypeCalculateOp::getSuccessorInputs(RegionSuccessor successor) {
return successor.isParent() ? ValueRange(getResults()) : ValueRange();
}

//===----------------------------------------------------------------------===//
// ShapeCalculateYieldShapesOp
//===----------------------------------------------------------------------===//
Expand Down
Loading