Skip to content

Commit

Permalink
Add pipeline for TorchBackendToTcpBackend and lit test for VerifyTcpB…
Browse files Browse the repository at this point in the history
…ackendContractPass (#7)

Includes lit tests that check the pipeline.
  • Loading branch information
sjain-stanford authored Oct 16, 2023
1 parent 53b16f9 commit 0d42fc5
Show file tree
Hide file tree
Showing 27 changed files with 224 additions and 44 deletions.
11 changes: 8 additions & 3 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -149,23 +149,26 @@ gentbl_cc_library(
)

cc_library(
name = "TcpPasses",
name = "TcpDialectPasses",
srcs = [
"lib/Dialect/Transforms/FuseTcpOpsPass.cpp",
"lib/Dialect/Transforms/IsolateGroupOpsPass.cpp",
"lib/Dialect/Transforms/PassDetail.h",
"lib/Dialect/Transforms/Passes.cpp",
"lib/Dialect/Transforms/VerifyTcpBackendContractPass.cpp",
],
hdrs = [
"include/Dialect/Transforms/FuseTcpOpsPass.h",
"include/Dialect/Transforms/IsolateGroupOpsPass.h",
"include/Dialect/Transforms/Passes.h",
"include/Dialect/Transforms/VerifyTcpBackendContractPass.h",
],
strip_include_prefix = "include",
deps = [
":TcpDialect",
":TcpTransformsPassesIncGen",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
)
Expand Down Expand Up @@ -313,7 +316,7 @@ cc_library(
strip_include_prefix = "include",
deps = [
":TcpConversionPasses",
":TcpPasses",
":TcpDialectPasses",
"@llvm-project//mlir:AllExtensions",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:DialectUtils",
Expand All @@ -333,8 +336,10 @@ cc_library(
strip_include_prefix = "include",
deps = [
":TcpConversionPasses",
":TcpDialectPasses",
"@llvm-project//mlir:ConversionPasses",
"@llvm-project//mlir:Pass",
"@torch-mlir//:TorchMLIRTorchConversionPasses",
],
)

Expand All @@ -346,8 +351,8 @@ cc_binary(
deps = [
":Pipeline",
":TcpDialect",
":TcpDialectPasses",
":TcpInitAll",
":TcpPasses",
"@llvm-project//mlir:AllExtensions",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:MlirOptLib",
Expand Down
6 changes: 0 additions & 6 deletions include/Dialect/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,6 @@

namespace mlir::tcp {

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createTcpFuseElementwiseOpsPass();

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createTcpIsolateGroupOpsPass();

/// Registers all Tcp related passes.
void registerTcpPasses();

Expand Down
6 changes: 6 additions & 0 deletions include/Dialect/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,10 @@ def TcpIsolateGroupOps : Pass<"tcp-isolate-group-ops", "ModuleOp"> {
let constructor = "mlir::tcp::createTcpIsolateGroupOpsPass()";
}

// \brief This pass verifies conformity to the TCP backend contract.
def VerifyTcpBackendContract : Pass<"torch-verify-tcp-backend-contract", "ModuleOp"> {
let summary = "Verifies conformity to the tcp backend contract";
let constructor = "mlir::tcp::createVerifyTcpBackendContractPass()";
}

#endif // TCP_PASSES
21 changes: 21 additions & 0 deletions include/Dialect/Transforms/VerifyTcpBackendContractPass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#pragma once

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include <memory>

namespace mlir::tcp {

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createVerifyTcpBackendContractPass();

} // namespace mlir::tcp
1 change: 1 addition & 0 deletions lib/Dialect/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "Dialect/Transforms/Passes.h"
#include "Dialect/Transforms/FuseTcpOpsPass.h"
#include "Dialect/Transforms/IsolateGroupOpsPass.h"
#include "Dialect/Transforms/VerifyTcpBackendContractPass.h"
#include "mlir/Pass/Pass.h"
#include <memory>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
//
//===----------------------------------------------------------------------===//

#include "../PassDetail.h"
#include "Dialect/Transforms/VerifyTcpBackendContractPass.h"
#include "./PassDetail.h"

#include "Dialect/IR/TcpDialect.h"
#include "Dialect/IR/TcpOps.h"
Expand All @@ -17,11 +18,10 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::TorchConversion;

namespace mlir::tcp {

namespace {
class VerifyTcpBackendContractPass
Expand Down Expand Up @@ -61,7 +61,8 @@ class VerifyTcpBackendContractPass
};
} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::TorchConversion::createVerifyTcpBackendContractPass() {
std::unique_ptr<OperationPass<ModuleOp>> createVerifyTcpBackendContractPass() {
return std::make_unique<VerifyTcpBackendContractPass>();
}

} // namespace mlir::tcp
37 changes: 34 additions & 3 deletions lib/Pipeline/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@

#include "Pipeline/Pipeline.h"

#include "Dialect/Transforms/VerifyTcpBackendContractPass.h"

#include "Conversion/TcpToArith/TcpToArith.h"
#include "Conversion/TcpToLinalg/TcpToLinalg.h"
#include "Conversion/TorchToTcp/TorchToTcp.h"

#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
Expand All @@ -27,9 +30,32 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"

using namespace mlir;

static void tcpToLlvmPipelineBuilder(OpPassManager &pm) {
static void createTorchBackendToTcpBackendPipeline(OpPassManager &pm) {
pm.addNestedPass<func::FuncOp>(tcp::createConvertTorchToTcpPass());

// Clean up any non-canonical code introduced above.
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<func::FuncOp>(createCSEPass());

// Finish the type conversion from `torch` types to the types of the
// TCP backend contract.
pm.addPass(torch::TorchConversion::createFuncBackendTypeConversionPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(
torch::TorchConversion::createFinalizingBackendTypeConversionPass());

// Verify that we have lowered to the form that TCP backend expects.
// This fails compilation (signalPassFailure) if the IR is not in the
// correct form.
pm.addPass(tcp::createVerifyTcpBackendContractPass());
}

static void createTcpToLlvmPipeline(OpPassManager &pm) {
pm.addPass(tcp::createConvertTcpToLinalgPass());
pm.addPass(tcp::createConvertTcpToArithPass());
pm.addPass(func::createFuncBufferizePass());
Expand All @@ -53,6 +79,11 @@ static void tcpToLlvmPipelineBuilder(OpPassManager &pm) {
}

void tcp::registerTcpPipelines() {
PassPipelineRegistration<>("lower-tcp-to-llvm", "Lowers TCP to LLVM",
tcpToLlvmPipelineBuilder);
PassPipelineRegistration<>(
"torch-backend-to-tcp-backend-pipeline",
"Pipeline lowering torch backend contract to TCP backend contract.",
createTorchBackendToTcpBackendPipeline);

PassPipelineRegistration<>("tcp-to-llvm-pipeline", "Lowers TCP to LLVM",
createTcpToLlvmPipeline);
}
22 changes: 22 additions & 0 deletions test/AotCompile/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

load("@//tools/aot:aot_compile.bzl", "aot_compile")
load("@rules_cc//cc:defs.bzl", "cc_test")

aot_compile(
name = "basic_tcp_ops",
tcp_source = "basic_tcp_ops.mlir",
)

cc_test(
name = "test_aot_compiled_basic_tcp_ops",
srcs = ["test_aot_compiled_basic_tcp_ops.cpp"],
deps = [
":aot_compiled_basic_tcp_ops",
"//tools/aot:abi",
"@googletest//:gtest_main",
],
)
File renamed without changes.
File renamed without changes.
18 changes: 1 addition & 17 deletions test/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
load("@llvm-project//llvm:lit_test.bzl", "lit_test")
load("@//tools/aot:aot_compile.bzl", "aot_compile")
load("@rules_cc//cc:defs.bzl", "cc_test")

expand_template(
name = "lit_site_cfg_py",
Expand Down Expand Up @@ -46,20 +44,6 @@ filegroup(
for src in glob([
"Conversion/**/*.mlir",
"Dialect/**/*.mlir",
"Pipeline/**/*.mlir",
])
]

aot_compile(
name = "basic_tcp_ops",
tcp_source = "basic_tcp_ops.mlir",
)

cc_test(
name = "test_aot_compiled_basic_tcp_ops",
srcs = ["test_aot_compiled_basic_tcp_ops.cpp"],
deps = [
":aot_compiled_basic_tcp_ops",
"//tools/aot:abi",
"@googletest//:gtest_main",
],
)
2 changes: 1 addition & 1 deletion test/Conversion/StablehloToTcp/stablehlo.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: tcp-opt <%s -convert-stablehlo-to-tcp -split-input-file | FileCheck %s
// RUN: tcp-opt %s -convert-stablehlo-to-tcp -split-input-file | FileCheck %s

// CHECK-LABEL: func.func @tanh(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
Expand Down
2 changes: 1 addition & 1 deletion test/Conversion/TcpToArith/basic.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: tcp-opt <%s -convert-tcp-to-arith -split-input-file | FileCheck %s
// RUN: tcp-opt %s -convert-tcp-to-arith -split-input-file | FileCheck %s

// CHECK-LABEL: func.func @test_constants() -> tensor<f32> {
// CHECK: %[[C0:.*]] = arith.constant dense<2.500000e+00> : tensor<f32>
Expand Down
2 changes: 1 addition & 1 deletion test/Conversion/TcpToLinalg/binary.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: tcp-opt <%s -convert-tcp-to-linalg -split-input-file | FileCheck %s
// RUN: tcp-opt %s -convert-tcp-to-linalg -split-input-file | FileCheck %s

// CHECK: #[[MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>

Expand Down
2 changes: 1 addition & 1 deletion test/Conversion/TcpToLinalg/misc.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: tcp-opt <%s -convert-tcp-to-linalg -split-input-file | FileCheck %s
// RUN: tcp-opt %s -convert-tcp-to-linalg -split-input-file | FileCheck %s

// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1) -> (0, d1)>
// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
Expand Down
2 changes: 1 addition & 1 deletion test/Conversion/TcpToLinalg/unary.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: tcp-opt <%s -convert-tcp-to-linalg -split-input-file | FileCheck %s
// RUN: tcp-opt %s -convert-tcp-to-linalg -split-input-file | FileCheck %s

// CHECK: #[[MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>

Expand Down
2 changes: 1 addition & 1 deletion test/Conversion/TorchToTcp/data_movement.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: tcp-opt <%s -convert-torch-to-tcp -split-input-file -verify-diagnostics | FileCheck %s
// RUN: tcp-opt %s -convert-torch-to-tcp -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: @torch.aten.cat
// CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.+]]: !torch.vtensor<[?,?],f32>
Expand Down
2 changes: 1 addition & 1 deletion test/Conversion/TorchToTcp/elementwise.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: tcp-opt <%s -convert-torch-to-tcp -split-input-file -verify-diagnostics | FileCheck %s
// RUN: tcp-opt %s -convert-torch-to-tcp -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: func.func @torch.aten.sigmoid(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
Expand Down
2 changes: 1 addition & 1 deletion test/Conversion/TorchToTcp/misc.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: tcp-opt <%s -convert-torch-to-tcp -split-input-file -verify-diagnostics | FileCheck %s
// RUN: tcp-opt %s -convert-torch-to-tcp -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: func.func @torch.vtensor.literal() -> !torch.vtensor<[4],f32> {
// CHECK: %[[T1:.*]] = tcp.const {value = dense<[5.000000e-01, 4.000000e-01, 3.000000e-01, 6.000000e-01]> : tensor<4xf32>} : tensor<4xf32>
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit 0d42fc5

Please sign in to comment.