Skip to content

Commit

Permalink
⬆️ TensorFlow 2.9 (#724)
Browse files Browse the repository at this point in the history
* Update TensorFlow to `r2.9` branch

* ⬆️ [email protected]

* Update ARM patch

* Update MLIR converter code to TF 2.9

* Fix Java build

Not sure if we should rather set `--incompatible_disallow_resource_jars=false` but this seems to work well.

For more info see bazelbuild/bazel#13221

* ⬆️ [email protected]

* ⬆️ [email protected]

* Add TF 2.9 to CI

* Switch to upstream `manylinux2014` toolchain

* Update lce_benchmark_model CMake target

Co-authored-by: Tom Bannink <[email protected]>
Co-authored-by: Cedric Nugteren <[email protected]>
  • Loading branch information
3 people authored May 19, 2022
1 parent d4b34d5 commit 94b62f6
Show file tree
Hide file tree
Showing 48 changed files with 296 additions and 2,188 deletions.
2 changes: 1 addition & 1 deletion .bazelversion
Original file line number Diff line number Diff line change
@@ -1 +1 @@
4.2.1
5.0.0
4 changes: 2 additions & 2 deletions .github/tools/release_linux.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ bazel build :build_pip_pkg \
--copt=-mavx \
--distinct_host_configuration=false \
--verbose_failures \
--crosstool_top=//third_party/toolchains/gcc7_manylinux2010-nvcc-cuda11:toolchain
--crosstool_top=@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain

# Package Whl
bazel-bin/build_pip_pkg artifacts

# Remove manylinux2010 config flags so that normal builds work as expected
# Remove manylinux2014 config flags so that normal builds work as expected
rm -f .lce_configure.bazelrc
8 changes: 4 additions & 4 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ jobs:
path: wheelhouse

manylinux-release-wheel:
name: Build release wheels for manylinux2010
name: Build release wheels for manylinux2014
runs-on: ubuntu-18.04
strategy:
matrix:
Expand All @@ -228,7 +228,7 @@ jobs:
continue-on-error: true
with:
credentials_json: ${{ secrets.gcs_bazel_cache }}
- name: Build manylinux2010 wheels
- name: Build manylinux2014 wheels
run: |
if [[ -n $GOOGLE_APPLICATION_CREDENTIALS ]]; then
echo -e 'build --remote_http_cache=https://storage.googleapis.com/plumerai-bazel-cache/lce-release-manylinux-python${{ matrix.python-version }}' >> .bazelrc.user
Expand All @@ -239,14 +239,14 @@ jobs:
-e GOOGLE_APPLICATION_CREDENTIALS=/tmp/gcloud-credentials.json \
-v $GOOGLE_APPLICATION_CREDENTIALS:/tmp/gcloud-credentials.json:ro \
-v ${PWD}:/compute-engine -w /compute-engine \
tensorflow/build:2.8-python${{ matrix.python-version }} \
tensorflow/build:2.9-python${{ matrix.python-version }} \
.github/tools/release_linux.sh
sudo apt-get -y -qq install patchelf --no-install-recommends
python -m pip install auditwheel --no-cache-dir
for f in artifacts/*.whl; do
auditwheel repair --plat manylinux2010_x86_64 $f
auditwheel repair --plat manylinux2014_x86_64 $f
done
ls -al wheelhouse/
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/unittests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ jobs:
if: github.ref != 'refs/heads/main'
shell: bash
- name: Install pip dependencies
run: pip install tensorflow-cpu~=2.8.0 larq~=0.11 larq_zoo~=2.0 pytest tensorflow_datasets~=4.4 flatbuffers==1.12 tqdm --no-cache-dir
run: pip install tensorflow-cpu~=2.9.0 larq~=0.11 larq_zoo~=2.0 pytest tensorflow_datasets~=4.4 flatbuffers==1.12 tqdm --no-cache-dir
- name: Run Interpreter test
run: bazelisk test larq_compute_engine/tflite/tests:interpreter_test --test_output=all
- name: Run FileCheck tests
Expand All @@ -101,7 +101,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
tf-version: [1.14.0, 1.15.5, 2.0.4, 2.1.4, 2.2.3, 2.3.3, 2.4.4, 2.5.3, 2.6.3, 2.7.1, 2.8.0]
tf-version: [1.14.0, 1.15.5, 2.0.4, 2.1.4, 2.2.3, 2.3.3, 2.4.4, 2.5.3, 2.6.4, 2.7.2, 2.8.1, 2.9.0]
if: "!contains(github.event.head_commit.message, 'ci-skip')"
steps:
- uses: actions/checkout@v3
Expand Down
24 changes: 4 additions & 20 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,26 +67,10 @@ if(COMPILE_BENCHMARK)
${LCE_SOURCE_DIR}/tflite/benchmark/lce_benchmark_tflite_model.h
${TFLITE_SOURCE_DIR}/tools/benchmark/benchmark_model.h
)
set(TFLITE_BENCHMARK_SRCS # from ${TFLITE_SOURCE_DIR}/tools/benchmark/CMakeLists.txt
${TENSORFLOW_SOURCE_DIR}/tensorflow/core/util/stats_calculator.cc
${TFLITE_SOURCE_DIR}/kernels/internal/utils/sparsity_format_converter.cc
${TFLITE_SOURCE_DIR}/profiling/memory_info.cc
${TFLITE_SOURCE_DIR}/profiling/memory_usage_monitor.cc
${TFLITE_SOURCE_DIR}/profiling/profile_summarizer.cc
${TFLITE_SOURCE_DIR}/profiling/profile_summary_formatter.cc
${TFLITE_SOURCE_DIR}/profiling/time.cc
${TFLITE_SOURCE_DIR}/tools/command_line_flags.cc
${TFLITE_SOURCE_DIR}/tools/benchmark/benchmark_model.cc
${TFLITE_SOURCE_DIR}/tools/benchmark/benchmark_performance_options.cc
${TFLITE_SOURCE_DIR}/tools/benchmark/benchmark_tflite_model.cc
${TFLITE_SOURCE_DIR}/tools/benchmark/benchmark_utils.cc
${TFLITE_SOURCE_DIR}/tools/benchmark/profiling_listener.cc
${TFLITE_SOURCE_DIR}/tools/delegates/default_execution_provider.cc
${TFLITE_SOURCE_DIR}/tools/delegates/delegate_provider.cc
${TFLITE_SOURCE_DIR}/tools/delegates/xnnpack_delegate_provider.cc
${TFLITE_SOURCE_DIR}/tools/evaluation/utils.cc
${TFLITE_SOURCE_DIR}/tools/tool_params.cc
)

get_directory_property(TFLITE_BENCHMARK_SRCS DIRECTORY ${TFLITE_SOURCE_DIR}/tools/benchmark DEFINITION TFLITE_BENCHMARK_SRCS)
list(FILTER TFLITE_BENCHMARK_SRCS EXCLUDE REGEX benchmark_main.cc)

add_executable(lce_benchmark_model
${TFLITE_BENCHMARK_SRCS}
${LCE_CORE_SRCS} ${LCE_CORE_HDRS}
Expand Down
1 change: 0 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,5 @@ RUN pip install six numpy --no-cache-dir
WORKDIR /compute-engine
COPY . .
RUN ./third_party/install_android.sh
ENV MANYLINUX2010=1
RUN ./configure.py
RUN bazelisk --version
6 changes: 3 additions & 3 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ http_archive(
"//third_party/tensorflow_patches:disable_forced_mkl.patch",
"//third_party/tensorflow_patches:fix_armhf_xnnpack.patch",
],
sha256 = "66b953ae7fba61fd78969a2e24e350b26ec116cf2e6a7eb93d02c63939c6f9f7",
strip_prefix = "tensorflow-2.8.0",
sha256 = "8087cb0c529f04a4bfe480e49925cd64a904ad16d8ec66b98e2aacdfd53c80ff",
strip_prefix = "tensorflow-2.9.0",
urls = [
"https://github.com/tensorflow/tensorflow/archive/v2.8.0.tar.gz",
"https://github.com/tensorflow/tensorflow/archive/v2.9.0.tar.gz",
],
)

Expand Down
16 changes: 8 additions & 8 deletions larq_compute_engine/mlir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ td_library(
srcs = ["transforms/op_removal_patterns.td"],
includes = ["/external/org_tensorflow"],
deps = [
"@llvm-project//mlir:StdOpsTdFiles",
"@llvm-project//mlir:FuncTdFiles",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
],
)
Expand All @@ -43,7 +43,7 @@ td_library(
includes = ["/external/org_tensorflow"],
deps = [
":lce_ops_td_file",
"@llvm-project//mlir:StdOpsTdFiles",
"@llvm-project//mlir:FuncTdFiles",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files",
],
)
Expand All @@ -54,7 +54,7 @@ td_library(
includes = ["/external/org_tensorflow"],
deps = [
":lce_ops_td_file",
"@llvm-project//mlir:StdOpsTdFiles",
"@llvm-project//mlir:FuncTdFiles",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
],
Expand Down Expand Up @@ -182,7 +182,7 @@ gentbl_cc_library(
td_file = "transforms/bitpack_activations_patterns.td",
deps = [
":lce_ops_td_file",
"@llvm-project//mlir:StdOpsTdFiles",
"@llvm-project//mlir:FuncTdFiles",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files",
],
)
Expand All @@ -199,7 +199,7 @@ gentbl_cc_library(
td_file = "transforms/bitpack_weights_patterns.td",
deps = [
":lce_ops_td_file",
"@llvm-project//mlir:StdOpsTdFiles",
"@llvm-project//mlir:FuncTdFiles",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
],
Expand Down Expand Up @@ -288,7 +288,7 @@ cc_library(
"transforms/passes.h",
],
deps = [
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:FuncDialect",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow",
],
alwayslink = 1,
Expand All @@ -308,7 +308,7 @@ cc_library(
deps = [
":larq_compute_engine",
"//larq_compute_engine/core:types",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:FuncDialect",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf",
"@org_tensorflow//tensorflow/compiler/mlir/lite:validators",
Expand Down Expand Up @@ -429,7 +429,7 @@ cc_library(
"transforms/passes.h",
],
deps = [
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:FuncDialect",
],
alwayslink = 1,
)
Expand Down
2 changes: 1 addition & 1 deletion larq_compute_engine/mlir/ir/lce_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def LarqDialect : Dialect {
//===----------------------------------------------------------------------===//

// Base class for the operation in this dialect
class LQ_Op<string mnemonic, list<OpTrait> traits = []> :
class LQ_Op<string mnemonic, list<Trait> traits = []> :
Op<LarqDialect, mnemonic, traits> {

let extraClassDeclaration = [{
Expand Down
11 changes: 5 additions & 6 deletions larq_compute_engine/mlir/lce_mlir_opt.cc
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
#include "larq_compute_engine/mlir/ir/lce_ops.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Support/MlirOptMain.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "mlir/Transforms/Passes.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

int main(int argc, char** argv) {
mlir::registerTransformsPasses();
mlir::DialectRegistry registry;
registry.insert<mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect,
registry.insert<mlir::arith::ArithmeticDialect, mlir::func::FuncDialect,
mlir::quant::QuantizationDialect, mlir::TF::TensorFlowDialect,
mlir::TFL::TensorFlowLiteDialect, mlir::lq::LarqDialect>();
return failed(mlir::MlirOptMain(argc, argv,
"Larq Compute Engine pass driver\n", registry,
/*preloadDialectsInContext=*/false));
return failed(mlir::MlirOptMain(
argc, argv, "Larq Compute Engine pass driver\n", registry));
}
11 changes: 6 additions & 5 deletions larq_compute_engine/mlir/python/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ LCETarget GetLCETarget(const std::string& target_str) {
}
}

