This repository has been archived by the owner on Jan 30, 2025. It is now read-only.

[AOT Compile] E2E compile->execute->test framework using TorchDynamo export, TCP compilation, CPU codegen (#46)

### [Large PR] For ease of review, I've stacked the changes into 3 commits:
- ada5753: *[Please review]* AOT compile flow for TD export + TCP compilation and CPU codegen
- 75031b6: *[Please skim]* AOT README (also pasted as PR description below)
- 705aeba: *[Please skip]* NFC / mechanical changes

Inlining `tools/aot/README.md` below for readability.

-------

AOT Compile (Developer Guide)
=============================

The
[`aot_compile`](https://github.com/cruise-automation/mlir-tcp/blob/main/tools/aot/aot_compile.bzl)
bazel macro implements an end-to-end framework that compiles PyTorch (or
TCP) programs to a CPU library, executes it, and tests the generated code
for functional correctness. The pipeline starts with TorchDynamo export of
the PyTorch program, converts and lowers it through the {Torch, TCP,
Linalg, LLVM} MLIR dialects, translates the result to LLVM assembly,
compiles that to assembly source for the host architecture (CPU), and
finally generates a shared object that can be dynamically linked into an
executable at runtime. It uses a series of genrules to stitch the
compilation pipeline together, and a simple meta-programming trick to
auto-generate C++ tests (specialized to the input program's function
signature) that execute the compiled code and validate its numerics
against reference PyTorch.

When authoring new TCP ops with dialect conversions from/to Torch and
Linalg, adding an `aot_compile` target is a fast, automated and
standardized way to test the e2e compilation and validate that the op
lowerings are implemented consistently with PyTorch semantics.

## Compile PyTorch programs

Onboarding to the `aot_compile` macro is straightforward (examples
[here](https://github.com/cruise-automation/mlir-tcp/blob/main/test/AotCompile/BUILD)).
Start by adding the following line to the `BUILD` file to load the macro:
```starlark
load("//tools/aot:aot_compile.bzl", "aot_compile")
```

Then call the macro like this:
```starlark
aot_compile(
    name = "broadcast_add_mixed_ranks",
    torch_loader_lib = ":add_mul_loader_lib",
    torch_loader_path = "test.AotCompile.add_mul_loader_lib.broadcast_add_mixed_ranks_loader",
)
```

Here, `torch_loader_lib` expects a `py_library` target for the module
that defines the PyTorch program to be AOT compiled, and
`torch_loader_path` is the full python import path (dot separated) to
the loader function.
```starlark
py_library(
    name = "add_mul_loader_lib",
    srcs = ["add_mul_loader_lib.py"],
    visibility = ["//visibility:public"],
    deps = [
        requirement("torch"),
        "//tools/aot:torch_loader_utils",
    ],
)
```
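
The dot-separated `torch_loader_path` can be read as a module path followed by a function name. As a minimal sketch of how such a path might be resolved (the helper name `resolve_loader` is hypothetical, and this is not necessarily how the macro's exporter does it), using only the standard library:

```python
import importlib
from typing import Callable


def resolve_loader(dotted_path: str) -> Callable:
    """Split 'pkg.module.func' into module path and attribute, then import it."""
    module_path, _, func_name = dotted_path.rpartition(".")
    if not module_path:
        raise ValueError(f"expected 'module.function', got {dotted_path!r}")
    module = importlib.import_module(module_path)
    return getattr(module, func_name)


# Stand-in demonstration with a standard-library function. In the repository
# the path would look like
# "test.AotCompile.add_mul_loader_lib.broadcast_add_mixed_ranks_loader".
sqrt = resolve_loader("math.sqrt")
print(sqrt(9.0))  # 3.0
```

A path with no dot is rejected early, since it cannot name both a module and a function.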

The loader function can have any name, but it must define the PyTorch
program, sample inputs, and dynamic dim constraints (if any), and always
return a `TorchLoaderOutput` object. The PyTorch program's forward
function must always consume and return tensors, like so:
```python
import torch
from torch.export import dynamic_dim

from tools.aot.torch_loader_utils import TorchLoaderOutput


def broadcast_add_mixed_ranks_loader() -> TorchLoaderOutput:
    class BroadcastAddMixedRanks(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            add = torch.add(x, y)
            return add

    # Sample inputs
    x = torch.tensor(10.0)
    y = torch.randn(2)

    # Dynamic dim constraints
    constraints = [dynamic_dim(y, 0)]

    return TorchLoaderOutput(
        model=BroadcastAddMixedRanks(),
        inputs=[x, y],
        constraints=constraints,
    )
```
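
To make the loader contract concrete, here is a torch-free sketch of the container the loader returns. The field names `model`, `inputs` and `constraints` come from the call above; everything else (the class name `LoaderOutputSketch`, the types, the default, and the `check_loader_output` helper) is an assumption for illustration and not the actual definition in `tools/aot/torch_loader_utils`:

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class LoaderOutputSketch:
    """Hypothetical stand-in for TorchLoaderOutput (not the real class)."""

    model: Any  # a torch.nn.Module in the real flow
    inputs: List[Any]  # sample tensors, in forward() argument order
    constraints: Optional[List[Any]] = None  # dynamic dim constraints, if any


def check_loader_output(out: LoaderOutputSketch) -> None:
    """Sanity checks a consumer might run before attempting an export."""
    if not callable(out.model):
        raise TypeError("model must be callable")
    if not out.inputs:
        raise ValueError("at least one sample input is required")


# Plain floats stand in for tensors so the sketch runs without torch.
out = LoaderOutputSketch(model=lambda x, y: x + y, inputs=[10.0, 2.0])
check_loader_output(out)
print(out.model(*out.inputs))  # 12.0
```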

An invocation of `aot_compile(name="foo", ...)` generates several
targets (see
[here](https://github.com/cruise-automation/mlir-tcp/blob/main/tools/aot/aot_compile.bzl#L43)
for the list) that help in debugging the intermediate steps of the
compilation process.

To get the full list of `aot_compile` macro generated targets for
`broadcast_add_mixed_ranks`, run the query:
```shell
$ bazel query 'attr(name, "broadcast_add_mixed_ranks", //test/AotCompile/...)'

//test/AotCompile:aot_compiled_broadcast_add_mixed_ranks
//test/AotCompile:broadcast_add_mixed_ranks_compile_execute_test
//test/AotCompile:broadcast_add_mixed_ranks_execute_test_generator
//test/AotCompile:broadcast_add_mixed_ranks_torch_exporter
//test/AotCompile:gen_broadcast_add_mixed_ranks_execute_test
//test/AotCompile:gen_broadcast_add_mixed_ranks_host_asm
//test/AotCompile:gen_broadcast_add_mixed_ranks_llvm_ir
//test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_llvm
//test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_tcp
//test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_torch
//test/AotCompile:gen_broadcast_add_mixed_ranks_reference_tensors
```

Let's walk through the steps involved in debugging an e2e compilation
pipeline. These steps do not need to be run manually one at a time
(although they can be): Bazel identifies the DAG of dependencies and
executes just what is needed to build the specified target.
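
The dependency-driven ordering can be pictured with a small sketch. The target names follow the query output above; the edges between them are an assumption inferred from the order of the steps below, not read from `aot_compile.bzl`, and `build_order` is a hypothetical helper:

```python
from graphlib import TopologicalSorter
from typing import List


def build_order(name: str) -> List[str]:
    """Return one valid build order for the (assumed) dependency graph."""
    # Each key maps to the set of targets it is assumed to depend on.
    deps = {
        f"gen_{name}_mlir_torch": {f"{name}_torch_exporter"},
        f"gen_{name}_mlir_tcp": {f"gen_{name}_mlir_torch"},
        f"gen_{name}_mlir_llvm": {f"gen_{name}_mlir_tcp"},
        f"gen_{name}_llvm_ir": {f"gen_{name}_mlir_llvm"},
        f"gen_{name}_host_asm": {f"gen_{name}_llvm_ir"},
        f"aot_compiled_{name}": {f"gen_{name}_host_asm"},
        f"gen_{name}_reference_tensors": {f"{name}_torch_exporter"},
        f"gen_{name}_execute_test": {f"{name}_execute_test_generator"},
        f"{name}_compile_execute_test": {
            f"aot_compiled_{name}",
            f"gen_{name}_reference_tensors",
            f"gen_{name}_execute_test",
        },
    }
    return list(TopologicalSorter(deps).static_order())


order = build_order("broadcast_add_mixed_ranks")
for target in order:
    print(target)
```

Requesting only an intermediate target, such as `gen_foo_mlir_tcp`, corresponds to building just the part of this graph that the target depends on.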

#### 1. Inspect the Torch dialect (`*_torch.mlir`) exported from the PyTorch program:
```shell
$ bazel build //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_torch

INFO: Analyzed target //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_torch (61 packages loaded, 16582 targets configured).
INFO: Found 1 target...
Target //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_torch up-to-date:
  bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_torch.mlir
INFO: Elapsed time: 6.085s, Critical Path: 0.69s
INFO: 1 process: 1 internal.
INFO: Build completed successfully, 1 total action
```
```ll
$ cat bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_torch.mlir

module {
  func.func @func_main(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
    %int1 = torch.constant.int 1
    %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
    return %0 : !torch.vtensor<[?],f32>
  }
}
```
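
The extra `%int1` operand is the `alpha` scaling factor of `add.Tensor`, which computes `x + alpha * y`. A pure-Python reference for the rank-0 plus rank-1 case shown above might look like this (the function name is hypothetical and plain lists stand in for tensors):

```python
from typing import List


def aten_add_tensor(x: float, y: List[float], alpha: int = 1) -> List[float]:
    """Reference for add.Tensor with a rank-0 `x` and a rank-1 `y`."""
    # The scalar is implicitly broadcast across every element of `y`.
    return [x + alpha * y_i for y_i in y]


print(aten_add_tensor(10.0, [1.0, 2.0]))  # [11.0, 12.0]
```

With the default `alpha = 1` this reduces to an ordinary broadcast add, which is why the TCP lowering in the next step needs no multiplication.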

#### 2. Inspect the TCP dialect (`*_tcp.mlir`) lowered from the Torch dialect:
```shell
$ bazel build //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_tcp

INFO: Analyzed target //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_tcp (0 packages loaded, 0 targets configured).
INFO: Found 1 target...
Target //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_tcp up-to-date:
  bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_tcp.mlir
INFO: Elapsed time: 0.572s, Critical Path: 0.03s
INFO: 1 process: 1 internal.
INFO: Build completed successfully, 1 total action
```
```ll
$ cat bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_tcp.mlir

module {
  func.func @func_main(%arg0: tensor<f32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
    %c0 = arith.constant 0 : index
    %expanded = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
    %dim = tensor.dim %arg1, %c0 : tensor<?xf32>
    %0 = tcp.broadcast %expanded, %dim {axes = [0]} : tensor<1xf32>, index -> tensor<?xf32>
    %1 = tcp.add %0, %arg1 : tensor<?xf32>, tensor<?xf32> -> tensor<?xf32>
    return %1 : tensor<?xf32>
  }
}
```
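
Unlike the Torch dialect, the TCP program spells out the broadcast explicitly. The following sketch mirrors the four ops above one by one with Python lists; the helper names are hypothetical and only model the rank-0/rank-1 shapes used in this example:

```python
from typing import List


def expand_shape_scalar(x: float) -> List[float]:
    """tensor.expand_shape: tensor<f32> -> tensor<1xf32>."""
    return [x]


def broadcast_axis0(x: List[float], dim: int) -> List[float]:
    """tcp.broadcast along axis 0: tensor<1xf32> -> tensor<?xf32>."""
    if len(x) != 1:
        raise ValueError("broadcast source must have size 1 on axis 0")
    return x * dim


def tcp_add(a: List[float], b: List[float]) -> List[float]:
    """tcp.add: elementwise add of two equally sized tensors."""
    if len(a) != len(b):
        raise ValueError("tcp.add operands must have the same shape")
    return [p + q for p, q in zip(a, b)]


def func_main(arg0: float, arg1: List[float]) -> List[float]:
    expanded = expand_shape_scalar(arg0)
    dim = len(arg1)  # tensor.dim %arg1, %c0
    bcast = broadcast_axis0(expanded, dim)
    return tcp_add(bcast, arg1)


print(func_main(10.0, [1.0, 2.0]))  # [11.0, 12.0]
```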

#### 3. Inspect the LLVM dialect (`*_llvm.mlir`) lowered from the TCP dialect:
```shell
$ bazel build //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_llvm

INFO: Analyzed target //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_llvm (0 packages loaded, 0 targets configured).
INFO: Found 1 target...
Target //test/AotCompile:gen_broadcast_add_mixed_ranks_mlir_llvm up-to-date:
  bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_llvm.mlir
INFO: Elapsed time: 0.305s, Critical Path: 0.00s
INFO: 1 process: 1 internal.
INFO: Build completed successfully, 1 total action
```
```ll
$ cat bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_llvm.mlir

module {
  llvm.func @malloc(i64) -> !llvm.ptr
  llvm.func @func_main(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: !llvm.ptr, %arg4: !llvm.ptr, %arg5: i64, %arg6: i64, %arg7: i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> {
    %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64)>
    %1 = llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr, ptr, i64)>
    %2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr, ptr, i64)>
    %3 = llvm.insertvalue %arg2, %2[2] : !llvm.struct<(ptr, ptr, i64)>
.
.
.
    %57 = llvm.load %56 : !llvm.ptr -> f32
    %58 = llvm.fadd %54, %57  : f32
    %59 = llvm.getelementptr %44[%51] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    llvm.store %58, %59 : f32, !llvm.ptr
    %60 = llvm.add %51, %14  : i64
    llvm.br ^bb4(%60 : i64)
  ^bb6:  // pred: ^bb4
    llvm.return %50 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
  }
}
```
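
The eight arguments of `func_main` come from unpacking each input memref descriptor into scalars: an allocated pointer, an aligned pointer and an offset, followed by one size and one stride per dimension. A short sketch of that bookkeeping (the helper names are hypothetical):

```python
from typing import Iterable


def memref_abi_arg_count(rank: int) -> int:
    """Scalars per memref: 2 pointers + offset + `rank` sizes + `rank` strides."""
    return 3 + 2 * rank


def function_abi_arg_count(input_ranks: Iterable[int]) -> int:
    """Total scalar arguments for a function taking memrefs of these ranks."""
    return sum(memref_abi_arg_count(rank) for rank in input_ranks)


# Rank-0 `x` contributes 3 arguments, rank-1 `y` contributes 5.
print(function_abi_arg_count([0, 1]))  # 8
```

The same layout explains the result type: a rank-1 output is returned as a struct of two pointers, an offset, one size and one stride.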

#### 4. Inspect the LLVM assembly (`*.ll`) translated from the LLVM dialect:
```shell
$ bazel build //test/AotCompile:gen_broadcast_add_mixed_ranks_llvm_ir

INFO: Analyzed target //test/AotCompile:gen_broadcast_add_mixed_ranks_llvm_ir (0 packages loaded, 0 targets configured).
INFO: Found 1 target...
Target //test/AotCompile:gen_broadcast_add_mixed_ranks_llvm_ir up-to-date:
  bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks.ll
INFO: Elapsed time: 0.312s, Critical Path: 0.00s
INFO: 1 process: 1 internal.
INFO: Build completed successfully, 1 total action
```
```ll
$ cat bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks.ll

; ModuleID = 'LLVMDialectModule'
source_filename = "LLVMDialectModule"

declare ptr @malloc(i64)

define { ptr, ptr, i64, [1 x i64], [1 x i64] } @func_main(ptr %0, ptr %1, i64 %2, ptr %3, ptr %4, i64 %5, i64 %6, i64 %7) {
  %9 = insertvalue { ptr, ptr, i64 } undef, ptr %0, 0
  %10 = insertvalue { ptr, ptr, i64 } %9, ptr %1, 1
  %11 = insertvalue { ptr, ptr, i64 } %10, i64 %2, 2
  %12 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } undef, ptr %3, 0
.
.
.
63:                                               ; preds = %51
  ret { ptr, ptr, i64, [1 x i64], [1 x i64] } %50
}

!llvm.module.flags = !{!0}

!0 = !{i32 2, !"Debug Info Version", i32 3}
```

#### 5. Inspect the assembly source (`*.S`) compiled for the host architecture (CPU):
```shell
$ bazel build //test/AotCompile:gen_broadcast_add_mixed_ranks_host_asm

INFO: Analyzed target //test/AotCompile:gen_broadcast_add_mixed_ranks_host_asm (0 packages loaded, 0 targets configured).
INFO: Found 1 target...
Target //test/AotCompile:gen_broadcast_add_mixed_ranks_host_asm up-to-date:
  bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks.S
INFO: Elapsed time: 0.360s, Critical Path: 0.03s
INFO: 1 process: 1 internal.
INFO: Build completed successfully, 1 total action
```
```ll
$ cat bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks.S

        .text
        .file   "LLVMDialectModule"
        .globl  func_main                       # -- Begin function func_main
        .p2align        4, 0x90
        .type   func_main,@function
func_main:                              # @func_main
        .cfi_startproc
# %bb.0:
        pushq   %rbp
        .cfi_def_cfa_offset 16
.
.
.
        popq    %r14
        .cfi_def_cfa_offset 24
        popq    %r15
        .cfi_def_cfa_offset 16
        popq    %rbp
        .cfi_def_cfa_offset 8
        retq
.Lfunc_end0:
        .size   func_main, .Lfunc_end0-func_main
        .cfi_endproc
                                        # -- End function
        .section        ".note.GNU-stack","",@progbits
```

#### 6. Build the shared object (`*.so`) from the host assembly that can be dynamically linked into an executable at runtime:
```shell
$ bazel build //test/AotCompile:aot_compiled_broadcast_add_mixed_ranks

INFO: Analyzed target //test/AotCompile:aot_compiled_broadcast_add_mixed_ranks (8 packages loaded, 8403 targets configured).
INFO: Found 1 target...
Target //test/AotCompile:aot_compiled_broadcast_add_mixed_ranks up-to-date:
  bazel-bin/test/AotCompile/libaot_compiled_broadcast_add_mixed_ranks.a
  bazel-bin/test/AotCompile/libaot_compiled_broadcast_add_mixed_ranks.so
INFO: Elapsed time: 2.264s, Critical Path: 0.12s
INFO: 1 process: 1 internal.
INFO: Build completed successfully, 1 total action
```

#### 7. Save the reference input and output tensors needed for validation of the compiled code:
```shell
$ bazel build //test/AotCompile:gen_broadcast_add_mixed_ranks_reference_tensors

INFO: Analyzed target //test/AotCompile:gen_broadcast_add_mixed_ranks_reference_tensors (0 packages loaded, 5 targets configured).
INFO: Found 1 target...
Target //test/AotCompile:gen_broadcast_add_mixed_ranks_reference_tensors up-to-date:
  bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_reference_tensors.npz
INFO: Elapsed time: 0.743s, Critical Path: 0.15s
INFO: 1 process: 1 internal.
INFO: Build completed successfully, 1 total action
```
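
The generated C++ test in the next step looks tensors up in the `.npz` archive by the keys `Input0`, `Input1` and `Output0`. The sketch below shows only that naming convention; the helper `name_reference_tensors` is hypothetical, and in the real flow the resulting mapping would presumably be written out with NumPy rather than printed:

```python
from typing import Any, Dict, Sequence


def name_reference_tensors(
    inputs: Sequence[Any], outputs: Sequence[Any]
) -> Dict[str, Any]:
    """Assign positional keys: Input0, Input1, ..., Output0, Output1, ..."""
    named = {f"Input{i}": tensor for i, tensor in enumerate(inputs)}
    named.update({f"Output{i}": tensor for i, tensor in enumerate(outputs)})
    return named


# Plain Python values stand in for the tensors of broadcast_add_mixed_ranks.
named = name_reference_tensors(inputs=[10.0, [1.0, 2.0]], outputs=[[11.0, 12.0]])
print(list(named))  # ['Input0', 'Input1', 'Output0']
```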

#### 8. Inspect the C++ test (`*_execute_test.cpp`) auto-generated from the template:
```shell
$ bazel build //test/AotCompile:gen_broadcast_add_mixed_ranks_execute_test

INFO: Analyzed target //test/AotCompile:gen_broadcast_add_mixed_ranks_execute_test (22 packages loaded, 91 targets configured).
INFO: Found 1 target...
Target //test/AotCompile:gen_broadcast_add_mixed_ranks_execute_test up-to-date:
  bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_execute_test.cpp
INFO: Elapsed time: 0.329s, Critical Path: 0.02s
INFO: 1 process: 1 internal.
INFO: Build completed successfully, 1 total action
```
```cpp
$ cat bazel-bin/test/AotCompile/_internal_broadcast_add_mixed_ranks_execute_test.cpp

//===------------------------------------------------------------*- C++ -*-===//
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "tools/aot/abi.h"

#include "cnpy.h"
#include "gtest/gtest.h"

using namespace mlir::tcp;

#pragma clang diagnostic ignored "-Wreturn-type-c-linkage"

template <typename DataType, int Rank>
static StridedMemRefType<DataType, Rank>
CreateMemRefFromNpyArray(cnpy::NpyArray &arr) {
  StridedMemRefType<DataType, Rank> Result;
  Result.basePtr = arr.data<DataType>();
  Result.data = arr.data<DataType>();
  Result.offset = 0;

  // Check if the Rank matches
  if (arr.shape.size() != Rank) {
    std::cerr << "Error: Rank mismatch." << std::endl;
    // Return an uninitialized memref
    return Result;
  }

  // Check if the DataType matches
  if (arr.word_size != sizeof(DataType)) {
    std::cerr << "Error: Data type mismatch." << std::endl;
    // Return an uninitialized memref
    return Result;
  }

  // Set sizes and strides based on the shape of the numpy array
  int stride = 1;
  for (int i = Rank - 1; i >= 0; --i) {
    Result.sizes[i] = arr.shape[i];
    Result.strides[i] = stride;
    stride *= arr.shape[i];
  }

  return Result;
}

// CreateMemRefFromNpyArray function specialized for rank 0
template <typename DataType>
static StridedMemRefType<DataType, 0>
CreateMemRefFromNpyArray(cnpy::NpyArray &arr) {
  StridedMemRefType<DataType, 0> Result;
  Result.basePtr = arr.data<DataType>();
  Result.data = arr.data<DataType>();
  Result.offset = 0;

  // Check if the Rank matches
  if (!arr.shape.empty()) {
    std::cerr << "Error: Rank mismatch. Expected rank-0 array." << std::endl;
    // Return an uninitialized memref
    return Result;
  }

  // Check if the DataType matches
  if (arr.word_size != sizeof(DataType)) {
    std::cerr << "Error: Data type mismatch." << std::endl;
    // Return an uninitialized memref
    return Result;
  }

  return Result;
}

// ### DO NOT MODIFY ### //
// This template file is pre-processed by `aot_compile` bazel macro
// to materialize the templated parameters based on the inputs
// passed by the callsite where the macro is instantiated.

struct OutputMemRefDescriptor {
  StridedMemRefType<float, 1> Output0;
};

extern "C" OutputMemRefDescriptor func_main(
    DECL_RANK_0_MEMREF_ABI(float),
    DECL_RANK_1_MEMREF_ABI(float)
);

TEST(AotCompiled, ExecuteTest) {
  cnpy::npz_t reference_tensors = cnpy::npz_load(
      "test/AotCompile/_internal_broadcast_add_mixed_ranks_reference_tensors.npz"
  );

  cnpy::NpyArray refInput0 = reference_tensors["Input0"];
  cnpy::NpyArray refInput1 = reference_tensors["Input1"];
  cnpy::NpyArray refOutput0 = reference_tensors["Output0"];

  StridedMemRefType<float, 0> Input0 =
      CreateMemRefFromNpyArray<float>(refInput0);
  StridedMemRefType<float, 1> Input1 =
      CreateMemRefFromNpyArray<float, 1>(refInput1);

  OutputMemRefDescriptor Result = func_main(
      PASS_RANK_0_MEMREF(Input0),
      PASS_RANK_1_MEMREF(Input1)
  );

  ASSERT_EQ(Result.Output0.sizes[0], refOutput0.shape[0]);

  for (int i = 0; i < refOutput0.num_vals; i++)
    EXPECT_EQ(Result.Output0.data[i], refOutput0.data<float>()[i]);

  free(Result.Output0.basePtr);
}
```
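
The size and stride loop in `CreateMemRefFromNpyArray` assumes the NumPy array is stored contiguously in row-major order. The same computation in Python, as a sketch for checking strides by hand (the function name is hypothetical; strides are in elements, not bytes):

```python
from typing import List, Sequence


def row_major_strides(shape: Sequence[int]) -> List[int]:
    """Element strides of a contiguous row-major array with this shape."""
    strides = [0] * len(shape)
    stride = 1
    # Walk from the innermost dimension outwards, as the C++ loop does.
    for i in range(len(shape) - 1, -1, -1):
        strides[i] = stride
        stride *= shape[i]
    return strides


print(row_major_strides([2, 3]))  # [3, 1]
```

For a rank-0 array the loop body never runs and the stride list is empty, which matches the separate rank-0 overload in the C++ code.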

#### 9. Run the C++ test to execute the generated code and validate functional correctness against reference PyTorch:
```shell
$ bazel run //test/AotCompile:broadcast_add_mixed_ranks_compile_execute_test

INFO: Analyzed target //test/AotCompile:broadcast_add_mixed_ranks_compile_execute_test (0 packages loaded, 0 targets configured).
INFO: Found 1 target...
Target //test/AotCompile:broadcast_add_mixed_ranks_compile_execute_test up-to-date:
  bazel-bin/test/AotCompile/broadcast_add_mixed_ranks_compile_execute_test
INFO: Elapsed time: 0.215s, Critical Path: 0.00s
INFO: 1 process: 1 internal.
INFO: Build completed successfully, 1 total action
INFO: Running command line: external/bazel_tools/tools/test/test-setup.sh test/AotCompile/broadcast_add_mixed_ranks_compile_execute_test
exec ${PAGER:-/usr/bin/less} "$0" || exit 1
Executing tests from //test/AotCompile:broadcast_add_mixed_ranks_compile_execute_test
-----------------------------------------------------------------------------
Running main() from gmock_main.cc
[==========] Running 1 test from 1 test suite.
[----------] Global test environment set-up.
[----------] 1 test from AotCompiled
[ RUN      ] AotCompiled.ExecuteTest
[       OK ] AotCompiled.ExecuteTest (1 ms)
[----------] 1 test from AotCompiled (1 ms total)

[----------] Global test environment tear-down
[==========] 1 test from 1 test suite ran. (1 ms total)
[  PASSED  ] 1 test.
```

## Compile TCP programs

The `aot_compile` macro also accepts TCP dialect programs as inputs
(instead of PyTorch programs). This is useful to maintain framework
neutrality by allowing alternate ingress pathways (like Stablehlo, JAX,
TensorFlow, ONNX etc.) into the TCP dialect. When `tcp_source` is
specified, the generated `aot_compiled_foo` CPU library has one global
function for every function in the TCP program. Let's look at an
example.

```starlark
aot_compile(
    name = "basic_tcp_ops",
    tcp_source = "basic_tcp_ops.mlir",
)
```

Here, `tcp_source` expects a `.mlir` file containing TCP programs, like
so:
```ll
// basic_tcp_ops.mlir

func.func @func_1(%arg0: tensor<?x?xf32>,
                  %arg1: tensor<?x?xf32>,
                  %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = tcp.add %arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
  %1 = tcp.mul %0, %arg2 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}

func.func @func_2(%arg0: tensor<?x?xf32>,
                  %arg1: tensor<?x?xf32>,
                  %arg2: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
  %0 = tcp.add %arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
  %1 = tcp.mul %0, %arg2 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
  return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
}

func.func @func_3(%arg0: tensor<f32>,
                  %arg1: tensor<?xf32>) -> tensor<?xf32> {
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %arg1, %c0 : tensor<?xf32>
  %arg0_ex = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
  %arg0_bcast = tcp.broadcast %arg0_ex, %dim {axes = [0]} : tensor<1xf32>, index -> tensor<?xf32>
  %0 = tcp.add %arg0_bcast, %arg1 : tensor<?xf32>, tensor<?xf32> -> tensor<?xf32>
  return %0 : tensor<?xf32>
}
```

Now run the query to get all the relevant targets created.
```shell
$ bazel query 'attr(name, "basic_tcp_ops", //test/AotCompile/...)'

//test/AotCompile:aot_compiled_basic_tcp_ops
//test/AotCompile:gen_basic_tcp_ops_host_asm
//test/AotCompile:gen_basic_tcp_ops_llvm_ir
//test/AotCompile:gen_basic_tcp_ops_mlir_llvm
```

Note we're missing the
`//test/AotCompile:basic_tcp_ops_compile_execute_test` target. Because
there is no PyTorch reference implementation to compare against, the
`aot_compile` macro does not auto-generate C++ execute tests, but they
can be written manually (example
[here](https://github.com/cruise-automation/mlir-tcp/blob/main/test/AotCompile/test_aot_compiled_basic_tcp_ops.cpp)).
These tests should include an `extern "C"` declaration for every
function in the input TCP source, using the same function name.
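
When calling such a function by hand, each memref is passed as a flat run of scalars in a fixed order. The sketch below reproduces the argument order of the hand-written `func_3(&Arr0, &Arr0, 0, Arr1, Arr1, 0, 2, 1)` call from the example test; the `MemRef` class and `flatten_args` helper are hypothetical, and strings stand in for pointers:

```python
from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class MemRef:
    """Illustrative descriptor; mirrors the fields of StridedMemRefType."""

    base_ptr: str
    data: str
    offset: int = 0
    sizes: List[int] = field(default_factory=list)
    strides: List[int] = field(default_factory=list)


def flatten_args(*memrefs: MemRef) -> List[Any]:
    """Unpack descriptors as: base pointer, data pointer, offset, sizes, strides."""
    flat: List[Any] = []
    for m in memrefs:
        flat += [m.base_ptr, m.data, m.offset, *m.sizes, *m.strides]
    return flat


arr0 = MemRef("&Arr0", "&Arr0")  # rank 0: no sizes or strides
arr1 = MemRef("Arr1", "Arr1", 0, [2], [1])  # rank 1: two elements, unit stride
print(flatten_args(arr0, arr1))
```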

The remaining steps to debug the e2e compilation pipeline are
essentially the same.
sjain-stanford authored Mar 6, 2024
1 parent 8d5d10e commit c3442f9
Showing 26 changed files with 1,370 additions and 56 deletions.
5 changes: 4 additions & 1 deletion .bazelignore
@@ -3,4 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

third_party/
# ignore local_repos of llvm-project, torch-mlir, stablehlo
third_party/llvm-project
third_party/torch-mlir
third_party/stablehlo
2 changes: 1 addition & 1 deletion .github/workflows/bazelBuildAndTestTcp.yml
@@ -55,7 +55,7 @@ jobs:
mlir-tcp:ci \
find . -type f -name "*.cpp" -o -name "*.h" | xargs clang-format -i
if [ -n "$(git status --porcelain)" ]; then
echo "Please run clang-format 'find . -type f -name "*.cpp" -o -name "*.h" | xargs clang-format -i' and commit changes."
echo "Please run 'find . -type f -name "*.cpp" -o -name "*.h" | xargs clang-format -i' and commit changes."
exit 1
fi
6 changes: 5 additions & 1 deletion .gitignore
@@ -8,7 +8,11 @@ bazel-bin
bazel-out
bazel-mlir-tcp
bazel-testlogs
third_party/

# ignore local_repos of llvm, torch-mlir, stablehlo
third_party/llvm-project
third_party/torch-mlir
third_party/stablehlo

# clangd related
.cache
2 changes: 1 addition & 1 deletion README.md
@@ -31,7 +31,7 @@ bazel build //:tcp-opt
bazel test //...
```

We welcome contributions to `mlir-tcp`. If you do contribute, please finalize your PR with clang-format and bazel buildifier to ensure the C++ sources and BUILD files are formatted consistently:
We welcome contributions to `mlir-tcp`. When authoring new TCP ops with dialect conversions from/to Torch and Linalg, please include lit tests for dialect and conversions, as well as [aot_compile](https://github.com/cruise-automation/mlir-tcp/blob/main/tools/aot/README.md) generated e2e integration tests. Finally, please finalize your PR with clang-format and bazel buildifier to ensure the C++ sources and BUILD files are formatted consistently:
```shell
# clang-format
find . -type f -name "*.cpp" -o -name "*.h" | xargs clang-format -i
10 changes: 9 additions & 1 deletion deps.bzl
@@ -43,8 +43,8 @@ def third_party_deps():
TORCH_MLIR_SHA256 = "205ffab6683d5bcbe9bff6afca5aa547826990b0d9d7d58644f9777c37558fd1"
http_archive(
name = "torch-mlir-raw",
sha256 = TORCH_MLIR_SHA256,
build_file_content = "# empty",
sha256 = TORCH_MLIR_SHA256,
strip_prefix = "torch-mlir-" + TORCH_MLIR_COMMIT,
urls = ["https://github.com/llvm/torch-mlir/archive/{commit}.tar.gz".format(commit = TORCH_MLIR_COMMIT)],
)
@@ -151,3 +151,11 @@ def third_party_deps():
strip_prefix = "bazel-compile-commands-extractor-6d58fa6bf39f612304e55566fa628fd160b38177",
url = "https://github.com/hedronvision/bazel-compile-commands-extractor/archive/6d58fa6bf39f612304e55566fa628fd160b38177.tar.gz",
)

http_archive(
name = "cnpy",
build_file = "//third_party:cnpy.BUILD",
sha256 = "5120abc54a564efa92c642cc0199cc4fd3f345901157de9fbbdcedbb34d28d8a",
strip_prefix = "cnpy-4e8810b1a8637695171ed346ce68f6984e585ef4",
urls = ["https://github.com/rogersce/cnpy/archive/4e8810b1a8637695171ed346ce68f6984e585ef4.tar.gz"],
)
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,3 +1,4 @@
-f https://github.com/llvm/torch-mlir-release/releases/expanded_assets/dev-wheels
torch
torch-mlir
numpy
16 changes: 9 additions & 7 deletions requirements_lock.txt
@@ -127,7 +127,9 @@ numpy==1.26.4 \
--hash=sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef \
--hash=sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3 \
--hash=sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f
# via torch-mlir
# via
# -r requirements.txt
# torch-mlir
packaging==23.2 \
--hash=sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5 \
--hash=sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7
@@ -142,11 +144,11 @@ torch==2.3.0.dev20240220+cpu \
# via
# -r requirements.txt
# torch-mlir
torch-mlir==20240223.16 \
--hash=sha256:429bce23c3830485b2c35ae48c1308ad66f63d3e86e4a9e19a17b11b58a1f06d \
--hash=sha256:8b463ad944951781a3c2c51f4240bfc15eb459a6565762499a279d332cc611cf
torch-mlir==20240306.29 \
--hash=sha256:c505cb254196f694ba1447af0ba3300d514fafe5520c2b3266d5d9c2e9ee4f93 \
--hash=sha256:da91a833acfba6e80def65295a805156af1399447a141c46f1becdf5f7ae13a3
# via -r requirements.txt
typing-extensions==4.9.0 \
--hash=sha256:23478f88c37f27d76ac8aee6c905017a143b0b1b886c3c9f66bc2fd94f9f5783 \
--hash=sha256:af72aea155e91adfc61c3ae9e0e342dbc0cba726d6cba4b6c72c1f34e47291cd
typing-extensions==4.10.0 \
--hash=sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475 \
--hash=sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb
# via torch
34 changes: 32 additions & 2 deletions test/AotCompile/BUILD
@@ -5,15 +5,45 @@

load("//tools/aot:aot_compile.bzl", "aot_compile")
load("@rules_cc//cc:defs.bzl", "cc_test")
load("@rules_python//python:defs.bzl", "py_library")
load("@pip_deps//:requirements.bzl", "requirement")

py_library(
name = "add_mul_loader_lib",
srcs = ["add_mul_loader_lib.py"],
visibility = ["//visibility:public"],
deps = [
requirement("torch"),
"//tools/aot:torch_loader_utils",
],
)

aot_compile(
name = "add_mul_single_output",
torch_loader_lib = ":add_mul_loader_lib",
torch_loader_path = "test.AotCompile.add_mul_loader_lib.add_mul_single_output_loader",
)

aot_compile(
name = "add_mul_multi_output",
torch_loader_lib = ":add_mul_loader_lib",
torch_loader_path = "test.AotCompile.add_mul_loader_lib.add_mul_multi_output_loader",
)

aot_compile(
name = "broadcast_add_mixed_ranks",
torch_loader_lib = ":add_mul_loader_lib",
torch_loader_path = "test.AotCompile.add_mul_loader_lib.broadcast_add_mixed_ranks_loader",
)

aot_compile(
name = "basic_tcp_ops",
tcp_source = "basic_tcp_ops.mlir",
)

cc_test(
name = "test_aot_compiled_basic_tcp_ops",
srcs = ["test_aot_compiled_basic_tcp_ops.cpp"],
name = "basic_tcp_ops_compile_execute_test",
srcs = ["basic_tcp_ops_compile_execute_test.cpp"],
tags = ["aot_tests"],
deps = [
":aot_compiled_basic_tcp_ops",
100 changes: 100 additions & 0 deletions test/AotCompile/add_mul_loader_lib.py
@@ -0,0 +1,100 @@
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import torch
from torch.export import dynamic_dim

from tools.aot.torch_loader_utils import TorchLoaderOutput


def add_mul_single_output_loader() -> TorchLoaderOutput:
    class AddMulNetSingleOutput(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(
            self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor
        ) -> torch.Tensor:
            add = torch.add(x, y)
            mul = torch.mul(add, z)
            return mul

    # Sample inputs
    x = torch.randn(2, 3)
    y = torch.randn(2, 3)
    z = torch.randn(2, 3)

    # Dynamic dim constraints
    constraints = [
        # Dim 1
        dynamic_dim(x, 0) == dynamic_dim(y, 0),
        dynamic_dim(y, 0) == dynamic_dim(z, 0),
        # Dim 2
        dynamic_dim(x, 1) == dynamic_dim(y, 1),
        dynamic_dim(y, 1) == dynamic_dim(z, 1),
    ]

    return TorchLoaderOutput(
        model=AddMulNetSingleOutput(),
        inputs=[x, y, z],
        constraints=constraints,
    )


def add_mul_multi_output_loader() -> TorchLoaderOutput:
    class AddMulNetMultiOutput(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(
            self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor
        ) -> tuple[torch.Tensor]:
            add = torch.add(x, y)
            mul = torch.mul(add, z)
            return add, mul

    # Sample inputs
    x = torch.randn(2, 3)
    y = torch.randn(2, 3)
    z = torch.randn(2, 3)

    # Dynamic dim constraints
    constraints = [
        # Dim 1
        dynamic_dim(x, 0) == dynamic_dim(y, 0),
        dynamic_dim(y, 0) == dynamic_dim(z, 0),
        # Dim 2
        dynamic_dim(x, 1) == dynamic_dim(y, 1),
        dynamic_dim(y, 1) == dynamic_dim(z, 1),
    ]

    return TorchLoaderOutput(
        model=AddMulNetMultiOutput(),
        inputs=[x, y, z],
        constraints=constraints,
    )


def broadcast_add_mixed_ranks_loader() -> TorchLoaderOutput:
    class BroadcastAddMixedRanks(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            add = torch.add(x, y)
            return add

    # Sample inputs
    x = torch.tensor(10.0)
    y = torch.randn(2)

    # Dynamic dim constraints
    constraints = [dynamic_dim(y, 0)]

    return TorchLoaderOutput(
        model=BroadcastAddMixedRanks(),
        inputs=[x, y],
        constraints=constraints,
    )
@@ -61,7 +61,7 @@ TEST(AotCompiled, SingleOutput) {
for (int i = 0; i < 2; i++)
for (int j = 0; j < 3; j++) {
float Expected = (5 + (2 + i)) * (3 + j);
EXPECT_EQ(Result.data[3 * i + j], Expected);
EXPECT_FLOAT_EQ(Result.data[3 * i + j], Expected);
}

free(Result.basePtr);
@@ -109,10 +109,10 @@ TEST(AotCompiled, MultiOutput) {
for (int i = 0; i < 2; i++)
for (int j = 0; j < 3; j++) {
float ExpectedA = 5 + (2 + i);
EXPECT_EQ(Result.A.data[3 * i + j], ExpectedA);
EXPECT_FLOAT_EQ(Result.A.data[3 * i + j], ExpectedA);

float ExpectedB = ExpectedA * (3 + j);
EXPECT_EQ(Result.B.data[3 * i + j], ExpectedB);
EXPECT_FLOAT_EQ(Result.B.data[3 * i + j], ExpectedB);
}

free(Result.A.basePtr);
@@ -129,10 +129,10 @@ TEST(AotCompiled, MixedRanks) {
StridedMemRefType<float, 1> Result =
func_3(&Arr0, &Arr0, 0, Arr1, Arr1, 0, 2, 1);

EXPECT_EQ(Result.sizes[0], 2);
EXPECT_EQ(Result.strides[0], 1);
EXPECT_EQ(Result.data[0], 11.0);
EXPECT_EQ(Result.data[1], 12.0);
ASSERT_EQ(Result.sizes[0], 2);
ASSERT_EQ(Result.strides[0], 1);
EXPECT_FLOAT_EQ(Result.data[0], 11.0);
EXPECT_FLOAT_EQ(Result.data[1], 12.0);

free(Result.basePtr);
}
2 changes: 1 addition & 1 deletion test/BUILD
@@ -46,7 +46,7 @@ filegroup(
tags = ["lit_tests"],
deps = [
requirement("torch"),
requirement("torch_mlir"),
requirement("torch-mlir"),
],
)
for src in glob([
20 changes: 10 additions & 10 deletions test/lit.cfg.py
@@ -12,24 +12,24 @@

# Populate Lit configuration with the minimal required metadata.
# Some metadata is populated in lit.site.cfg.py.in.
config.name = 'MLIR_TCP_TESTS_SUITE'
config.name = "MLIR_TCP_TESTS_SUITE"
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
config.suffixes = ['.mlir', '.py']
config.suffixes = [".mlir", ".py"]

tool_dirs = [
config.llvm_tools_dir,
config.tcp_tools_dir,
config.llvm_tools_dir,
config.tcp_tools_dir,
]

# Make LLVM, TCP and PYTHON tools available in RUN directives
tools = [
'tcp-opt',
'FileCheck',
'count',
'not',
ToolSubst('%PYTHON', config.python_executable, unresolved='ignore'),
"tcp-opt",
"FileCheck",
"count",
"not",
ToolSubst("%PYTHON", config.python_executable, unresolved="ignore"),
]

llvm_config.add_tool_substitutions(tools, tool_dirs)

llvm_config.with_environment('PYTHONPATH', config.python_path, append_path=True)
llvm_config.with_environment("PYTHONPATH", config.python_path, append_path=True)
2 changes: 1 addition & 1 deletion test/python/BUILD
@@ -12,7 +12,7 @@ py_test(
tags = ["python_tests"],
deps = [
requirement("torch"),
requirement("torch_mlir"),
requirement("torch-mlir"),
],
)

3 changes: 2 additions & 1 deletion test/python/fx_import/basic_test.py
@@ -1,4 +1,4 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
@@ -11,6 +11,7 @@

from torch_mlir import fx


def test_import_frozen_exported_program():
@torch._dynamo.assume_constant_result
def get_a():
2 changes: 1 addition & 1 deletion test/python_lit/fx_import/basic_test.py
@@ -1,4 +1,4 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
4 changes: 4 additions & 0 deletions third_party/BUILD
@@ -0,0 +1,4 @@
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
30 changes: 30 additions & 0 deletions third_party/cnpy.BUILD
@@ -0,0 +1,30 @@
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test")

licenses(["notice"]) # MIT

package(
    default_visibility = [
        "//visibility:public",
    ],
)

cc_library(
    name = "cnpy",
    srcs = ["cnpy.cpp"],
    hdrs = ["cnpy.h"],
    copts = [
        "-Wno-unused-variable",
    ],
    deps = ["@llvm_zlib//:zlib"],
)

cc_test(
    name = "test_cnpy",
    srcs = ["example1.cpp"],
    deps = [":cnpy"],
)