Status GetNumInputs(mlir::OwningModuleRef* module, int* num_inputs) {
Status GetNumInputs(mlir::OwningOpRef<mlir::ModuleOp>* module,
int* num_inputs) {
*num_inputs = 0;
mlir::FuncOp entry_function = nullptr;
for (auto func : module->get().getOps<mlir::FuncOp>()) {
mlir::func::FuncOp entry_function = nullptr;
for (auto func : module->get().getOps<mlir::func::FuncOp>()) {
if (auto tf_attrs =
func->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function")) {
// TODO(jaesung): There could be multiple entry functions. Let's handle
Expand Down Expand Up @@ -70,13 +71,13 @@ Status GetNumInputs(mlir::OwningModuleRef* module, int* num_inputs) {
}

pybind11::bytes ConvertMLIRModuleToTFLiteFlatBuffer(
mlir::OwningModuleRef* module, mlir::MLIRContext& context,
mlir::OwningOpRef<mlir::ModuleOp>* module, mlir::MLIRContext& context,
const LCETarget target, const pybind11::object& default_ranges,
const std::unordered_set<std::string>& saved_model_tags,
llvm::StringRef saved_model_dir,
llvm::Optional<tensorflow::Session*> session, const int num_inputs,
const bool should_quantize, const bool mark_as_post_training_quant) {
mlir::TFL::QuantizationSpecs quant_specs;
mlir::quant::QuantizationSpecs quant_specs;
if (should_quantize) {
// Normally we'd only set `inference_type` to QINT8 when there are
// fake_quant nodes in the graph. However this did not work reliably, and
Expand Down
4 changes: 2 additions & 2 deletions larq_compute_engine/mlir/python/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ namespace tensorflow {

LCETarget GetLCETarget(const std::string& target_str);

Status GetNumInputs(mlir::OwningModuleRef* module, int* num_inputs);
Status GetNumInputs(mlir::OwningOpRef<mlir::ModuleOp>* module, int* num_inputs);

pybind11::bytes ConvertMLIRModuleToTFLiteFlatBuffer(
mlir::OwningModuleRef* module, mlir::MLIRContext& context,
mlir::OwningOpRef<mlir::ModuleOp>* module, mlir::MLIRContext& context,
const LCETarget target, const pybind11::object& default_ranges,
const std::unordered_set<std::string>& saved_model_tags,
llvm::StringRef saved_model_dir,
Expand Down
2 changes: 1 addition & 1 deletion larq_compute_engine/mlir/tests/bitpack-weights.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: lce-tf-opt %s -tfl-lce-bitpack-weights -verify-diagnostics | FileCheck %s

// CHECK-LABEL: @bitpack_bconv2d_filters
func @bitpack_bconv2d_filters(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: none) -> tensor<256x30x30x16xf32> {
func.func @bitpack_bconv2d_filters(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: none) -> tensor<256x30x30x16xf32> {
%cst = arith.constant dense<1.0> : tensor<16x3x3x3xf32>
%0 = "lq.Bconv2d"(%arg0, %cst, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32>
return %0 : tensor<256x30x30x16xf32>
Expand Down
4 changes: 2 additions & 2 deletions larq_compute_engine/mlir/tests/const-fold.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: lce-tf-opt %s -canonicalize | FileCheck %s

// CHECK-LABEL: @quantize
func @quantize() -> (tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32>) {
func.func @quantize() -> (tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32>) {
%pos = arith.constant dense< 0.5> : tensor<1x1x2x32xf32>
%neg = arith.constant dense<-0.5> : tensor<1x1x2x32xf32>
%0 = "lq.Quantize"(%pos) {} : (tensor<1x1x2x32xf32>) -> tensor<1x1x2x1xi32>
Expand All @@ -14,7 +14,7 @@ func @quantize() -> (tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32>) {
}

// CHECK-LABEL: @dequantize
func @dequantize() -> (tensor<1x1x2x32xf32>, tensor<1x1x2x32xf32>) {
func.func @dequantize() -> (tensor<1x1x2x32xf32>, tensor<1x1x2x32xf32>) {
%pos = arith.constant dense<0> : tensor<1x1x2x1xi32>
%neg = arith.constant dense<-1> : tensor<1x1x2x1xi32>
%0 = "lq.Dequantize"(%pos) {} : (tensor<1x1x2x1xi32>) -> tensor<1x1x2x32xf32>
Expand Down
12 changes: 6 additions & 6 deletions larq_compute_engine/mlir/tests/fuse_padding.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: lce-tf-opt %s -tfl-fuse-padding -verify-diagnostics | FileCheck %s

// CHECK-LABEL: @fuse_pad_into_conv_valid
func @fuse_pad_into_conv_valid(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64x16xf32> {
func.func @fuse_pad_into_conv_valid(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64x16xf32> {
%cst0 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>
%cst1 = arith.constant dense<1.0> : tensor<16x3x3x8xf32>
%cst2 = arith.constant dense<1.0> : tensor<16xf32>
Expand All @@ -14,7 +14,7 @@ func @fuse_pad_into_conv_valid(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64x1
}

// CHECK-LABEL: @fuse_padv2_into_conv_valid
func @fuse_padv2_into_conv_valid(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64x16xf32> {
func.func @fuse_padv2_into_conv_valid(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64x16xf32> {
%cst0 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>
%cst1 = arith.constant dense<0.0> : tensor<f32>
%cst2 = arith.constant dense<1.0> : tensor<16x3x3x8xf32>
Expand All @@ -28,7 +28,7 @@ func @fuse_padv2_into_conv_valid(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64
}

// CHECK-LABEL: @fuse_pad_into_dwconv_valid
func @fuse_pad_into_dwconv_valid(%arg0: tensor<1x64x64x16xf32>) -> tensor<1x64x64x16xf32> {
func.func @fuse_pad_into_dwconv_valid(%arg0: tensor<1x64x64x16xf32>) -> tensor<1x64x64x16xf32> {
%cst0 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>
%cst1 = arith.constant dense<1.0> : tensor<1x3x3x16xf32>
%cst2 = arith.constant dense<1.0> : tensor<16xf32>
Expand All @@ -41,7 +41,7 @@ func @fuse_pad_into_dwconv_valid(%arg0: tensor<1x64x64x16xf32>) -> tensor<1x64x6
}

// CHECK-LABEL: @do_not_fuse_padv2_into_conv_wrong_pad_value
func @do_not_fuse_padv2_into_conv_wrong_pad_value(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64x16xf32> {
func.func @do_not_fuse_padv2_into_conv_wrong_pad_value(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64x16xf32> {
%cst0 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>
%cst1 = arith.constant dense<1.0> : tensor<f32>
%cst2 = arith.constant dense<1.0> : tensor<16x3x3x8xf32>
Expand All @@ -54,7 +54,7 @@ func @do_not_fuse_padv2_into_conv_wrong_pad_value(%arg0: tensor<1x64x64x8xf32>)
}

// CHECK-LABEL: @do_not_fuse_pad_into_conv_same
func @do_not_fuse_pad_into_conv_same(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x66x66x16xf32> {
func.func @do_not_fuse_pad_into_conv_same(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x66x66x16xf32> {
%cst0 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>
%cst1 = arith.constant dense<1.0> : tensor<f32>
%cst2 = arith.constant dense<1.0> : tensor<16x3x3x8xf32>
Expand All @@ -67,7 +67,7 @@ func @do_not_fuse_pad_into_conv_same(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x6
}

// CHECK-LABEL: @do_not_fuse_pad_into_dwconv_channelpad
func @do_not_fuse_pad_into_dwconv_channelpad(%arg0: tensor<1x64x64x12xf32>) -> tensor<1x64x64x16xf32> {
func.func @do_not_fuse_pad_into_dwconv_channelpad(%arg0: tensor<1x64x64x12xf32>) -> tensor<1x64x64x16xf32> {
%cst0 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [1, 3]]> : tensor<4x2xi32>
%cst1 = arith.constant dense<1.0> : tensor<1x3x3x16xf32>
%cst2 = arith.constant dense<1.0> : tensor<16xf32>
Expand Down
8 changes: 4 additions & 4 deletions larq_compute_engine/mlir/tests/legalize-lce.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// RUN: lce-tf-opt %s -tfl-legalize-lce -lce-translate-tfl -verify-diagnostics | FileCheck %s --check-prefix=TRANSLATE

// CHECK-LABEL: @legalize_bconv2d
func @legalize_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: none) -> tensor<256x30x30x16xf32> {
func.func @legalize_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: none) -> tensor<256x30x30x16xf32> {
%0 = "lq.Bconv2d"(%arg0, %arg1, %arg2, %arg3, %arg4) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32>
return %0 : tensor<256x30x30x16xf32>

Expand All @@ -14,7 +14,7 @@ func @legalize_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf3
}

// CHECK-LABEL: @legalize_bmax_pool2d
func @legalize_bmax_pool2d(%arg0: tensor<256x32x32x3xi32>) -> tensor<256x16x16x3xi32> {
func.func @legalize_bmax_pool2d(%arg0: tensor<256x32x32x3xi32>) -> tensor<256x16x16x3xi32> {
%0 = "lq.BMaxPool2d"(%arg0) {filter_height = 2 : i32, filter_width = 2 : i32, padding = "SAME", stride_height = 2 : i32, stride_width = 2 : i32} : (tensor<256x32x32x3xi32>) -> tensor<256x16x16x3xi32>
return %0 : tensor<256x16x16x3xi32>

Expand All @@ -26,7 +26,7 @@ func @legalize_bmax_pool2d(%arg0: tensor<256x32x32x3xi32>) -> tensor<256x16x16x3
}

// CHECK-LABEL: @legalize_quantize
func @legalize_quantize(%arg0: tensor<256x32x32x64xf32>) -> tensor<256x32x32x2xi32> {
func.func @legalize_quantize(%arg0: tensor<256x32x32x64xf32>) -> tensor<256x32x32x2xi32> {
%0 = "lq.Quantize"(%arg0) {} : (tensor<256x32x32x64xf32>) -> tensor<256x32x32x2xi32>
return %0 : tensor<256x32x32x2xi32>

Expand All @@ -38,7 +38,7 @@ func @legalize_quantize(%arg0: tensor<256x32x32x64xf32>) -> tensor<256x32x32x2xi
}

// CHECK-LABEL: @legalize_dequantize
func @legalize_dequantize(%arg0: tensor<256x32x32x2xi32>) -> tensor<256x32x32x64xf32> {
func.func @legalize_dequantize(%arg0: tensor<256x32x32x2xi32>) -> tensor<256x32x32x64xf32> {
%0 = "lq.Dequantize"(%arg0) {} : (tensor<256x32x32x2xi32>) -> tensor<256x32x32x64xf32>
return %0 : tensor<256x32x32x64xf32>

Expand Down
Loading

0 comments on commit 94b62f6

Please sign in to comment.