diff --git a/.github/workflows/ci_linux_arm64_clang.yml b/.github/workflows/ci_linux_arm64_clang.yml index 0972afe3a2c6..2229b4a9b656 100644 --- a/.github/workflows/ci_linux_arm64_clang.yml +++ b/.github/workflows/ci_linux_arm64_clang.yml @@ -63,7 +63,7 @@ jobs: run: ./build_tools/cmake/test_iree_dialects.sh "${BUILD_DIR}" - name: Post to Discord on Failure - uses: sarisia/actions-status-discord@b8381b25576cb341b2af39926ab42c5056cc44ed # v1.15.5 + uses: sarisia/actions-status-discord@eb045afee445dc055c18d3d90bd0f244fd062708 # v1.16.0 if: failure() && github.ref_name == 'main' && github.repository_owner == 'iree-org' with: webhook: ${{ secrets.DISCORD_WEBHOOK }} diff --git a/.github/workflows/ci_linux_x64_clang_byollvm.yml b/.github/workflows/ci_linux_x64_clang_byollvm.yml index a322e9c83891..e866c3262620 100644 --- a/.github/workflows/ci_linux_x64_clang_byollvm.yml +++ b/.github/workflows/ci_linux_x64_clang_byollvm.yml @@ -30,7 +30,7 @@ jobs: run: ./build_tools/cmake/build_and_test_byo_llvm.sh - name: Post to Discord on Failure - uses: sarisia/actions-status-discord@b8381b25576cb341b2af39926ab42c5056cc44ed # v1.15.5 + uses: sarisia/actions-status-discord@eb045afee445dc055c18d3d90bd0f244fd062708 # v1.16.0 if: failure() && github.ref_name == 'main' && github.repository_owner == 'iree-org' with: webhook: ${{ secrets.DISCORD_WEBHOOK }} diff --git a/.github/workflows/ci_linux_x64_clang_debug.yml b/.github/workflows/ci_linux_x64_clang_debug.yml index 33582148d477..493379d08c2d 100644 --- a/.github/workflows/ci_linux_x64_clang_debug.yml +++ b/.github/workflows/ci_linux_x64_clang_debug.yml @@ -47,7 +47,7 @@ jobs: # would add 10+ minutes to the job. 
- name: Post to Discord on Failure - uses: sarisia/actions-status-discord@b8381b25576cb341b2af39926ab42c5056cc44ed # v1.15.5 + uses: sarisia/actions-status-discord@eb045afee445dc055c18d3d90bd0f244fd062708 # v1.16.0 if: failure() && github.ref_name == 'main' && github.repository_owner == 'iree-org' with: webhook: ${{ secrets.DISCORD_WEBHOOK }} diff --git a/.github/workflows/ci_linux_x64_clang_tsan.yml b/.github/workflows/ci_linux_x64_clang_tsan.yml index a1dbe97509f5..2b1a2a6db568 100644 --- a/.github/workflows/ci_linux_x64_clang_tsan.yml +++ b/.github/workflows/ci_linux_x64_clang_tsan.yml @@ -50,7 +50,7 @@ jobs: sccache --show-stats - name: Post to Discord on Failure - uses: sarisia/actions-status-discord@b8381b25576cb341b2af39926ab42c5056cc44ed # v1.15.5 + uses: sarisia/actions-status-discord@eb045afee445dc055c18d3d90bd0f244fd062708 # v1.16.0 if: failure() && github.ref_name == 'main' && github.repository_owner == 'iree-org' with: webhook: ${{ secrets.DISCORD_WEBHOOK }} diff --git a/.github/workflows/ci_linux_x64_clang_ubsan.yml b/.github/workflows/ci_linux_x64_clang_ubsan.yml index 956ab5fe18b9..b245e1d9d9b0 100644 --- a/.github/workflows/ci_linux_x64_clang_ubsan.yml +++ b/.github/workflows/ci_linux_x64_clang_ubsan.yml @@ -4,7 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -name: CI - Linux x64 clang UBSan +name: CI - Linux x64 clang UBSan and Reverse Iteration on: workflow_call: @@ -29,6 +29,9 @@ jobs: # Use a modern clang explicitly. CC: clang-19 CXX: clang++-19 + # Enable reverse iteration of unordered LLVM containers. This helps + # catch non-determinism bugs. 
+ IREE_REVERSE_ITERATE: "ON" SCCACHE_AZURE_CONNECTION_STRING: "${{ secrets.AZURE_CCACHE_CONNECTION_STRING }}" SCCACHE_AZURE_BLOB_CONTAINER: ccache-container SCCACHE_CACHE_ZSTD_LEVEL: 10 diff --git a/.github/workflows/ci_linux_x64_gcc.yml b/.github/workflows/ci_linux_x64_gcc.yml index d5600d5bc250..d453bbbd88d7 100644 --- a/.github/workflows/ci_linux_x64_gcc.yml +++ b/.github/workflows/ci_linux_x64_gcc.yml @@ -38,7 +38,7 @@ jobs: run: ./build_tools/cmake/build_all.sh "${BUILD_DIR}" - name: Post to Discord on Failure - uses: sarisia/actions-status-discord@b8381b25576cb341b2af39926ab42c5056cc44ed # v1.15.5 + uses: sarisia/actions-status-discord@eb045afee445dc055c18d3d90bd0f244fd062708 # v1.16.0 if: failure() && github.ref_name == 'main' && github.repository_owner == 'iree-org' with: webhook: ${{ secrets.DISCORD_WEBHOOK }} diff --git a/.github/workflows/ci_macos_arm64_clang.yml b/.github/workflows/ci_macos_arm64_clang.yml index e924f9fa5ca9..d18a8bb532bc 100644 --- a/.github/workflows/ci_macos_arm64_clang.yml +++ b/.github/workflows/ci_macos_arm64_clang.yml @@ -60,7 +60,7 @@ jobs: run: bash ./build_tools/cmake/build_all.sh "${BUILD_DIR}" - name: Post to Discord on Failure - uses: sarisia/actions-status-discord@b8381b25576cb341b2af39926ab42c5056cc44ed # v1.15.5 + uses: sarisia/actions-status-discord@eb045afee445dc055c18d3d90bd0f244fd062708 # v1.16.0 if: failure() && github.ref_name == 'main' && github.repository_owner == 'iree-org' with: webhook: ${{ secrets.DISCORD_WEBHOOK }} diff --git a/.github/workflows/ci_macos_x64_clang.yml b/.github/workflows/ci_macos_x64_clang.yml index d6aa3fa29e06..722de844ebb7 100644 --- a/.github/workflows/ci_macos_x64_clang.yml +++ b/.github/workflows/ci_macos_x64_clang.yml @@ -53,7 +53,7 @@ jobs: run: bash ./build_tools/cmake/ctest_all.sh "${BUILD_DIR}" - name: Post to Discord on Failure - uses: sarisia/actions-status-discord@b8381b25576cb341b2af39926ab42c5056cc44ed # v1.15.5 + uses: 
sarisia/actions-status-discord@eb045afee445dc055c18d3d90bd0f244fd062708 # v1.16.0 if: failure() && github.ref_name == 'main' && github.repository_owner == 'iree-org' with: webhook: ${{ secrets.DISCORD_WEBHOOK }} diff --git a/.github/workflows/pkgci.yml b/.github/workflows/pkgci.yml index 48013b8e6757..8ebe09c4b330 100644 --- a/.github/workflows/pkgci.yml +++ b/.github/workflows/pkgci.yml @@ -64,6 +64,12 @@ jobs: if: contains(fromJson(needs.setup.outputs.enabled-jobs), 'test_amd_w7900') uses: ./.github/workflows/pkgci_test_amd_w7900.yml + test_amd_r9700: + name: Test AMD R9700 + needs: [setup, build_packages] + if: contains(fromJson(needs.setup.outputs.enabled-jobs), 'test_amd_r9700') + uses: ./.github/workflows/pkgci_test_amd_r9700.yml + # TODO(#18238): migrate to new runner cluster # test_nvidia_t4: # name: Test NVIDIA T4 @@ -135,6 +141,7 @@ jobs: - test_amd_mi250 - test_amd_mi325 - test_amd_w7900 + - test_amd_r9700 # - test_nvidia_t4 - test_android - test_riscv64 diff --git a/.github/workflows/pkgci_test_amd_r9700.yml b/.github/workflows/pkgci_test_amd_r9700.yml new file mode 100644 index 000000000000..736769da82f7 --- /dev/null +++ b/.github/workflows/pkgci_test_amd_r9700.yml @@ -0,0 +1,67 @@ +# Copyright 2026 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +name: PkgCI Test AMD R9700 +on: + workflow_call: + inputs: + artifact_run_id: + type: string + default: "" + workflow_dispatch: + inputs: + artifact_run_id: + type: string + default: "" + +jobs: + test_r9700: + runs-on: [Linux, X64, iree-r9700] + env: + PACKAGE_DOWNLOAD_DIR: ${{ github.workspace }}/.packages + BUILD_DIR: build-tests + VENV_DIR: ${{ github.workspace }}/.venv + GH_TOKEN: ${{ github.token }} + IREE_CPU_DISABLE: 1 + IREE_VULKAN_DISABLE: 0 + IREE_CUDA_ENABLE: 0 + IREE_HIP_ENABLE: 1 + IREE_HIP_TEST_TARGET_CHIP: "gfx1201" + steps: + - name: Check out repository + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + submodules: false + - name: Check out runtime submodules + run: ./build_tools/scripts/git/update_runtime_submodules.sh + - uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 + with: + # Must match the subset of versions built in pkgci_build_packages. 
+ python-version: "3.11" + - uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + if: ${{ inputs.artifact_run_id == '' }} + with: + name: linux_x86_64_release_packages + path: ${{ env.PACKAGE_DOWNLOAD_DIR }} + - name: Setup base venv + run: | + ./build_tools/pkgci/setup_venv.py ${VENV_DIR} \ + --artifact-path=${PACKAGE_DOWNLOAD_DIR} \ + --fetch-gh-workflow=${{ inputs.artifact_run_id }} + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Build tests + run: ./build_tools/pkgci/build_tests_using_package.sh ${VENV_DIR}/bin + - name: Run GPU tests + env: + CTEST_PARALLEL_LEVEL: 1 + IREE_CTEST_LABEL_REGEX: ^requires-gpu|^driver=vulkan$|^driver=hip$ + IREE_AMD_RDNA4_TESTS_DISABLE: 0 + IREE_NVIDIA_GPU_TESTS_DISABLE: 0 + IREE_NVIDIA_SM80_TESTS_DISABLE: 1 + IREE_MULTI_DEVICE_TESTS_DISABLE: 0 + run: ./build_tools/cmake/ctest_all.sh ${BUILD_DIR} diff --git a/.github/workflows/pkgci_test_onnx.yml b/.github/workflows/pkgci_test_onnx.yml index 08293275f575..1c2dc7fcd399 100644 --- a/.github/workflows/pkgci_test_onnx.yml +++ b/.github/workflows/pkgci_test_onnx.yml @@ -47,6 +47,11 @@ jobs: numprocesses: 1 config-file: onnx_ops_gpu_hip_rdna3_O3.json runs-on: [Linux, X64, gfx1100] + # TODO(#23160): Fix the onnx ops test suite for gfx1201. 
+ # - name: amdgpu_hip_rdna4_O3 + # numprocesses: 1 + # config-file: onnx_ops_gpu_hip_rdna4_O3.json + # runs-on: [Linux, X64, gfx1201] - name: amdgpu_vulkan_O0 numprocesses: 1 config-file: onnx_ops_gpu_vulkan_O0.json @@ -103,7 +108,7 @@ jobs: uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: repository: iree-org/iree-test-suites - ref: 17a391dc3882f136e567bf2687806ef6af46ad64 + ref: 17ead09be6d84bf46d80e6192dc12e45ba776045 path: iree-test-suites - name: Install ONNX ops test suite requirements run: | @@ -154,6 +159,9 @@ jobs: - name: amdgpu_hip_rdna3 config-file: onnx_models_gpu_hip_rdna3.json runs-on: [Linux, X64, gfx1100, persistent-cache] + - name: amdgpu_hip_rdna4 + config-file: onnx_models_gpu_hip_rdna4.json + runs-on: [Linux, X64, gfx1201, persistent-cache] - name: amdgpu_vulkan config-file: onnx_models_gpu_vulkan.json # TODO(#22579): Remove `shark10-ci` label. There are vulkan driver issues on other runners. @@ -189,7 +197,7 @@ jobs: uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: repository: iree-org/iree-test-suites - ref: 17a391dc3882f136e567bf2687806ef6af46ad64 + ref: 17ead09be6d84bf46d80e6192dc12e45ba776045 path: iree-test-suites - name: Install ONNX models test suite requirements run: | diff --git a/.github/workflows/pkgci_test_sharktank.yml b/.github/workflows/pkgci_test_sharktank.yml index f905e894e6f6..031327dd2a4e 100644 --- a/.github/workflows/pkgci_test_sharktank.yml +++ b/.github/workflows/pkgci_test_sharktank.yml @@ -88,7 +88,7 @@ jobs: uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: repository: iree-org/iree-test-suites - ref: 17a391dc3882f136e567bf2687806ef6af46ad64 + ref: 17ead09be6d84bf46d80e6192dc12e45ba776045 path: iree-test-suites lfs: true @@ -197,7 +197,7 @@ jobs: uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: repository: iree-org/iree-test-suites - ref: 17a391dc3882f136e567bf2687806ef6af46ad64 + ref: 
17ead09be6d84bf46d80e6192dc12e45ba776045 path: iree-test-suites lfs: true diff --git a/.github/workflows/pkgci_test_torch.yml b/.github/workflows/pkgci_test_torch.yml index a98d7b924c80..d58bf432be78 100644 --- a/.github/workflows/pkgci_test_torch.yml +++ b/.github/workflows/pkgci_test_torch.yml @@ -37,6 +37,9 @@ jobs: - name: amdgpu_hip_gfx1100_O3 config-file: torch_ops_gpu_hip_gfx1100_O3.json runs-on: [Linux, X64, gfx1100] + - name: amdgpu_hip_gfx1201_O3 + config-file: torch_ops_gpu_hip_gfx1201_O3.json + runs-on: [Linux, X64, gfx1201] - name: amdgpu_vulkan_O3 config-file: torch_ops_gpu_vulkan_O3.json # TODO(#22579): Remove `shark10-ci` label. There are vulkan driver issues on other runners. @@ -74,7 +77,7 @@ jobs: uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: repository: iree-org/iree-test-suites - ref: 132f91e49d629c35f98492a9f619017b83782aba + ref: 17ead09be6d84bf46d80e6192dc12e45ba776045 path: iree-test-suites - name: Install Torch ops test suite requirements run: | @@ -138,7 +141,7 @@ jobs: uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: repository: iree-org/iree-test-suites - ref: 17a391dc3882f136e567bf2687806ef6af46ad64 + ref: 17ead09be6d84bf46d80e6192dc12e45ba776045 path: iree-test-suites # Don't need lfs for torch models yet. 
lfs: false diff --git a/.github/workflows/workflow_summary.yml b/.github/workflows/workflow_summary.yml index 2a184b66a0ff..23ba06067479 100644 --- a/.github/workflows/workflow_summary.yml +++ b/.github/workflows/workflow_summary.yml @@ -55,7 +55,7 @@ jobs: exit 1 fi - name: Post to Discord on Failure - uses: sarisia/actions-status-discord@b8381b25576cb341b2af39926ab42c5056cc44ed # v1.15.5 + uses: sarisia/actions-status-discord@eb045afee445dc055c18d3d90bd0f244fd062708 # v1.16.0 if: failure() && github.ref_name == 'main' && github.repository_owner == 'iree-org' with: webhook: ${{ secrets.DISCORD_WEBHOOK }} diff --git a/.gitignore b/.gitignore index 98a5c89a6f3f..b7da4ef62955 100644 --- a/.gitignore +++ b/.gitignore @@ -57,6 +57,7 @@ imgui.ini # Source indexing files compile_commands.json +tablegen_compile_commands.yml .cache/clangd # Language server configuration files diff --git a/CMakeLists.txt b/CMakeLists.txt index 2ad29f2646ed..1d15a10b04bc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -523,11 +523,13 @@ option(IREE_ENABLE_ASAN "Enable address sanitizer" OFF) option(IREE_ENABLE_MSAN "Enable memory sanitizer" OFF) option(IREE_ENABLE_TSAN "Enable thread sanitizer" OFF) option(IREE_ENABLE_UBSAN "Enable undefined behavior sanitizer" OFF) +option(IREE_ENABLE_FUZZING "Enable libFuzzer-based fuzz targets" OFF) option(IREE_ENABLE_SPLIT_DWARF "Enable gsplit-dwarf for debug information if the platform supports it" OFF) option(IREE_ENABLE_THIN_ARCHIVES "Enables thin ar archives (elf systems only). 
Disable for released static archives" OFF) option(IREE_LINK_COMPILER_SHARED_LIBRARY "Links IREE tools using the compiler compiled into a shared library" ON) option(IREE_ENABLE_WERROR_FLAG "Enable `-Werror` flag, treat error as warning" ON) option(IREE_ENABLE_POSITION_INDEPENDENT_CODE "Enable position independent code" TRUE) +option(IREE_REVERSE_ITERATE "Reverse iteration over unordered LLVM containers" OFF) if(IREE_LINK_COMPILER_SHARED_LIBRARY AND IREE_ENABLE_COMPILER_TRACING) message(SEND_ERROR @@ -569,6 +571,10 @@ if(IREE_ENABLE_RUNTIME_COVERAGE AND NOT _UPPERCASE_CMAKE_BUILD_TYPE STREQUAL "DE message(FATAL_ERROR "IREE_ENABLE_*_COVERAGE requires building in Debug") endif() +if(IREE_REVERSE_ITERATE) + set(LLVM_ENABLE_REVERSE_ITERATION ON CACHE BOOL "" FORCE) +endif() + #------------------------------------------------------------------------------- # IREE assertions # We don't love the way this is done, but we have to line it up with how LLVM @@ -629,6 +635,7 @@ include(iree_copts) include(iree_cc_binary) include(iree_cc_library) include(iree_cc_test) +include(iree_cc_fuzz) include(iree_import_binary) include(iree_install_support) include(iree_external_cmake_options) diff --git a/build_tools/bazel/build_defs.oss.bzl b/build_tools/bazel/build_defs.oss.bzl index 6cf23934c622..c8f3c9c2cc48 100644 --- a/build_tools/bazel/build_defs.oss.bzl +++ b/build_tools/bazel/build_defs.oss.bzl @@ -4,11 +4,23 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "gentbl_filegroup", "td_library") + +# All load statements must come first in Starlark. 
load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_test") +load( + "//build_tools/bazel:iree_cc_fuzz.bzl", + _iree_cc_fuzz = "iree_cc_fuzz", + _iree_compiler_cc_fuzz = "iree_compiler_cc_fuzz", + _iree_runtime_cc_fuzz = "iree_runtime_cc_fuzz", +) """Common Bazel definitions for IREE.""" -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "gentbl_filegroup", "td_library") +# Re-export fuzz rules for external use. +iree_cc_fuzz = _iree_cc_fuzz +iree_compiler_cc_fuzz = _iree_compiler_cc_fuzz +iree_runtime_cc_fuzz = _iree_runtime_cc_fuzz def defaulting_select(selector): """Pass through to select() with special semantics when converting to CMake. diff --git a/build_tools/bazel/iree.bazelrc b/build_tools/bazel/iree.bazelrc index 60bfb010cc7e..2369df478c8e 100644 --- a/build_tools/bazel/iree.bazelrc +++ b/build_tools/bazel/iree.bazelrc @@ -178,6 +178,17 @@ build:msvc_release --compilation_mode=opt # https://github.com/google/sanitizers/wiki/AddressSanitizer ############################################################################### +# Don't strip debug info +build:sanitizer --strip=never +# Ignore settings of `linkopts = ["-static"]` which can screw up the sanitizer. +# We don't use this in IREE (that's what linkstatic is for), but it could show +# up in dependencies. +build:sanitizer --force_ignore_dash_static +# sanitizer tests tend to take longer, so increase the timeouts +build:sanitizer --test_timeout=120,600,1800,-1 +# Get better stack traces +build:sanitizer --copt=-fno-omit-frame-pointer + # ASAN (address sanitizer) # https://clang.llvm.org/docs/AddressSanitizer.html build:asan --config=sanitizer @@ -216,16 +227,14 @@ build:ubsan --linkopt=-fsanitize=undefined build:ubsan --linkopt=-lubsan build:ubsan --cc_output_directory_tag=ubsan -# Don't strip debug info -build:sanitizer --strip=never -# Ignore settings of `linkopts = ["-static"]` which can screw up the sanitizer. 
-# We don't use this in IREE (that's what linkstatic is for), but it could show -# up in dependencies. -build:sanitizer --force_ignore_dash_static -# sanitizer tests tend to take longer, so increase the timeouts -build:sanitizer --test_timeout=120,600,1800,-1 -# Get better stack traces -build:sanitizer --copt=-fno-omit-frame-pointer +# Fuzzer (libFuzzer) configuration +# https://llvm.org/docs/LibFuzzer.html +# Includes ASAN by default - there's no reason to fuzz without memory sanitization. +build:fuzzer --config=asan +build:fuzzer --copt=-fsanitize=fuzzer-no-link +build:fuzzer --linkopt=-fsanitize=fuzzer-no-link +build:fuzzer --cc_output_directory_tag=asan-fuzzer +build:fuzzer --copt=-DFUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION ############################################################################### # Architecture specific options diff --git a/build_tools/bazel/iree_cc_fuzz.bzl b/build_tools/bazel/iree_cc_fuzz.bzl new file mode 100644 index 000000000000..d50e65033207 --- /dev/null +++ b/build_tools/bazel/iree_cc_fuzz.bzl @@ -0,0 +1,109 @@ +# Copyright 2026 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Macros for defining libFuzzer-based fuzz targets. + +Fuzz targets require --config=fuzzer to build properly. The config instruments +all code for coverage feedback and adds appropriate compile/link flags. 
+ +Example usage: + load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_fuzz") + + iree_runtime_cc_fuzz( + name = "unicode_fuzz", + srcs = ["unicode_fuzz.cc"], + deps = [":unicode"], + ) + +Building and running: + bazel build --config=fuzzer //path/to:unicode_fuzz + ./bazel-bin/path/to/unicode_fuzz corpus/ -max_total_time=60 +""" + +def iree_cc_fuzz( + name, + srcs, + deps = None, + data = None, + copts = None, + defines = None, + linkopts = None, + tags = None, + **kwargs): + """Creates a libFuzzer-based fuzz target. + + Args: + name: Target name (e.g., "unicode_fuzz"). + srcs: Source files containing LLVMFuzzerTestOneInput(). + deps: Library dependencies. + data: Data file dependencies. + copts: Additional compile options. + defines: Preprocessor definitions. + linkopts: Additional link options. + tags: Target tags. "fuzz" tag is added automatically. + **kwargs: Additional cc_binary attributes. + """ + if deps == None: + deps = [] + if data == None: + data = [] + if copts == None: + copts = [] + if defines == None: + defines = [] + if linkopts == None: + linkopts = [] + if tags == None: + tags = [] + + # Add "fuzz" tag if not present. + if "fuzz" not in tags: + tags = tags + ["fuzz"] + + native.cc_binary( + name = name, + srcs = srcs, + deps = deps, + data = data, + copts = copts, + defines = defines + ["FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION"], + linkopts = linkopts + ["-fsanitize=fuzzer"], + tags = tags, + testonly = True, + **kwargs + ) + +def iree_runtime_cc_fuzz(deps = None, **kwargs): + """Fuzz target for runtime code using libFuzzer. + + Wraps iree_cc_fuzz and adds //runtime/src:runtime_defines dependency. + + Args: + deps: Library dependencies (runtime_defines added automatically). + **kwargs: Additional arguments passed to iree_cc_fuzz. 
+ """ + if deps == None: + deps = [] + iree_cc_fuzz( + deps = deps + ["//runtime/src:runtime_defines"], + **kwargs + ) + +def iree_compiler_cc_fuzz(deps = None, **kwargs): + """Fuzz target for compiler code using libFuzzer. + + Wraps iree_cc_fuzz and adds //compiler/src:defs dependency. + + Args: + deps: Library dependencies (compiler defs added automatically). + **kwargs: Additional arguments passed to iree_cc_fuzz. + """ + if deps == None: + deps = [] + iree_cc_fuzz( + deps = deps + ["//compiler/src:defs"], + **kwargs + ) diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py index 39d46b384829..9ef6efdc223d 100644 --- a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py +++ b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py @@ -542,6 +542,48 @@ def cc_binary( f")\n\n" ) + def iree_cc_fuzz( + self, + name, + srcs=None, + data=None, + deps=None, + copts=None, + defines=None, + linkopts=None, + tags=None, + **kwargs, + ): + if self._should_skip_target(tags=tags, **kwargs): + return + name_block = self._convert_string_arg_block("NAME", name, quote=False) + srcs_block = self._convert_srcs_block(srcs) + data_block = self._convert_target_list_block("DATA", data) + deps_block = self._convert_target_list_block("DEPS", deps) + copts_block = self._convert_string_list_block("COPTS", copts, sort=False) + defines_block = self._convert_string_list_block("DEFINES", defines) + linkopts_block = self._convert_string_list_block("LINKOPTS", linkopts) + labels_block = self._convert_string_list_block("LABELS", tags) + + self._converter.body += ( + f"iree_cc_fuzz(\n" + f"{name_block}" + f"{srcs_block}" + f"{data_block}" + f"{deps_block}" + f"{copts_block}" + f"{defines_block}" + f"{linkopts_block}" + f"{labels_block}" + f")\n\n" + ) + + def iree_runtime_cc_fuzz(self, **kwargs): + self.iree_cc_fuzz(**kwargs) + + def iree_compiler_cc_fuzz(self, **kwargs): + self.iree_cc_fuzz(**kwargs) + def iree_c_embed_data( 
self, name, diff --git a/build_tools/cmake/build_and_test_ubsan.sh b/build_tools/cmake/build_and_test_ubsan.sh index 9cc551bba448..27bde66c8504 100755 --- a/build_tools/cmake/build_and_test_ubsan.sh +++ b/build_tools/cmake/build_and_test_ubsan.sh @@ -21,6 +21,7 @@ set -xeuo pipefail BUILD_DIR="${1:-${IREE_UBSAN_BUILD_DIR:-build-ubsan}}" IREE_ENABLE_ASSERTIONS="${IREE_ENABLE_ASSERTIONS:-ON}" +IREE_REVERSE_ITERATE="${IREE_REVERSE_ITERATE:-OFF}" # Enable CUDA and HIP/ROCM compiler and runtime by default if not on Darwin. OFF_IF_DARWIN="$(uname | awk '{print ($1 == "Darwin") ? "OFF" : "ON"}')" IREE_HAL_DRIVER_CUDA="${IREE_HAL_DRIVER_CUDA:-${OFF_IF_DARWIN}}" @@ -45,6 +46,7 @@ CMAKE_ARGS=( "-DIREE_BUILD_PYTHON_BINDINGS=OFF" "-DIREE_ENABLE_ASSERTIONS=${IREE_ENABLE_ASSERTIONS}" + "-DIREE_REVERSE_ITERATE=${IREE_REVERSE_ITERATE}" "-DIREE_ENABLE_LLD=ON" "-DIREE_ENABLE_SPLIT_DWARF=ON" "-DIREE_ENABLE_THIN_ARCHIVES=ON" diff --git a/build_tools/cmake/iree_c_module.cmake b/build_tools/cmake/iree_c_module.cmake index 3e4ad104e59d..11359824cc57 100644 --- a/build_tools/cmake/iree_c_module.cmake +++ b/build_tools/cmake/iree_c_module.cmake @@ -89,6 +89,14 @@ function(iree_c_module) DEPENDS ${_COMPILE_TOOL} ${_SRC_PATH} ) + # Generated EmitC code may have unused variables from optimization + # barriers and other cases where an SSA value is consumed by an op + # that produces a new value. Suppress this warning for Clang/GCC. + iree_select_compiler_opts(_EMITC_SUPPRESS_OPTS + CLANG_OR_GCC + "-Wno-unused-but-set-variable" + ) + iree_cc_library( NAME ${_RULE_NAME} HDRS "${_RULE_H_FILE_OUTPUT}" @@ -96,12 +104,21 @@ function(iree_c_module) INCLUDES "${CMAKE_CURRENT_BINARY_DIR}" COPTS "-DEMITC_IMPLEMENTATION=\"${_RULE_H_FILE_OUTPUT}\"" + ${_EMITC_SUPPRESS_OPTS} "${_TESTONLY_ARG}" DEPS # Include paths and options for the runtime sources. iree_defs ) + # Apply warning suppression to consumers (tests, etc.) that include the + # generated headers. 
+ iree_package_name(_PACKAGE_NAME) + set(_TARGET "${_PACKAGE_NAME}_${_RULE_NAME}") + if(_EMITC_SUPPRESS_OPTS) + target_compile_options(${_TARGET} INTERFACE ${_EMITC_SUPPRESS_OPTS}) + endif() + if(_RULE_NO_RUNTIME) return() endif() diff --git a/build_tools/cmake/iree_cc_fuzz.cmake b/build_tools/cmake/iree_cc_fuzz.cmake new file mode 100644 index 000000000000..296f4f26edc4 --- /dev/null +++ b/build_tools/cmake/iree_cc_fuzz.cmake @@ -0,0 +1,115 @@ +# Copyright 2026 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# iree_cc_fuzz() +# +# CMake function to create a libFuzzer-based fuzz target. +# +# Parameters: +# NAME: name of target. This name is used for the generated executable. +# SRCS: List of source files for the fuzzer (must define LLVMFuzzerTestOneInput). +# DATA: List of other targets and files required for this binary. +# DEPS: List of other libraries to be linked in to the binary targets. +# COPTS: List of private compile options. +# DEFINES: List of public defines. +# LINKOPTS: List of link options. +# LABELS: Additional labels to apply to the target. +# +# Note: +# - Fuzz targets require IREE_BUILD_TESTS=ON AND IREE_ENABLE_FUZZING=ON. +# - Fuzz targets are NOT added to CTest (they run differently than tests). +# - Fuzz targets are excluded from the default 'all' target (build explicitly). +# - Binary name is ${NAME} in the bin directory. +# +# Usage: +# iree_cc_fuzz( +# NAME +# unicode_fuzz +# SRCS +# "unicode_fuzz.cc" +# DEPS +# iree::base::internal::unicode +# ) +function(iree_cc_fuzz) + # Fuzz targets require both tests enabled AND fuzzing enabled. 
+  if(NOT IREE_BUILD_TESTS) +    return() +  endif() +  if(NOT IREE_ENABLE_FUZZING) +    return() +  endif() + +  cmake_parse_arguments( +    _RULE +    "" +    "NAME" +    "SRCS;COPTS;DEFINES;LINKOPTS;DATA;DEPS;LABELS" +    ${ARGN} +  ) + +  # Prefix the library with the package name, so we get: iree_package_name +  iree_package_name(_PACKAGE_NAME) +  iree_package_ns(_PACKAGE_NS) +  set(_NAME "${_PACKAGE_NAME}_${_RULE_NAME}") + +  add_executable(${_NAME} "") +  # Alias the iree_package_name fuzz binary to iree::package::name. +  add_executable(${_PACKAGE_NS}::${_RULE_NAME} ALIAS ${_NAME}) + +  set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "${_RULE_NAME}") +  target_sources(${_NAME} +    PRIVATE +      ${_RULE_SRCS} +  ) +  target_include_directories(${_NAME} SYSTEM +    PUBLIC +      "$<BUILD_INTERFACE:${IREE_SOURCE_DIR}>" +      "$<BUILD_INTERFACE:${IREE_BINARY_DIR}>" +  ) +  target_compile_definitions(${_NAME} +    PUBLIC +      ${_RULE_DEFINES} +  ) +  target_compile_options(${_NAME} +    PRIVATE +      ${IREE_DEFAULT_COPTS} +      ${_RULE_COPTS} +  ) + +  # Link with libFuzzer runtime. The -fsanitize=fuzzer flag provides the main() +  # function and fuzzing driver. All other code is compiled with +  # -fsanitize=fuzzer-no-link (set in iree_setup_toolchain.cmake) for coverage +  # instrumentation without linking the fuzzer runtime. +  target_link_options(${_NAME} +    PRIVATE +      ${IREE_DEFAULT_LINKOPTS} +      ${_RULE_LINKOPTS} +      "-fsanitize=fuzzer" +  ) + +  # Replace dependencies passed by ::name with iree::package::name +  list(TRANSFORM _RULE_DEPS REPLACE "^::" "${_PACKAGE_NS}::") + +  # Implicit deps. +  if(IREE_IMPLICIT_DEFS_CC_DEPS) +    list(APPEND _RULE_DEPS ${IREE_IMPLICIT_DEFS_CC_DEPS}) +  endif() + +  target_link_libraries(${_NAME} +    PUBLIC +      ${_RULE_DEPS} +  ) +  iree_add_data_dependencies(NAME ${_NAME} DATA ${_RULE_DATA}) + +  # Add all IREE fuzz targets to a folder in the IDE for organization. 
+ set_property(TARGET ${_NAME} PROPERTY FOLDER ${IREE_IDE_FOLDER}/fuzz) + + set_property(TARGET ${_NAME} PROPERTY CXX_STANDARD ${IREE_CXX_STANDARD}) + set_property(TARGET ${_NAME} PROPERTY CXX_STANDARD_REQUIRED ON) + + # Exclude from 'all' target - fuzz targets must be built explicitly. + set_property(TARGET ${_NAME} PROPERTY EXCLUDE_FROM_ALL ON) +endfunction() diff --git a/build_tools/cmake/iree_setup_toolchain.cmake b/build_tools/cmake/iree_setup_toolchain.cmake index 756d9b073409..5ff74afbf296 100644 --- a/build_tools/cmake/iree_setup_toolchain.cmake +++ b/build_tools/cmake/iree_setup_toolchain.cmake @@ -144,6 +144,12 @@ macro(iree_setup_toolchain) # defined with the same sanitizer flags, including e.g. standard library # symbols that might be used by both IREE and non-IREE (e.g. LLVM) code. + # Fuzzing requires ASan - enable it automatically if not already set. + if(IREE_ENABLE_FUZZING AND NOT IREE_ENABLE_ASAN) + message(STATUS "Fuzzing enabled: automatically enabling ASan") + set(IREE_ENABLE_ASAN ON) + endif() + if(IREE_ENABLE_ASAN) string(APPEND CMAKE_CXX_FLAGS " -fsanitize=address") string(APPEND CMAKE_C_FLAGS " -fsanitize=address") @@ -187,6 +193,13 @@ macro(iree_setup_toolchain) string(APPEND CMAKE_CXX_FLAGS " -fsanitize=undefined") string(APPEND CMAKE_C_FLAGS " -fsanitize=undefined") endif() + if(IREE_ENABLE_FUZZING) + # Instrument all code for libFuzzer coverage feedback without linking the + # fuzzer runtime. Fuzz targets link with -fsanitize=fuzzer separately to + # get the main() function and driver. 
+ string(APPEND CMAKE_CXX_FLAGS " -fsanitize=fuzzer-no-link") + string(APPEND CMAKE_C_FLAGS " -fsanitize=fuzzer-no-link") + endif() #----------------------------------------------------------------------------- # Build performance optimizations diff --git a/build_tools/github_actions/configure_ci.py b/build_tools/github_actions/configure_ci.py index 86978ce4fcd0..7a62b1d633c5 100755 --- a/build_tools/github_actions/configure_ci.py +++ b/build_tools/github_actions/configure_ci.py @@ -191,6 +191,10 @@ def contains(cls, val): ".github/worklflows/ci_windows_x64_msvc.yml", ], ), + ( + "test_torch", + ["tests/external/iree-test-suites/torch*"], + ), ] PR_DESCRIPTION_TEMPLATE = string.Template("${title}\n\n${body}") diff --git a/compiler/.clang-format b/compiler/.clang-format index f50fe3d2d350..3e499f9a3782 100644 --- a/compiler/.clang-format +++ b/compiler/.clang-format @@ -10,6 +10,7 @@ # ordering. BasedOnStyle: LLVM AlwaysBreakTemplateDeclarations: Yes +InsertBraces: Yes IncludeCategories: - Regex: '^<.*\.h>' Priority: 1 diff --git a/compiler/bindings/c/iree/compiler/embedding_api.h b/compiler/bindings/c/iree/compiler/embedding_api.h index 6da0379e9afd..eaf873dcf35e 100644 --- a/compiler/bindings/c/iree/compiler/embedding_api.h +++ b/compiler/bindings/c/iree/compiler/embedding_api.h @@ -268,6 +268,9 @@ enum iree_compiler_pipeline_t { // This is experimental and this should be changed as we move to a more // cohesive approach for managing compilation phases. IREE_COMPILER_PIPELINE_PRECOMPILE = 2, + // VM transformation pipeline only. Converts from input dialects to the VM + // dialect without serialization. 
+ IREE_COMPILER_PIPELINE_VM = 3, }; IREE_EMBED_EXPORTED bool ireeCompilerInvocationPipeline(iree_compiler_invocation_t *inv, diff --git a/compiler/bindings/c/iree/compiler/loader/loader.cpp b/compiler/bindings/c/iree/compiler/loader/loader.cpp index f3c04646bb1e..73b6647e0d14 100644 --- a/compiler/bindings/c/iree/compiler/loader/loader.cpp +++ b/compiler/bindings/c/iree/compiler/loader/loader.cpp @@ -19,8 +19,9 @@ namespace { using DlHandle = HMODULE; DlHandle loadLibrary(const char *libraryPath) { HMODULE lib = LoadLibraryExA(libraryPath, nullptr, 0); - if (lib) + if (lib) { return lib; + } DWORD errorMessageID = GetLastError(); LPSTR messageBuffer = nullptr; size_t size = FormatMessageA( @@ -48,8 +49,9 @@ DlHandle loadLibrary(const char *libraryPath) { DlHandle lib = dlopen(libraryPath, RTLD_NOW | RTLD_LOCAL); if (!lib) { const char *reason = dlerror(); - if (!reason) + if (!reason) { reason = ""; + } fprintf(stderr, "IREE COMPILER ERROR: Could not open compiler library %s : %s\n", libraryPath, reason); @@ -73,8 +75,9 @@ DlHandle libraryHandle = nullptr; #undef HANDLE_VERSIONED_SYMBOL void assertLoaded() { - if (libraryHandle) + if (libraryHandle) { return; + } fprintf(stderr, "FATAL ERROR: Attempt to call IREE compiler stub methods before " "library loaded\n"); diff --git a/compiler/bindings/python/CMakeLists.txt b/compiler/bindings/python/CMakeLists.txt index 5956c2d24cb8..5a148eddcdcf 100644 --- a/compiler/bindings/python/CMakeLists.txt +++ b/compiler/bindings/python/CMakeLists.txt @@ -416,29 +416,6 @@ add_custom_target(IREECompilerPythonDylibFiles add_dependencies(IREECompilerPythonModules IREECompilerPythonDylibFiles) -################################################################################ -# Windows DLL colocation fix -# On Windows, the nanobind-mlir.dll ends up in iree/build/_mlir_libs/ but we -# need to copy it to iree/compiler/_mlir_libs/ for the Python extensions to find -# it at runtime. 
-################################################################################ -if(WIN32) - set(_nanobind_src "${_PYTHON_BUILD_PREFIX}/iree/build/_mlir_libs/nanobind-mlir.dll") - set(_nanobind_dst "${_PYTHON_BUILD_PREFIX}/iree/compiler/_mlir_libs/nanobind-mlir.dll") - add_custom_command( - OUTPUT "${_nanobind_dst}" - DEPENDS "${_nanobind_src}" - COMMAND ${CMAKE_COMMAND} -E copy_if_different - "${_nanobind_src}" "${_nanobind_dst}" - COMMENT "Copying nanobind-mlir.dll to iree/compiler/_mlir_libs/ for Windows DLL loading" - ) - add_custom_target(IREECompilerPythonNanobindCopy - DEPENDS "${_nanobind_dst}" - ) - add_dependencies(IREECompilerPythonNanobindCopy IREECompilerBuildPythonModules) - add_dependencies(IREECompilerPythonModules IREECompilerPythonNanobindCopy) -endif() - ################################################################################ # Subdirectories ################################################################################ diff --git a/compiler/bindings/python/IREECompilerDialectsModule.cpp b/compiler/bindings/python/IREECompilerDialectsModule.cpp index 93ef30749db3..d20e58afafd4 100644 --- a/compiler/bindings/python/IREECompilerDialectsModule.cpp +++ b/compiler/bindings/python/IREECompilerDialectsModule.cpp @@ -51,8 +51,9 @@ ireeCodegenGetTunerRootOpsBinding(MlirModule module) { } static std::vector getIntArrayAttrValues(MlirAttribute attr) { - if (mlirAttributeIsNull(attr) || !mlirAttributeIsAArray(attr)) + if (mlirAttributeIsNull(attr) || !mlirAttributeIsAArray(attr)) { return {}; + } std::vector result; size_t n = mlirArrayAttrGetNumElements(attr); @@ -261,8 +262,9 @@ NB_MODULE(_ireeCompilerDialects, m) { "prefetch_num_stages", [](MlirAttribute self) -> std::optional { auto attr = ireeGPUPipelineOptionsAttrGetPrefetchNumStages(self); - if (!mlirAttributeIsNull(attr)) + if (!mlirAttributeIsNull(attr)) { return mlirIntegerAttrGetValueInt(attr); + } return std::nullopt; }) .def_property_readonly( @@ -271,16 +273,18 @@ 
NB_MODULE(_ireeCompilerDialects, m) { auto attr = ireeGPUPipelineOptionsAttrGetNoReduceSharedMemoryBankConflicts( self); - if (!mlirAttributeIsNull(attr)) + if (!mlirAttributeIsNull(attr)) { return mlirBoolAttrGetValue(attr); + } return std::nullopt; }) .def_property_readonly( "use_igemm_convolution", [](MlirAttribute self) -> std::optional { auto attr = ireeGPUPipelineOptionsAttrGetUseIgemmConvolution(self); - if (!mlirAttributeIsNull(attr)) + if (!mlirAttributeIsNull(attr)) { return mlirBoolAttrGetValue(attr); + } return std::nullopt; }) .def_property_readonly( @@ -288,8 +292,9 @@ NB_MODULE(_ireeCompilerDialects, m) { [](MlirAttribute self) -> std::optional { auto attr = ireeGPUPipelineOptionsAttrGetReorderWorkgroupsStrategy(self); - if (!mlirAttributeIsNull(attr)) + if (!mlirAttributeIsNull(attr)) { return attr; + } return std::nullopt; }); @@ -485,8 +490,9 @@ NB_MODULE(_ireeCompilerDialects, m) { .def_property_readonly( "mma_kind", [](MlirAttribute self) -> std::optional { auto attr = ireeGPULoweringConfigAttrGetMmaKind(self); - if (!mlirAttributeIsNull(attr)) + if (!mlirAttributeIsNull(attr)) { return attr; + } return std::nullopt; }); diff --git a/compiler/plugins/input/StableHLO/Conversion/ConvertCollectives.cpp b/compiler/plugins/input/StableHLO/Conversion/ConvertCollectives.cpp index e6175c2af2c5..aff8e78f0c2b 100644 --- a/compiler/plugins/input/StableHLO/Conversion/ConvertCollectives.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/ConvertCollectives.cpp @@ -97,8 +97,9 @@ static std::pair makeSplitColorAndKey(Location loc, OpBuilder &builder) { IndexSet indexSet(loc, builder); Value noColor = indexSet.get(-1); - if (!groups) + if (!groups) { return std::make_pair(noColor, noColor); + } auto groupsType = cast(groups.getType()); assert(groupsType.getRank() == 2); @@ -311,8 +312,9 @@ static Value createChannelWithGroupInfo( DenseIntElementsAttr replicaGroups, std::optional useGlobalDeviceIds, OpBuilder &builder) { // Set numPartitions to 1 if not set by 
the user. - if (numPartitions == -1) + if (numPartitions == -1) { numPartitions = 1; + } // Base channel that may be split by the group info. Value baseChannel = IREE::Flow::ChannelDefaultOp::create( @@ -854,8 +856,9 @@ struct CollectivePermuteOpConversion int64_t numParticipants = mode == CollectiveOpGroupMode::CrossReplica ? numReplicas : numPartitions; - if (numParticipants == -1) + if (numParticipants == -1) { numParticipants = 1; + } SmallVector replicaGroups; for (int64_t i = 0; i < numParticipants; ++i) { replicaGroups.push_back(rewriter.getI64IntegerAttr(i)); diff --git a/compiler/plugins/input/StableHLO/Conversion/LegalizeControlFlow.cpp b/compiler/plugins/input/StableHLO/Conversion/LegalizeControlFlow.cpp index dd1d9c9434be..f2506ecafef0 100644 --- a/compiler/plugins/input/StableHLO/Conversion/LegalizeControlFlow.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/LegalizeControlFlow.cpp @@ -65,12 +65,14 @@ struct ScfForBounds { std::optional extractForBounds(mlir::stablehlo::WhileOp op) { Block &cond = op.getCond().front(); Block &body = op.getBody().front(); - if (cond.getOperations().size() != 2) + if (cond.getOperations().size() != 2) { return std::nullopt; + } auto matchBbArg = [](Value v, Block &block) -> std::optional { - if (!isa(v) || v.getParentBlock() != &block) + if (!isa(v) || v.getParentBlock() != &block) { return std::nullopt; + } return cast(v).getArgNumber(); }; @@ -87,8 +89,9 @@ std::optional extractForBounds(mlir::stablehlo::WhileOp op) { } std::optional iterArg = matchBbArg(compare.getLhs(), cond); - if (!iterArg) + if (!iterArg) { return std::nullopt; + } auto add = dyn_cast_if_present( body.getTerminator()->getOperand(*iterArg).getDefiningOp()); diff --git a/compiler/plugins/input/StableHLO/Conversion/LegalizeShapeComputations.cpp b/compiler/plugins/input/StableHLO/Conversion/LegalizeShapeComputations.cpp index d80d4ed123bb..070352296c86 100644 --- a/compiler/plugins/input/StableHLO/Conversion/LegalizeShapeComputations.cpp +++ 
b/compiler/plugins/input/StableHLO/Conversion/LegalizeShapeComputations.cpp @@ -47,8 +47,9 @@ struct HloElementwiseConverter : OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final { - if (!opIsShapeComputation(op)) + if (!opIsShapeComputation(op)) { return failure(); + } auto resultTy = cast(op.getType()); @@ -86,8 +87,9 @@ struct ConcatenateConverter final LogicalResult matchAndRewrite(mlir::stablehlo::ConcatenateOp op, PatternRewriter &rewriter) const override { - if (!opIsShapeComputation(op)) + if (!opIsShapeComputation(op)) { return failure(); + } Location loc = op.getLoc(); auto resultTy = cast(op.getType()); @@ -144,14 +146,16 @@ struct ReshapeConverter : OpRewritePattern { PatternRewriter &rewriter) const override { Value operand = op.getOperand(); auto shapedTy = cast(operand.getType()); - if (!shapedTy.hasRank() || shapedTy.getRank() > 1) + if (!shapedTy.hasRank() || shapedTy.getRank() > 1) { return failure(); + } auto resultTy = cast(op.getType()); auto fromElements = op.getOperand().getDefiningOp(); - if (!fromElements) + if (!fromElements) { return failure(); + } rewriter.replaceOpWithNewOp( op, resultTy, fromElements.getOperands()); diff --git a/compiler/plugins/input/StableHLO/Conversion/MapStableHLOToScalarOp.h b/compiler/plugins/input/StableHLO/Conversion/MapStableHLOToScalarOp.h index 06f30deb65a1..0c91d76954a1 100644 --- a/compiler/plugins/input/StableHLO/Conversion/MapStableHLOToScalarOp.h +++ b/compiler/plugins/input/StableHLO/Conversion/MapStableHLOToScalarOp.h @@ -456,8 +456,9 @@ inline Value mapStableHloOpToStdScalarOp( return ScalarFOp::create(*b, loc, predicate.value(), lhs, rhs); } - if (auto complexType = dyn_cast(elementType)) + if (auto complexType = dyn_cast(elementType)) { return cmpComplex(loc, lhs, rhs, comparisonDirection, b); + } return nullptr; } @@ -602,11 +603,12 @@ inline Value mapStableHloOpToStdScalarOp( Value lhs = operands.front(); Type complexTy = lhs.getType(); - if 
(!isa(complexTy)) + if (!isa(complexTy)) { return MapStableHloOpToScalarOpImpl< IsFloatType, arith::MaximumFOp, IsSignedIntegerType, arith::MaxSIOp, IsUnsignedIntegerType, arith::MaxUIOp>{}(loc, resultTypes, argTypes, adaptor.getOperands(), b); + } assert(resultTypes.size() == 1 && "MaxOp should return a single result"); assert(operands.size() == 2 && "MaxOp should take exactly two arguments"); @@ -626,11 +628,12 @@ inline Value mapStableHloOpToStdScalarOp( Value lhs = operands.front(); Type complexTy = lhs.getType(); - if (!isa(complexTy)) + if (!isa(complexTy)) { return MapStableHloOpToScalarOpImpl< IsFloatType, arith::MinimumFOp, IsSignedIntegerType, arith::MinSIOp, IsUnsignedIntegerType, arith::MinUIOp>{}(loc, resultTypes, argTypes, adaptor.getOperands(), b); + } assert(resultTypes.size() == 1 && "MinOp should return a single result"); assert(operands.size() == 2 && "MinOp should take exactly two arguments"); @@ -646,8 +649,9 @@ template <> inline Value mapStableHloOpToStdScalarOp( Location loc, ArrayRef resultTypes, ArrayRef argTypes, stablehlo::RealOp::Adaptor adaptor, OpBuilder *b) { - if (!isa(adaptor.getOperand().getType())) + if (!isa(adaptor.getOperand().getType())) { return adaptor.getOperand(); + } return MapStableHloOpToScalarOpImpl{}( loc, resultTypes, argTypes, adaptor.getOperands(), b); } @@ -656,9 +660,10 @@ template <> inline Value mapStableHloOpToStdScalarOp( Location loc, ArrayRef resultTypes, ArrayRef argTypes, stablehlo::ImagOp::Adaptor adaptor, OpBuilder *b) { - if (!isa(adaptor.getOperand().getType())) + if (!isa(adaptor.getOperand().getType())) { return arith::ConstantOp::create( *b, loc, b->getZeroAttr(adaptor.getOperand().getType())); + } return MapStableHloOpToScalarOpImpl{}( loc, resultTypes, argTypes, adaptor.getOperands(), b); } @@ -813,15 +818,18 @@ inline Value mapStableHloOpToStdScalarOp( Type resultType = getElementTypeOrSelf(resultTypes.front()); // Skip needless casts. 
- if (argType == resultType) + if (argType == resultType) { return adaptor.getOperand(); + } if (!isa(resultType) || - !isa(argType)) + !isa(argType)) { return nullptr; + } - if (resultType.getIntOrFloatBitWidth() != argType.getIntOrFloatBitWidth()) + if (resultType.getIntOrFloatBitWidth() != argType.getIntOrFloatBitWidth()) { return nullptr; + } return mlir::arith::BitcastOp::create(*b, loc, resultTypes, adaptor.getOperands()); diff --git a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/Canonicalization.cpp b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/Canonicalization.cpp index 8a996a19d6b5..6518c9653572 100644 --- a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/Canonicalization.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/Canonicalization.cpp @@ -54,8 +54,9 @@ static bool isIotaRange(ArrayRef dims) { static bool isIotaRange(ElementsAttr attr) { auto elems = attr.tryGetValues(); - if (!elems) + if (!elems) { return false; + } for (auto [idx, value] : llvm::enumerate(*elems)) { if (idx != value) { @@ -119,8 +120,9 @@ struct AddOpCanon final : OpRewritePattern { LogicalResult matchAndRewrite(mlir::stablehlo::AddOp op, PatternRewriter &rewriter) const override { auto type = dyn_cast(op.getType()); - if (!type) + if (!type) { return failure(); + } Value lhs = op.getLhs(); Value rhs = op.getRhs(); @@ -166,8 +168,9 @@ struct SubtractOpCanon final : OpRewritePattern { LogicalResult matchAndRewrite(mlir::stablehlo::SubtractOp op, PatternRewriter &rewriter) const override { auto type = dyn_cast(op.getType()); - if (!type) + if (!type) { return failure(); + } Value lhs = op.getLhs(); Value rhs = op.getRhs(); @@ -208,8 +211,9 @@ struct MulOpCanon final : OpRewritePattern { LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op, PatternRewriter &rewriter) const override { auto type = dyn_cast(op.getType()); - if (!type) + if (!type) { return failure(); + } Value lhs = op.getLhs(); Value rhs = op.getRhs(); @@ -334,8 
+338,9 @@ struct CompareOpCanon final : OpRewritePattern { LogicalResult matchAndRewrite(mlir::stablehlo::CompareOp op, PatternRewriter &rewriter) const override { auto type = dyn_cast(op.getType()); - if (!type) + if (!type) { return failure(); + } // Bail out on non-integer comparison. // TODO: Support more comparison types. @@ -410,8 +415,9 @@ struct SelectOpCanon final : OpRewritePattern { LogicalResult matchAndRewrite(mlir::stablehlo::SelectOp op, PatternRewriter &rewriter) const override { auto type = dyn_cast(op.getType()); - if (!type) + if (!type) { return failure(); + } Value trueVal = op.getOnTrue(); Value falseVal = op.getOnFalse(); @@ -437,16 +443,19 @@ struct SelectOpCanon final : OpRewritePattern { // Handle elementwise selection when both outcomes are also constants. This // will create a new, likely non-splat constant. - if (cond.getNumElements() > kFoldOpEltLimit) + if (cond.getNumElements() > kFoldOpEltLimit) { return failure(); + } ElementsAttr trueAttr; - if (!matchPattern(trueVal, m_Constant(&trueAttr))) + if (!matchPattern(trueVal, m_Constant(&trueAttr))) { return failure(); + } ElementsAttr falseAttr; - if (!matchPattern(falseVal, m_Constant(&falseAttr))) + if (!matchPattern(falseVal, m_Constant(&falseAttr))) { return failure(); + } SmallVector newValues; newValues.reserve(cond.getNumElements()); @@ -469,13 +478,15 @@ struct BroadcastInDimOpCanon final LogicalResult matchAndRewrite(mlir::stablehlo::BroadcastInDimOp op, PatternRewriter &rewriter) const override { auto type = dyn_cast(op.getType()); - if (!type) + if (!type) { return failure(); + } Value operand = op.getOperand(); auto operandTy = dyn_cast(operand.getType()); - if (!operandTy) + if (!operandTy) { return failure(); + } // Fold when broadcast is a noop. 
auto dims = op.getBroadcastDimensions(); @@ -534,12 +545,14 @@ struct ConcatenateOpCanon final LogicalResult matchAndRewrite(mlir::stablehlo::ConcatenateOp op, PatternRewriter &rewriter) const override { auto type = dyn_cast(op.getType()); - if (!type || !type.hasStaticShape()) + if (!type || !type.hasStaticShape()) { return failure(); + } size_t numElems = type.getNumElements(); - if (numElems > kFoldOpEltLimit) + if (numElems > kFoldOpEltLimit) { return failure(); + } // Fold concatenate when all inputs are constants. OperandRange inputs = op.getInputs(); @@ -578,8 +591,9 @@ struct ConvertOpCanon final : OpRewritePattern { LogicalResult matchAndRewrite(mlir::stablehlo::ConvertOp op, PatternRewriter &rewriter) const override { // Check if this convert is a noop. - if (op.getOperand().getType() != op.getType()) + if (op.getOperand().getType() != op.getType()) { return failure(); + } rewriter.replaceOp(op, op.getOperand()); return success(); @@ -673,8 +687,9 @@ struct ChainedDynamicBroadcastInDimCanonicalization final auto precedingBcast = bcast.getOperand() .getDefiningOp(); - if (!precedingBcast) + if (!precedingBcast) { return failure(); + } // Compose broadcast dimensions. SmallVector composition; @@ -759,8 +774,9 @@ struct EmptyReduceOpCanon final : OpRewritePattern { "unranked input unsupported"); } - if (!llvm::is_contained(elemTy.getShape(), 0)) + if (!llvm::is_contained(elemTy.getShape(), 0)) { return failure(); + } Location loc = op.getLoc(); DenseI64ArrayAttr empty = rewriter.getDenseI64ArrayAttr({}); @@ -799,8 +815,9 @@ struct DynamicReshapeOpCanon final PatternRewriter &rewriter) const override { // This is a noop when the output type is already a static shape. 
auto type = dyn_cast(op.getType()); - if (!type || !type.hasStaticShape()) + if (!type || !type.hasStaticShape()) { return failure(); + } rewriter.replaceOpWithNewOp(op, type, op.getOperand()); @@ -816,8 +833,9 @@ struct GetTupleElementOpCanon final PatternRewriter &rewriter) const override { auto constructor = op.getOperand().getDefiningOp(); - if (!constructor) + if (!constructor) { return failure(); + } Value result = constructor.getOperand(op.getIndex()); rewriter.replaceOp(op, result); @@ -831,8 +849,9 @@ struct RealOpCanon final : OpRewritePattern { LogicalResult matchAndRewrite(mlir::stablehlo::RealOp op, PatternRewriter &rewriter) const override { auto complex = op.getOperand().getDefiningOp(); - if (!complex) + if (!complex) { return failure(); + } rewriter.replaceOp(op, complex.getLhs()); return success(); @@ -845,8 +864,9 @@ struct ImagOpCanon final : OpRewritePattern { LogicalResult matchAndRewrite(mlir::stablehlo::ImagOp op, PatternRewriter &rewriter) const override { auto complex = op.getOperand().getDefiningOp(); - if (!complex) + if (!complex) { return failure(); + } rewriter.replaceOp(op, complex.getRhs()); return success(); @@ -861,12 +881,14 @@ struct GetDimensionSizeOpCanon final PatternRewriter &rewriter) const override { // Fold get_dimension_size when the queried dim is statically known. 
auto tensorTy = dyn_cast(op.getOperand().getType()); - if (!tensorTy) + if (!tensorTy) { return failure(); + } int64_t dimSize = tensorTy.getDimSize(op.getDimension()); - if (dimSize < 0) + if (dimSize < 0) { return failure(); + } auto elemTy = cast(op.getType().getElementType()); IntegerAttr elemVal = rewriter.getIntegerAttr(elemTy, dimSize); @@ -903,8 +925,9 @@ struct GatherOpCanon final : OpRewritePattern { auto operandType = dyn_cast(gather->getOperand(0).getType()); - if (!operandType || !operandType.hasStaticShape()) + if (!operandType || !operandType.hasStaticShape()) { return failure(); + } auto sliceEnd = llvm::to_vector(gather.getSliceSizes()); SmallVector sliceStart(sliceEnd.size(), 0); @@ -1044,13 +1067,16 @@ struct TransposeIsReshape final nonZeroPerms.reserve(permValues.size()); for (auto idx : permValues) { auto sz = inputTy.getDimSize(idx); - if (sz != 1) + if (sz != 1) { nonZeroPerms.push_back(idx); + } } - for (int i = 1, s = nonZeroPerms.size(); i < s; ++i) - if (nonZeroPerms[i - 1] > nonZeroPerms[i]) + for (int i = 1, s = nonZeroPerms.size(); i < s; ++i) { + if (nonZeroPerms[i - 1] > nonZeroPerms[i]) { return rewriter.notifyMatchFailure(op, "memory layout change"); + } + } rewriter.replaceOpWithNewOp(op, op.getType(), input); diff --git a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/DotGeneralToDot.cpp b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/DotGeneralToDot.cpp index 744a8c523d8b..6f080dfd4e2b 100644 --- a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/DotGeneralToDot.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/DotGeneralToDot.cpp @@ -65,8 +65,9 @@ Value transposeReshape(Value arg, Location loc, auto transposeType = RankedTensorType::get(transposedShape, elementType); Value transposeResult = mlir::stablehlo::TransposeOp::create( rewriter, loc, transposeType, arg, transposePermutationAttr); - if (noReshape) + if (noReshape) { return transposeResult; + } // Return the final 
result. auto reshapedType = RankedTensorType::get({leftSize, rightSize}, elementType); @@ -176,12 +177,14 @@ struct GeneralDotRemoveBatch final // We no longer include the batch dimension of 1. llvm::SmallVector newLhsContractingDims; - for (auto dim : dimNumbers.getLhsContractingDimensions()) + for (auto dim : dimNumbers.getLhsContractingDimensions()) { newLhsContractingDims.push_back(dim - 1); + } llvm::SmallVector newRhsContractingDims; - for (auto dim : dimNumbers.getRhsContractingDimensions()) + for (auto dim : dimNumbers.getRhsContractingDimensions()) { newRhsContractingDims.push_back(dim - 1); + } auto lhs = mlir::stablehlo::ReshapeOp::create( rewriter, op.getLoc(), lhsTy.clone(lhsTy.getShape().drop_front()), @@ -231,8 +234,9 @@ struct GeneralDotConvert final ArrayAttr precisionConfig; auto opPrecisionConfig = op.getPrecisionConfig(); - if (opPrecisionConfig.has_value()) + if (opPrecisionConfig.has_value()) { precisionConfig = *opPrecisionConfig; + } auto resultTy = cast(op.getType()); @@ -246,8 +250,9 @@ struct GeneralDotConvert final RankedTensorType lhsTy = dyn_cast(lhs.getType()); RankedTensorType rhsTy = dyn_cast(rhs.getType()); - if (!lhsTy || !rhsTy) + if (!lhsTy || !rhsTy) { return failure(); + } // The StableHLO dot operator directly supports a vector dot product // (two vectors reduce into a scalar) as well as a matrix vector @@ -295,8 +300,9 @@ struct GeneralDotConvert final // For any sparse situation, don't use any of the following rules, since // transposing and reshaping is not without cost. Instead, rely on the // default linalg lowering that follows later in the pipeline. - if (sparse_tensor::hasAnySparseOperandOrResult(op)) + if (sparse_tensor::hasAnySparseOperandOrResult(op)) { return failure(); + } // Compute the, possibly, transposed-reshaped operands. lhs = cast>(processDotArg( @@ -307,8 +313,9 @@ struct GeneralDotConvert final // Accept only static shaped types. 
auto lhsShapeType = dyn_cast_if_present(lhs.getType()); auto rhsShapeType = dyn_cast_if_present(rhs.getType()); - if (!lhsShapeType || !rhsShapeType) + if (!lhsShapeType || !rhsShapeType) { return failure(); + } // Generate new dot operator on expanded types. ShapedType newTy = RankedTensorType::get( diff --git a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/FlattenTuplesInCFG.cpp b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/FlattenTuplesInCFG.cpp index 931a43fd8b42..6ce7b58454f3 100644 --- a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/FlattenTuplesInCFG.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/FlattenTuplesInCFG.cpp @@ -69,8 +69,9 @@ void copyOperationAttrs(Operation *oldOp, Operation *newOp) { // Don't copy segment attributes as these correspond to the number operands, // which may be different. if (oldAttr.getName() == "operandSegmentSizes" || - oldAttr.getName() == "resultSegmentSizes") + oldAttr.getName() == "resultSegmentSizes") { continue; + } newOp->setAttr(oldAttr.getName(), oldAttr.getValue()); } @@ -127,8 +128,9 @@ class DetupleReturnOp : public OpRewritePattern { LogicalResult matchAndRewrite(func::ReturnOp op, PatternRewriter &builder) const override { - if (!hasTuples(op.getOperands())) + if (!hasTuples(op.getOperands())) { return builder.notifyMatchFailure(op, "No detupling required"); + } llvm::SmallVector newOperands; if (failed(untupleAndLookupValues(op.getOperands(), newOperands, builder, @@ -147,8 +149,9 @@ class DetupleCallOp : public OpRewritePattern { LogicalResult matchAndRewrite(func::CallOp oldOp, PatternRewriter &builder) const override { - if (!hasTuples(oldOp.getOperands()) && !hasTuples(oldOp.getResults())) + if (!hasTuples(oldOp.getOperands()) && !hasTuples(oldOp.getResults())) { return builder.notifyMatchFailure(oldOp, "No detupling required"); + } llvm::SmallVector newArgs; if (failed(untupleAndLookupValues(oldOp.getOperands(), newArgs, builder, @@ -180,8 +183,9 @@ 
class DetupleIndirectCallOp : public OpRewritePattern { LogicalResult matchAndRewrite(func::CallIndirectOp oldOp, PatternRewriter &builder) const override { - if (!hasTuples(oldOp.getOperands()) && !hasTuples(oldOp.getResults())) + if (!hasTuples(oldOp.getOperands()) && !hasTuples(oldOp.getResults())) { return builder.notifyMatchFailure(oldOp, "No detupling required"); + } llvm::SmallVector newArgs; if (failed(untupleAndLookupValues(oldOp.getOperands(), newArgs, builder, @@ -202,8 +206,9 @@ class DetupleBranchOp : public OpRewritePattern { LogicalResult matchAndRewrite(cf::BranchOp oldOp, PatternRewriter &builder) const override { - if (!hasTuples(oldOp.getOperands())) + if (!hasTuples(oldOp.getOperands())) { return builder.notifyMatchFailure(oldOp, "No detupling required"); + } llvm::SmallVector newArgs; if (failed(untupleAndLookupValues(oldOp.getOperands(), newArgs, builder, @@ -225,8 +230,9 @@ class DetupleConditionOp : public OpRewritePattern { LogicalResult matchAndRewrite(cf::CondBranchOp oldOp, PatternRewriter &builder) const override { - if (!hasTuples(oldOp.getOperands())) + if (!hasTuples(oldOp.getOperands())) { return builder.notifyMatchFailure(oldOp, "No detupling required"); + } llvm::SmallVector trueArgs; if (failed(untupleAndLookupValues(oldOp.getTrueOperands(), trueArgs, @@ -279,8 +285,9 @@ LogicalResult convertFunction(func::FuncOp oldFunction, // existing ones along path that produces tuples are used further, so just // remove instead of flattening. 
if (hasTupleSig && (attr.getName() == oldFunction.getArgAttrsAttrName() || - attr.getName() == oldFunction.getResAttrsAttrName())) + attr.getName() == oldFunction.getResAttrsAttrName())) { continue; + } newFunction->setAttr(attr.getName(), attr.getValue()); } diff --git a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/FlattenTuplesInSCF.cpp b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/FlattenTuplesInSCF.cpp index 5be0ac2a8338..145a2270996e 100644 --- a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/FlattenTuplesInSCF.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/FlattenTuplesInSCF.cpp @@ -114,8 +114,9 @@ class DetupleYieldOp : public OpRewritePattern { recursiveUntuple(operand, b, mapping, operands); } - if (!hasTuples) + if (!hasTuples) { return rewriter.notifyMatchFailure(op, "no tupled arguments"); + } rewriter.replaceOpWithNewOp(op, operands); return success(); @@ -137,8 +138,9 @@ class DetupleConditionOp : public OpRewritePattern { recursiveUntuple(operand, b, mapping, operands); } - if (!hasTuples) + if (!hasTuples) { return rewriter.notifyMatchFailure(op, "no tupled arguments"); + } rewriter.replaceOpWithNewOp(op, op.getCondition(), operands); @@ -159,8 +161,9 @@ class DetupleIfOp : public OpRewritePattern { hasTuples |= isa(type); } - if (!hasTuples) + if (!hasTuples) { return rewriter.notifyMatchFailure(op, "no tupled arguments"); + } llvm::SmallVector types; untupleTypes(op.getResultTypes(), types); @@ -204,8 +207,9 @@ class DetupleWhileOp : public OpRewritePattern { recursiveUntuple(operand, b, mapping, operands); } - if (!hasTuples) + if (!hasTuples) { return rewriter.notifyMatchFailure(op, "no tupled arguments"); + } llvm::SmallVector types; untupleTypes(op.getResultTypes(), types); diff --git a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/LowerComplex.cpp b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/LowerComplex.cpp index ff767d26b95b..11ec46a9887b 100644 --- 
a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/LowerComplex.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/LowerComplex.cpp @@ -96,12 +96,14 @@ ElementsAttr getSplat(Builder *b, RankedTensorType ty, T constant) { if (auto complexTy = dyn_cast(elementTy)) { auto complexElementTy = complexTy.getElementType(); - if (complexElementTy.isF32()) + if (complexElementTy.isF32()) { return DenseElementsAttr::get(ty, static_cast>(constant)); - if (complexElementTy.isF64()) + } + if (complexElementTy.isF64()) { return DenseElementsAttr::get( ty, static_cast>(constant)); + } } llvm_unreachable("unhandled element type"); } diff --git a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp index ba77c12a9bf4..c67ceb8e5bae 100644 --- a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp @@ -34,8 +34,9 @@ namespace { bool isIota(ArrayRef array) { for (auto [idx, value] : llvm::enumerate(array)) { - if (static_cast(idx) != value) + if (static_cast(idx) != value) { return false; + } } return true; } @@ -122,8 +123,9 @@ struct ReorderConvOpKernelDimensions final PatternRewriter &rewriter) const override { auto kernel = op.getRhs(); auto kernelType = cast(kernel.getType()); - if (!kernelType.hasRank()) + if (!kernelType.hasRank()) { return failure(); + } auto kernelShape = kernelType.getShape(); auto dimensionNumbers = op.getDimensionNumbers(); @@ -142,8 +144,9 @@ struct ReorderConvOpKernelDimensions final permutation.push_back(outputFeatureDimension); // If the permutation is iota, then no transpose is required. 
- if (isIota(permutation)) + if (isIota(permutation)) { return failure(); + } llvm::SmallVector transposeShape; for (int64_t perm : permutation) { @@ -253,8 +256,9 @@ struct ReorderConvOpOutputDimensions final bool isConsecutive(ArrayRef array) { for (size_t i = 1, e = array.size(); i < e; ++i) { - if (array[i] - array[i - 1] != 1) + if (array[i] - array[i - 1] != 1) { return false; + } } return true; } @@ -274,8 +278,9 @@ struct TransposeReshapeGenericDotGeneral final Value TransposeIfNonConsecutive(OpBuilder &b, Location loc, Value src, ArrayRef targetOrder) const { - if (isConsecutive(targetOrder)) + if (isConsecutive(targetOrder)) { return src; + } auto type = cast(src.getType()); SmallVector transposeShape; @@ -292,8 +297,9 @@ struct TransposeReshapeGenericDotGeneral final auto type = cast(src.getType()); ArrayRef shape = type.getShape(); if (dimsBorder0 <= 1 && dimsBorder1 - dimsBorder0 <= 1 && - shape.size() - dimsBorder1 <= 1) + shape.size() - dimsBorder1 <= 1) { return src; + } int64_t resultShape[] = { llvm::product_of(shape.take_front(dimsBorder0)), @@ -308,15 +314,17 @@ struct TransposeReshapeGenericDotGeneral final auto lhsShapeType = dyn_cast(op.getLhs().getType()); auto rhsShapeType = dyn_cast(op.getRhs().getType()); auto resultType = dyn_cast(op.getResult().getType()); - if (!lhsShapeType || !rhsShapeType || !resultType) + if (!lhsShapeType || !rhsShapeType || !resultType) { return failure(); + } // TODO(jpienaar): This pattern is not safe for dynamic shapes and seems to // be (now) redundant with later pass that does handle them. To decouple // fixing and verifying redundant, this just limits to static shapes and // then will remove this in follow up. 
- if (!lhsShapeType.hasStaticShape() || !rhsShapeType.hasStaticShape()) + if (!lhsShapeType.hasStaticShape() || !rhsShapeType.hasStaticShape()) { return failure(); + } SmallVector lhsTargetOrder, rhsTargetOrder; mlir::stablehlo::DotDimensionNumbersAttr dimNumbers = @@ -394,8 +402,9 @@ struct TransposeReshapeGenericDotGeneral final rhs = ReshapeIfNonStandard(rewriter, op.getLoc(), rhs, rhsBatchingDims.size(), numRhsContractionDims); - if (lhs == op.getLhs() && rhs == op.getRhs()) + if (lhs == op.getLhs() && rhs == op.getRhs()) { return rewriter.notifyMatchFailure(op, "already in canonical form"); + } auto dimensionNumbers = mlir::stablehlo::DotDimensionNumbersAttr::get( rewriter.getContext(), /*lhsBatchingDimensions=*/0, @@ -409,11 +418,13 @@ struct TransposeReshapeGenericDotGeneral final // batching、lhs parallel、rhs parallel this order is a conversion SmallVector newShape = {lhsNewType.getShape()[0]}; - if (lhsNewType.getRank() > 2) + if (lhsNewType.getRank() > 2) { newShape.push_back(lhsNewType.getDimSize(1)); + } - if (rhsNewType.getRank() > 2) + if (rhsNewType.getRank() > 2) { newShape.push_back(rhsNewType.getDimSize(2)); + } TensorType newResultType = RankedTensorType::get(newShape, resultType.getElementType()); @@ -537,8 +548,9 @@ struct ScatterImplicitBatch final static Value addUnitBatchDim(Location loc, Value value, PatternRewriter &rewriter) { auto valueTy = cast(value.getType()); - if (!valueTy.hasRank()) + if (!valueTy.hasRank()) { return nullptr; + } // Materialize the implicit indices dim. SmallVector reassociationMap(valueTy.getRank()); @@ -565,8 +577,9 @@ struct ScatterImplicitBatch final auto indicesTy = dyn_cast(indices.getType()); // Check whether indices has no batch dimension. 
- if (!indicesTy) + if (!indicesTy) { return failure(); + } if (indicesTy.getRank() != 1 || indexVectorDim != 0) { return rewriter.notifyMatchFailure(op, "no implicit batch dimension to add."); @@ -620,8 +633,9 @@ struct ScatterCollapseBatch final static Value collapseBatchDims(Location loc, Value value, int64_t batchCount, PatternRewriter &rewriter) { auto valueTy = dyn_cast(value.getType()); - if (!valueTy) + if (!valueTy) { return nullptr; + } SmallVector reassociationMap(1); reassociationMap.reserve(valueTy.getRank() - batchCount + 1); @@ -733,12 +747,14 @@ struct ScatterBatchFirst final : OpRewritePattern { llvm::SmallVector perm; perm.reserve(indicesTy.getRank()); for (int i = 0, s = indicesTy.getRank(); i < s; ++i) { - if (i != indexVectorDim) + if (i != indexVectorDim) { perm.push_back(i); + } } - if (perm.size() < indicesTy.getRank()) + if (perm.size() < indicesTy.getRank()) { perm.push_back(indexVectorDim); + } llvm::SmallVector newShape; for (int i = 0, s = perm.size(); i < s; ++i) { @@ -761,21 +777,25 @@ struct ScatterBatchFirst final : OpRewritePattern { // Determine which dimensions are batch dimensions. llvm::SmallVector isBatch(updates0Ty.getRank(), true); - for (int i = 0, s = updatedWindowDims.size(); i < s; ++i) + for (int i = 0, s = updatedWindowDims.size(); i < s; ++i) { isBatch[updatedWindowDims[i]] = false; + } // Permute batch dimensions to the start of the update tensor. 
llvm::SmallVector updatePerm; updatePerm.reserve(updates0Ty.getRank()); - for (int i = 0, s = isBatch.size(); i < s; ++i) - if (isBatch[i]) + for (int i = 0, s = isBatch.size(); i < s; ++i) { + if (isBatch[i]) { updatePerm.push_back(i); + } + } updatePerm.append(updatedWindowDims.begin(), updatedWindowDims.end()); llvm::SmallVector newUpdatedWindowDims; int64_t batchCount = updates0Ty.getRank() - updatedWindowDims.size(); - for (int i = batchCount, s = updates0Ty.getRank(); i < s; i++) + for (int i = batchCount, s = updates0Ty.getRank(); i < s; i++) { newUpdatedWindowDims.push_back(i); + } bool indicesChanged = indices != op.getScatterIndices(); bool updatesChanged = @@ -787,17 +807,19 @@ struct ScatterBatchFirst final : OpRewritePattern { auto updateTy = cast(update.getType()); llvm::SmallVector newShape; newShape.reserve(updateTy.getRank()); - for (int i = 0, s = updatePerm.size(); i < s; i++) + for (int i = 0, s = updatePerm.size(); i < s; i++) { newShape.push_back(updateTy.getDimSize(updatePerm[i])); + } update = mlir::stablehlo::TransposeOp::create( builder, updateTy.clone(newShape), update, builder.getDenseI64ArrayAttr(updatePerm)); } } - if (!indicesChanged && !updatesChanged) + if (!indicesChanged && !updatesChanged) { return rewriter.notifyMatchFailure( op, "batch dimensions are already leading"); + } auto newDimNumbers = mlir::stablehlo::ScatterDimensionNumbersAttr::get( op.getContext(), newUpdatedWindowDims, @@ -882,8 +904,9 @@ struct ScatterMaterializeInsertedDim final int64_t firstNonIndex = 0; for (int64_t s = scatterDimsToOperandDims.size(); firstNonIndex < s; ++firstNonIndex) { - if (!isIndexDim[firstNonIndex]) + if (!isIndexDim[firstNonIndex]) { break; + } } llvm::SmallVector isInsertDims(operandTy.getRank(), false); @@ -898,9 +921,9 @@ struct ScatterMaterializeInsertedDim final } } - llvm::ArrayRef toInsertDims = + auto toInsertDims = llvm::ArrayRef(isInsertDims).drop_front(frontInsertedDims); - if (!llvm::any_of(toInsertDims, [](auto d) { return 
d; })) { + if (llvm::none_of(toInsertDims, [](bool d) { return d; })) { return rewriter.notifyMatchFailure(op, "no dimensions to insert"); } @@ -908,9 +931,10 @@ struct ScatterMaterializeInsertedDim final SmallVector reassociationMap; reassociationMap.push_back({rewriter.getAffineDimExpr(0)}); - for (auto it : llvm::enumerate(llvm::ArrayRef(toInsertDims))) { - if (!it.value()) + for (auto it : llvm::enumerate(toInsertDims)) { + if (!it.value()) { reassociationMap.push_back({}); + } reassociationMap.back().push_back( rewriter.getAffineDimExpr(it.index() + 1)); } @@ -962,8 +986,9 @@ struct ScatterMaterializeInsertedDim final bool isFromBool(Value val) { while (true) { Operation *op = val.getDefiningOp(); - if (!op) + if (!op) { return false; + } if (auto convertOp = dyn_cast(op)) { auto inTy = cast(convertOp.getOperand().getType()); @@ -993,17 +1018,20 @@ struct MulCastOfBool final : OpRewritePattern { LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op, PatternRewriter &rewriter) const override { auto resultTy = cast(op.getType()); - if (!isa(resultTy.getElementType())) + if (!isa(resultTy.getElementType())) { return failure(); + } Value lhs = op.getLhs(); Value rhs = op.getRhs(); bool lhsIsBool = isFromBool(lhs); bool rhsIsBool = isFromBool(rhs); - if (lhsIsBool == rhsIsBool) + if (lhsIsBool == rhsIsBool) { return failure(); - if (rhsIsBool) + } + if (rhsIsBool) { std::swap(lhs, rhs); + } Type eType = resultTy.getElementType(); auto lhsTy = cast(lhs.getType()); @@ -1023,8 +1051,9 @@ struct MulCastOfBool final : OpRewritePattern { auto valueTy = cast(value.getType()); auto newTy = RankedTensorType::get(resultTy.getShape(), valueTy.getElementType()); - if (valueTy == newTy) + if (valueTy == newTy) { return value; + } auto dimensions = llvm::to_vector( llvm::seq(resultRank - valueTy.getRank(), resultRank)); return mlir::stablehlo::DynamicBroadcastInDimOp::create( @@ -1047,19 +1076,22 @@ struct ExpandRngNormal final : OpRewritePattern { LogicalResult 
matchAndRewrite(mlir::stablehlo::RngOp op, PatternRewriter &rewriter) const override { - if (op.getRngDistribution() != mlir::stablehlo::RngDistribution::NORMAL) + if (op.getRngDistribution() != mlir::stablehlo::RngDistribution::NORMAL) { return failure(); + } auto resTy = dyn_cast(op.getType()); // We can support static shapes, but it's easier to implement Box-Muller // transform if we know the number of elements. - if (!resTy || !resTy.hasStaticShape()) + if (!resTy || !resTy.hasStaticShape()) { return failure(); + } // The algorithm requires even numbers and will generate pairs. auto numElems = resTy.getNumElements(); - if (numElems & 1) + if (numElems & 1) { numElems++; + } auto halfNumElems = numElems / 2; ImplicitLocOpBuilder b(op.getLoc(), rewriter); @@ -1193,11 +1225,13 @@ struct ReorderBroadcastInDimOpAndElementwiseOp final // NOTE: bcastOps may contain duplicates. SetVector deadOps; for (auto bcastOp : bcastOps) { - if (bcastOp.getOperation()->use_empty()) + if (bcastOp.getOperation()->use_empty()) { deadOps.insert(bcastOp); + } } - for (auto *deadOp : deadOps) + for (auto *deadOp : deadOps) { rewriter.eraseOp(deadOp); + } return success(); } @@ -1238,8 +1272,9 @@ struct FuseWidenOperands final : OpRewritePattern { if (llvm::all_of( llvm::zip_equal(operands, op->getOperands()), - [](auto pair) { return std::get<0>(pair) == std::get<1>(pair); })) + [](auto pair) { return std::get<0>(pair) == std::get<1>(pair); })) { return failure(); + } rewriter.replaceOpWithNewOp(op, op->getResultTypes(), operands, op->getAttrs()); @@ -1266,8 +1301,9 @@ struct DotToMul final : OpRewritePattern { return rewriter.notifyMatchFailure(op, "lhs and rhs must be rank-2"); } - if (lhsTy.getDimSize(1) != 1) + if (lhsTy.getDimSize(1) != 1) { return failure(); + } // Dynamically compute the shape of the result of the DotOp by querying // the 0-th dimensions, of the left, and the 1st dimension of the right. 
@@ -1298,10 +1334,13 @@ struct DotToMul final : OpRewritePattern { outSize, rewriter.getDenseI64ArrayAttr({0, 1})); auto computeETy = lhsTy.getElementType(); - if (computeETy.getIntOrFloatBitWidth() < rhsTy.getElementTypeBitWidth()) + if (computeETy.getIntOrFloatBitWidth() < rhsTy.getElementTypeBitWidth()) { computeETy = rhsTy.getElementType(); - if (computeETy.getIntOrFloatBitWidth() < resultTy.getElementTypeBitWidth()) + } + if (computeETy.getIntOrFloatBitWidth() < + resultTy.getElementTypeBitWidth()) { computeETy = resultTy.getElementType(); + } auto computeTy = resultTy.clone(computeETy); @@ -1362,8 +1401,9 @@ struct ZeroConcat final : OpRewritePattern { LogicalResult matchAndRewrite(mlir::stablehlo::ConcatenateOp op, PatternRewriter &rewriter) const override { auto type = dyn_cast(op.getType()); - if (!type || !type.hasStaticShape()) + if (!type || !type.hasStaticShape()) { return failure(); + } uint64_t axis = op.getDimension(); OperandRange origInputs = op.getInputs(); @@ -1371,15 +1411,18 @@ struct ZeroConcat final : OpRewritePattern { for (auto input : origInputs) { auto type = dyn_cast(input.getType()); ArrayRef shape = type.getShape(); - if (axis > shape.size()) + if (axis > shape.size()) { return failure(); + } - if (shape[axis] != 0) + if (shape[axis] != 0) { nonzeroInputs.push_back(input); + } } - if (nonzeroInputs.size() == origInputs.size()) + if (nonzeroInputs.size() == origInputs.size()) { return failure(); + } rewriter.replaceOpWithNewOp( op, nonzeroInputs, /*dimension=*/axis); @@ -1402,8 +1445,9 @@ struct DotGeneralIsMul final : OpRewritePattern { auto resultTy = dyn_cast(op.getType()); ImplicitLocOpBuilder builder(op.getLoc(), rewriter); - if (!lhsTy || !rhsTy || !resultTy) + if (!lhsTy || !rhsTy || !resultTy) { return failure(); + } auto dNums = op.getDotDimensionNumbers(); auto batchDimsL = dNums.getLhsBatchingDimensions(); @@ -1414,14 +1458,18 @@ struct DotGeneralIsMul final : OpRewritePattern { llvm::SmallVector 
isLhsParallelDim(lhsTy.getRank(), true); llvm::SmallVector isRhsParallelDim(rhsTy.getRank(), true); - for (auto dim : batchDimsL) + for (auto dim : batchDimsL) { isLhsParallelDim[dim] = false; - for (auto dim : batchDimsR) + } + for (auto dim : batchDimsR) { isRhsParallelDim[dim] = false; - for (auto dim : contractDimsL) + } + for (auto dim : contractDimsL) { isLhsParallelDim[dim] = false; - for (auto dim : contractDimsR) + } + for (auto dim : contractDimsR) { isRhsParallelDim[dim] = false; + } for (auto dim : contractDimsL) { if (lhsTy.getDimSize(dim) != 1) { @@ -1437,13 +1485,15 @@ struct DotGeneralIsMul final : OpRewritePattern { permRhs.append(batchDimsR.begin(), batchDimsR.end()); for (auto [idx, value] : llvm::enumerate(isLhsParallelDim)) { - if (value) + if (value) { permLhs.push_back(idx); + } } for (auto [idx, value] : llvm::enumerate(isRhsParallelDim)) { - if (value) + if (value) { permRhs.push_back(idx); + } } llvm::append_range(permLhs, contractDimsL); @@ -1452,10 +1502,12 @@ struct DotGeneralIsMul final : OpRewritePattern { // Determine the transpose shape based on the generate permutations. llvm::SmallVector lhsTransposeShape; llvm::SmallVector rhsTransposeShape; - for (auto dim : permLhs) + for (auto dim : permLhs) { lhsTransposeShape.push_back(lhsTy.getDimSize(dim)); - for (auto dim : permRhs) + } + for (auto dim : permRhs) { rhsTransposeShape.push_back(rhsTy.getDimSize(dim)); + } // Transpose the left hand side and the right hand side. lhs = mlir::stablehlo::TransposeOp::create( @@ -1733,9 +1785,10 @@ struct IotaSortSliceIsTopK final : OpRewritePattern { int64_t k; // Check that the output of the sort op gets fed into a slice. 
for (auto [idx, result] : llvm::enumerate(opResults)) { - if (result.getUsers().empty()) + if (result.getUsers().empty()) { return rewriter.notifyMatchFailure( op, "sort isn't calling into a slice op"); + } auto sliceOp = dyn_cast(*result.getUsers().begin()); if (!sliceOp) { @@ -1774,8 +1827,9 @@ struct ApproxTopK final : OpRewritePattern { LogicalResult matchAndRewrite(mlir::stablehlo::CustomCallOp op, PatternRewriter &rewriter) const override { - if (op.getCallTargetName() != "ApproxTopK") + if (op.getCallTargetName() != "ApproxTopK") { return rewriter.notifyMatchFailure(op, "not ApproxTopK operation."); + } auto computationName = dyn_cast(op.getCalledComputationsAttr()[0]); @@ -1784,11 +1838,13 @@ struct ApproxTopK final : OpRewritePattern { parent = parent->getParentOp()) { funcOp = SymbolTable::lookupNearestSymbolFrom( parent, computationName); - if (funcOp) + if (funcOp) { break; + } } - if (!funcOp) + if (!funcOp) { return rewriter.notifyMatchFailure(op, "computation function not found."); + } int64_t k = cast(op.getType(0)).getShape().back(); auto input = op.getOperand(0); diff --git a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/test/BUILD.bazel b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/test/BUILD.bazel index 988c92ddc26d..1ddae17e0da1 100644 --- a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/test/BUILD.bazel +++ b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "canonicalization.mlir", "canonicalize_dot_general.mlir", diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOCustomCalls.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOCustomCalls.cpp index c26220cdf5b6..e7314c129b3a 100644 --- a/compiler/plugins/input/StableHLO/Conversion/StableHLOCustomCalls.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOCustomCalls.cpp @@ -193,8 +193,9 @@ struct 
HouseholderReflectorRewriter final Value householder = computeHouseholderSlice(matrix, tau, iv, b); std::vector batch(rank - 2); - for (int i = 0; i < rank - 2; ++i) + for (int i = 0; i < rank - 2; ++i) { batch[i] = i; + } std::vector lhsContract = {rank - 1}; std::vector rhsContract = {rank - 2}; diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp index 75ba4b3817b8..3b264d2603a9 100644 --- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp @@ -91,8 +91,9 @@ struct ConcatenateOpConversion final auto toOpFoldResult = [](Value v) -> OpFoldResult { auto op = v.getDefiningOp(); - if (!op) + if (!op) { return v; + } return op.getValue(); }; @@ -233,8 +234,9 @@ static bool isValidFuncAttr(DictionaryAttr attrs) { // TODO: switch to using a dialect-based exclusion list or some other way that // is not a big string table. 
for (auto attr : attrs) { - if (attr.getName() == "tf.aliasing_output") + if (attr.getName() == "tf.aliasing_output") { return false; + } } return true; } @@ -246,13 +248,15 @@ static void setFuncEncodings(func::FuncOp funcOp, FunctionType oldFuncType, auto encodingName = StringAttr::get(funcOp.getContext(), "iree.abi.encoding"); for (auto [i, oldType, newType] : llvm::enumerate(oldFuncType.getInputs(), newFuncType.getInputs())) { - if (oldType != newType) + if (oldType != newType) { funcOp.setArgAttr(i, encodingName, TypeAttr::get(oldType)); + } } for (auto [i, oldType, newType] : llvm::enumerate(oldFuncType.getResults(), newFuncType.getResults())) { - if (oldType != newType) + if (oldType != newType) { funcOp.setResultAttr(i, encodingName, TypeAttr::get(oldType)); + } } } @@ -347,11 +351,13 @@ struct TensorEmptyPattern final : OpConversionPattern { ConversionPatternRewriter &rewriter) const override { auto oldType = cast(op.getType()); auto newType = getTypeConverter()->convertType(oldType); - if (newType == oldType) + if (newType == oldType) { return failure(); + } - if (!newType) + if (!newType) { return rewriter.notifyMatchFailure(op, "result type conversion failed"); + } rewriter.replaceOpWithNewOp( op, oldType.getShape(), @@ -369,8 +375,9 @@ struct GlobalOpPattern final : OpConversionPattern { ConversionPatternRewriter &rewriter) const override { Type oldType = globalOp.getType(); Type newType = getTypeConverter()->convertType(oldType); - if (newType == oldType) + if (newType == oldType) { return failure(); + } if (!newType) { return rewriter.notifyMatchFailure(globalOp, "result type conversion failed"); @@ -452,21 +459,24 @@ static void stripFrontendAttrs(mlir::ModuleOp moduleOp) { auto filterOpAttrs = [&](Operation *op) { SmallVector newAttrs; for (auto attr : op->getDialectAttrs()) { - if (!isAttrFiltered(attr)) + if (!isAttrFiltered(attr)) { newAttrs.push_back(attr); + } } op->setDialectAttrs(newAttrs); }; auto filterAttrDicts = [&](ArrayAttr 
allOldAttrs, SmallVectorImpl &newAttrs) { - if (!allOldAttrs) + if (!allOldAttrs) { return false; + } for (auto oldAttrs : allOldAttrs.getAsRange()) { SmallVector preservedAttrs; preservedAttrs.reserve(oldAttrs.size()); for (auto attr : oldAttrs) { - if (!isAttrFiltered(attr)) + if (!isAttrFiltered(attr)) { preservedAttrs.push_back(attr); + } } newAttrs.push_back( DictionaryAttr::get(allOldAttrs.getContext(), preservedAttrs)); @@ -554,12 +564,14 @@ struct ConvertStableHloToIreeInputDialects final auto isIllegalType = [&](Type t) { return !typeConverter->isLegal(t); }; auto isLegallyTypedOp = [&](Operation *op) -> bool { for (Type type : op->getResultTypes()) { - if (isIllegalType(type)) + if (isIllegalType(type)) { return false; + } } for (Type type : op->getOperandTypes()) { - if (isIllegalType(type)) + if (isIllegalType(type)) { return false; + } } return true; }; @@ -582,17 +594,20 @@ struct ConvertStableHloToIreeInputDialects final } } for (Type type : funcOp.getFunctionType().getInputs()) { - if (isIllegalType(type)) + if (isIllegalType(type)) { return false; + } } for (Type type : funcOp.getFunctionType().getResults()) { - if (isIllegalType(type)) + if (isIllegalType(type)) { return false; + } } for (Block &block : funcOp.getFunctionBody()) { for (Type type : block.getArgumentTypes()) { - if (isIllegalType(type)) + if (isIllegalType(type)) { return false; + } } } return true; diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp index 8c3fab7e1a39..9bcdaf7f4962 100644 --- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp @@ -51,8 +51,9 @@ Type convertIntegerToSignless(IntegerType intType) { } std::optional convertRank0TensorToScalar(RankedTensorType tensorType) { - if (tensorType.getRank() != 0) + if (tensorType.getRank() != 0) { return std::nullopt; + } Type 
elementType = tensorType.getElementType(); if (auto intType = dyn_cast(elementType)) { elementType = convertIntegerToSignless(intType); @@ -72,8 +73,9 @@ Value materializeCast(OpBuilder &builder, Type toType, ValueRange inputs, assert(inputs.size() == 1 && "too many inputs to type conversion"); Value fromValue = inputs[0]; auto fromType = dyn_cast(fromValue.getType()); - if (!fromType) + if (!fromType) { return Value(); + } if (auto intFromType = dyn_cast(fromType.getElementType())) { Type castType = getElementTypeOrSelf(toType); @@ -88,8 +90,9 @@ Value materializeCast(OpBuilder &builder, Type toType, ValueRange inputs, } } - if (fromType.getRank() != 0) + if (fromType.getRank() != 0) { return fromValue; + } Type extractType = getElementTypeOrSelf(toType); return builder.createOrFold(loc, extractType, fromValue); @@ -131,11 +134,13 @@ struct LinalgExtRegionHLOOpConversion final : OpConversionPattern { LogicalResult matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!isInBodyOfLinalgExtOps(op)) + if (!isInBodyOfLinalgExtOps(op)) { return failure(); + } TensorType origRetType = dyn_cast(op.getType()); - if (!origRetType) + if (!origRetType) { return failure(); + } SmallVector scalarArgs; Type newRetType = getElementTypeOrSelf( this->typeConverter->convertType(origRetType.getElementType())); @@ -152,8 +157,9 @@ struct LinalgExtRegionReturnOpConversion final LogicalResult matchAndRewrite(mlir::stablehlo::ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!isInBodyOfLinalgExtOps(op)) + if (!isInBodyOfLinalgExtOps(op)) { return failure(); + } rewriter.replaceOpWithNewOp( op, adaptor.getOperands()); return success(); @@ -222,23 +228,28 @@ struct ScatterOpConversion final auto indexDepth = indicesType.getShape().back(); auto scatterDimsToOperandDims = dimNumbers.getScatterDimsToOperandDims(); - if (indicesRank != 2) + if (indicesRank != 2) { return false; - if 
(indexVectorDim != indicesRank - 1) + } + if (indexVectorDim != indicesRank - 1) { return false; - if (scatterDimsToOperandDims.size() != indexDepth) + } + if (scatterDimsToOperandDims.size() != indexDepth) { return false; + } auto insertedWindowDims = dimNumbers.getInsertedWindowDims(); for (auto [idx, dim] : llvm::enumerate(insertedWindowDims)) { - if (idx != dim) + if (idx != dim) { return false; + } } // Check that there is only one batch dimension in the updates. for (auto [idx, dim] : llvm::enumerate(dimNumbers.getUpdateWindowDims())) { - if (idx + 1 != dim) + if (idx + 1 != dim) { return false; + } } return true; @@ -247,12 +258,15 @@ struct ScatterOpConversion final LogicalResult matchAndRewrite(mlir::stablehlo::ScatterOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!hasCanonicalDimensionNumbers(op)) + if (!hasCanonicalDimensionNumbers(op)) { return failure(); - if (llvm::size(op.getInputs()) != 1) + } + if (llvm::size(op.getInputs()) != 1) { return op.emitError("NYI variadic operands scatter"); - if (llvm::size(op.getUpdates()) != 1) + } + if (llvm::size(op.getUpdates()) != 1) { return op.emitError("NYI variadic updates scatter"); + } ImplicitLocOpBuilder b(op.getLoc(), rewriter); @@ -335,8 +349,9 @@ struct ReverseOpConversion final matchAndRewrite(mlir::stablehlo::ReverseOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto ty = dyn_cast(adaptor.getOperands()[0].getType()); - if (!ty) + if (!ty) { return failure(); + } Value input = op.getOperand(); auto inputTy = cast(input.getType()); @@ -426,8 +441,9 @@ struct ScanOpConversion final auto window = llvm::to_vector(op.getWindowDimensions()); llvm::SmallVector reduceAxes; for (int i = 0, s = window.size(); i < s; ++i) { - if (window[i] == 1) + if (window[i] == 1) { continue; + } if (window[i] == input0Ty.getDimSize(i)) { reduceAxes.push_back(i); continue; @@ -454,8 +470,9 @@ struct ScanOpConversion final } for (int i = 0, s = 
padding.size(); i < s; i += 2) { - if (i == reduceAxis * 2) + if (i == reduceAxis * 2) { continue; + } if (padding[i] != 0 || padding[i + 1] != 0) { return rewriter.notifyMatchFailure(op, "padding along non-reduction axis"); @@ -484,8 +501,9 @@ struct ScanOpConversion final llvm::SmallVector initDims; llvm::SmallVector initDynDims; for (int i = 0; i < input0Ty.getRank(); ++i) { - if (i == reduceAxis) + if (i == reduceAxis) { continue; + } initDims.push_back(input0Ty.getDimSize(i)); if (ShapedType::isDynamic(initDims.back())) { initDynDims.push_back( diff --git a/compiler/plugins/input/StableHLO/Conversion/test/BUILD.bazel b/compiler/plugins/input/StableHLO/Conversion/test/BUILD.bazel index c8f2420d6ae4..ca0da7f359e9 100644 --- a/compiler/plugins/input/StableHLO/Conversion/test/BUILD.bazel +++ b/compiler/plugins/input/StableHLO/Conversion/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "auto_input_conversion.mlir", "convert_collectives.mlir", diff --git a/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp b/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp index 6acdea63a8d6..2899ce23ea08 100644 --- a/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp +++ b/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp @@ -139,21 +139,25 @@ void Converti48Toi64Pass::runOnOperation() { target.markUnknownOpDynamicallyLegal([](Operation *op) { if (auto funcOp = dyn_cast(op)) { for (Type type : funcOp.getArgumentTypes()) { - if (isIllegalType(type)) + if (isIllegalType(type)) { return false; + } } for (Type type : funcOp.getResultTypes()) { - if (isIllegalType(type)) + if (isIllegalType(type)) { return false; + } } } for (Type type : op->getResultTypes()) { - if (type && isIllegalType(type)) + if (type && isIllegalType(type)) { return false; + } } for (Type type : op->getOperandTypes()) { - if (type && isIllegalType(type)) + if (type && isIllegalType(type)) { 
return false; + } } for (auto attr : op->getAttrs()) { if (auto typedAttr = dyn_cast(attr.getValue())) { diff --git a/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp b/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp index dc544ff4921c..7763ff95ee79 100644 --- a/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp +++ b/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp @@ -78,8 +78,9 @@ class GenericTypeConvert : public ConversionPattern { }; static bool isIllegalType(Type type) { - if (IntegerType ity = dyn_cast(type)) + if (IntegerType ity = dyn_cast(type)) { return !ity.isSignless(); + } if (auto shapedType = dyn_cast(type)) { return isIllegalType(shapedType.getElementType()); } @@ -94,21 +95,25 @@ void StripSignednessPass::runOnOperation() { target.markUnknownOpDynamicallyLegal([](Operation *op) { if (auto funcOp = dyn_cast(op)) { for (Type type : funcOp.getArgumentTypes()) { - if (isIllegalType(type)) + if (isIllegalType(type)) { return false; + } } for (Type type : funcOp.getResultTypes()) { - if (isIllegalType(type)) + if (isIllegalType(type)) { return false; + } } } for (Type type : op->getResultTypes()) { - if (type && isIllegalType(type)) + if (type && isIllegalType(type)) { return false; + } } for (Type type : op->getOperandTypes()) { - if (type && isIllegalType(type)) + if (type && isIllegalType(type)) { return false; + } } return true; }); diff --git a/compiler/plugins/input/TOSA/InputConversion/TosaToLinalgExt.cpp b/compiler/plugins/input/TOSA/InputConversion/TosaToLinalgExt.cpp index 8e7561953a6f..500d309c5b31 100644 --- a/compiler/plugins/input/TOSA/InputConversion/TosaToLinalgExt.cpp +++ b/compiler/plugins/input/TOSA/InputConversion/TosaToLinalgExt.cpp @@ -42,9 +42,10 @@ class ScatterConversion : public OpRewritePattern { auto updatesTy = dyn_cast(updates.getType()); ImplicitLocOpBuilder builder(op.getLoc(), rewriter); - if (!valuesTy || !indicesTy || !updatesTy) + if (!valuesTy || !indicesTy 
|| !updatesTy) { return rewriter.notifyMatchFailure(op, "tosa.gather has unknown input rank"); + } // TOSA's scatter does not include a index dimension, instead it implicitly // supports an index depth of one. We materialize that implicit index of @@ -68,9 +69,11 @@ class ScatterConversion : public OpRewritePattern { // Materialize the batch indice as LinalgExt scatter is not batched. { llvm::SmallVector dynDims; - for (int i = 0, s = indicesTy.getRank(); i < s; ++i) - if (indicesTy.isDynamicDim(i)) + for (int i = 0, s = indicesTy.getRank(); i < s; ++i) { + if (indicesTy.isDynamicDim(i)) { dynDims.push_back(tensor::DimOp::create(builder, indices, i)); + } + } Value empty = tensor::EmptyOp::create( builder, indicesTy.getShape(), indicesTy.getElementType(), dynDims); @@ -159,8 +162,9 @@ class TosaToLinalgExtPass final mlir::FunctionOpInterface funcOp = getOperation(); mlir::iree_compiler::populateTosaToLinalgExtPatterns(&patterns); - if (failed(applyFullConversion(funcOp, target, std::move(patterns)))) + if (failed(applyFullConversion(funcOp, target, std::move(patterns)))) { signalPassFailure(); + } } }; diff --git a/compiler/plugins/input/TOSA/InputConversion/test/BUILD.bazel b/compiler/plugins/input/TOSA/InputConversion/test/BUILD.bazel index 8d592371c72d..9c9e6a66d044 100644 --- a/compiler/plugins/input/TOSA/InputConversion/test/BUILD.bazel +++ b/compiler/plugins/input/TOSA/InputConversion/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "apply_pdl_patterns_tosa.mlir", "auto_input_conversion.mlir", diff --git a/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp b/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp index 5fbef105cd71..b46fce9904d5 100644 --- a/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp +++ b/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp @@ -86,13 +86,15 @@ class BindSymbolicShapesPass final auto 
operand = bindOp.getOperand(); // Torch programs are single block and use structured control flow, so // presume this is an entrypoint. - if (isa(operand)) + if (isa(operand)) { return true; + } // Mutable tensors can exist at the boundary and must be "copied" to a // vtensor prior to use. Therefore, we anchor on the point of copy. - if (operand.getDefiningOp()) + if (operand.getDefiningOp()) { return true; + } return false; } @@ -117,10 +119,12 @@ class BindSymbolicShapesPass final // Gets the canonical dim for this symbol, returning {} if there // is no canonical dim. Value getCanonicalDimValue(OpBuilder &builder) { - if (canonicalDimValue) + if (canonicalDimValue) { return canonicalDimValue; - if (equalityDimInfos.empty()) + } + if (equalityDimInfos.empty()) { return {}; + } canonicalDimValue = getEqualityDimValue(builder, 0); return canonicalDimValue; } @@ -213,8 +217,9 @@ class BindSymbolicShapesPass final std::optional> evaluateExprBounds(AffineExpr expr, llvm::DenseMap &symbolInfos) { - if (!expr.isSymbolicOrConstant()) + if (!expr.isSymbolicOrConstant()) { return {}; + } llvm::SmallVector> lowerBounds; llvm::SmallVector> upperBounds; lowerBounds.reserve(symbols.size()); @@ -233,14 +238,16 @@ class BindSymbolicShapesPass final auto upperBound = getBoundForAffineExpr( expr, /*numDims=*/0, /*numSymbols=*/symbols.size(), lowerBounds, upperBounds, /*isUpper=*/true); - if (!upperBound) + if (!upperBound) { return {}; + } auto lowerBound = getBoundForAffineExpr( expr, /*numDims=*/0, /*numSymbols=*/symbols.size(), lowerBounds, upperBounds, /*isUpper=*/false); - if (!lowerBound) + if (!lowerBound) { return {}; + } return std::make_pair(*lowerBound, *upperBound); } @@ -250,8 +257,9 @@ class BindSymbolicShapesPass final void associateEqualityDims(llvm::DenseMap &symbolInfos) { OpBuilder builder(anchorOp); for (auto [index, expr] : llvm::enumerate(shapeMap.getResults())) { - if (expr.getKind() != AffineExprKind::SymbolId) + if (expr.getKind() != 
AffineExprKind::SymbolId) { continue; + } auto symbolPos = cast(expr).getPosition(); Value symbol = symbols[symbolPos]; auto symbolInfoIt = symbolInfos.find(symbol); @@ -268,12 +276,14 @@ class BindSymbolicShapesPass final if (auto binaryExpr = dyn_cast(genericExpr)) { auto lhs = materializeDimExpr(loc, builder, binaryExpr.getLHS(), symbolInfos); - if (!lhs) + if (!lhs) { return {}; + } auto rhs = materializeDimExpr(loc, builder, binaryExpr.getRHS(), symbolInfos); - if (!rhs) + if (!rhs) { return {}; + } switch (binaryExpr.getKind()) { case AffineExprKind::Add: @@ -303,12 +313,14 @@ class BindSymbolicShapesPass final case AffineExprKind::SymbolId: { auto symExpr = cast(genericExpr); auto pos = symExpr.getPosition(); - if (pos >= symbols.size()) + if (pos >= symbols.size()) { break; + } Value symbolValue = symbols[pos]; auto foundIt = symbolInfos.find(symbolValue); - if (foundIt == symbolInfos.end()) + if (foundIt == symbolInfos.end()) { break; + } SymbolInfo &info = foundIt->second; return info.getCanonicalDimValue(builder); // May legally return {} } @@ -327,8 +339,9 @@ class BindSymbolicShapesPass final void materializeDims(llvm::DenseMap &symbolInfos) { OpBuilder builder(anchorOp); for (auto [index, expr] : llvm::enumerate(shapeMap.getResults())) { - if (!builtinTensorType.isDynamicDim(index)) + if (!builtinTensorType.isDynamicDim(index)) { continue; + } Value dimValue = materializeDimExpr(anchorOp->getLoc(), builder, expr, symbolInfos); @@ -412,8 +425,9 @@ class BindSymbolicShapesPass final SymbolInfo(symbolOp)); } else if (auto bindOp = dyn_cast(childOp)) { cleanupOpList.push_back(bindOp); - if (!isEligibleBinding(bindOp)) + if (!isEligibleBinding(bindOp)) { return; + } auto torchType = cast(bindOp.getOperand().getType()); auto builtinType = dyn_cast_if_present( diff --git a/compiler/plugins/input/Torch/InputConversion/BitCastTensor.cpp b/compiler/plugins/input/Torch/InputConversion/BitCastTensor.cpp index 1b3cfacb9dcc..a2a761adc592 100644 --- 
a/compiler/plugins/input/Torch/InputConversion/BitCastTensor.cpp +++ b/compiler/plugins/input/Torch/InputConversion/BitCastTensor.cpp @@ -142,30 +142,37 @@ class BitCastMatmul : public OpRewritePattern { return success(); }; int unpackedBitWidth; - if (failed(getConstantIntegerFromDefiningOp(bitWidth, unpackedBitWidth))) + if (failed(getConstantIntegerFromDefiningOp(bitWidth, unpackedBitWidth))) { return failure(); + } auto rhsType = dyn_cast(rhs.getType()); - if (!rhsType) + if (!rhsType) { return failure(); + } - if (!rhsType.hasDtype()) + if (!rhsType.hasDtype()) { return failure(); + } Type dType = rhsType.getDtype(); int dTypeWidth = dType.getIntOrFloatBitWidth(); // If the dtype width already matches the target width, nothing to do. - if (dTypeWidth == unpackedBitWidth) + if (dTypeWidth == unpackedBitWidth) { return failure(); + } - if (!rhsType.hasSizes()) + if (!rhsType.hasSizes()) { return failure(); + } SmallVector tensorShape(rhsType.getSizes()); // Constants should have constant shape. - if (llvm::any_of(tensorShape, - [](int64_t s) { return s == torch::Torch::kUnknownSize; })) + if (llvm::any_of(tensorShape, [](int64_t s) { + return s == torch::Torch::kUnknownSize; + })) { return failure(); + } int packRatio = dTypeWidth / unpackedBitWidth; tensorShape[tensorShape.size() - 1] *= packRatio; @@ -185,10 +192,11 @@ class BitCastMatmul : public OpRewritePattern { // Cast back to the (un)signed torch tensor type to inform later lowerings. 
Type unpackedElementType; - if (dType.isSignedInteger()) + if (dType.isSignedInteger()) { unpackedElementType = rewriter.getIntegerType(unpackedBitWidth, true); - else + } else { unpackedElementType = rewriter.getIntegerType(unpackedBitWidth, false); + } torch::Torch::ValueTensorType newRhsType = torch::Torch::ValueTensorType::get(rewriter.getContext(), tensorShape, unpackedElementType); @@ -215,8 +223,9 @@ class BitCastTensorPass final patterns.add(context); patterns.add, BitCastViewComplex>(context); - if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); + } } }; } // namespace diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp index e5ce5246f614..9aeee4d25549 100644 --- a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp +++ b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp @@ -51,8 +51,9 @@ struct ScatterOpConversion LogicalResult matchAndRewrite(mlir::torch::TMTensor::ScatterOp op, PatternRewriter &rewriter) const override { auto indicesTy = op.getIndicesType(); - if (!indicesTy.hasRank()) + if (!indicesTy.hasRank()) { return failure(); + } if (indicesTy.isDynamicDim(indicesTy.getRank() - 1)) { return rewriter.notifyMatchFailure(op, "number of indices is unknown"); @@ -60,8 +61,9 @@ struct ScatterOpConversion auto numIndices = indicesTy.getShape().back(); llvm::SmallVector dimMap(numIndices); - for (int i = 0; i < numIndices; i++) + for (int i = 0; i < numIndices; i++) { dimMap[i] = i; + } auto updatesTy = op.getUpdateType(); @@ -182,8 +184,9 @@ struct AttentionOpConversion int64_t numBatches = op.getQueryType().getRank() - 2; for (AffineMap &map : indexingMaps) { map = map.shiftDims(numBatches); - if (map.getNumResults() == 0) + if (map.getNumResults() == 0) { continue; + } for 
(int batch : llvm::seq(numBatches)) { map = map.insertResult(rewriter.getAffineDimExpr(batch), batch); } diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp index 7f18e169f157..5947639692b7 100644 --- a/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp +++ b/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp @@ -6,7 +6,15 @@ #include "compiler/plugins/input/Torch/InputConversion/Passes.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" +#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h" +#include "llvm/ADT/APFloat.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" @@ -163,21 +171,371 @@ struct FftRfftOpConversion } }; +// Utility to add a score modification region to the attention op. 
+void createScoreModificationRegion( + PatternRewriter &rewriter, Location loc, + IREE::LinalgExt::AttentionOp attentionOp, + std::optional scoreModSymbol, FloatType floatType, + const int kAttentionRank) { + OpBuilder::InsertionGuard g(rewriter); + Block *block = rewriter.createBlock(&attentionOp.getRegion()); + + block->addArgument(floatType, loc); + rewriter.setInsertionPointToStart(block); + + Value score = block->getArgument(0); + Value modifiedScore = score; + + if (scoreModSymbol) { + Type i32Type = rewriter.getI32Type(); + Type si32Type = + IntegerType::get(rewriter.getContext(), 32, IntegerType::Signed); + RankedTensorType scalarTensorType = RankedTensorType::get({}, floatType); + torch::Torch::ValueTensorType torchScalarType = + rewriter.getType(ArrayRef{}, + floatType); + RankedTensorType i32ScalarTensorType = RankedTensorType::get({}, i32Type); + torch::Torch::ValueTensorType torchI32ScalarType = + rewriter.getType(ArrayRef{}, + si32Type); + + Value scoreTensor = tensor::FromElementsOp::create( + rewriter, loc, scalarTensorType, ValueRange{score}); + Value torchScore = torch::TorchConversion::FromBuiltinTensorOp::create( + rewriter, loc, torchScalarType, scoreTensor); + + SmallVector callArgs; + callArgs.push_back(torchScore); + + for (unsigned i = 0; i < kAttentionRank; ++i) { + Value idx = IREE::LinalgExt::IndexOp::create(rewriter, loc, i); + Value idxI32 = arith::IndexCastOp::create(rewriter, loc, i32Type, idx); + Value idxTensor = tensor::FromElementsOp::create( + rewriter, loc, i32ScalarTensorType, ValueRange{idxI32}); + Value torchIdx = torch::TorchConversion::FromBuiltinTensorOp::create( + rewriter, loc, torchI32ScalarType, idxTensor); + callArgs.push_back(torchIdx); + } + + auto callOp = + func::CallOp::create(rewriter, loc, TypeRange{torchScalarType}, + scoreModSymbol.value(), ValueRange(callArgs)); + Value torchResult = callOp.getResult(0); + + Value resultTensor = torch::TorchConversion::ToBuiltinTensorOp::create( + rewriter, loc, 
scalarTensorType, torchResult); + + modifiedScore = + tensor::ExtractOp::create(rewriter, loc, resultTensor, ValueRange{}); + } + + IREE::LinalgExt::YieldOp::create(rewriter, loc, modifiedScore); +} + +// Utility to compute dynamic sizes for attention tensors. +// This helper is used in two places: +// +// For the mask tensor. Shape = (B, Hq, L, S). Any of these may be dynamic, so +// we extract B/Hq/L from the query tensor and S from the key tensor. The +// resulting dynamic sizes are passed to tensor.empty when materialising the +// mask. +// +// For the output tensor. Shape = (B, Hq, L, Ev). Since Ev is statically known, +// only B/Hq/L may be dynamic. The helper again generates the needed tensor.dim +// ops from the query/value tensors so that tensor.splat/tensor.empty gets the +// correct dynamic extents. Assuming the standard 4D layout: +// Query: (B, Hq, L, E) +// Key: (B, Hkv, S, E) +// Value: (B, Hkv, S, Ev) +// When constructing new tensors (mask/output), we need dynamic sizes for +// dimensions that come from the input shapes. +// +// For dims (B, H, L), the runtime sizes always come from the query tensor. +// For dim 3, the required runtime size depends on what we are building: +// For the mask (shape = B×H×L×S), the 3rd axis is S, which lives at +// index 2 of the Key tensor. +// For the output (shape = B×H×L×Ev), Ev is statically known, so we never need a +// dynamic dimension for i = 3. + +void computeDynamicSizes(PatternRewriter &rewriter, Location loc, + const SmallVector &shape, + SmallVector &dynSizes, Value first, + Value second, const int kAttentionRank) { + for (int i = 0; i < kAttentionRank; ++i) { + if (shape[i] == torch::Torch::kUnknownSize) { + Value idx = arith::ConstantIndexOp::create(rewriter, loc, std::min(i, 2)); + Value dim = + tensor::DimOp::create(rewriter, loc, i < 3 ? first : second, idx); + dynSizes.push_back(dim); + } + } +} + +// Utility to create a modified mask tensor. 
+Value createModifiedMask(PatternRewriter &rewriter, Location loc, + MLIRContext *ctx, FlatSymbolRefAttr maskModRef, + int64_t batch, int64_t numHeads, int64_t seqLenQ, + int64_t seqLenKV, FloatType floatType, + Value builtinQuery, Value builtinKey, Value zero, + const int kAttentionRank) { + static const int kNumModificationIndices = 4; + // Create mask tensor [B, H, M, N] with values 0.0 (attend) or -inf + // (mask). + RankedTensorType boolScalarTensorType = + RankedTensorType::get({}, rewriter.getI1Type()); + torch::Torch::ValueTensorType torchBoolScalarType = + rewriter.getType(ArrayRef{}, + rewriter.getI1Type()); + Type i32Type = rewriter.getI32Type(); + RankedTensorType i32ScalarTensorType = RankedTensorType::get({}, i32Type); + Type si32Type = + IntegerType::get(rewriter.getContext(), 32, IntegerType::Signed); + torch::Torch::ValueTensorType torchI32ScalarType = + rewriter.getType(ArrayRef{}, + si32Type); + SmallVector maskShape = {batch, numHeads, seqLenQ, seqLenKV}; + SmallVector maskDynSizes; + + computeDynamicSizes(rewriter, loc, maskShape, maskDynSizes, builtinQuery, + builtinKey, kAttentionRank); + + Value maskTensor = tensor::EmptyOp::create(rewriter, loc, maskShape, + floatType, maskDynSizes); + // Create linalg.generic to materialize mask. + SmallVector maskMaps; + maskMaps.push_back(AffineMap::getMultiDimIdentityMap(kAttentionRank, ctx)); + + SmallVector iteratorTypes(kAttentionRank, + utils::IteratorType::parallel); + + Value negInf = arith::ConstantFloatOp::create( + rewriter, loc, floatType, + llvm::APFloat::getInf(floatType.getFloatSemantics(), + /*Negative=*/true)); + + auto maskGeneric = linalg::GenericOp::create( + rewriter, loc, TypeRange{maskTensor.getType()}, ValueRange{}, + ValueRange{maskTensor}, maskMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + // Get indices and convert to torch tensors. 
+ SmallVector torchIndices; + for (unsigned i = 0; i < kNumModificationIndices; ++i) { + Value idx = linalg::IndexOp::create(b, loc, i); + Value idxI32 = + arith::IndexCastOp::create(b, loc, rewriter.getI32Type(), idx); + Value idxTensor = tensor::FromElementsOp::create( + b, loc, i32ScalarTensorType, ValueRange{idxI32}); + Value torchIdx = torch::TorchConversion::FromBuiltinTensorOp::create( + b, loc, torchI32ScalarType, idxTensor); + torchIndices.push_back(torchIdx); + } + + // Call mask_mod_fn(b, h, q_idx, kv_idx). + auto callOp = + func::CallOp::create(b, loc, TypeRange{torchBoolScalarType}, + maskModRef, ValueRange(torchIndices)); + Value torchMaskResult = callOp.getResult(0); + + Value maskResult = torch::TorchConversion::ToBuiltinTensorOp::create( + b, loc, boolScalarTensorType, torchMaskResult); + + Value maskBool = + tensor::ExtractOp::create(b, loc, maskResult, ValueRange{}); + + Value maskValue = + arith::SelectOp::create(b, loc, maskBool, zero, negInf); + + linalg::YieldOp::create(b, loc, maskValue); + }); + + return maskGeneric.getResult(0); +} + +Value convertToBuiltinTensor(PatternRewriter &rewriter, Location loc, + Value torchTensor) { + auto torchType = cast(torchTensor.getType()); + return torch::TorchConversion::ToBuiltinTensorOp::create( + rewriter, loc, torchType.toBuiltinTensor(), torchTensor); +} + +struct FlexAttentionOpConversion + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + // Attention tensors are 4D: [batch, head, query_seq, key_seq]. 
+ static const int kAttentionRank = 4; + + LogicalResult matchAndRewrite(torch::Torch::HigherOrderFlexAttentionOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + MLIRContext *ctx = getContext(); + Value query = op.getQuery(); + Value key = op.getKey(); + Value value = op.getValue(); + Value scaleValue = op.getScale(); + auto scoreModSymbol = op.getScoreModFn(); + auto maskModSymbol = op.getMaskModFn(); + + bool returnLseValue; + if (!matchPattern(op.getReturnLse(), + torch::Torch::m_TorchConstantBool(&returnLseValue))) { + return rewriter.notifyMatchFailure( + op, "expected return_lse to be a constant bool"); + } + + bool returnMaxScoresValue; + if (!matchPattern( + op.getReturnMaxScores(), + torch::Torch::m_TorchConstantBool(&returnMaxScoresValue))) { + return rewriter.notifyMatchFailure( + op, "expected return_max_scores to be a constant bool"); + } + + auto queryType = cast(query.getType()); + auto keyType = cast(key.getType()); + auto valueType = cast(value.getType()); + + ArrayRef queryShape = queryType.getSizes(); + ArrayRef valueShape = valueType.getSizes(); + + int64_t batch = queryShape[0]; + int64_t numHeads = queryShape[1]; + int64_t seqLenQ = queryShape[2]; + int64_t headDim = queryShape[3]; + int64_t seqLenKV = keyType.getSizes()[2]; + int64_t valueDim = valueShape[3]; + + // Dynamic head dim is not supported. + if (headDim == torch::Torch::kUnknownSize) { + return rewriter.notifyMatchFailure(op, "NYI: dynamic head dimension"); + } + + auto floatType = dyn_cast(queryType.getOptionalDtype()); + // Default scale: 1.0 / sqrt(head_dim). 
+ double scaleVal; + if (!matchPattern(scaleValue, + torch::Torch::m_TorchConstantFloat(&scaleVal))) { + scaleVal = 1.0 / std::sqrt(static_cast(headDim)); + } + + Value scale = arith::ConstantOp::create( + rewriter, loc, floatType, rewriter.getFloatAttr(floatType, scaleVal)); + + Value builtinQuery = convertToBuiltinTensor(rewriter, loc, query); + Value builtinKey = convertToBuiltinTensor(rewriter, loc, key); + Value builtinValue = convertToBuiltinTensor(rewriter, loc, value); + + // Declare common types for mask and score modification regions. + Value zero = arith::ConstantFloatOp::create( + rewriter, loc, floatType, + llvm::APFloat::getZero(floatType.getFloatSemantics())); + Value mask; + if (maskModSymbol) { + FlatSymbolRefAttr maskModRef = + FlatSymbolRefAttr::get(ctx, *maskModSymbol); + mask = createModifiedMask(rewriter, loc, ctx, maskModRef, batch, numHeads, + seqLenQ, seqLenKV, floatType, builtinQuery, + builtinKey, zero, kAttentionRank); + } + + // Create output tensor for attention. + SmallVector outputDynSizes; + SmallVector outputShape = {batch, numHeads, seqLenQ, valueDim}; + computeDynamicSizes(rewriter, loc, outputShape, outputDynSizes, + builtinQuery, builtinValue, kAttentionRank); + + // Initialize output tensor with identity value (0.0 for addition). + Value outputInit = arith::getIdentityValue(arith::AtomicRMWKind::addf, + floatType, rewriter, loc, + /*useOnlyFiniteValue=*/true); + Value outputTensor = tensor::SplatOp::create(rewriter, loc, outputInit, + outputShape, outputDynSizes); + + // Build indexing maps for attention. + // Standard maps: Q, K, V, scale, [mask], output. 
+ AffineExpr b, h, m, n, k1, k2; + bindDims(ctx, b, h, m, n, k1, k2); + + auto qMap = AffineMap::get(6, 0, {b, h, m, k1}, ctx); + auto kMap = AffineMap::get(6, 0, {b, h, n, k1}, ctx); + auto vMap = AffineMap::get(6, 0, {b, h, n, k2}, ctx); + auto sMap = AffineMap::get(6, 0, {}, ctx); + auto oMap = AffineMap::get(6, 0, {b, h, m, k2}, ctx); + + SmallVector indexingMaps = {qMap, kMap, vMap, sMap}; + if (mask) { + indexingMaps.push_back(AffineMap::get(6, 0, {b, h, m, n}, ctx)); + } + + indexingMaps.push_back(oMap); + + // Create attention op. + auto attentionOp = IREE::LinalgExt::AttentionOp::create( + rewriter, loc, outputTensor.getType(), builtinQuery, builtinKey, + builtinValue, scale, outputTensor, + rewriter.getAffineMapArrayAttr(indexingMaps), mask); + + createScoreModificationRegion(rewriter, loc, attentionOp, scoreModSymbol, + floatType, kAttentionRank); + + rewriter.setInsertionPointAfter(attentionOp); + + Value normalizedOutput = attentionOp.getResult(0); + + auto outputTorchType = + queryType.getWithSizesAndDtype(outputShape, floatType); + Value torchOutput = torch::TorchConversion::FromBuiltinTensorOp::create( + rewriter, loc, outputTorchType, normalizedOutput); + + // Handle logsumexp. + // Note: AttentionOp doesn't expose intermediate max/sum + // values needed for LSE calculation. Return a dummy tensor - logsumexp + // shape is output_shape[:-1] (remove last dim). + if (returnLseValue) { + op.emitWarning("FlexAttention: logsumexp output is a dummy (zeros), " + "actual values are not available from AttentionOp"); + } + // Same goes for max_scores computation from AttentionOp. 
+ if (returnMaxScoresValue) { + op.emitWarning("FlexAttention: max_scores output is a dummy (zeros), " + "actual values are not available from AttentionOp"); + } + SmallVector lseShape = outputShape; + lseShape.pop_back(); + + SmallVector lseDynSizes = outputDynSizes; + if (ShapedType::isDynamic(outputShape.back())) { + lseDynSizes.pop_back(); + } + + Value lseTensor = + tensor::SplatOp::create(rewriter, loc, zero, lseShape, lseDynSizes); + + auto lseTorchType = queryType.getWithSizesAndDtype(lseShape, floatType); + Value torchLogsumexp = torch::TorchConversion::FromBuiltinTensorOp::create( + rewriter, loc, lseTorchType, lseTensor); + + rewriter.replaceOp( + op, {torchOutput, torchLogsumexp, /*max_scores=*/torchLogsumexp}); + return success(); + } +}; + class ConvertTorchUnstructuredToLinalgExtPass final : public impl::ConvertTorchUnstructuredToLinalgExtPassBase< ConvertTorchUnstructuredToLinalgExtPass> { public: void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert< + IREE::LinalgExt::IREELinalgExtDialect, torch::Torch::TorchDialect, + tensor::TensorDialect, linalg::LinalgDialect, arith::ArithDialect, + func::FuncDialect, torch::TorchConversion::TorchConversionDialect>(); } void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); - patterns.add(context); + patterns.add(context); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); diff --git a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp index 85f436ded555..5b6d944c136b 100644 --- a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp +++ b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp @@ -82,8 +82,9 @@ getEnclosingWaitSignalFences(Operation *op) { auto parentFuncOp = dyn_cast(op); if (!parentFuncOp) { parentFuncOp = parentFuncOp->getParentOfType(); - if 
(!parentFuncOp) + if (!parentFuncOp) { return {}; + } } Block *entryBlock = &parentFuncOp.front(); auto numArguments = entryBlock->getNumArguments(); @@ -99,8 +100,9 @@ getEnclosingWaitSignalFences(Value value) { Value convertToBuiltinTensor(OpBuilder &builder, Value possibleTorchTensor) { Type ty = possibleTorchTensor.getType(); - if (isa(ty)) + if (isa(ty)) { return possibleTorchTensor; + } if (auto defining = dyn_cast_if_present( possibleTorchTensor.getDefiningOp())) { @@ -177,8 +179,9 @@ struct ConvertedAsyncFunctionInfo { }; LogicalResult ConvertedAsyncFunctionInfo::postProcess() { - if (funcOp.isExternal()) + if (funcOp.isExternal()) { return success(); + } if (returnOps.size() != 1) { // Multi-exit/CFG could be supported but requires more complicated dominance @@ -197,14 +200,17 @@ LogicalResult ConvertedAsyncFunctionInfo::postProcess() { llvm::zip_equal(inputDispositions, entryArgs, torchInputTypes)) { switch (disp) { case TypeDisposition::IMMUTABLE_TENSOR: { - if (failed( - convertImmutableTensorArg(argValue, torchType, preambleBuilder))) + if (failed(convertImmutableTensorArg(argValue, torchType, + preambleBuilder))) { return failure(); + } break; } case TypeDisposition::MUTABLE_TENSOR: { - if (failed(convertMutableTensorArg(argValue, torchType, preambleBuilder))) + if (failed( + convertMutableTensorArg(argValue, torchType, preambleBuilder))) { return failure(); + } break; } case TypeDisposition::TORCH_PRIMITIVE: { @@ -374,12 +380,14 @@ LogicalResult ConvertedAsyncFunctionInfo::convertImmutableTensorArg( // it. bool hasNonTrivialUse = false; for (auto *userOp : argValue.getUsers()) { - if (isa(userOp)) + if (isa(userOp)) { continue; + } hasNonTrivialUse = true; } - if (!hasNonTrivialUse) + if (!hasNonTrivialUse) { return success(); + } // Remember original uses so we can redirect them. 
OriginalUses originalUses(argValue); @@ -481,8 +489,9 @@ void retainFunctionAttributes(Operation *srcOp, IREE::Util::FuncOp destOp) { for (auto retainAttrName : retainedAttributes) { StringRef attrName(retainAttrName); Attribute attr = srcOp->getAttr(attrName); - if (attr) + if (attr) { destOp->setAttr(attrName, attr); + } } } @@ -566,8 +575,9 @@ class FuncConversionPass final SmallVector eraseFuncOps; std::vector convertedFuncInfos; for (auto funcOp : moduleOp.getOps()) { - if (!shouldConvertFunc(funcOp)) + if (!shouldConvertFunc(funcOp)) { continue; + } ConvertedAsyncFunctionInfo &convertedFuncInfo = convertedFuncInfos.emplace_back(); if (failed(convertFuncOp(funcOp, convertedFuncInfo))) { @@ -594,12 +604,14 @@ class FuncConversionPass final // calling convention. In the future, we may support "torch externals" // which we convert to mate up with a torch module. We can remove/adapt // this when that is elaborated. - if (torchFunc.isExternal()) + if (torchFunc.isExternal()) { return false; + } // Something has already converted this and told us not to touch it. - if (torchFunc->hasAttr("iree.abi.stub")) + if (torchFunc->hasAttr("iree.abi.stub")) { return false; + } return true; } @@ -640,14 +652,16 @@ class FuncConversionPass final for (size_t i = 0; i < convertedFuncInfo.torchInputTypes.size(); ++i) { if (failed(convertType(loc, convertedFuncInfo.torchInputTypes[i], ireeInputTypes[i], - convertedFuncInfo.inputDispositions[i]))) + convertedFuncInfo.inputDispositions[i]))) { return failure(); + } } for (size_t i = 0; i < convertedFuncInfo.torchResultTypes.size(); ++i) { if (failed(convertType(loc, convertedFuncInfo.torchResultTypes[i], ireeResultTypes[i], - convertedFuncInfo.resultDispositions[i]))) + convertedFuncInfo.resultDispositions[i]))) { return failure(); + } } // Build tied operands index mapping results back to operands. 
diff --git a/compiler/plugins/input/Torch/InputConversion/Passes.cpp b/compiler/plugins/input/Torch/InputConversion/Passes.cpp index 2d04729c0065..4ed603cb4444 100644 --- a/compiler/plugins/input/Torch/InputConversion/Passes.cpp +++ b/compiler/plugins/input/Torch/InputConversion/Passes.cpp @@ -50,9 +50,10 @@ void createTorchToIREEPipeline( torch::Torch::createReduceOpVariantsPass(llvm::StringRef())); pm.addNestedPass( mlir::torch::TorchConversion::createConvertCustomQuantOpPass()); - if (options.decompose) + if (options.decompose) { pm.addNestedPass( torch::Torch::createDecomposeComplexOpsPass(BackendLegalOps::get())); + } pm.addNestedPass(torch::Torch::createFuseQuantizedOpsPass()); pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(torch::Torch::createScalarizeShapesPass()); diff --git a/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir b/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir index a568966d906b..0432200fa3d4 100644 --- a/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir +++ b/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(torch-iree-torch-unstructured-to-linalg-ext))" %s | FileCheck %s +// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(func.func(torch-iree-torch-unstructured-to-linalg-ext))" %s | FileCheck %s // CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d2)> // CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> @@ -99,3 +99,108 @@ func.func @fft_rfft.last(%arg0: !torch.vtensor<[3,8,16],f32>) -> !torch.vtensor< // CHECK: %[[VAR12:.*]] = torch.aten.cat %[[VAR11]], %[[INTM1]] : !torch.list>, !torch.int -> !torch.vtensor<[3,8,9,2],f32> // CHECK: %[[VAR13:.*]] = torch.aten.view_as_complex %[[VAR12]] : !torch.vtensor<[3,8,9,2],f32> -> !torch.vtensor<[3,8,9],complex> // 
CHECK: return %[[VAR13]] : !torch.vtensor<[3,8,9],complex> + +// ----- + +//===----------------------------------------------------------------------===// +// FlexAttention tests +//===----------------------------------------------------------------------===// + + +func.func @flex_attn_with_scoremod_and_maskmod(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>) attributes {torch.assume_strict_symbolic_shapes} { + %float1.000000e00 = torch.constant.float 1.000000e+00 + %false = torch.constant.bool false + %output, %logsumexp, %maxscores = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %false, %false {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> + return %output : !torch.vtensor<[4,8,1024,64],f32> +} +// CHECK-LABEL: func.func @flex_attn_with_scoremod_and_maskmod( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG1:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG2:.*]]: !torch.vtensor<[4,8,1024,64],f32>) -> !torch.vtensor<[4,8,1024,64],f32> +// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<4x8x1024x64xf32> +// CHECK-DAG: %[[CST_0:.*]] = arith.constant 0xFF800000 : f32 +// CHECK-DAG: %[[CST_1:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[CST_2:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[QUERY:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> +// CHECK: %[[KEY:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> +// CHECK: %[[VALUE:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[4,8,1024,64],f32> -> 
tensor<4x8x1024x64xf32> +// CHECK: %[[MASK_EMPTY:.*]] = tensor.empty() : tensor<4x8x1024x1024xf32> +// CHECK: %[[MASK:.*]] = linalg.generic +// CHECK-SAME: outs(%[[MASK_EMPTY]] : tensor<4x8x1024x1024xf32>) +// CHECK: func.call @sdpa_mask0 +// CHECK: %[[ATTENTION:.*]] = iree_linalg_ext.attention +// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[VALUE]], %[[CST_2]], %[[MASK]] : +// CHECK-SAME: outs(%[[CST]] : tensor<4x8x1024x64xf32>) +// CHECK: func.call @sdpa_score0 +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[ATTENTION]] : tensor<4x8x1024x64xf32> -> !torch.vtensor<[4,8,1024,64],f32> +// CHECK: return %[[RESULT]] : !torch.vtensor<[4,8,1024,64],f32> + +func.func private @sdpa_score0(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32> { + %0 = torch.aten.tanh %arg0 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> +} +// CHECK-LABEL: func.func private @sdpa_score0( +// CHECK: %{{.*}} = torch.aten.tanh %{{.*}} : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> + +func.func private @sdpa_mask0(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>) -> !torch.vtensor<[],i1> { + %0 = torch.aten.ge.Tensor %arg2, %arg3 : !torch.vtensor<[],si32>, !torch.vtensor<[],si32> -> !torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> +} +// CHECK-LABEL: func.func private @sdpa_mask0( +// CHECK: %{{.*}} = torch.aten.ge.Tensor %{{.*}}, %{{.*}} : !torch.vtensor<[],si32>, !torch.vtensor<[],si32> -> !torch.vtensor<[],i1> + +func.func @flex_attn_with_scoremod_only(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>) attributes {torch.assume_strict_symbolic_shapes} { + %float1.000000e00 = torch.constant.float 1.000000e+00 + 
%false = torch.constant.bool false + %output, %logsumexp, %maxscores = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %false, %false {score_mod_fn = @sdpa_score0} : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> + return %output : !torch.vtensor<[4,8,1024,64],f32> +} +// CHECK-LABEL: func.func @flex_attn_with_scoremod_only +// CHECK-NOT: linalg.generic +// CHECK-NOT: func.call @sdpa_mask0 +// CHECK: iree_linalg_ext.attention +// CHECK: func.call @sdpa_score0 +// CHECK: return + +func.func @flex_attn_with_maskmod_only(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>) attributes {torch.assume_strict_symbolic_shapes} { + %float1.000000e00 = torch.constant.float 1.000000e+00 + %false = torch.constant.bool false + %output, %logsumexp, %maxscores = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %false, %false {mask_mod_fn = @sdpa_mask0} : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> + return %output : !torch.vtensor<[4,8,1024,64],f32> +} +// CHECK-LABEL: func.func @flex_attn_with_maskmod_only +// CHECK: linalg.generic +// CHECK: func.call @sdpa_mask0 +// CHECK: iree_linalg_ext.attention +// CHECK-NOT: func.call @sdpa_score0 +// CHECK: return + +// ----- + +func.func @flex_attn_without_mods(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>) attributes {torch.assume_strict_symbolic_shapes} { + %float1.000000e00 = 
torch.constant.float 1.000000e+00 + %false = torch.constant.bool false + %output, %logsumexp, %maxscores = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %false, %false : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> + return %output : !torch.vtensor<[4,8,1024,64],f32> +} +// CHECK-LABEL: func.func @flex_attn_without_mods +// CHECK-NOT: linalg.generic +// CHECK-NOT: func.call @sdpa_mask0 +// CHECK: iree_linalg_ext.attention +// CHECK-NOT: func.call @sdpa_score0 +// CHECK: return + +// ----- + +func.func @flex_attn_without_mods_return_maxscore_and_lse(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>) attributes {torch.assume_strict_symbolic_shapes} { + %float1.000000e00 = torch.constant.float 1.000000e+00 + %true = torch.constant.bool true + // expected-warning @+2 {{FlexAttention: logsumexp output is a dummy (zeros), actual values are not available from AttentionOp}} + // expected-warning @+1 {{FlexAttention: max_scores output is a dummy (zeros), actual values are not available from AttentionOp}} + %output, %logsumexp, %maxscores = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %true, %true : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> + return %output : !torch.vtensor<[4,8,1024,64],f32> +} +// CHECK-LABEL: func.func @flex_attn_without_mods_return_maxscore_and_lse +// CHECK-NOT: linalg.generic +// CHECK-NOT: func.call @sdpa_mask0 +// CHECK: iree_linalg_ext.attention +// CHECK-NOT: func.call @sdpa_score0 +// CHECK: return 
diff --git a/compiler/plugins/target/CUDA/CUDATarget.cpp b/compiler/plugins/target/CUDA/CUDATarget.cpp index 87422ee4b6ba..b4fe9c265bd9 100644 --- a/compiler/plugins/target/CUDA/CUDATarget.cpp +++ b/compiler/plugins/target/CUDA/CUDATarget.cpp @@ -122,14 +122,17 @@ static constexpr char kPtxasCompilerName[] = "ptxas"; static FailureOr findPtxasCompiler(const CUDAOptions &options, std::string *message) { std::string ptxasCompiler; - if (!options.clUsePtxasFrom.empty()) + if (!options.clUsePtxasFrom.empty()) { ptxasCompiler = options.clUsePtxasFrom; - if (llvm::sys::fs::exists(ptxasCompiler)) + } + if (llvm::sys::fs::exists(ptxasCompiler)) { return ptxasCompiler; + } ptxasCompiler = findTool(kPtxasCompilerName); - if (llvm::sys::fs::exists(ptxasCompiler)) + if (llvm::sys::fs::exists(ptxasCompiler)) { return ptxasCompiler; + } *message = std::string( "Could not find ptxas compiler. Try passing it explicitly with " @@ -181,8 +184,9 @@ static FailureOr compileWithPtxas(StringRef ptxasCompiler, llvm::StringSaver stringSaver(scratchAllocator); SmallVector rawArgs; Tokenize(ptxasParams, stringSaver, rawArgs, /*MarkEOLs=*/false); - for (auto rawArg : rawArgs) + for (auto rawArg : rawArgs) { ArgVector.push_back(StringRef(rawArg)); + } std::optional redirects[] = { stdinFile.str(), @@ -233,8 +237,9 @@ static FailureOr compileWithPtxas(StringRef ptxasCompiler, static std::string produceGpuImage(const CUDAOptions &options, StringRef targetArch, std::string &ptxImage) { - if (!options.clUsePtxas) + if (!options.clUsePtxas) { return ptxImage; + } std::string message; FailureOr ptxasCompiler = findPtxasCompiler(options, &message); @@ -243,8 +248,9 @@ static std::string produceGpuImage(const CUDAOptions &options, FailureOr maybeCubinImage = compileWithPtxas(ptxasCompiler.value(), targetArch, options.clUsePtxasParams, ptxImage, &message); - if (succeeded(maybeCubinImage)) + if (succeeded(maybeCubinImage)) { return maybeCubinImage.value(); + } } llvm::WithColor::warning() @@ -414,8 
+420,9 @@ class CUDATargetBackend final : public TargetBackend { getExecutableTarget(MLIRContext *context) const { Builder b(context); SmallVector configItems; - if (failed(options.verify(b))) + if (failed(options.verify(b))) { return nullptr; + } if (auto target = GPU::getCUDATargetDetails( options.clTarget, options.clTargetFeatures, context)) { diff --git a/compiler/plugins/target/CUDA/test/BUILD.bazel b/compiler/plugins/target/CUDA/test/BUILD.bazel index 4bb0e5e82b7c..51edda9651ef 100644 --- a/compiler/plugins/target/CUDA/test/BUILD.bazel +++ b/compiler/plugins/target/CUDA/test/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "smoketest.mlir", ], diff --git a/compiler/plugins/target/LLVMCPU/Builtins/Device.cpp b/compiler/plugins/target/LLVMCPU/Builtins/Device.cpp index 9ae3527d078d..0689fbd07a9a 100644 --- a/compiler/plugins/target/LLVMCPU/Builtins/Device.cpp +++ b/compiler/plugins/target/LLVMCPU/Builtins/Device.cpp @@ -17,8 +17,9 @@ namespace mlir::iree_compiler::IREE::HAL { static const iree_file_toc_t *lookupDeviceFile(StringRef filename) { for (size_t i = 0; i < iree_builtins_libdevice_bitcode_size(); ++i) { const auto &file_toc = iree_builtins_libdevice_bitcode_create()[i]; - if (filename == file_toc.name) + if (filename == file_toc.name) { return &file_toc; + } } return nullptr; } @@ -67,8 +68,9 @@ loadDeviceBitcode(llvm::TargetMachine *targetMachine, llvm::MemoryBufferRef bitcodeBufferRef( llvm::StringRef(file->data, file->size), file->name); auto bitcodeModuleValue = llvm::parseBitcodeFile(bitcodeBufferRef, context); - if (!bitcodeModuleValue) + if (!bitcodeModuleValue) { return bitcodeModuleValue; + } auto bitcodeModule = std::move(bitcodeModuleValue.get()); // Clang adds its own per-function attributes that we need to strip so that @@ -86,8 +88,9 @@ static void overridePlatformGlobal(llvm::Module &module, StringRef globalName, uint32_t newValue) { // NOTE: the global will not be defined if 
it is not used in the module. auto *globalValue = module.getNamedGlobal(globalName); - if (!globalValue) + if (!globalValue) { return; + } globalValue->setLinkage(llvm::GlobalValue::PrivateLinkage); globalValue->setDSOLocal(true); globalValue->setConstant(true); diff --git a/compiler/plugins/target/LLVMCPU/Builtins/Musl.cpp b/compiler/plugins/target/LLVMCPU/Builtins/Musl.cpp index 337526155678..0189438d2bf0 100644 --- a/compiler/plugins/target/LLVMCPU/Builtins/Musl.cpp +++ b/compiler/plugins/target/LLVMCPU/Builtins/Musl.cpp @@ -16,8 +16,9 @@ namespace mlir::iree_compiler::IREE::HAL { static const iree_file_toc_t *lookupMuslFile(StringRef filename) { for (size_t i = 0; i < iree_builtins_libmusl_size(); ++i) { const auto &file_toc = iree_builtins_libmusl_create()[i]; - if (filename == file_toc.name) + if (filename == file_toc.name) { return &file_toc; + } } return nullptr; } diff --git a/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp b/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp index bdc5b16fc346..69e066ec94a6 100644 --- a/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp +++ b/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp @@ -272,8 +272,9 @@ class LLVMCPUTargetBackend final : public TargetBackend { // multi-threading issues. llvm::LLVMContext context; auto maybeTarget = getVariantTarget(variantOp); - if (!maybeTarget) + if (!maybeTarget) { return failure(); + } const LLVMTarget &target = *maybeTarget; LLVM_DEBUG(dbgs() << "LLVM-CPU SerializeExecutable:\n" << "-----------------------------\n"; @@ -384,8 +385,9 @@ class LLVMCPUTargetBackend final : public TargetBackend { for (auto exportOp : variantOp.getBlock().getOps()) { // Find the matching function in the LLVM module. 
auto *llvmFunc = llvmModule->getFunction(exportOp.getName()); - if (!llvmFunc) + if (!llvmFunc) { continue; + } llvmFunc->setLinkage(llvm::GlobalValue::LinkageTypes::InternalLinkage); llvmFunc->setDSOLocal(true); @@ -595,8 +597,9 @@ class LLVMCPUTargetBackend final : public TargetBackend { // Strip any compiler identifiers that may have snuck in. We let the linker // tag the module. auto *llvmIdent = llvmModule->getNamedMetadata("llvm.ident"); - if (llvmIdent) + if (llvmIdent) { llvmIdent->clearOperands(); + } // Dump all linked bitcode prior to optimization. if (!options.dumpIntermediatesPath.empty()) { diff --git a/compiler/plugins/target/LLVMCPU/LLVMIRPasses.cpp b/compiler/plugins/target/LLVMCPU/LLVMIRPasses.cpp index c045c061ce2f..2b482ef98b54 100644 --- a/compiler/plugins/target/LLVMCPU/LLVMIRPasses.cpp +++ b/compiler/plugins/target/LLVMCPU/LLVMIRPasses.cpp @@ -89,8 +89,9 @@ LogicalResult runLLVMIRPasses(const LLVMTarget &target, modulePassManager.run(*module, moduleAnalysisManager); } - if (llvm::verifyModule(*module)) + if (llvm::verifyModule(*module)) { return failure(); + } return success(); } diff --git a/compiler/plugins/target/LLVMCPU/LLVMTargetOptions.cpp b/compiler/plugins/target/LLVMCPU/LLVMTargetOptions.cpp index 8ead1b21a39c..d4ba1d652e52 100644 --- a/compiler/plugins/target/LLVMCPU/LLVMTargetOptions.cpp +++ b/compiler/plugins/target/LLVMCPU/LLVMTargetOptions.cpp @@ -85,8 +85,9 @@ std::optional LLVMTarget::createForHost() { << getMessage(status, triple) << "\n"; return std::nullopt; } - if (target) + if (target) { target->populateDefaultsFromTargetMachine(); + } return target; } @@ -163,16 +164,21 @@ void LLVMTarget::storeToConfigAttrs(MLIRContext *context, if (!staticLibraryOutput.empty()) { addString("static_library_output", staticLibraryOutput); } - if (pipelineTuningOptions.LoopInterleaving != DEFAULT_LOOP_INTERLEAVING) + if (pipelineTuningOptions.LoopInterleaving != DEFAULT_LOOP_INTERLEAVING) { addBool("loop_interleaving", 
pipelineTuningOptions.LoopInterleaving); - if (pipelineTuningOptions.LoopVectorization != DEFAULT_LOOP_VECTORIZATION) + } + if (pipelineTuningOptions.LoopVectorization != DEFAULT_LOOP_VECTORIZATION) { addBool("loop_vectorization", pipelineTuningOptions.LoopVectorization); - if (pipelineTuningOptions.LoopUnrolling != DEFAULT_LOOP_UNROLLING) + } + if (pipelineTuningOptions.LoopUnrolling != DEFAULT_LOOP_UNROLLING) { addBool("loop_unrolling", pipelineTuningOptions.LoopUnrolling); - if (pipelineTuningOptions.SLPVectorization != DEFAULT_SLP_VECTORIZATION) + } + if (pipelineTuningOptions.SLPVectorization != DEFAULT_SLP_VECTORIZATION) { addBool("slp_vectorization", pipelineTuningOptions.SLPVectorization); - if (!llvmTargetOptions.MCOptions.ABIName.empty()) + } + if (!llvmTargetOptions.MCOptions.ABIName.empty()) { addString("target_abi", llvmTargetOptions.MCOptions.ABIName); + } if (llvmTargetOptions.FloatABIType != DEFAULT_FLOAT_ABI) { switch (llvmTargetOptions.FloatABIType) { case llvm::FloatABI::Default: @@ -186,10 +192,12 @@ void LLVMTarget::storeToConfigAttrs(MLIRContext *context, break; } } - if (ukernels.compare(DEFAULT_ENABLE_UKERNELS) != 0) + if (ukernels.compare(DEFAULT_ENABLE_UKERNELS) != 0) { addString("ukernels", ukernels); - if (linkUkernelBitcode != DEFAULT_LINK_UKERNEL_BITCODE) + } + if (linkUkernelBitcode != DEFAULT_LINK_UKERNEL_BITCODE) { addBool("link_ukernel_bitcode", linkUkernelBitcode); + } } std::optional @@ -274,13 +282,13 @@ LLVMTarget::loadFromConfigAttr(Location loc, DictionaryAttr config, target.linkStatic = getBool("link_static", DEFAULT_LINK_STATIC); auto sanitizer = getOptionalString("sanitizer"); if (sanitizer) { - if (sanitizer == "none") + if (sanitizer == "none") { target.sanitizerKind = SanitizerKind::kNone; - else if (sanitizer == "address") + } else if (sanitizer == "address") { target.sanitizerKind = SanitizerKind::kAddress; - else if (sanitizer == "thread") + } else if (sanitizer == "thread") { target.sanitizerKind = 
SanitizerKind::kThread; - else { + } else { emitError(loc) << "executable config unexpected value for 'sanitizer': " << *sanitizer; return {}; @@ -297,17 +305,18 @@ LLVMTarget::loadFromConfigAttr(Location loc, DictionaryAttr config, target.pipelineTuningOptions.SLPVectorization = getBool( "slp_vectorization", target.pipelineTuningOptions.SLPVectorization); auto targetAbi = getOptionalString("target_abi"); - if (targetAbi) + if (targetAbi) { target.llvmTargetOptions.MCOptions.ABIName = *targetAbi; + } auto floatAbi = getOptionalString("float_abi"); if (floatAbi) { - if (floatAbi == "default") + if (floatAbi == "default") { target.llvmTargetOptions.FloatABIType = llvm::FloatABI::Default; - else if (floatAbi == "soft") + } else if (floatAbi == "soft") { target.llvmTargetOptions.FloatABIType = llvm::FloatABI::Default; - else if (floatAbi == "hard") + } else if (floatAbi == "hard") { target.llvmTargetOptions.FloatABIType = llvm::FloatABI::Default; - else { + } else { emitError(loc) << "executable config unexpected value for 'float_abi'"; return {}; } @@ -389,8 +398,9 @@ createTargetMachine(const LLVMTarget &target) { std::string errorMessage; auto llvmTarget = llvm::TargetRegistry::lookupTarget( llvm::Triple(target.getTriple()), errorMessage); - if (!llvmTarget) + if (!llvmTarget) { return nullptr; + } llvm::Triple triple(target.getTriple()); std::unique_ptr machine(llvmTarget->createTargetMachine( triple, target.getCpu() /* cpu e.g k8 */, diff --git a/compiler/plugins/target/LLVMCPU/LibraryBuilder.cpp b/compiler/plugins/target/LLVMCPU/LibraryBuilder.cpp index d765b2adfafb..83c724ccaf1a 100644 --- a/compiler/plugins/target/LLVMCPU/LibraryBuilder.cpp +++ b/compiler/plugins/target/LLVMCPU/LibraryBuilder.cpp @@ -520,8 +520,9 @@ LibraryBuilder::buildLibraryV0ImportTable(std::string libraryName) { SmallVector symbolNameValues; for (auto &import : imports) { auto symbolName = import.symbol_name; - if (import.weak) + if (import.weak) { symbolName = "?" 
+ symbolName; + } symbolNameValues.push_back(createStringConstant(symbolName, module)); } symbolNames = createArrayConstant(libraryName + "_import_names", @@ -552,8 +553,9 @@ LibraryBuilder::buildLibraryV0ExportTable(std::string libraryName) { // iree_hal_executable_export_table_v0_t::ptrs SmallVector exportPtrValues; - for (auto dispatch : exports) + for (auto dispatch : exports) { exportPtrValues.push_back(dispatch.func); + } llvm::Constant *exportPtrs = createArrayConstant( libraryName + "_funcs", ptrType, exportPtrValues, module); @@ -603,8 +605,9 @@ LibraryBuilder::buildLibraryV0ExportTable(std::string libraryName) { llvm::Constant *exportNames = llvm::Constant::getNullValue(ptrType); if (mode == Mode::INCLUDE_REFLECTION_ATTRS) { SmallVector exportNameValues; - for (auto dispatch : exports) + for (auto dispatch : exports) { exportNameValues.push_back(createStringConstant(dispatch.name, module)); + } exportNames = createArrayConstant(libraryName + "_names", ptrType, exportNameValues, module); } @@ -615,9 +618,10 @@ LibraryBuilder::buildLibraryV0ExportTable(std::string libraryName) { exports, [](auto &dispatch) { return !dispatch.tag.empty(); }); if (mode == Mode::INCLUDE_REFLECTION_ATTRS && hasAnyTags) { SmallVector exportTagValues; - for (auto dispatch : exports) + for (auto dispatch : exports) { exportTagValues.push_back( createStringConstantOrNull(dispatch.tag, module)); + } exportTags = createArrayConstant(libraryName + "_tags", ptrType, exportTagValues, module); } diff --git a/compiler/plugins/target/LLVMCPU/LinkerTool.cpp b/compiler/plugins/target/LLVMCPU/LinkerTool.cpp index b763830da3d6..ef13059db486 100644 --- a/compiler/plugins/target/LLVMCPU/LinkerTool.cpp +++ b/compiler/plugins/target/LLVMCPU/LinkerTool.cpp @@ -56,8 +56,9 @@ Artifact Artifact::createVariant(StringRef basePath, StringRef suffix) { } void Artifact::keep() const { - if (outputFile) + if (outputFile) { outputFile->keep(); + } } std::optional> Artifact::read() const { @@ -129,8 +130,9 @@ 
LogicalResult LinkerTool::runLinkCommand(std::string commandLine, commandLine = escapeCommandLineComponent(commandLine); } int exitCode = system(commandLine.c_str()); - if (exitCode == 0) + if (exitCode == 0) { return success(); + } llvm::errs() << "Linking failed; escaped command line returned exit code " << exitCode << ":\n\n" << commandLine << "\n\n"; diff --git a/compiler/plugins/target/LLVMCPU/ResolveCPUAndCPUFeatures.cpp b/compiler/plugins/target/LLVMCPU/ResolveCPUAndCPUFeatures.cpp index 59da38c40bfa..1705ba96ff7e 100644 --- a/compiler/plugins/target/LLVMCPU/ResolveCPUAndCPUFeatures.cpp +++ b/compiler/plugins/target/LLVMCPU/ResolveCPUAndCPUFeatures.cpp @@ -6,6 +6,7 @@ #include "compiler/plugins/target/LLVMCPU/ResolveCPUAndCPUFeatures.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/TargetParser/AArch64TargetParser.h" @@ -29,9 +30,18 @@ resolveHostCPUAndCPUFeatures(std::string &cpu, std::string &cpuFeatures) { return ResolveCPUAndCPUFeaturesStatus::InconsistentHost; } cpu = llvm::sys::getHostCPUName(); + // Sort features to ensure deterministic iteration order. The StringMap + // returned by getHostCPUFeatures() has non-deterministic iteration order. + llvm::StringMap hostFeatures = + llvm::sys::getHostCPUFeatures(); + auto sortedFeatures = + llvm::to_vector_of(hostFeatures.keys()); + llvm::sort(sortedFeatures); + + // Add all features in lexicographically sorted order. 
llvm::SubtargetFeatures features; - for (auto &feature : llvm::sys::getHostCPUFeatures()) { - features.AddFeature(feature.first(), feature.second); + for (llvm::StringRef feature : sortedFeatures) { + features.AddFeature(feature, hostFeatures.lookup(feature)); } cpuFeatures = features.getString(); return ResolveCPUAndCPUFeaturesStatus::OK; diff --git a/compiler/plugins/target/LLVMCPU/internal/AndroidLinkerTool.cpp b/compiler/plugins/target/LLVMCPU/internal/AndroidLinkerTool.cpp index 6d4c208cd48a..1ceb47dd5ac2 100644 --- a/compiler/plugins/target/LLVMCPU/internal/AndroidLinkerTool.cpp +++ b/compiler/plugins/target/LLVMCPU/internal/AndroidLinkerTool.cpp @@ -104,8 +104,9 @@ class AndroidLinkerTool : public LinkerTool { std::string getSystemToolPath() const override { auto toolPath = LinkerTool::getSystemToolPath(); - if (!toolPath.empty()) + if (!toolPath.empty()) { return toolPath; + } // ANDROID_NDK must be set for us to infer the tool path. char *androidNDKPath = std::getenv("ANDROID_NDK"); @@ -216,8 +217,9 @@ class AndroidLinkerTool : public LinkerTool { flagsToPrefixForLinker.clear(); auto commandLine = llvm::join(flags, " "); - if (failed(runLinkCommand(commandLine))) + if (failed(runLinkCommand(commandLine))) { return std::nullopt; + } return artifacts; } }; diff --git a/compiler/plugins/target/LLVMCPU/internal/EmbeddedLinkerTool.cpp b/compiler/plugins/target/LLVMCPU/internal/EmbeddedLinkerTool.cpp index 61f9a82ed4f5..1dd6eeef5316 100644 --- a/compiler/plugins/target/LLVMCPU/internal/EmbeddedLinkerTool.cpp +++ b/compiler/plugins/target/LLVMCPU/internal/EmbeddedLinkerTool.cpp @@ -50,15 +50,17 @@ class EmbeddedLinkerTool : public LinkerTool { // Fall back to check for setting the linker explicitly via environment // variables. 
char *envVarPath = std::getenv("IREE_LLVM_EMBEDDED_LINKER_PATH"); - if (envVarPath && envVarPath[0] != '\0') + if (envVarPath && envVarPath[0] != '\0') { return std::string(envVarPath); + } // No explicit linker specified, search the install/build dir or env. const SmallVector &toolNames{"iree-lld", "lld", "ld.lld", "lld-link"}; std::string toolPath = findTool(toolNames); - if (!toolPath.empty()) + if (!toolPath.empty()) { return toolPath; + } llvm::errs() << "error: required embedded linker tool (typically `lld`) not found " @@ -119,8 +121,9 @@ class EmbeddedLinkerTool : public LinkerTool { artifacts.libraryFile.close(); std::string embeddedToolPath = getEmbeddedToolPath(); - if (embeddedToolPath.empty()) + if (embeddedToolPath.empty()) { return std::nullopt; + } SmallVector flags = { embeddedToolPath, diff --git a/compiler/plugins/target/LLVMCPU/internal/UnixLinkerTool.cpp b/compiler/plugins/target/LLVMCPU/internal/UnixLinkerTool.cpp index aa71989094a8..529a036182dc 100644 --- a/compiler/plugins/target/LLVMCPU/internal/UnixLinkerTool.cpp +++ b/compiler/plugins/target/LLVMCPU/internal/UnixLinkerTool.cpp @@ -24,8 +24,9 @@ class UnixLinkerTool : public LinkerTool { std::string getSystemToolPath() const override { // First check for setting the linker explicitly. auto toolPath = LinkerTool::getSystemToolPath(); - if (!toolPath.empty()) + if (!toolPath.empty()) { return toolPath; + } // No explicit linker specified, search the environment for common tools. // We want LLD: @@ -53,8 +54,9 @@ class UnixLinkerTool : public LinkerTool { // of these, at least given current behavior. 
toolPath = findToolInEnvironment({"ld.lld", "ld"}); } - if (!toolPath.empty()) + if (!toolPath.empty()) { return toolPath; + } llvm::errs() << "No Unix linker tool found in environment.\n"; return ""; @@ -129,8 +131,9 @@ class UnixLinkerTool : public LinkerTool { } auto commandLine = llvm::join(flags, " "); - if (failed(runLinkCommand(commandLine))) + if (failed(runLinkCommand(commandLine))) { return std::nullopt; + } return artifacts; } diff --git a/compiler/plugins/target/LLVMCPU/internal/WasmLinkerTool.cpp b/compiler/plugins/target/LLVMCPU/internal/WasmLinkerTool.cpp index 9e2e4fe6c6fc..5ae6f60fad79 100644 --- a/compiler/plugins/target/LLVMCPU/internal/WasmLinkerTool.cpp +++ b/compiler/plugins/target/LLVMCPU/internal/WasmLinkerTool.cpp @@ -54,8 +54,9 @@ class WasmLinkerTool : public LinkerTool { // or install directories) for common tools. std::string toolPath = findToolFromExecutableDir( {"wasm-ld", "iree-lld", "lld", "ld.lld", "lld-link"}); - if (!toolPath.empty()) + if (!toolPath.empty()) { return toolPath; + } llvm::errs() << "No Wasm linker tool specified or discovered\n"; return ""; @@ -131,8 +132,9 @@ class WasmLinkerTool : public LinkerTool { } auto commandLine = llvm::join(flags, " "); - if (failed(runLinkCommand(commandLine))) + if (failed(runLinkCommand(commandLine))) { return std::nullopt; + } return artifacts; } }; diff --git a/compiler/plugins/target/LLVMCPU/internal/WindowsLinkerTool.cpp b/compiler/plugins/target/LLVMCPU/internal/WindowsLinkerTool.cpp index 8c3502771144..824af9a90a3c 100644 --- a/compiler/plugins/target/LLVMCPU/internal/WindowsLinkerTool.cpp +++ b/compiler/plugins/target/LLVMCPU/internal/WindowsLinkerTool.cpp @@ -23,14 +23,16 @@ class WindowsLinkerTool : public LinkerTool { std::string getSystemToolPath() const override { // First check for setting the linker explicitly. 
auto toolPath = LinkerTool::getSystemToolPath(); - if (!toolPath.empty()) + if (!toolPath.empty()) { return toolPath; + } // No explicit linker specified, search the executable directory (i.e. our // own build or install directories) for common tools. toolPath = findToolFromExecutableDir({"lld-link"}); - if (!toolPath.empty()) + if (!toolPath.empty()) { return toolPath; + } llvm::errs() << "No Windows linker tool specified or discovered\n"; return ""; @@ -273,8 +275,9 @@ class WindowsLinkerTool : public LinkerTool { } auto commandLine = llvm::join(flags, " "); - if (failed(runLinkCommand(commandLine))) + if (failed(runLinkCommand(commandLine))) { return std::nullopt; + } // PDB file gets generated wtih the same path + .pdb. artifacts.debugFile = diff --git a/compiler/plugins/target/LLVMCPU/test/BUILD.bazel b/compiler/plugins/target/LLVMCPU/test/BUILD.bazel index a8ce13583dd3..496cb7698608 100644 --- a/compiler/plugins/target/LLVMCPU/test/BUILD.bazel +++ b/compiler/plugins/target/LLVMCPU/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "hal_target_device_attributes.mlir", "materialize_homogeneous_encodings.mlir", diff --git a/compiler/plugins/target/MetalSPIRV/MSLToMetalLib.cpp b/compiler/plugins/target/MetalSPIRV/MSLToMetalLib.cpp index 2937fefc354e..c50969dab99d 100644 --- a/compiler/plugins/target/MetalSPIRV/MSLToMetalLib.cpp +++ b/compiler/plugins/target/MetalSPIRV/MSLToMetalLib.cpp @@ -54,8 +54,9 @@ static std::string getMetalCompileCommand(MetalTargetPlatform platform, static LogicalResult runSystemCommand(llvm::StringRef command) { LLVM_DEBUG(llvm::dbgs() << "Running system command: '" << command << "'\n"); int exitCode = system(command.data()); - if (exitCode == 0) + if (exitCode == 0) { return success(); + } llvm::errs() << "Failed to run system command '" << command << "' with error code: " << exitCode << "\n"; return failure(); @@ -78,8 +79,9 @@ compileMSLToMetalLib(MetalTargetPlatform 
targetPlatform, std::string command = getMetalCompileCommand(targetPlatform, mslFile, libFile); - if (failed(runSystemCommand(command))) + if (failed(runSystemCommand(command))) { return nullptr; + } auto fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(libFile, /*isText=*/false); diff --git a/compiler/plugins/target/MetalSPIRV/SPIRVToMSL.cpp b/compiler/plugins/target/MetalSPIRV/SPIRVToMSL.cpp index a86849ab565a..413056dfe2e2 100644 --- a/compiler/plugins/target/MetalSPIRV/SPIRVToMSL.cpp +++ b/compiler/plugins/target/MetalSPIRV/SPIRVToMSL.cpp @@ -33,8 +33,9 @@ class SPIRVToMSLCompiler : public SPIRV_CROSS_NAMESPACE::CompilerMSL { entryName.str(), spv::ExecutionModel::ExecutionModelGLCompute); const auto &workgroupSize = entryPoint.workgroup_size; // TODO(antiagainst): support specialization constant. - if (workgroupSize.constant != 0) + if (workgroupSize.constant != 0) { return {0, 0, 0}; + } return {workgroupSize.x, workgroupSize.y, workgroupSize.z}; } @@ -127,8 +128,9 @@ crossCompileSPIRVToMSL(IREE::HAL::MetalTargetPlatform targetPlatform, SmallVector descriptors; bool hasPushConstant = false; - if (!spvCrossCompiler.getResources(&descriptors, &hasPushConstant)) + if (!spvCrossCompiler.getResources(&descriptors, &hasPushConstant)) { return std::nullopt; + } // Explicitly set the argument buffer [[id(N)]] location for each SPIR-V // resource variable. 
diff --git a/compiler/plugins/target/MetalSPIRV/test/BUILD.bazel b/compiler/plugins/target/MetalSPIRV/test/BUILD.bazel index 0bf6c5a2c1da..9eeac27b6974 100644 --- a/compiler/plugins/target/MetalSPIRV/test/BUILD.bazel +++ b/compiler/plugins/target/MetalSPIRV/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted ["smoketest.mlir"], include = ["*.mlir"], ), diff --git a/compiler/plugins/target/ROCM/Dialect/ROCM/IR/BUILD.bazel b/compiler/plugins/target/ROCM/Dialect/ROCM/IR/BUILD.bazel index 144d37168e2c..3c37fbacaba3 100644 --- a/compiler/plugins/target/ROCM/Dialect/ROCM/IR/BUILD.bazel +++ b/compiler/plugins/target/ROCM/Dialect/ROCM/IR/BUILD.bazel @@ -21,6 +21,7 @@ exports_files([ iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "ROCMAttrs.td", "ROCMDialect.td", diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp index 516dead102d5..f994f722b356 100644 --- a/compiler/plugins/target/ROCM/ROCMTarget.cpp +++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp @@ -608,8 +608,9 @@ class ROCMTargetBackend final : public TargetBackend { for (auto func : innerModuleOp.getOps()) { llvm::Function *llvmFunc = llvmModule->getFunction(func.getName()); - if (llvmFunc->isDeclaration()) + if (llvmFunc->isDeclaration()) { continue; + } // Override flags as given by target func attrs. if (auto funcAttrs = @@ -702,8 +703,9 @@ class ROCMTargetBackend final : public TargetBackend { llvmModule->addModuleFlag(llvm::Module::Error, "amdhsa_code_object_version", abiVersion); - for (llvm::Function &f : llvmModule->functions()) + for (llvm::Function &f : llvmModule->functions()) { f.addFnAttr(llvm::Attribute::AlwaysInline); + } // Link user-provided modules. llvm::Linker linker(*llvmModule); @@ -814,8 +816,9 @@ class ROCMTargetBackend final : public TargetBackend { // final FlatBuffer. 
std::string targetObj = translateModuleToObj(*llvmModule, *targetMachine); targetHSACO = createHsaco(variantOp.getLoc(), targetObj, libraryName); - if (targetHSACO.empty()) + if (targetHSACO.empty()) { return failure(); + } if (options.enableRegSpillWarning) { checkRegisterSpilling(variantOp, targetObj); diff --git a/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp b/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp index ce4120d2d96b..e6bc477fbdc1 100644 --- a/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp +++ b/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp @@ -50,8 +50,9 @@ loadIRModule(Location loc, const std::string &filename, static LogicalResult linkWithBitcodeFiles(Location loc, llvm::Module *module, ArrayRef bitcodePaths) { - if (bitcodePaths.empty()) + if (bitcodePaths.empty()) { return success(); + } llvm::Linker linker(*module); for (auto &bitcodePath : bitcodePaths) { if (!llvm::sys::fs::exists(bitcodePath)) { @@ -62,8 +63,9 @@ static LogicalResult linkWithBitcodeFiles(Location loc, llvm::Module *module, } std::unique_ptr bitcodeModule = loadIRModule(loc, bitcodePath, &module->getContext()); - if (!bitcodeModule) + if (!bitcodeModule) { return failure(); + } // Ignore the data layout of the module we're importing. This avoids a // warning from the linker. bitcodeModule->setDataLayout(module->getDataLayout()); @@ -107,8 +109,9 @@ static void overridePlatformGlobal(llvm::Module *module, StringRef globalName, uint32_t newValue, llvm::Type *globalTy) { // NOTE: the global will not be defined if it is not used in the module. 
auto *globalValue = module->getNamedGlobal(globalName); - if (!globalValue) + if (!globalValue) { return; + } globalValue->setDSOLocal(true); globalValue->setConstant(true); globalValue->setInitializer(llvm::ConstantInt::get( @@ -160,10 +163,11 @@ LogicalResult linkHIPBitcodeIfNeeded(Location loc, llvm::Module *module, for (const llvm::Function &function : module->functions()) { if (!function.isIntrinsic() && function.isDeclaration()) { auto functionName = function.getName(); - if (functionName.starts_with("__ocml_")) + if (functionName.starts_with("__ocml_")) { usesOCML = true; - else if (functionName.starts_with("__ockl_")) + } else if (functionName.starts_with("__ockl_")) { usesOCKL = true; + } } } diff --git a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/BUILD.bazel b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/BUILD.bazel index 8b3a36ede083..ee2a7e155773 100644 --- a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/BUILD.bazel +++ b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/BUILD.bazel @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception load("//build_tools/bazel:build_defs.oss.bzl", "iree_cmake_extra_content") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite") load("//build_tools/embed_data:build_defs.bzl", "iree_c_embed_data") @@ -23,7 +24,14 @@ endif() inline = True, ) -ukernel_patterns_mlir_files = glob(["ukernel_patterns_*.mlir"]) +ukernel_patterns_mlir_files = enforce_glob( + # keep sorted + [ + "ukernel_patterns_gfx942.mlir", + "ukernel_patterns_gfx950.mlir", + ], + include = ["ukernel_patterns_*.mlir"], +) iree_c_embed_data( name = "iree_mlir_ukernel_patterns_amdgpu", diff --git a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/CMakeLists.txt b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/CMakeLists.txt index f392a56c96c3..9af5d0c64320 100644 --- a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/CMakeLists.txt +++ 
b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/CMakeLists.txt @@ -14,12 +14,12 @@ if(NOT IREE_TARGET_BACKEND_ROCM) return() endif() -file(GLOB _GLOB_UKERNEL_PATTERNS_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS ukernel_patterns_*.mlir) iree_c_embed_data( NAME iree_mlir_ukernel_patterns_amdgpu SRCS - "${_GLOB_UKERNEL_PATTERNS_X_MLIR}" + "ukernel_patterns_gfx942.mlir" + "ukernel_patterns_gfx950.mlir" C_FILE_OUTPUT "iree_mlir_ukernel_patterns_amdgpu.c" H_FILE_OUTPUT @@ -32,7 +32,8 @@ iree_lit_test_suite( NAME verify_mlir_ukernel_patterns_amdgpu SRCS - "${_GLOB_UKERNEL_PATTERNS_X_MLIR}" + "ukernel_patterns_gfx942.mlir" + "ukernel_patterns_gfx950.mlir" TOOLS iree-opt ) diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_bf16i32.c b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_bf16i32.c index aeebb08a26b6..cb542ed9d0a5 100644 --- a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_bf16i32.c +++ b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_bf16i32.c @@ -31,8 +31,9 @@ float newIn = idx >= reductionSize ? -FLT_MAX : (float)(inputBuffer[input_offset + idx]); - if (newIn == laneMax) + if (newIn == laneMax) { continue; + } laneMax = __builtin_fmaxf(newIn, laneMax); laneResult = newIn == laneMax ? idx : laneResult; } diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_bf16i64.c b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_bf16i64.c index 50388dac062b..d68f27693e2d 100644 --- a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_bf16i64.c +++ b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_bf16i64.c @@ -31,8 +31,9 @@ float newIn = idx >= reductionSize ? -FLT_MAX : (float)(inputBuffer[input_offset + idx]); - if (newIn == laneMax) + if (newIn == laneMax) { continue; + } laneMax = __builtin_fmaxf(newIn, laneMax); laneResult = newIn == laneMax ? 
idx : laneResult; } diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i32.c b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i32.c index 2d8f51add345..d2ee3cff419a 100644 --- a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i32.c +++ b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i32.c @@ -25,8 +25,9 @@ int32_t idx = warpSize * i + laneID; _Float16 newIn = idx >= reductionSize ? NEG_F16_MAX : inputBuffer[input_offset + idx]; - if (newIn == laneMax) + if (newIn == laneMax) { continue; + } laneMax = __builtin_fmaxf16(newIn, laneMax); laneResult = newIn == laneMax ? idx : laneResult; } diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i64.c b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i64.c index 2232d5f3887a..fd00bc8c1c33 100644 --- a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i64.c +++ b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i64.c @@ -25,8 +25,9 @@ int32_t idx = warpSize * i + laneID; _Float16 newIn = idx >= reductionSize ? NEG_F16_MAX : inputBuffer[input_offset + idx]; - if (newIn == laneMax) + if (newIn == laneMax) { continue; + } laneMax = __builtin_fmaxf16(newIn, laneMax); laneResult = newIn == laneMax ? idx : laneResult; } diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f32i32.c b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f32i32.c index ad5d5088e054..7819020f9f0f 100644 --- a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f32i32.c +++ b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f32i32.c @@ -24,8 +24,9 @@ int32_t idx = warpSize * i + laneID; float newIn = idx >= reductionSize ? 
-FLT_MAX : inputBuffer[input_offset + idx]; - if (newIn == laneMax) + if (newIn == laneMax) { continue; + } laneMax = __builtin_fmaxf(newIn, laneMax); laneResult = newIn == laneMax ? idx : laneResult; } diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f32i64.c b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f32i64.c index 5438c79cc182..d608d2368e69 100644 --- a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f32i64.c +++ b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f32i64.c @@ -24,8 +24,9 @@ int32_t idx = warpSize * i + laneID; float newIn = idx >= reductionSize ? -FLT_MAX : inputBuffer[input_offset + idx]; - if (newIn == laneMax) + if (newIn == laneMax) { continue; + } laneMax = __builtin_fmaxf(newIn, laneMax); laneResult = newIn == laneMax ? idx : laneResult; } diff --git a/compiler/plugins/target/VMVX/test/BUILD.bazel b/compiler/plugins/target/VMVX/test/BUILD.bazel index f53b5b481690..1fdf96d9a36f 100644 --- a/compiler/plugins/target/VMVX/test/BUILD.bazel +++ b/compiler/plugins/target/VMVX/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "smoketest.mlir", ], diff --git a/compiler/plugins/target/VMVX/test/smoketest.mlir b/compiler/plugins/target/VMVX/test/smoketest.mlir index c4217e983a05..0b2834da1cda 100644 --- a/compiler/plugins/target/VMVX/test/smoketest.mlir +++ b/compiler/plugins/target/VMVX/test/smoketest.mlir @@ -52,9 +52,11 @@ stream.executable public @add_dispatch_0 { // CHECK-DAG: %[[C1_I32:.+]] = vm.const.i32 1 // CHECK-DAG: %[[C1_I64:.+]] = vm.const.i64 1 // CHECK-DAG: %[[C2_I32:.+]] = vm.const.i32 2 +// CHECK: vm.discard.refs %[[SCRATCHPAD]], %[[CONSTANTS]] // CHECK-NEXT: %[[LHS_BUF:.+]] = vm.list.get.ref %[[BINDINGS]], %[[C0_I32]] : (!vm.list, i32) -> !vm.buffer // CHECK-NEXT: %[[RHS_BUF:.+]] = vm.list.get.ref %[[BINDINGS]], %[[C1_I32]] : (!vm.list, i32) -> !vm.buffer // CHECK-NEXT: 
%[[RET_BUF:.+]] = vm.list.get.ref %[[BINDINGS]], %[[C2_I32]] : (!vm.list, i32) -> !vm.buffer +// CHECK-NEXT: vm.discard.refs %[[BINDINGS]] // CHECK: vm.br ^bb1(%[[C0_I64]] : i64) // CHECK-NEXT: ^bb1(%[[IDX:.+]]: i64): // CHECK-NEXT: %slt = vm.cmp.lt.i64.s %[[IDX]], %{{.+}} : i64 @@ -68,6 +70,7 @@ stream.executable public @add_dispatch_0 { // CHECK-NEXT: %[[NEXT_IDX:.+]] = vm.add.i64 %[[IDX]], %[[C1_I64]] : i64 // CHECK-NEXT: vm.br ^bb1(%[[NEXT_IDX]] : i64) // CHECK-NEXT: ^bb3: +// CHECK: vm.discard.refs %[[LHS_BUF]], %[[RHS_BUF]], %[[RET_BUF]] // CHECK-NEXT: vm.return // CHECK-NEXT: } // CHECK-NEXT: vm.export @add_dispatch_0 diff --git a/compiler/plugins/target/VulkanSPIRV/test/BUILD.bazel b/compiler/plugins/target/VulkanSPIRV/test/BUILD.bazel index b8394437156e..ca0a92f251c8 100644 --- a/compiler/plugins/target/VulkanSPIRV/test/BUILD.bazel +++ b/compiler/plugins/target/VulkanSPIRV/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "materialize_homogeneous_encodings.mlir", "smoketest.mlir", diff --git a/compiler/src/iree/compiler/API/Internal/BUILD.bazel b/compiler/src/iree/compiler/API/Internal/BUILD.bazel index 11db7604420b..4d801083ab61 100644 --- a/compiler/src/iree/compiler/API/Internal/BUILD.bazel +++ b/compiler/src/iree/compiler/API/Internal/BUILD.bazel @@ -27,6 +27,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/VM/Target:init_targets", "//compiler/src/iree/compiler/Dialect/VM/Target/Bytecode", "//compiler/src/iree/compiler/Dialect/VM/Target/C", + "//compiler/src/iree/compiler/Dialect/VM/Transforms", "//compiler/src/iree/compiler/Pipelines", "//compiler/src/iree/compiler/PluginAPI", "//compiler/src/iree/compiler/PluginAPI:PluginManager", diff --git a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt index b68b35a226a6..18272b115a91 100644 --- a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt +++ 
b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt @@ -34,6 +34,7 @@ iree_cc_library( iree::compiler::Dialect::VM::Target::Bytecode iree::compiler::Dialect::VM::Target::C iree::compiler::Dialect::VM::Target::init_targets + iree::compiler::Dialect::VM::Transforms iree::compiler::Pipelines iree::compiler::PluginAPI iree::compiler::PluginAPI::PluginManager diff --git a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp index ddade384115e..1a22b8a05779 100644 --- a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp +++ b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp @@ -43,6 +43,7 @@ #include "iree/compiler/API/Internal/Diagnostics.h" #include "iree/compiler/ConstEval/Passes.h" #include "iree/compiler/Dialect/VM/Target/init_targets.h" +#include "iree/compiler/Dialect/VM/Transforms/Passes.h" #include "iree/compiler/Pipelines/Pipelines.h" #include "iree/compiler/PluginAPI/PluginManager.h" #include "iree/compiler/Tools/init_dialects.h" @@ -246,6 +247,7 @@ struct GlobalInit { InputDialectOptions *clInputOptions = nullptr; PreprocessingOptions *clPreprocessingOptions = nullptr; GlobalOptimizationOptions *clGlobalOptimizationOptions = nullptr; + ParameterOptions *clParameterOptions = nullptr; DispatchCreationOptions *clDispatchCreationOptions = nullptr; SchedulingOptions *clSchedulingOptions = nullptr; IREE::HAL::TargetOptions *clHalTargetOptions = nullptr; @@ -292,6 +294,7 @@ void GlobalInit::registerCommandLineOptions() { clInputOptions = &InputDialectOptions::FromFlags::get(); clPreprocessingOptions = &PreprocessingOptions::FromFlags::get(); clGlobalOptimizationOptions = &GlobalOptimizationOptions::FromFlags::get(); + clParameterOptions = &ParameterOptions::FromFlags::get(); clDispatchCreationOptions = &DispatchCreationOptions::FromFlags::get(); clSchedulingOptions = &SchedulingOptions::FromFlags::get(); clHalTargetOptions = &IREE::HAL::TargetOptions::FromFlags::get(); @@ 
-402,6 +405,7 @@ struct Session { BindingOptions bindingOptions; InputDialectOptions inputOptions; PreprocessingOptions preprocessingOptions; + ParameterOptions parameterOptions; GlobalOptimizationOptions highLevelOptimizationOptions; DispatchCreationOptions dispatchCreationOptions; SchedulingOptions schedulingOptions; @@ -431,6 +435,7 @@ Session::Session(GlobalInit &globalInit) inputOptions = *globalInit.clInputOptions; preprocessingOptions = *globalInit.clPreprocessingOptions; highLevelOptimizationOptions = *globalInit.clGlobalOptimizationOptions; + parameterOptions = *globalInit.clParameterOptions; dispatchCreationOptions = *globalInit.clDispatchCreationOptions; schedulingOptions = *globalInit.clSchedulingOptions; halTargetOptions = *globalInit.clHalTargetOptions; @@ -452,6 +457,7 @@ Session::Session(GlobalInit &globalInit) preprocessingOptions.bindOptions(binder); inputOptions.bindOptions(binder); highLevelOptimizationOptions.bindOptions(binder); + parameterOptions.bindOptions(binder); dispatchCreationOptions.bindOptions(binder); schedulingOptions.bindOptions(binder); halTargetOptions.bindOptions(binder); @@ -527,8 +533,9 @@ Error *Source::split(void (*callback)(iree_compiler_source_t *source, SmallVector rawSubBuffers; // Split dropping the last checkLen chars to enable flagging near misses. 
origMemBuffer->getBuffer().split(rawSubBuffers, splitMarker); - if (rawSubBuffers.empty()) + if (rawSubBuffers.empty()) { return nullptr; + } for (StringRef subBuffer : rawSubBuffers) { auto splitLoc = SMLoc::getFromPointer(subBuffer.data()); @@ -690,8 +697,9 @@ Error *Output::openMembuffer() { } void Output::keep() { - if (outputFile) + if (outputFile) { outputFile->keep(); + } } // Invocation corresponds to iree_compiler_invocation_t @@ -909,8 +917,9 @@ bool Invocation::importModule(Operation *inputModule, bool steal) { } Operation *Invocation::exportModule() { - if (!parsedModuleIsOwned) + if (!parsedModuleIsOwned) { return nullptr; + } parsedModuleIsOwned = false; return parsedModule; } @@ -954,14 +963,16 @@ bool Invocation::getCompilationPhase(IREEVMPipelinePhase &compileFrom, void Invocation::dumpCompilationPhase(IREEVMPipelinePhase phase, OpPassManager &passManager) { - if (!parsedModule || dumpCompilationPhasesTo.empty()) + if (!parsedModule || dumpCompilationPhasesTo.empty()) { return; + } std::string phaseName; enumerateIREEVMPipelinePhases( [&](IREEVMPipelinePhase enumeratedPhase, StringRef name, StringRef desc) { - if (enumeratedPhase == phase) + if (enumeratedPhase == phase) { phaseName = name; + } }); std::string fileName = @@ -1014,10 +1025,10 @@ bool Invocation::runPipeline(enum iree_compiler_pipeline_t pipeline) { buildIREEVMTransformPassPipeline( session.targetRegistry, session.pipelineOptions, session.bindingOptions, session.inputOptions, session.preprocessingOptions, - session.highLevelOptimizationOptions, session.dispatchCreationOptions, - session.schedulingOptions, session.halTargetOptions, - session.vmTargetOptions, pipelineHooks, *passManager, compileFrom, - compileTo); + session.parameterOptions, session.highLevelOptimizationOptions, + session.dispatchCreationOptions, session.schedulingOptions, + session.halTargetOptions, session.vmTargetOptions, pipelineHooks, + *passManager, compileFrom, compileTo); break; } case 
IREE_COMPILER_PIPELINE_HAL_EXECUTABLE: { @@ -1048,9 +1059,15 @@ bool Invocation::runPipeline(enum iree_compiler_pipeline_t pipeline) { buildIREEPrecompileTransformPassPipeline( session.targetRegistry, session.pipelineOptions, session.bindingOptions, session.inputOptions, session.preprocessingOptions, - session.highLevelOptimizationOptions, session.dispatchCreationOptions, - session.schedulingOptions, session.halTargetOptions, pipelineHooks, - *passManager, compileFrom, compileTo); + session.parameterOptions, session.highLevelOptimizationOptions, + session.dispatchCreationOptions, session.schedulingOptions, + session.halTargetOptions, pipelineHooks, *passManager, compileFrom, + compileTo); + break; + } + case IREE_COMPILER_PIPELINE_VM: { + IREE::VM::buildVMTransformPassPipeline(*passManager, + session.vmTargetOptions); break; } default: @@ -1069,8 +1086,9 @@ bool Invocation::runPipeline(enum iree_compiler_pipeline_t pipeline) { bool Invocation::runTextualPassPipeline(const char *textPassPipeline) { auto passManager = createPassManager(); if (failed(mlir::parsePassPipeline(textPassPipeline, *passManager, - llvm::errs()))) + llvm::errs()))) { return false; + } if (failed(passManager->run(parsedModule))) { return false; } @@ -1084,8 +1102,9 @@ Error *Invocation::outputIR(Output &output) { Error *Invocation::outputIRBytecode(Output &output, int bytecodeVersion) { mlir::BytecodeWriterConfig config; - if (bytecodeVersion >= 0) + if (bytecodeVersion >= 0) { config.setDesiredBytecodeVersion(bytecodeVersion); + } if (failed(mlir::writeBytecodeToFile(parsedModule, *output.outputStream, config))) { return new Error("illegal bytecode version requested"); @@ -1190,8 +1209,9 @@ void llvmVersionPrinter(llvm::raw_ostream &os) { #endif #if LLVM_VERSION_PRINTER_SHOW_HOST_TARGET_INFO std::string CPU = std::string(llvm::sys::getHostCPUName()); - if (CPU == "generic") + if (CPU == "generic") { CPU = "(unknown)"; + } os << ".\n" << " Default target: " << 
llvm::sys::getDefaultTargetTriple() << '\n' << " Host CPU: " << CPU; diff --git a/compiler/src/iree/compiler/API/Internal/Diagnostics.cpp b/compiler/src/iree/compiler/API/Internal/Diagnostics.cpp index cd279d9fe7eb..2a4aa1d7ad35 100644 --- a/compiler/src/iree/compiler/API/Internal/Diagnostics.cpp +++ b/compiler/src/iree/compiler/API/Internal/Diagnostics.cpp @@ -22,10 +22,12 @@ namespace mlir::iree_compiler::embed { namespace { /// Return a processable CallSiteLoc from the given location. std::optional getCallSiteLoc(Location loc) { - if (auto callLoc = dyn_cast(loc)) + if (auto callLoc = dyn_cast(loc)) { return callLoc; - if (auto nameLoc = dyn_cast(loc)) + } + if (auto nameLoc = dyn_cast(loc)) { return getCallSiteLoc(cast(loc).getChildLoc()); + } if (auto fusedLoc = dyn_cast(loc)) { for (auto subLoc : cast(loc).getLocations()) { if (auto callLoc = getCallSiteLoc(subLoc)) { @@ -49,9 +51,11 @@ std::optional findLocToShow(Location loc) { .Case([&](FusedLoc fusedLoc) -> std::optional { // Fused location is unique in that we try to find a sub-location to // show, rather than the top-level location itself. - for (Location childLoc : fusedLoc.getLocations()) - if (std::optional showableLoc = findLocToShow(childLoc)) + for (Location childLoc : fusedLoc.getLocations()) { + if (std::optional showableLoc = findLocToShow(childLoc)) { return showableLoc; + } + } return std::nullopt; }) .Case([&](NameLoc nameLoc) -> std::optional { @@ -105,8 +109,9 @@ LogicalResult FormattingDiagnosticHandler::emit(Diagnostic &diag) { // Assemble location fragments. SmallVector> locationStack; auto addLocToStack = [&](Location loc, StringRef locContext) { - if (std::optional showableLoc = findLocToShow(loc)) + if (std::optional showableLoc = findLocToShow(loc)) { locationStack.emplace_back(*showableLoc, locContext); + } }; // Add locations to display for this diagnostic. 
@@ -121,10 +126,11 @@ LogicalResult FormattingDiagnosticHandler::emit(Diagnostic &diag) { const unsigned callStackLimit = 50; for (unsigned curDepth = 0; curDepth < callStackLimit; ++curDepth) { addLocToStack(loc, "called from"); - if ((callLoc = getCallSiteLoc(loc))) + if ((callLoc = getCallSiteLoc(loc))) { loc = callLoc->getCaller(); - else + } else { break; + } } } @@ -134,8 +140,9 @@ LogicalResult FormattingDiagnosticHandler::emit(Diagnostic &diag) { appendDiag(diag.getLocation(), diag.str(), diag.getSeverity()); } else { appendDiag(locationStack.front().first, diag.str(), diag.getSeverity()); - for (auto &it : llvm::drop_begin(locationStack)) + for (auto &it : llvm::drop_begin(locationStack)) { appendDiag(it.first, it.second, DiagnosticSeverity::Note); + } } // Append each of the notes. diff --git a/compiler/src/iree/compiler/API/Internal/IREEOptToolEntryPoint.cpp b/compiler/src/iree/compiler/API/Internal/IREEOptToolEntryPoint.cpp index 8c981a2b0e29..001680d9d52d 100644 --- a/compiler/src/iree/compiler/API/Internal/IREEOptToolEntryPoint.cpp +++ b/compiler/src/iree/compiler/API/Internal/IREEOptToolEntryPoint.cpp @@ -93,8 +93,9 @@ static LogicalResult ireeOptMainFromCL(int argc, char **argv, auto localBinder = mlir::iree_compiler::OptionsBinder::local(); mlir::iree_compiler::PluginManagerSession pluginSession( pluginManager, localBinder, pluginManagerOptions); - if (failed(pluginSession.initializePlugins())) + if (failed(pluginSession.initializePlugins())) { return failure(); + } pluginSession.registerDialects(registry); // In the normal compiler flow, activated plugins maintain a scoped registry @@ -127,9 +128,10 @@ static LogicalResult ireeOptMainFromCL(int argc, char **argv, // and the process "appears to be stuck". Print a message to let the user know // about it! 
if (inputFilename == "-" && - sys::Process::FileDescriptorIsDisplayed(fileno(stdin))) + sys::Process::FileDescriptorIsDisplayed(fileno(stdin))) { llvm::errs() << "(processing input from stdin now, hit ctrl-c/ctrl-d to " "interrupt)\n"; + } // Set up the input file. std::string errorMessage; @@ -144,8 +146,9 @@ static LogicalResult ireeOptMainFromCL(int argc, char **argv, llvm::errs() << errorMessage << "\n"; return failure(); } - if (failed(MlirOptMain(output->os(), std::move(file), registry, config))) + if (failed(MlirOptMain(output->os(), std::move(file), registry, config))) { return failure(); + } // Keep the output file if the invocation of MlirOptMain was successful. output->keep(); diff --git a/compiler/src/iree/compiler/API/Internal/IREEReduceToolEntryPoint.cpp b/compiler/src/iree/compiler/API/Internal/IREEReduceToolEntryPoint.cpp index 4cb61ec3c421..3b0e4f017f04 100644 --- a/compiler/src/iree/compiler/API/Internal/IREEReduceToolEntryPoint.cpp +++ b/compiler/src/iree/compiler/API/Internal/IREEReduceToolEntryPoint.cpp @@ -84,9 +84,10 @@ static LogicalResult ireeReduceMainFromCL(int argc, char **argv, // and the process "appears to be stuck". Print a message to let the user know // about it! if (inputFilename == "-" && - sys::Process::FileDescriptorIsDisplayed(fileno(stdin))) + sys::Process::FileDescriptorIsDisplayed(fileno(stdin))) { llvm::errs() << "(processing input from stdin now, hit ctrl-c/ctrl-d to " "interrupt)\n"; + } OwningOpRef module = loadModule(registry, inputFilename); diff --git a/compiler/src/iree/compiler/API/Internal/LLDToolEntryPoint.cpp b/compiler/src/iree/compiler/API/Internal/LLDToolEntryPoint.cpp index b61f6f6cba0f..8c492f4d24b6 100644 --- a/compiler/src/iree/compiler/API/Internal/LLDToolEntryPoint.cpp +++ b/compiler/src/iree/compiler/API/Internal/LLDToolEntryPoint.cpp @@ -71,11 +71,13 @@ static Flavor getFlavor(StringRef s) { static Flavor parseFlavor(std::vector &v) { // Parse -flavor option. 
if (v.size() > 1 && v[1] == StringRef("-flavor")) { - if (v.size() <= 2) + if (v.size() <= 2) { die("missing arg value for '-flavor'"); + } Flavor f = getFlavor(v[2]); - if (f == Invalid) + if (f == Invalid) { die("Unknown flavor: " + StringRef(v[2])); + } v.erase(v.begin() + 1, v.begin() + 3); return f; } diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/ConvertStreamableOps.cpp b/compiler/src/iree/compiler/Bindings/Native/Transforms/ConvertStreamableOps.cpp index d404dba88dba..0d258f8c204a 100644 --- a/compiler/src/iree/compiler/Bindings/Native/Transforms/ConvertStreamableOps.cpp +++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/ConvertStreamableOps.cpp @@ -280,8 +280,9 @@ static LogicalResult convertStreamableCall(StreamableFunc &streamableFunc, for (auto [i, resultType] : llvm::enumerate(callOp.getResultTypes())) { if (auto shapedType = dyn_cast(resultType)) { const auto &resultDimArgs = streamableFunc.resultDimArgs[i]; - if (resultDimArgs.empty()) + if (resultDimArgs.empty()) { continue; + } if (resultDimArgs.front() == kTiedDim) { // Source from a tied operand. Types must match exactly. 
assert(streamableFunc.tiedOperands[i] != @@ -360,8 +361,9 @@ class ConvertStreamableOpsPass for (auto originalFuncOp : originalFuncOps) { auto streamableFuncOr = convertStreamableFunc(moduleOp, originalFuncOp, symbolTable); - if (!streamableFuncOr.has_value()) + if (!streamableFuncOr.has_value()) { return signalPassFailure(); + } auto streamableFunc = std::move(streamableFuncOr).value(); streamableFuncs[streamableFunc.funcOp.getName()] = std::move(streamableFunc); diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp index 9a61a5419a90..2aa6f8bbc35a 100644 --- a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp +++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp @@ -400,18 +400,20 @@ static StringAttr inferResultName(MLIRContext *context, int index, } static DictionaryAttr getIOAttr(ArrayAttr allAttrs, unsigned i) { - if (!allAttrs) + if (!allAttrs) { return nullptr; + } return cast_or_null(allAttrs.getValue()[i]); } static void formatIOAttr(DictionaryAttr attrs, llvm::raw_ostream &os) { - if (!attrs || attrs.empty()) + if (!attrs || attrs.empty()) { return; + } auto shouldIncludeAttr = [](const NamedAttribute &attr) { return attr.getName().getValue() != "iree.abi.name"; }; - if (!llvm::any_of(attrs, shouldIncludeAttr)) { + if (llvm::none_of(attrs, shouldIncludeAttr)) { return; } os << " {"; diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/BUILD.bazel index ec389ac0bb29..e44c2a853dc8 100644 --- a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "convert_streamable_ops.mlir", "wrap_entry_points.mlir", diff --git 
a/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp index ec0c6bc606c4..1d755871f224 100644 --- a/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp +++ b/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp @@ -335,8 +335,9 @@ class WrapEntryPointsPass auto shapeType = dynamicDims.tensorType; unsigned dynamicDimIdx = 0; for (unsigned i = 0; i < shapeType.getRank(); ++i) { - if (!shapeType.isDynamicDim(i)) + if (!shapeType.isDynamicDim(i)) { continue; + } auto dimValue = IREE::Util::ListGetOp::create( builder, loc, builder.getIndexType(), listValue, builder.createOrFold(loc, i)) @@ -492,8 +493,9 @@ class WrapEntryPointsPass wrapperFuncOp.setAllResultAttrs(resultAttrDict); populateReflectionAttrs(entryFuncOp, wrapperFuncOp); - if (auto affinityAttr = entryFuncOp->getAttr("stream.affinity")) + if (auto affinityAttr = entryFuncOp->getAttr("stream.affinity")) { wrapperFuncOp->setAttr("stream.affinity", affinityAttr); + } // Call the entryFuncOp and return the results. 
// If we wanted to perform additional work here to invalidate cached shapes diff --git a/compiler/src/iree/compiler/Bindings/TFLite/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Bindings/TFLite/Transforms/test/BUILD.bazel index 1579818a792e..c6692436852a 100644 --- a/compiler/src/iree/compiler/Bindings/TFLite/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Bindings/TFLite/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "wrap_entry_points.mlir", ], diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel index e5e5806b78f9..f87cd7ddc22a 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel @@ -105,6 +105,7 @@ iree_compiler_cc_library( "FissionTransferOpsInControlFlow.cpp", "FlattenMemRefSubspanPass.cpp", "FlattenMemRefs.cpp", + "FlattenSwizzleHintAllocs.cpp", "FoldAffineMinInDistributedLoops.cpp", "FoldSplitReductionAndWorkgroupMappingLoopsPass.cpp", "FoldTensorExtractOpPass.cpp", @@ -147,6 +148,7 @@ iree_compiler_cc_library( "PropagateReshapesByExpansion.cpp", "ReconcileTranslationInfo.cpp", "RematerializeParallelOps.cpp", + "RemoveIndexHints.cpp", "RemoveSingleIterationLoop.cpp", "ReplaceSlowMinMaxOps.cpp", "ReshapePatterns.cpp", @@ -195,6 +197,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Codegen/Common:FoldTensorExtractOpIncGen", "//compiler/src/iree/compiler/Codegen/Dialect/CPU/IR:IREECPUDialect", "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect", + "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms:IREECodegenTransforms", "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils", "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect", "//compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils:KnownTargets", diff --git 
a/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp b/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp index b033e618392b..fc1a56a7cc6f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp @@ -6,6 +6,7 @@ #include "iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h" #include "iree/compiler/Codegen/Common/Transforms.h" +#include "iree/compiler/Codegen/Dialect/Codegen/Transforms/Transforms.h" #include "iree/compiler/Codegen/Transforms/Transforms.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h" @@ -36,17 +37,6 @@ using TensorDivisibilityInfo = namespace { -struct RemoveOptimizationBarrier final - : public OpRewritePattern { - using Base::Base; - - LogicalResult matchAndRewrite(IREE::Util::OptimizationBarrierOp barrierOp, - PatternRewriter &rewriter) const override { - rewriter.replaceOp(barrierOp, barrierOp.getOperands()); - return success(); - } -}; - /// This pass is used to materialize information about dynamic dimensions of /// `tensor` operands of an operation in the IR. If a dynamic dimension is /// known to be a multiple of a compile-time constant value, this pass @@ -85,12 +75,14 @@ getTensorDivisibilityInfo(const TensorDynamicDimAnalysis &dynamicDimAnalysis, } for (auto [index, dim] : llvm::enumerate(tensorType.getShape())) { - if (!tensorType.isDynamicDim(index)) + if (!tensorType.isDynamicDim(index)) { continue; + } std::optional dimDivisibility = dynamicDimAnalysis.getDivisibilityInfo(v, index); - if (!dimDivisibility) + if (!dimDivisibility) { continue; + } divisibilityInfo[index] = std::move(dimDivisibility.value()); } @@ -110,10 +102,6 @@ getTensorDivisibilityInfo(const TensorDynamicDimAnalysis &dynamicDimAnalysis, /// inverses of each other. 
The `util.optimization.barrier` avoid these from /// getting folded away during reshape propagation. Return the result of the /// `tensor.collapse_shape generated. -struct ReshapeOps { - tensor::ExpandShapeOp expandShapeOp; - tensor::CollapseShapeOp collapseShapeOp; -}; static std::optional blockDynamicDimensionsOfValue(RewriterBase &rewriter, const TensorDivisibilityInfo &divisibilityInfo, @@ -205,14 +193,17 @@ static LogicalResult blockDynamicDimensions( Operation *operation, llvm::SmallDenseSet limitToOperandNumbers, llvm::SmallDenseSet limitToResultNumbers) { for (OpOperand &operand : operation->getOpOperands()) { - if (!limitToOperandNumbers.contains(operand.getOperandNumber())) + if (!limitToOperandNumbers.contains(operand.getOperandNumber())) { continue; - if (operand.get().getDefiningOp()) + } + if (operand.get().getDefiningOp()) { continue; + } TensorDivisibilityInfo operandDivisibilityInfo = getTensorDivisibilityInfo(dynamicDimAnalysis, operand.get()); - if (operandDivisibilityInfo.empty()) + if (operandDivisibilityInfo.empty()) { continue; + } std::optional reshapes = blockDynamicDimensionsOfValue( rewriter, operandDivisibilityInfo, operand.get()); if (reshapes) { @@ -224,12 +215,14 @@ static LogicalResult blockDynamicDimensions( OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(operation); for (OpResult result : operation->getResults()) { - if (!limitToResultNumbers.contains(result.getResultNumber())) + if (!limitToResultNumbers.contains(result.getResultNumber())) { continue; + } TensorDivisibilityInfo resultDivisibilityInfo = getTensorDivisibilityInfo(dynamicDimAnalysis, result); - if (resultDivisibilityInfo.empty()) + if (resultDivisibilityInfo.empty()) { continue; + } std::optional reshapes = blockDynamicDimensionsOfValue(rewriter, resultDivisibilityInfo, result); if (reshapes) { @@ -333,6 +326,8 @@ void BlockDynamicDimensionsPass::runOnOperation() { controlFusionFn); 
IREE::LinalgExt::populateFoldReshapeOpsByExpansionPatterns(patterns, controlFusionFn); + IREE::Codegen::populateFoldReshapeOpsByExpansionPatterns(patterns, + controlFusionFn); // Add patterns to fold `tensor.empty` operations with its consumers. tensor::populateFoldTensorEmptyPatterns(patterns); // Add some additional patterns that can simplify the IR. @@ -382,6 +377,8 @@ void BlockDynamicDimensionsPass::runOnOperation() { controlFn); IREE::LinalgExt::populateFoldReshapeOpsByExpansionPatterns( bubbleExpandShapePatterns, controlFn); + IREE::Codegen::populateFoldReshapeOpsByExpansionPatterns( + bubbleExpandShapePatterns, controlFn); // Add patterns to fold the "bubbled-up" `tensor.expand_shape` operation and // "pushed-down" `tensor.collapse_shape` operation with their interface // bindings or `tensor.empty` operations. @@ -413,7 +410,7 @@ void BlockDynamicDimensionsPass::runOnOperation() { // Delete the optimization barrier and run some further cleanup. { RewritePatternSet removeBarrierOpsPatterns(context); - removeBarrierOpsPatterns.insert(context); + populateRemoveOptimizationBarrierPatterns(removeBarrierOpsPatterns); tensor::ExpandShapeOp::getCanonicalizationPatterns(removeBarrierOpsPatterns, context); tensor::CollapseShapeOp::getCanonicalizationPatterns( diff --git a/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp index db3f6fd95c83..c10f29fbbc42 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp @@ -88,8 +88,9 @@ static bool isFromReadOnlyTensor(Value v, const BufferizationPlan &plan) { /// here). 
static LogicalResult analyseConstantOp(arith::ConstantOp constantOp, BufferizationPlan &plan) { - if (!isa(constantOp.getResult().getType())) + if (!isa(constantOp.getResult().getType())) { return success(); + } plan.insert(constantOp.getResult()); return success(); } @@ -112,12 +113,14 @@ static OpType getEquivalentOpOfType(Value value, BufferizationPlan &plan) { SmallVector mappedTensors = plan.getTensorsMappedToSameSet(value); for (auto v : mappedTensors) { auto definingOp = v.getDefiningOp(); - if (!definingOp) + if (!definingOp) { continue; + } assert((!equivalentOp || equivalentOp == definingOp) && "found two interface binding ops marked as equivalent"); - if (!equivalentOp) + if (!equivalentOp) { equivalentOp = definingOp; + } } return equivalentOp; } @@ -252,12 +255,14 @@ getTiedOperandsForDPSOps(DestinationStyleOpInterface dpsOp, /// same equivalence class. static LogicalResult analyseDPSOps(DestinationStyleOpInterface dpsOp, BufferizationPlan &plan) { - if (!dpsOp.hasPureTensorSemantics()) + if (!dpsOp.hasPureTensorSemantics()) { return success(); + } auto results = dpsOp->getResults(); auto tiedOperands = getTiedOperandsForDPSOps(dpsOp, plan); - if (tiedOperands.empty()) + if (tiedOperands.empty()) { return failure(); + } for (auto [index, resultTensor, tiedOperand] : llvm::zip_equal( llvm::seq(0, results.size()), results, tiedOperands)) { if (tiedOperand) { @@ -328,13 +333,15 @@ static LogicalResult analyseDestructiveUpdateOp(Operation *op, Value source, } static LogicalResult analyseScfIfOp(scf::IfOp ifOp, BufferizationPlan &plan) { - if (!ifOp.getNumResults()) + if (!ifOp.getNumResults()) { return success(); + } for (auto [result, thenOperand, elseOperand] : llvm::zip_equal(ifOp.getResults(), ifOp.thenYield().getOperands(), ifOp.elseYield().getOperands())) { - if (!isa(result.getType())) + if (!isa(result.getType())) { continue; + } // All results and yields of the if-then-else are tied together. 
plan.unionSets(result, thenOperand); plan.unionSets(result, elseOperand); @@ -344,8 +351,9 @@ static LogicalResult analyseScfIfOp(scf::IfOp ifOp, BufferizationPlan &plan) { static LogicalResult analyseScfForOp(scf::ForOp forOp, BufferizationPlan &plan) { - if (forOp.getResults().empty()) + if (forOp.getResults().empty()) { return success(); + } if (!llvm::all_of(forOp->getResultTypes(), [](Type resultType) { return isa(resultType); })) { @@ -406,8 +414,9 @@ static void hasDestructiveUpdatePattern(Value source, BufferizationPlan &plan) { for (OpOperand &use : source.getUses()) { auto user = use.getOwner(); // Process only update ops uses here. - if (!isUpdateOp(user)) + if (!isUpdateOp(user)) { continue; + } // If this is not the first use in a tensor::InsertSliceOp abort. if (updateOp) { return; @@ -432,8 +441,9 @@ static void hasDestructiveUpdatePattern(Value source, BufferizationPlan &plan) { Block *updateOpBlock = updateOp->getBlock(); for (OpOperand &use : source.getUses()) { Operation *user = use.getOwner(); - if (user == updateOp) + if (user == updateOp) { continue; + } if (isReadOp(user)) { Value source = getSource(user); assert(source && "unable to find source from read op"); @@ -494,8 +504,9 @@ void BufferizationPlan::dump() { unsigned numSets = 0; for (auto it = mappedTensors.begin(), ie = mappedTensors.end(); it != ie; ++it) { - if (!(*it)->isLeader()) + if (!(*it)->isLeader()) { continue; + } llvm::dbgs() << "\tSet " << numSets; if (storeLeaders.count( getLeaderValue(getValue(*mappedTensors.member_begin(**it))))) { diff --git a/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.h b/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.h index 6523d7771b0a..7dc3015cefd2 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.h +++ b/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.h @@ -62,8 +62,9 @@ class BufferizationPlan { /// the dispatch region. 
bool isInStoreSet(Value v) { Value leader = getLeaderValue(v); - if (!leader) + if (!leader) { return false; + } return storeLeaders.count(leader); } diff --git a/compiler/src/iree/compiler/Codegen/Common/BufferizeCopyOnlyDispatchesPass.cpp b/compiler/src/iree/compiler/Codegen/Common/BufferizeCopyOnlyDispatchesPass.cpp index bc8d6a07830e..d77447291b7b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BufferizeCopyOnlyDispatchesPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/BufferizeCopyOnlyDispatchesPass.cpp @@ -63,11 +63,13 @@ void BufferizeCopyOnlyDispatchesPass::runOnOperation() { hasDispatchStore = true; return success(isReadOnly(storeOp.getValue())); }); - if (walkResult.wasInterrupted()) + if (walkResult.wasInterrupted()) { return; + } // The function is just a copy and is not yet bufferized. - if (!hasDispatchStore) + if (!hasDispatchStore) { return; + } // Apply the bufferization passes. std::optional maybeBufferizationPipeline = diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt index 128b254f6a70..48497a85be7f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt @@ -98,6 +98,7 @@ iree_cc_library( "FissionTransferOpsInControlFlow.cpp" "FlattenMemRefSubspanPass.cpp" "FlattenMemRefs.cpp" + "FlattenSwizzleHintAllocs.cpp" "FoldAffineMinInDistributedLoops.cpp" "FoldSplitReductionAndWorkgroupMappingLoopsPass.cpp" "FoldTensorExtractOpPass.cpp" @@ -140,6 +141,7 @@ iree_cc_library( "PropagateReshapesByExpansion.cpp" "ReconcileTranslationInfo.cpp" "RematerializeParallelOps.cpp" + "RemoveIndexHints.cpp" "RemoveSingleIterationLoop.cpp" "ReplaceSlowMinMaxOps.cpp" "ReshapePatterns.cpp" @@ -229,6 +231,7 @@ iree_cc_library( iree::compiler::Codegen::Common::FoldTensorExtractOpIncGen iree::compiler::Codegen::Dialect::CPU::IR::IREECPUDialect 
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect + iree::compiler::Codegen::Dialect::Codegen::Transforms::IREECodegenTransforms iree::compiler::Codegen::Dialect::Codegen::Utils iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect iree::compiler::Codegen::Dialect::GPU::TargetUtils::KnownTargets diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPULowerToUKernels.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPULowerToUKernels.cpp index a599c5ce6419..971a4d3c7c18 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPULowerToUKernels.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPULowerToUKernels.cpp @@ -77,8 +77,9 @@ class CPULowerToUKernelsPass /// Returns `true` if an `outsOperand` value is initialized to zero. static bool isInitializedToZero(Value outsOperand) { auto fillOp = outsOperand.getDefiningOp(); - if (!fillOp) + if (!fillOp) { return false; + } Value fillVal = fillOp.getDpsInputOperand(0)->get(); return matchPattern(fillVal, m_Zero()) || matchPattern(fillVal, m_AnyZeroFloat()); diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUPrepareUkernels.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUPrepareUkernels.cpp index 4c35dd38456c..3c3865e26b00 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUPrepareUkernels.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUPrepareUkernels.cpp @@ -60,8 +60,9 @@ static void tileNonPackedDimsFor3DPackOps(RewriterBase &rewriter, // Skip the tiling if the size is already 1. RankedTensorType srcType = packOp.getSourceType(); for (auto [idx, val] : llvm::enumerate(tileSizes)) { - if (val && srcType.getDimSize(idx) == 1) + if (val && srcType.getDimSize(idx) == 1) { return; + } } auto outerDimsPerm = packOp.getOuterDimsPerm(); @@ -96,8 +97,9 @@ static void tileNonPackedDimsFor5DPUnpackOps(RewriterBase &rewriter, // Skip the tiling if the size is already 1. 
RankedTensorType destType = unpackOp.getDestType(); for (auto [idx, val] : llvm::enumerate(tileSizes)) { - if (val && destType.getDimSize(idx) == 1) + if (val && destType.getDimSize(idx) == 1) { return; + } } auto tilingInterfaceOp = cast(unpackOp.getOperation()); @@ -157,8 +159,9 @@ dropBatchTileSize(IREE::CPU::LoweringConfigAttr config) { SmallVector newItems; for (auto [level, tileSizes, scalableTileFlags] : tilingInfo) { tileSizes.erase(tileSizes.begin()); - if (!scalableTileFlags.empty()) + if (!scalableTileFlags.empty()) { scalableTileFlags.erase(scalableTileFlags.begin()); + } newItems.emplace_back( IREE::CPU::getTilingLevelName(level), IREE::CPU::LoweringConfigAttr::getTilingLevelAttr( @@ -262,16 +265,18 @@ struct Convert3DPackto2DPackPattern : public OpRewritePattern { llvm::SmallDenseSet s; s.insert(packOp.getInnerDimsPos().begin(), packOp.getInnerDimsPos().end()); for (auto dim : llvm::seq(0, packOp.getSourceRank())) { - if (s.contains(dim)) + if (s.contains(dim)) { continue; + } srcPos = dim; break; } int destPos = srcPos; for (auto [idx, val] : llvm::enumerate(packOp.getOuterDimsPerm())) { - if (val == srcPos) + if (val == srcPos) { destPos = idx; + } } if (packOp.getSourceType().getDimSize(srcPos) != 1) { @@ -284,15 +289,17 @@ struct Convert3DPackto2DPackPattern : public OpRewritePattern { SmallVector newInnerDimsPos(packOp.getInnerDimsPos()); for (auto &val : newInnerDimsPos) { assert(val != srcPos); - if (val > srcPos) + if (val > srcPos) { val--; + } } SmallVector newOuterDimsPerm(packOp.getOuterDimsPerm()); if (!newOuterDimsPerm.empty()) { newOuterDimsPerm.erase(newOuterDimsPerm.begin() + destPos); for (auto &val : newOuterDimsPerm) { - if (val > srcPos) + if (val > srcPos) { val--; + } } } @@ -341,8 +348,9 @@ struct Convert5DUnPackto4DUnPackPattern int64_t destPos = 0; for (auto [idx, val] : llvm::enumerate(seqOrOuterDimsPerm)) { - if (s.contains(val)) + if (s.contains(val)) { continue; + } srcPos = idx; destPos = val; break; @@ -361,16 +369,18 
@@ struct Convert5DUnPackto4DUnPackPattern SmallVector newInnerDimsPos(unpackOp.getInnerDimsPos()); for (auto &val : newInnerDimsPos) { assert(val != destPos); - if (val > destPos) + if (val > destPos) { val--; + } } SmallVector newOuterDimsPerm(unpackOp.getOuterDimsPerm()); if (!newOuterDimsPerm.empty()) { newOuterDimsPerm.erase(newOuterDimsPerm.begin() + srcPos); for (auto &val : newOuterDimsPerm) { - if (val > destPos) + if (val > destPos) { val--; + } } } diff --git a/compiler/src/iree/compiler/Codegen/Common/ConcretizePadResultShape.cpp b/compiler/src/iree/compiler/Codegen/Common/ConcretizePadResultShape.cpp index d4b782b67eb9..6458cb6f2702 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ConcretizePadResultShape.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/ConcretizePadResultShape.cpp @@ -32,8 +32,9 @@ static Value getAsIndexValue(OpFoldResult attrOrValue, OpBuilder &builder, Location loc) { IntegerAttr attr; if (Value val = dyn_cast(attrOrValue)) { - if (val.getType().isIndex()) + if (val.getType().isIndex()) { return val; + } matchPattern(val, m_Constant(&attr)); } else { attr = cast(cast(attrOrValue)); @@ -52,8 +53,9 @@ struct ConcretizePadResultShape final : public OpRewritePattern { LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const override { // If the result shape is already static, then nothing to do. 
- if (padOp.getResultType().hasStaticShape()) + if (padOp.getResultType().hasStaticShape()) { return failure(); + } int rank = padOp.getResultType().getRank(); SmallVector staticShape; @@ -61,8 +63,9 @@ struct ConcretizePadResultShape final : public OpRewritePattern { auto sourceIfxOp = dyn_cast_if_present( padOp.getSource().getDefiningOp()); - if (!sourceIfxOp) + if (!sourceIfxOp) { return failure(); + } SmallVector lowPad = padOp.getMixedLowPad(); SmallVector source = sourceIfxOp.getMixedSizes(); @@ -111,8 +114,9 @@ struct ConcretizePadResultShape final : public OpRewritePattern { affine::canonicalizeMapAndOperands(&map, &valueSizes); cstExpr = dyn_cast(map.getResult(0)); } - if (!cstExpr) + if (!cstExpr) { return failure(); + } staticShape.push_back(cstExpr.getValue()); } diff --git a/compiler/src/iree/compiler/Codegen/Common/ConfigTrackingCanonicalizer.cpp b/compiler/src/iree/compiler/Codegen/Common/ConfigTrackingCanonicalizer.cpp index 3ff9ece926f7..71e3b933b9f3 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ConfigTrackingCanonicalizer.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/ConfigTrackingCanonicalizer.cpp @@ -77,10 +77,12 @@ struct ConfigTrackingCanonicalizerPass final GreedySimplifyRegionLevel::Normal); RewritePatternSet owningPatterns(context); - for (auto *dialect : context->getLoadedDialects()) + for (auto *dialect : context->getLoadedDialects()) { dialect->getCanonicalizationPatterns(owningPatterns); - for (RegisteredOperationName op : context->getRegisteredOperations()) + } + for (RegisteredOperationName op : context->getRegisteredOperations()) { op.getCanonicalizationPatterns(owningPatterns, context); + } patterns = std::make_shared(std::move(owningPatterns)); diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ArithToF32.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ArithToF32.cpp index fd33f84de711..b7941bcb2b7e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ArithToF32.cpp +++ 
b/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ArithToF32.cpp @@ -46,8 +46,9 @@ Value convertRankedFloat(OpBuilder &builder, Type type, ValueRange inputs, Location loc) { Type eTy = getElementTypeOrSelf(type); Type inputETy = getElementTypeOrSelf(inputs[0].getType()); - if (!isa(getElementTypeOrSelf(type))) + if (!isa(getElementTypeOrSelf(type))) { return nullptr; + } if (inputETy.getIntOrFloatBitWidth() > eTy.getIntOrFloatBitWidth()) { return arith::TruncFOp::create(builder, loc, type, inputs[0]); @@ -66,8 +67,9 @@ struct PrimitiveTypeConverter : public TypeConverter { explicit PrimitiveTypeConverter() { addConversion([](Type type) { return type; }); addConversion([&](SourceType type) -> Type { - if (!isSourceType(type)) + if (!isSourceType(type)) { return type; + } return getTargetType(type); }); addConversion([&](ComplexType type) { @@ -262,16 +264,19 @@ struct ConvertBf16ArithToF32Pass final auto checkOp = [&](Operation *op) { for (Type type : op->getResultTypes()) { - if (!typeConverter.isLegal(type)) + if (!typeConverter.isLegal(type)) { return false; + } } for (Type type : op->getOperandTypes()) { - if (!typeConverter.isLegal(type)) + if (!typeConverter.isLegal(type)) { return false; + } } for (auto ®ion : op->getRegions()) { - if (!typeConverter.isLegal(®ion)) + if (!typeConverter.isLegal(®ion)) { return false; + } } return true; }; diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp index 94b32b9db14d..458ea0d894de 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp @@ -48,8 +48,9 @@ class Bf16EmulationConverter : public TypeConverter { // Scalar case. 
addConversion([](FloatType ty) -> std::optional { - if (ty.isBF16()) + if (ty.isBF16()) { return IntegerType::get(ty.getContext(), 16); + } return ty; }); @@ -59,12 +60,14 @@ class Bf16EmulationConverter : public TypeConverter { addConversion([this](FunctionType ty) -> std::optional { SmallVector inputs; - if (failed(convertTypes(ty.getInputs(), inputs))) + if (failed(convertTypes(ty.getInputs(), inputs))) { return std::nullopt; + } SmallVector results; - if (failed(convertTypes(ty.getResults(), results))) + if (failed(convertTypes(ty.getResults(), results))) { return std::nullopt; + } return FunctionType::get(ty.getContext(), inputs, results); }); @@ -82,10 +85,11 @@ struct ConvertHalInterfaceBindingSubspan final matchAndRewrite(IREE::HAL::InterfaceBindingSubspanOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type newResultTy = getTypeConverter()->convertType(op.getType()); - if (!newResultTy) + if (!newResultTy) { return rewriter.notifyMatchFailure( op->getLoc(), llvm::formatv("failed to legalize memref type: {}", op.getType())); + } auto newOp = rewriter.replaceOpWithNewOp( @@ -105,10 +109,11 @@ struct ConvertMemRefAlloc final : OpConversionPattern { matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type newTy = getTypeConverter()->convertType(op.getType()); - if (!newTy) + if (!newTy) { return rewriter.notifyMatchFailure( op->getLoc(), llvm::formatv("failed to convert memref type: {}", op.getType())); + } rewriter.replaceOpWithNewOp( op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(), @@ -191,10 +196,11 @@ struct ConvertMemRefLoad final : OpConversionPattern { matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type newResTy = getTypeConverter()->convertType(op.getType()); - if (!newResTy) + if (!newResTy) { return rewriter.notifyMatchFailure( op->getLoc(), llvm::formatv("failed to convert memref type: 
{}", op.getMemRefType())); + } rewriter.replaceOpWithNewOp( op, newResTy, adaptor.getMemref(), adaptor.getIndices(), @@ -210,10 +216,11 @@ struct ConvertMemRefStore final : OpConversionPattern { matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type newTy = getTypeConverter()->convertType(op.getMemRefType()); - if (!newTy) + if (!newTy) { return rewriter.notifyMatchFailure( op->getLoc(), llvm::formatv("failed to convert memref type: {}", op.getMemRefType())); + } rewriter.replaceOpWithNewOp( op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices(), @@ -327,8 +334,9 @@ struct ConvertBf16ToUInt16BuffersPass final RewritePatternSet patterns(ctx); populateIreeBf16EmulationPatterns(patterns, typeConverter); - if (failed(applyPartialConversion(op, target, std::move(patterns)))) + if (failed(applyPartialConversion(op, target, std::move(patterns)))) { signalPassFailure(); + } } } }; diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp index 46c140f1acd6..d718f83ef523 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp @@ -147,8 +147,9 @@ walkUseToGetDispatchStoreOp(Value value, const BufferizationPlan &plan, return user; } value = getTiedResultForOperand(use, plan); - if (!value) + if (!value) { return nullptr; + } traversedUses.push_back(&use); } // If the value has a use which is a store, then use that directly. 
@@ -271,8 +272,9 @@ convertToDestinationPassingStyle(OpBuilder &b, auto walkResult = funcOp.walk( [&](tensor::EmptyOp emptyOp) -> WalkResult { for (auto result : emptyOp->getResults()) { - if (!isa(result.getType())) + if (!isa(result.getType())) { continue; + } if (plan.isInStoreSet(result) && !processed.count(result)) { return modifyResultToUseStoreBuffer(b, result, plan, processed); } @@ -291,20 +293,23 @@ canUseInOperandAsInitOperand(OpOperand *inOperand, OpOperand *initOperand, return false; } - if (inOperand->getOwner() != initOperand->getOwner()) + if (inOperand->getOwner() != initOperand->getOwner()) { return false; + } auto linalgOp = dyn_cast(inOperand->getOwner()); - if (!linalgOp) + if (!linalgOp) { return false; + } if (linalgOp.getMatchingIndexingMap(inOperand) != linalgOp.getMatchingIndexingMap(initOperand)) { return false; } - if (inOperand->get().getType() != initOperand->get().getType()) + if (inOperand->get().getType() != initOperand->get().getType()) { return false; + } if (useWARForCooperativeMatrixCodegen) { return true; @@ -330,8 +335,9 @@ canModifyUseToGetValueIntoStoreSet(BufferizationPlan &plan, OpOperand *use, // Currently only look at use in linalg.generic ops. auto genericOpConsumer = dyn_cast(use->getOwner()); - if (!genericOpConsumer) + if (!genericOpConsumer) { return std::nullopt; + } // All loops need to be parallel. if (genericOpConsumer.getNumLoops() != @@ -339,17 +345,20 @@ canModifyUseToGetValueIntoStoreSet(BufferizationPlan &plan, OpOperand *use, return std::nullopt; } - if (genericOpConsumer.isDpsInit(use)) + if (genericOpConsumer.isDpsInit(use)) { return std::nullopt; + } for (auto [index, initOperand] : llvm::enumerate(genericOpConsumer.getDpsInitsMutable())) { // Output tensor is unused in the body computation. - if (genericOpConsumer.payloadUsesValueFromOperand(&initOperand)) + if (genericOpConsumer.payloadUsesValueFromOperand(&initOperand)) { continue; + } // The result of this operation needs to be in a store set. 
- if (!plan.isInStoreSet(genericOpConsumer->getResult(index))) + if (!plan.isInStoreSet(genericOpConsumer->getResult(index))) { continue; + } if (!canUseInOperandAsInitOperand(use, &initOperand, useWARForCooperativeMatrixCodegen)) { continue; @@ -441,8 +450,9 @@ static LogicalResult adaptComputeConsumerToAvoidStackAllocation( [&](TilingInterface computeOp) -> WalkResult { for (auto result : computeOp->getResults()) { // If result is already in a store set. Nothing to do. - if (plan.isInStoreSet(result)) + if (plan.isInStoreSet(result)) { continue; + } // Check if there are any uses that can be modified to reuse the output // buffer. @@ -450,11 +460,13 @@ static LogicalResult adaptComputeConsumerToAvoidStackAllocation( std::optional reusableOperand = canModifyUseToGetValueIntoStoreSet( plan, &use, useWARForCooperativeMatrixCodegen); - if (!reusableOperand) + if (!reusableOperand) { continue; - if (failed(modifyUseToGetValueIntoStoreSet(rewriter, &use, - reusableOperand.value()))) + } + if (failed(modifyUseToGetValueIntoStoreSet( + rewriter, &use, reusableOperand.value()))) { continue; + } return WalkResult::interrupt(); } } @@ -486,8 +498,9 @@ replaceUnpackEmptyWithAllocTensor(OpBuilder &b, return; } auto emptyOp = unpackOp.getDest().getDefiningOp(); - if (!emptyOp) + if (!emptyOp) { return; + } OpBuilder::InsertionGuard g(b); b.setInsertionPointAfter(emptyOp); @@ -511,13 +524,16 @@ struct RemoveCstOutsDependency Location loc = op.getLoc(); for (OpOperand &opOperand : op.getDpsInitsMutable()) { ElementsAttr attr; - if (!matchPattern(opOperand.get(), m_Constant(&attr))) + if (!matchPattern(opOperand.get(), m_Constant(&attr))) { continue; - if (!attr.isSplat()) + } + if (!attr.isSplat()) { continue; + } auto type = dyn_cast(attr.getType()); - if (!type) + if (!type) { continue; + } TypedAttr scalarAttr = attr.getValues()[0]; modifiedOutput = true; diff --git a/compiler/src/iree/compiler/Codegen/Common/DecomposeConvolutionToLowerDimOps.cpp 
b/compiler/src/iree/compiler/Codegen/Common/DecomposeConvolutionToLowerDimOps.cpp index f7dcced3766a..88107abf00bc 100644 --- a/compiler/src/iree/compiler/Codegen/Common/DecomposeConvolutionToLowerDimOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/DecomposeConvolutionToLowerDimOps.cpp @@ -72,19 +72,22 @@ computeDecomposedLoweringConfig(ArrayRef computeOps, // ATM only folding of the H dim is supported. // TODO: Add support for cases where the W dim is folded. - if (!foldHDim(convOp)) + if (!foldHDim(convOp)) { return failure(); + } // 2. Get the current lowering config attached to the Conv Op. FailureOr loweringConfigAttr = getFirstLoweringConfig(computeOps); - if (failed(loweringConfigAttr)) + if (failed(loweringConfigAttr)) { return failure(); + } // TODO: Either remove "interchange" from lowering_config or add support in // this pass. - if (!loweringConfigAttr->isInterchangeEmpty()) + if (!loweringConfigAttr->isInterchangeEmpty()) { return failure(); + } // 3. Calculate new tiling levels. 
// Note that this will basically erase the _H_ dims from the orignal lowering @@ -159,8 +162,9 @@ class DecomposeConvolutionToLowerDimOpsPass final if (numConvOps == 1 && succeeded(newLoweringConfig)) { auto computeOps = getComputeOps(funcOp); for (auto computeOp : computeOps) { - if (isa(computeOp)) + if (isa(computeOp)) { setLoweringConfig(computeOp, newLoweringConfig.value()); + } } } } diff --git a/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp b/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp index 154bbd313dd5..cce497024012 100644 --- a/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp @@ -169,8 +169,9 @@ static LogicalResult commonRunOnOperation( scf::tileConsumerAndFuseProducersUsingSCF( rewriter, cast(op.getOperation()), packOptions); - if (failed(tileAndFuseResult)) + if (failed(tileAndFuseResult)) { return WalkResult::interrupt(); + } rewriter.replaceOp(op, tileAndFuseResult->replacements[op.getResult()]); return WalkResult::advance(); }); @@ -203,8 +204,9 @@ static LogicalResult commonRunOnOperation( FailureOr tilingResult = scf::tileUsingSCF( rewriter, cast(op.getOperation()), unpackTilingOptions); - if (failed(tilingResult)) + if (failed(tilingResult)) { return WalkResult::interrupt(); + } rewriter.replaceOp(op, tilingResult->replacements); return WalkResult::advance(); }); @@ -342,9 +344,8 @@ static LogicalResult isUnpaddedAndAtBoundary(Operation *op) { // If all consumers are dispatch tensor stores, then the `op` is decomposable // if it is an UnPackOp. 
if (isa(op) && - llvm::all_of(op->getUsers(), [&](Operation *user) { - return isa(user); - })) { + llvm::all_of(op->getUsers(), + llvm::IsaPred)) { return success(); } return failure(); diff --git a/compiler/src/iree/compiler/Codegen/Common/DecomposeSoftmax.cpp b/compiler/src/iree/compiler/Codegen/Common/DecomposeSoftmax.cpp index bcdb06e0f3b3..3e3ca38bed75 100644 --- a/compiler/src/iree/compiler/Codegen/Common/DecomposeSoftmax.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/DecomposeSoftmax.cpp @@ -43,8 +43,9 @@ struct FuseElementWiseGenericOps : public OpRewritePattern { // Find the first operand that is defined by another generic op on tensors. for (OpOperand &opOperand : genericOp->getOpOperands()) { - if (!linalg::areElementwiseOpsFusable(&opOperand)) + if (!linalg::areElementwiseOpsFusable(&opOperand)) { continue; + } // Don't fuse if it has external capture. For e.g., the gather like // payload operation like 'tensor.extract' would be cloned in // every consumer op, which is not what we want. diff --git a/compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp b/compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp index 82ec1b7b8d0a..8bede523de65 100644 --- a/compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp @@ -130,8 +130,9 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc, // When extracting all available elements, just use the source vector as the // result. - if (vectorType.getNumElements() == numElemsToExtract) + if (vectorType.getNumElements() == numElemsToExtract) { return src; + } auto offsets = rewriter.getI64ArrayAttr({offset}); auto sizes = rewriter.getI64ArrayAttr({numElemsToExtract}); @@ -160,8 +161,9 @@ static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc, "expected source and dest to be rank-1 vector types"); // If overwritting the destination vector, just return the source. 
- if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0) + if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0) { return src; + } auto offsets = rewriter.getI64ArrayAttr({offset}); auto strides = rewriter.getI64ArrayAttr({1}); @@ -344,9 +346,10 @@ struct IREEConvertVectorStore final : OpConversionPattern { ConversionPatternRewriter &rewriter) const override { // See #115653 - if (op.getValueToStore().getType().getRank() != 1) + if (op.getValueToStore().getType().getRank() != 1) { return rewriter.notifyMatchFailure(op, "only 1-D vectors are supported ATM"); + } auto loc = op.getLoc(); diff --git a/compiler/src/iree/compiler/Codegen/Common/EraseHALDescriptorTypeFromMemRef.cpp b/compiler/src/iree/compiler/Codegen/Common/EraseHALDescriptorTypeFromMemRef.cpp index ae957c479fa9..8d978d85c473 100644 --- a/compiler/src/iree/compiler/Codegen/Common/EraseHALDescriptorTypeFromMemRef.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/EraseHALDescriptorTypeFromMemRef.cpp @@ -46,8 +46,9 @@ struct EraseHALDescriptorTypeFromMemRefPass final AttrTypeReplacer replacer; replacer.addReplacement( [](BaseMemRefType memRefType) -> std::optional { - if (isLegalType(memRefType)) + if (isLegalType(memRefType)) { return std::nullopt; + } // Erase the #hal.descriptor_type memory space. 
if (auto rankedType = dyn_cast(memRefType)) { @@ -74,8 +75,9 @@ struct ConvertHALDescriptorTypeToGPUAddressSpacePass final AttrTypeReplacer replacer; replacer.addReplacement( [](BaseMemRefType memRefType) -> std::optional { - if (isLegalType(memRefType)) + if (isLegalType(memRefType)) { return std::nullopt; + } Attribute globalSpace = gpu::AddressSpaceAttr::get( memRefType.getContext(), gpu::AddressSpace::Global); diff --git a/compiler/src/iree/compiler/Codegen/Common/ExtractAddressComputation.h b/compiler/src/iree/compiler/Codegen/Common/ExtractAddressComputation.h index e18de97bae05..69ecb1b09d29 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ExtractAddressComputation.h +++ b/compiler/src/iree/compiler/Codegen/Common/ExtractAddressComputation.h @@ -40,8 +40,9 @@ struct StoreLoadLikeOpRewriter : public OpRewritePattern { auto ldTy = cast(srcMemRef.getType()); unsigned storeLoadRank = ldTy.getRank(); // Don't waste compile time if there is nothing to rewrite. - if (storeLoadRank == 0) + if (storeLoadRank == 0) { return failure(); + } // If our load already has only zeros as indices there is nothing // to do. 
diff --git a/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp b/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp index 814b73300be3..8e160b8fc4ec 100644 --- a/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp @@ -157,8 +157,9 @@ struct FlattenAlloc final : public OpConversionPattern { matchAndRewrite(AllocOpTy allocOp, typename AllocOpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto oldType = dyn_cast(allocOp.getType()); - if (!oldType || !oldType.getLayout().isIdentity()) + if (!oldType || !oldType.getLayout().isIdentity()) { return failure(); + } Value dynamicDim = createTotalElementCountValue( oldType, allocOp.getDynamicSizes(), allocOp.getLoc(), rewriter); @@ -176,8 +177,9 @@ struct FlattenGlobal final : public OpConversionPattern { using Base::Base; static Attribute flattenAttribute(Attribute value, ShapedType newType) { - if (!value) + if (!value) { return value; + } if (auto splatAttr = dyn_cast(value)) { return splatAttr.reshape(newType); } else if (auto denseAttr = dyn_cast(value)) { @@ -194,8 +196,9 @@ struct FlattenGlobal final : public OpConversionPattern { matchAndRewrite(memref::GlobalOp globalOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto oldType = dyn_cast(globalOp.getType()); - if (!oldType || !oldType.getLayout().isIdentity()) + if (!oldType || !oldType.getLayout().isIdentity()) { return failure(); + } auto tensorType = RankedTensorType::get({oldType.getNumElements()}, oldType.getElementType()); @@ -221,13 +224,15 @@ struct FlattenGetGlobal final matchAndRewrite(memref::GetGlobalOp getOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto oldType = dyn_cast(getOp.getType()); - if (!oldType || !oldType.getLayout().isIdentity()) + if (!oldType || !oldType.getLayout().isIdentity()) { return failure(); + } auto globalOp 
= dyn_cast_if_present( SymbolTable::lookupNearestSymbolFrom(getOp, getOp.getNameAttr())); - if (!globalOp) + if (!globalOp) { return failure(); + } auto loadedValue = rewriter.createOrFold( getOp.getLoc(), globalOp.getType(), getOp.getNameAttr()); @@ -250,8 +255,9 @@ struct FlattenBindingSubspan final auto oldType = dyn_cast(subspanOp.getType()); // IREE subspan ops only use memref types with the default identity // layout maps. - if (!oldType) + if (!oldType) { return failure(); + } OpFoldResult linearShape; if (oldType.hasStaticShape()) { @@ -441,8 +447,9 @@ struct FlattenSubView final : public OpConversionPattern { } Type neededResultType = getTypeConverter()->convertType(op.getResult().getType()); - if (!neededResultType || !isRankZeroOrOneMemRef(neededResultType)) + if (!neededResultType || !isRankZeroOrOneMemRef(neededResultType)) { return failure(); + } Value size = createTotalElementCountValue(op.getType(), op.getSizes(), op.getLoc(), rewriter); SmallVector offsets = mlir::getValueOrCreateConstantIndexOp( @@ -651,13 +658,15 @@ struct AdjustConversionCast final LogicalResult matchAndRewrite(UnrealizedConversionCastOp castOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (castOp->getNumOperands() != 1) + if (castOp->getNumOperands() != 1) { return failure(); + } Value input = adaptor.getOperands().front(); // We only want to handle cases where the cast op handles memref types. 
- if (!isa(input.getType())) + if (!isa(input.getType())) { return failure(); + } if (!isRankZeroOrOneMemRef(input.getType())) { return rewriter.notifyMatchFailure( @@ -695,8 +704,9 @@ struct FoldMemRefReshape final : public OpConversionPattern { Type newSourceType = adaptor.getSrc().getType(); Type neededResultType = typeConverter->convertType(op.getResult().getType()); - if (!neededResultType) + if (!neededResultType) { return failure(); + } if (newSourceType == neededResultType) { rewriter.replaceOp(op, adaptor.getSrc()); return success(); @@ -769,8 +779,9 @@ struct FlattenMemRefSubspanPass final [](MemRefType type) -> std::optional { // 0-D MemRef types can be used to represent raw pointers for // micro-kernel ABI purposes. Specially allow it. - if (isRankZeroMemRef(type)) + if (isRankZeroMemRef(type)) { return type; + } // Fall back to the default conversion flow. return std::nullopt; @@ -786,8 +797,9 @@ struct FlattenMemRefSubspanPass final internalTypeConverter.addConversion( [](MemRefType type) -> std::optional { // 0-D or 1-D MemRef types are okay. - if (isRankZeroOrOneMemRef(type)) + if (isRankZeroOrOneMemRef(type)) { return type; + } // Fall back to the default conversion flow. 
return std::nullopt; @@ -857,8 +869,9 @@ struct FlattenMemRefSubspanPass final }); target.addDynamicallyLegalOp( [](UnrealizedConversionCastOp castOp) { - if (castOp->getNumOperands() != 1) + if (castOp->getNumOperands() != 1) { return false; + } Type inputType = castOp->getOperandTypes().front(); return !isa(inputType) || diff --git a/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefs.cpp b/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefs.cpp index d880ff25396e..9d29253fb12b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefs.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefs.cpp @@ -45,8 +45,9 @@ static OpFoldResult computeProduct(Location loc, OpBuilder &builder, SmallVector dynamicPart; AffineExpr result = builder.getAffineConstantExpr(1); for (OpFoldResult term : terms) { - if (!term) + if (!term) { return term; + } std::optional maybeConst = getConstantIntValue(term); if (maybeConst) { result = result * builder.getAffineConstantExpr(*maybeConst); @@ -55,8 +56,9 @@ static OpFoldResult computeProduct(Location loc, OpBuilder &builder, result = result * builder.getAffineSymbolExpr(nDynamic++); } } - if (auto constant = dyn_cast(result)) + if (auto constant = dyn_cast(result)) { return getAsIndexOpFoldResult(builder.getContext(), constant.getValue()); + } return affine::AffineApplyOp::create(builder, loc, result, dynamicPart) .getResult(); } @@ -245,9 +247,10 @@ struct MemRefRewritePatternBase : public OpRewritePattern { LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const override { Value memref = getTargetMemref(op); - if (!needFlattenning(memref) || !checkLayout(memref)) + if (!needFlattenning(memref) || !checkLayout(memref)) { return rewriter.notifyMatchFailure(op, "nothing to do or unsupported layout"); + } auto &&[flatMemref, offset] = getFlattenMemrefAndOffset( rewriter, op->getLoc(), memref, op.getIndices()); replaceOp(op, rewriter, flatMemref, offset); @@ -301,11 +304,13 @@ struct 
FlattenSubview : public OpRewritePattern { LogicalResult matchAndRewrite(memref::SubViewOp op, PatternRewriter &rewriter) const override { Value memref = op.getSource(); - if (!needFlattenning(memref)) + if (!needFlattenning(memref)) { return rewriter.notifyMatchFailure(op, "nothing to do"); + } - if (!checkLayout(memref)) + if (!checkLayout(memref)) { return rewriter.notifyMatchFailure(op, "unsupported layout"); + } Location loc = op.getLoc(); SmallVector subOffsets = op.getMixedOffsets(); @@ -327,8 +332,9 @@ struct FlattenSubview : public OpRewritePattern { finalStrides.reserve(subRank); for (auto i : llvm::seq(0u, static_cast(srcType.getRank()))) { - if (droppedDims.test(i)) + if (droppedDims.test(i)) { continue; + } finalSizes.push_back(subSizes[i]); finalStrides.push_back(strides[i]); @@ -354,8 +360,9 @@ struct DecomposeMemrefsPass mlir::iree_compiler::populateDecomposeMemrefsPatterns(patterns); - if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); + } } }; diff --git a/compiler/src/iree/compiler/Codegen/Common/FlattenSwizzleHintAllocs.cpp b/compiler/src/iree/compiler/Codegen/Common/FlattenSwizzleHintAllocs.cpp new file mode 100644 index 000000000000..235ab5176d23 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/FlattenSwizzleHintAllocs.cpp @@ -0,0 +1,87 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir::iree_compiler { + +#define GEN_PASS_DEF_FLATTENSWIZZLEHINTALLOCSPASS +#include "iree/compiler/Codegen/Common/Passes.h.inc" + +namespace { +struct FlattenSwizzleHintAllocsPass final + : impl::FlattenSwizzleHintAllocsPassBase { + using Base::Base; + void runOnOperation() override; +}; +} // namespace + +/// This pass flattens swizzle hint ops that operate on allocations of rank > 1. +/// This is required since swizzle hint op indices require flat memrefs. +/// +/// Example: +/// ``` +/// %0 = iree.alloc() : tensor<512x32xf4E2M1FN> +/// %1 = iree.swizzle_hint %0 : tensor<512x32xf4E2M1FN> -> +/// tensor<512x32xf4E2M1FN> +/// ``` +/// +/// is flattened to: +/// ``` +/// %0 = iree.alloc() : tensor<16384xf4E2M1FN> +/// %1 = iree.swizzle_hint %0 : tensor<16384xf4E2M1FN> -> tensor<16384xf4E2M1FN> +/// %2 = iree.expand_shape %1 : tensor<16384xf4E2M1FN> -> +/// tensor<512x32xf4E2M1FN> +/// ``` +static void flattenSwizzleHintAllocs(RewriterBase &rewriter, + IREE::Codegen::SwizzleHintOp hintOp) { + auto allocOp = hintOp.getOperand().getDefiningOp(); + if (!allocOp || !allocOp->hasOneUse()) { + return; + } + MemRefType resultType = allocOp.getType(); + if (resultType.getRank() == 1 || !resultType.getLayout().isIdentity() || + !memref::isStaticShapeAndContiguousRowMajor(resultType)) { + return; + } + + SmallVector newResultShape = {resultType.getNumElements()}; + auto newResultType = + MemRefType::get(newResultShape, resultType.getElementType(), AffineMap(), + resultType.getMemorySpace()); + rewriter.setInsertionPoint(hintOp); + ReassociationIndices reassoc = + llvm::to_vector(llvm::seq(resultType.getRank())); + 
auto newAllocOp = + memref::AllocOp::create(rewriter, hintOp.getLoc(), newResultType); + auto newSwizzleHintOp = IREE::Codegen::SwizzleHintOp::create( + rewriter, hintOp.getLoc(), newAllocOp.getResult(), hintOp.getSwizzle()); + auto expandShape = memref::ExpandShapeOp::create(rewriter, hintOp.getLoc(), + resultType.getShape(), + newSwizzleHintOp, {reassoc}); + rewriter.replaceOp(hintOp, expandShape); +} + +void FlattenSwizzleHintAllocsPass::runOnOperation() { + FunctionOpInterface funcOp = getOperation(); + // Collect all swizzle hint ops that operate on allocations. + // Flatten all allocs of rank > 1. + SmallVector hintOps; + funcOp.walk( + [&](IREE::Codegen::SwizzleHintOp hint) { hintOps.push_back(hint); }); + + IRRewriter rewriter(funcOp->getContext()); + for (IREE::Codegen::SwizzleHintOp hintOp : hintOps) { + flattenSwizzleHintAllocs(rewriter, hintOp); + } +} + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/FoldAffineMinInDistributedLoops.cpp b/compiler/src/iree/compiler/Codegen/Common/FoldAffineMinInDistributedLoops.cpp index 34fb202ead83..3f3ea86d8aad 100644 --- a/compiler/src/iree/compiler/Codegen/Common/FoldAffineMinInDistributedLoops.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/FoldAffineMinInDistributedLoops.cpp @@ -57,8 +57,9 @@ canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op, rewriter.setInsertionPoint(op); FailureOr simplified = mlir::affine::simplifyConstrainedMinMaxOp(op, std::move(constraints)); - if (failed(simplified)) + if (failed(simplified)) { return failure(); + } return rewriter.replaceOpWithNewOp( op, simplified->getAffineMap(), simplified->getOperands()); } @@ -89,22 +90,26 @@ struct FoldAffineMinOverDistributedLoopInductionVariable final auto loopMatcher = [&](Value iv, OpFoldResult &lb, OpFoldResult &ub, OpFoldResult &step) { scf::ForOp forOp = scf::getForInductionVarOwner(iv); - if (!forOp) + if (!forOp) { return failure(); + } auto loopInfo = 
isTiledAndDistributedLoop(forOp); - if (!loopInfo) + if (!loopInfo) { return failure(); + } LLVM_DEBUG(llvm::dbgs() << *loopInfo); std::optional untiledStep = getConstantIntValue(loopInfo->untiledStep); // For IREE right now the original untiled loop should have step 1.. - if (!untiledStep || *untiledStep != 1) + if (!untiledStep || *untiledStep != 1) { return failure(); + } // ..and we tile according to some static tile sizes for processors. - if (!loopInfo->tileSize) + if (!loopInfo->tileSize) { return failure(); + } lb = loopInfo->untiledLowerBound; ub = loopInfo->untiledUpperBound; @@ -132,17 +137,21 @@ struct FoldAffineMinOverWorkgroupIDs final // Find all iteration variables among `minOp`'s operands add constrain them. for (Value operand : minOp->getOperands()) { // Skip duplicate ids. - if (!allIds.insert(operand).second) + if (!allIds.insert(operand).second) { continue; + } auto idOp = operand.getDefiningOp(); - if (!idOp) + if (!idOp) { continue; + } // Can't infer the range when workroupCount is unknown. 
unsigned index = idOp.getDimension().getZExtValue(); - if (index >= numWorkgroup.size()) + if (index >= numWorkgroup.size()) { return failure(); - if (numWorkgroup[index] == ShapedType::kDynamic) + } + if (numWorkgroup[index] == ShapedType::kDynamic) { continue; + } constraints.appendDimVar({idOp}); constraints.addBound(presburger::BoundType::LB, idOp, 0); constraints.addBound(presburger::BoundType::UB, idOp, diff --git a/compiler/src/iree/compiler/Codegen/Common/FoldTensorExtractOpPass.cpp b/compiler/src/iree/compiler/Codegen/Common/FoldTensorExtractOpPass.cpp index c4cc7e2c77a5..7ddd611d8a2e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/FoldTensorExtractOpPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/FoldTensorExtractOpPass.cpp @@ -59,7 +59,8 @@ class FoldTensorExtractOpPass final void FoldTensorExtractOpPass::runOnOperation() { RewritePatternSet patterns(&getContext()); populateWithGenerated(patterns); - if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); + } } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/FoldTensorSubsetIntoVectorTransferOps.cpp b/compiler/src/iree/compiler/Codegen/Common/FoldTensorSubsetIntoVectorTransferOps.cpp index 4e32b7bf1c6b..0041574d3a04 100644 --- a/compiler/src/iree/compiler/Codegen/Common/FoldTensorSubsetIntoVectorTransferOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/FoldTensorSubsetIntoVectorTransferOps.cpp @@ -22,8 +22,9 @@ using namespace mlir; static bool areAllRankReducedLeadingDim(tensor::ExtractSliceOp extractOp, unsigned trailingRank) { // If no ranks are reduced at all, it's a degenerated case; always true. 
- if (extractOp.getSourceType().getRank() == extractOp.getType().getRank()) + if (extractOp.getSourceType().getRank() == extractOp.getType().getRank()) { return true; + } RankedTensorType inferredType = extractOp.inferResultType( extractOp.getSourceType(), extractOp.getMixedSizes()); @@ -57,19 +58,25 @@ class FoldExtractSliceIntoTransferRead final LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override { // TODO: support 0-d corner case. - if (xferOp.getTransferRank() == 0) + if (xferOp.getTransferRank() == 0) { return failure(); - if (xferOp.hasOutOfBoundsDim()) + } + if (xferOp.hasOutOfBoundsDim()) { return failure(); - if (!xferOp.getPermutationMap().isMinorIdentity()) + } + if (!xferOp.getPermutationMap().isMinorIdentity()) { return failure(); - if (xferOp.getMask()) + } + if (xferOp.getMask()) { return failure(); + } auto extractOp = xferOp.getBase().getDefiningOp(); - if (!extractOp) + if (!extractOp) { return failure(); - if (!extractOp.hasUnitStride()) + } + if (!extractOp.hasUnitStride()) { return failure(); + } // Bail on illegal rank-reduction: we need to check that the rank-reduced // dims are exactly the leading dims. I.e. the following is illegal: @@ -87,8 +94,10 @@ class FoldExtractSliceIntoTransferRead final // ``` // For this, check the trailing `vectorRank` dims of the extract_slice // result tensor match the trailing dims of the inferred result tensor. - if (!areAllRankReducedLeadingDim(extractOp, extractOp.getType().getRank())) + if (!areAllRankReducedLeadingDim(extractOp, + extractOp.getType().getRank())) { return failure(); + } int64_t rankReduced = extractOp.getSourceType().getRank() - extractOp.getType().getRank(); @@ -132,12 +141,15 @@ class FoldExtractSliceIntoTransferRead final /// dynamic tensors, where it resolves the tensor sizes via value-bounds /// analysis, and then checks if the vector type fully overwrites the tensor. 
static bool isDestinationFullyOverwritten(vector::TransferWriteOp writeOp) { - if (writeOp.hasOutOfBoundsDim()) + if (writeOp.hasOutOfBoundsDim()) { return false; - if (writeOp.getVectorType().getRank() != writeOp.getShapedType().getRank()) + } + if (writeOp.getVectorType().getRank() != writeOp.getShapedType().getRank()) { return false; - if (writeOp.getMask()) + } + if (writeOp.getMask()) { return false; + } std::optional vscaleRange; auto vecType = writeOp.getVectorType(); @@ -155,8 +167,9 @@ static bool isDestinationFullyOverwritten(vector::TransferWriteOp writeOp) { [&](unsigned dimIndex) -> FailureOr { auto size = destShape[dimIndex]; // Fixed-size dimensions are simply included in the shape. - if (size != ShapedType::kDynamic) + if (size != ShapedType::kDynamic) { return iree_compiler::DimBoundSize{size}; + } // (Attempt to) resolve dynamic dimensions via value-bounds analysis. return iree_compiler::computeDimUpperBound(dest, dimIndex, vscaleRange); }; @@ -165,12 +178,15 @@ static bool isDestinationFullyOverwritten(vector::TransferWriteOp writeOp) { ArrayRef vecScalableFlags = vecType.getScalableDims(); for (unsigned d = 0, e = destShape.size(); d < e; ++d) { auto dimSize = resolveDestinationDimSize(d); - if (failed(dimSize)) + if (failed(dimSize)) { return false; - if (dimSize->scalable && !vecScalableFlags[d]) + } + if (dimSize->scalable && !vecScalableFlags[d]) { return false; - if (vecShape[d] != dimSize->baseSize) + } + if (vecShape[d] != dimSize->baseSize) { return false; + } } return true; } @@ -198,23 +214,28 @@ class FoldInsertSliceIntoTransferWrite final LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp, PatternRewriter &rewriter) const override { - if (!insertOp.hasUnitStride()) + if (!insertOp.hasUnitStride()) { return failure(); + } auto xferOp = insertOp.getSource().getDefiningOp(); - if (!xferOp) + if (!xferOp) { return failure(); + } // TODO: support 0-d corner case. 
- if (xferOp.getTransferRank() == 0) + if (xferOp.getTransferRank() == 0) { return failure(); - if (!xferOp.getPermutationMap().isIdentity()) + } + if (!xferOp.getPermutationMap().isIdentity()) { return failure(); + } // Fold only if the TransferWriteOp completely overwrites the `source` with // a vector. I.e., the result of the TransferWriteOp is a new tensor whose // content is the data of the vector. - if (!isDestinationFullyOverwritten(xferOp)) + if (!isDestinationFullyOverwritten(xferOp)) { return failure(); + } // Bail on illegal rank-reduction: we need to check that the rank-reduced // dims are exactly the leading dims. I.e. the following is illegal: @@ -241,8 +262,9 @@ class FoldInsertSliceIntoTransferWrite final auto actualSourceTensorShape = insertOp.getSourceType().getShape(); if (rankReduced > 0 && actualSourceTensorShape.take_back(vectorRank) != - inferredSourceTensorType.getShape().take_back(vectorRank)) + inferredSourceTensorType.getShape().take_back(vectorRank)) { return failure(); + } SmallVector indices = getValueOrCreateConstantIndexOp( rewriter, insertOp.getLoc(), insertOp.getMixedOffsets()); @@ -328,8 +350,9 @@ class FoldExtractSliceIntoTransferWrite final if (!maybeDestSize || !maybeIndex) { continue; } - if (vecSize + *maybeIndex <= *maybeDestSize) + if (vecSize + *maybeIndex <= *maybeDestSize) { inBounds[idx] = true; + } } rewriter.replaceOpWithNewOp( diff --git a/compiler/src/iree/compiler/Codegen/Common/ForOpCanonicalizationPass.cpp b/compiler/src/iree/compiler/Codegen/Common/ForOpCanonicalizationPass.cpp index 9c11c2a30050..8f230dfaafd0 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ForOpCanonicalizationPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/ForOpCanonicalizationPass.cpp @@ -158,8 +158,9 @@ struct CanonicalizeForOpInductionVarShape final mapping.map(loopIndVar, start); initArgs[index] = rewriter.clone(*finalIvUser, mapping)->getResult(0); } - if (iteratorFolded.empty()) + if (iteratorFolded.empty()) { return 
failure(); + } auto newLoop = scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), @@ -230,8 +231,9 @@ struct PackForOpInductionVarVector final : public OpRewritePattern { targetTypes.push_back(targetType); } } - if (ivIndices.empty()) + if (ivIndices.empty()) { return failure(); + } // Bit cast all init values to the smaller vector (fewer elements). auto ivInitValues = llvm::to_vector<8>(forOp.getInitArgs()); @@ -287,8 +289,9 @@ struct PackForOpInductionVarVector final : public OpRewritePattern { yieldOp->setOperands(ivRetValues); SmallVector forRetValues; - for (Value result : newLoop.getResults()) + for (Value result : newLoop.getResults()) { forRetValues.push_back(result); + } // Bit cast return values to the old type to fix for op uses. rewriter.setInsertionPointAfter(newLoop); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp index 3b5958aa2f1b..4324d169ec78 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp @@ -244,8 +244,9 @@ struct DistributeContract final : OpDistributionPattern { int64_t lhsKBatch = lhsLayout.getBatchTile()[lhsK]; int64_t rhsKBatch = rhsLayout.getBatchTile()[rhsK]; - if (lhsKBatch != rhsKBatch) + if (lhsKBatch != rhsKBatch) { return std::nullopt; + } return lhsKBatch; } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel index 7e905bf9f5b2..1ae7e44ee08c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel @@ -74,6 +74,7 @@ iree_compiler_cc_library( "GPUDistributeScfFor.cpp", "GPUDistributeSharedMemoryCopy.cpp", "GPUDistributionPatterns.cpp", + "GPUExpandDimensions.cpp", "GPUFuseAndHoistParallelLoops.cpp", "GPUGeneralizeNamedOps.cpp", 
"GPUGreedilyDistributeToThreads.cpp", @@ -125,6 +126,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/LinalgExt/Transforms", "//compiler/src/iree/compiler/Dialect/LinalgExt/Utils", "//compiler/src/iree/compiler/Dialect/TensorExt/IR", + "//compiler/src/iree/compiler/Dialect/Util/IR", "//compiler/src/iree/compiler/Utils", "@llvm-project//llvm:Support", "@llvm-project//mlir:AMDGPUDialect", diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt index 8fa41efd5439..c09cba442a02 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt @@ -67,6 +67,7 @@ iree_cc_library( "GPUDistributeScfFor.cpp" "GPUDistributeSharedMemoryCopy.cpp" "GPUDistributionPatterns.cpp" + "GPUExpandDimensions.cpp" "GPUFuseAndHoistParallelLoops.cpp" "GPUGeneralizeNamedOps.cpp" "GPUGreedilyDistributeToThreads.cpp" @@ -159,6 +160,7 @@ iree_cc_library( iree::compiler::Dialect::LinalgExt::Transforms iree::compiler::Dialect::LinalgExt::Utils iree::compiler::Dialect::TensorExt::IR + iree::compiler::Dialect::Util::IR iree::compiler::Utils PUBLIC ) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp index 8ab4d99bbf73..b5c51c5cad45 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp @@ -11,6 +11,7 @@ #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h" +#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLForwardCompat.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -61,7 +62,6 @@ 
getTiledOps(Operation *funcOp, IREE::GPU::TilingLevel tilingLevel) { void GPUApplyTilingLevelPass::runOnOperation() { FunctionOpInterface funcOp = getOperation(); - if (!llvm::is_contained({IREE::GPU::TilingLevel::Reduction, IREE::GPU::TilingLevel::Thread, IREE::GPU::TilingLevel::Subgroup, @@ -107,6 +107,7 @@ void GPUApplyTilingLevelPass::runOnOperation() { // Apply cleanup patterns. { RewritePatternSet patterns(context); + IREE::GPU::populateFoldSwizzleHintOpPatterns(patterns); // Merge consecutive insert/extract slice ops to simplify later loop // hoisting patterns. tensor::populateFoldTensorEmptyPatterns(patterns); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCheckResourceUsage.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCheckResourceUsage.cpp index 3db1921caa86..10836994c951 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCheckResourceUsage.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCheckResourceUsage.cpp @@ -29,8 +29,9 @@ static int shapedTypeStaticSize( std::function getIndexBitwidth) { int allocSize = 1; for (auto dimSize : shapedType.getShape()) { - if (ShapedType::isDynamic(dimSize)) + if (ShapedType::isDynamic(dimSize)) { continue; + } allocSize *= dimSize; } if (auto elementType = dyn_cast(shapedType.getElementType())) { @@ -42,8 +43,9 @@ static int shapedTypeStaticSize( assert(getIndexBitwidth && "getIndexBitwidth should have been set earlier"); allocSize *= getIndexBitwidth(func); - } else + } else { allocSize *= IREE::Util::getTypeBitWidth(shapedType.getElementType()); + } } return allocSize; } @@ -53,19 +55,22 @@ static int shapedTypeStaticSize( static LogicalResult checkGPUAllocationSize( mlir::FunctionOpInterface funcOp, unsigned limit, std::function getIndexBitwidth) { - if (funcOp.getFunctionBody().empty()) + if (funcOp.getFunctionBody().empty()) { return success(); + } SmallVector allocOps; funcOp.walk([&](memref::AllocOp allocOp) { allocOps.push_back(allocOp); }); - if (allocOps.empty()) 
+ if (allocOps.empty()) { return success(); + } int cumSize = 0; for (auto allocOp : allocOps) { auto allocType = cast(allocOp.getType()); - if (!hasSharedMemoryAddressSpace(allocType)) + if (!hasSharedMemoryAddressSpace(allocType)) { continue; + } if (!allocOp.getDynamicSizes().empty()) { return allocOp.emitOpError( diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp index a1de04220136..8d4f0fad59a6 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp @@ -108,8 +108,9 @@ computeThreadNumThreadsImpl(OpBuilder &builder, Operation *op, // Find minimum elements per transfer across all DMA sizes. int64_t minElementsPerTransfer = std::numeric_limits::max(); for (int64_t dmaSize : dmaSizes) { - if (dmaSize % elementBits != 0) + if (dmaSize % elementBits != 0) { continue; + } int64_t elementsPerLane = dmaSize / elementBits; int64_t elementsPerTransfer = *subgroupSize * elementsPerLane; minElementsPerTransfer = @@ -424,8 +425,9 @@ struct ConvertGatherToCoalescedDMA int64_t minElementsPerTransfer = std::numeric_limits::max(); for (int64_t dmaSize : dmaSizes) { - if (dmaSize % elementBits != 0) + if (dmaSize % elementBits != 0) { continue; + } int64_t elementsPerLane = dmaSize / elementBits; int64_t elementsPerTransfer = *subgroupSize * elementsPerLane; minElementsPerTransfer = @@ -611,20 +613,23 @@ struct GPUConvertToCoalescedDMAPass final OpTy op) { MLIRContext *context = &getContext(); auto dmaConfig = getLoweringConfig(op); - if (!dmaConfig) + if (!dmaConfig) { return failure(); + } // Get the function containing this operation. auto funcOp = op->template getParentOfType(); - if (!funcOp) + if (!funcOp) { return failure(); + } // Get workgroup size and subgroup size from translation_info. 
std::optional> workgroupSize = getWorkgroupSize(funcOp); std::optional subgroupSize = getSubgroupSize(funcOp); - if (!workgroupSize || !subgroupSize) + if (!workgroupSize || !subgroupSize) { return failure(); + } // Calculate number of subgroups per dimension. // workgroupSize is [X, Y, Z], and we divide by subgroupSize to get warps. @@ -670,8 +675,9 @@ struct GPUConvertToCoalescedDMAPass final // We need innermostDim >= subgroupSize * minElementsPerLane. int64_t minElementsPerTransfer = std::numeric_limits::max(); for (int64_t dmaSize : dmaSizes) { - if (dmaSize % elementBits != 0) + if (dmaSize % elementBits != 0) { continue; + } int64_t elementsPerLane = dmaSize / elementBits; int64_t elementsPerTransfer = *subgroupSize * elementsPerLane; minElementsPerTransfer = @@ -688,8 +694,9 @@ struct GPUConvertToCoalescedDMAPass final auto [tileSizes, numTiledDims] = computeSubgroupTileSizes(rewriter, shape, numWarps); - if (numTiledDims == 0) + if (numTiledDims == 0) { return failure(); + } scf::SCFTilingOptions tilingOptions; tilingOptions.setTileSizes(tileSizes); @@ -735,8 +742,9 @@ struct GPUConvertToCoalescedDMAPass final }) .Default([](Operation *) { return failure(); }); - if (failed(tilingResult)) + if (failed(tilingResult)) { continue; + } // Replace the original op with the tiled version. rewriter.replaceOp(op, tilingResult->replacements); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCreateFastSlowPath.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCreateFastSlowPath.cpp index efefcebb8cd7..b9f4cea258ae 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCreateFastSlowPath.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCreateFastSlowPath.cpp @@ -58,21 +58,24 @@ static void applyFastSlowPathConversion(mlir::FunctionOpInterface funcOp) { // Find the anchor tensor.pad op, from which we get the conditions for // switching between the fast and slow path. 
auto padOps = llvm::to_vector(body->getOps()); - if (llvm::size(padOps) != 1) + if (llvm::size(padOps) != 1) { return; + } tensor::PadOp padOp = *padOps.begin(); // If all padding sizes are zero, we don't need to do anything. SmallVector lowPads = padOp.getMixedLowPad(); SmallVector highPads = padOp.getMixedHighPad(); - if (llvm::all_of(lowPads, isZero) && llvm::all_of(highPads, isZero)) + if (llvm::all_of(lowPads, isZero) && llvm::all_of(highPads, isZero)) { return; + } IRRewriter rewriter(funcOp.getContext()); rewriter.setInsertionPoint(body->getTerminator()); SmallVector allOps; - for (Operation &op : body->without_terminator()) + for (Operation &op : body->without_terminator()) { allOps.push_back(&op); + } BackwardSliceOptions options; options.filter = [](Operation *op) { return true; }; @@ -96,13 +99,15 @@ static void applyFastSlowPathConversion(mlir::FunctionOpInterface funcOp) { } } Value ifCond = eqZeroCmpVals.front(); - for (Value cmp : llvm::ArrayRef(eqZeroCmpVals).drop_front()) + for (Value cmp : llvm::ArrayRef(eqZeroCmpVals).drop_front()) { ifCond = arith::AndIOp::create(rewriter, loc, ifCond, cmp); + } SmallVector cloneOps; for (Operation *op : allOps) { - if (!padSizeOps.contains(op)) + if (!padSizeOps.contains(op)) { cloneOps.push_back(op); + } } // Build the scf.if op itself. Clone all ops other than those used for @@ -122,15 +127,17 @@ static void applyFastSlowPathConversion(mlir::FunctionOpInterface funcOp) { }; auto elseBuilder = [&](OpBuilder &builder, Location loc) { IRMapping bvm; - for (Operation *op : cloneOps) + for (Operation *op : cloneOps) { builder.clone(*op, bvm); + } scf::YieldOp::create(builder, loc); }; scf::IfOp::create(rewriter, padOp.getLoc(), ifCond, thenBuilder, elseBuilder); // All of these ops have been cloned to both regions. Erease them now. 
- for (Operation *op : llvm::reverse(cloneOps)) + for (Operation *op : llvm::reverse(cloneOps)) { rewriter.eraseOp(op); + } } struct GPUCreateFastSlowPathPass final diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistribute.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistribute.cpp index e893d8a019ae..c5645046b66e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistribute.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistribute.cpp @@ -25,8 +25,9 @@ replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc, Block *parent, Value replacement, ArrayRef availableMappingSizes) { parent->walk([&](gpu::ThreadIdOp idOp) { - if (availableMappingSizes[static_cast(idOp.getDimension())] == 1) + if (availableMappingSizes[static_cast(idOp.getDimension())] == 1) { rewriter.replaceAllUsesWith(idOp.getResult(), replacement); + } }); } @@ -51,14 +52,17 @@ DiagnosedSilenceableFailure static mapNestedForallToThreadsImpl( diag = mlir::transform::gpu::mapOneForallToThreadsImpl( rewriter, std::nullopt, forallOp, blockDims, warpSize, syncAfterDistribute); - if (diag.isDefiniteFailure()) + if (diag.isDefiniteFailure()) { return WalkResult::interrupt(); - if (diag.succeeded()) + } + if (diag.succeeded()) { return WalkResult::skip(); + } return WalkResult::advance(); }); - if (walkResult.wasInterrupted()) + if (walkResult.wasInterrupted()) { return diag; + } // Replace ids of dimensions known to be 1 by 0 to simplify the IR. // Here, the result of mapping determines the available mapping sizes. 
@@ -96,16 +100,19 @@ struct GPUDistributePass final if (!hasWorkgroupMapping) { result = mapNestedForallToThreadsImpl( rewriter, forallOp, workgroupSize.value(), subgroupSize, false); - if (result.isDefiniteFailure()) + if (result.isDefiniteFailure()) { return WalkResult::interrupt(); - if (result.succeeded()) + } + if (result.succeeded()) { return WalkResult::skip(); + } } return WalkResult::advance(); }); - if (walkResult.wasInterrupted()) + if (walkResult.wasInterrupted()) { return signalPassFailure(); + } } }; } // namespace diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeScfFor.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeScfFor.cpp index 504e2366b47a..67b5b1c680fe 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeScfFor.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeScfFor.cpp @@ -40,8 +40,9 @@ struct DistributeLoop final : OpRewritePattern { // Only distribute if we see the marker attribute. auto numDimAttr = forOp->getAttrOfType(getGPUDistributeAttrName()); - if (!numDimAttr) + if (!numDimAttr) { return failure(); + } // Get workgroup sizes if not using gpu.block_dim SmallVector workgroupSize; diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp index 9a4fca90eb3e..31b55b56a77d 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp @@ -70,8 +70,9 @@ static LogicalResult tileCopyToWorkgroupMem(mlir::FunctionOpInterface funcOp, unsigned rank = dstMemRefType.getRank(); // Return empty tile size for zero dim tensor. 
- if (rank == 0) + if (rank == 0) { return tileSizesVal; + } int copyTileSize = copyVectorNumBits / dstMemRefType.getElementTypeBitWidth(); for (unsigned i = 0; i < rank - 1; i++) { @@ -145,8 +146,9 @@ getTileToDistributableSize(linalg::GenericOp copyOp, unroll.push_back(numThreads * numElementPerThread); assert(threadsAvailable % numThreads == 0); threadsAvailable = threadsAvailable / numThreads; - if (threadsAvailable == 1) + if (threadsAvailable == 1) { break; + } } assert(threadsAvailable == 1); unroll.resize(shape.size(), 1); @@ -162,8 +164,9 @@ static LogicalResult tileToUnroll(mlir::FunctionOpInterface funcOp, [flatWorkgroupSize](OpBuilder &builder, Operation *operation) { SmallVector tileSizesVal; auto copyOp = dyn_cast(operation); - if (!copyOp) + if (!copyOp) { return tileSizesVal; + } std::optional> staticSize = getTileToDistributableSize(copyOp, flatWorkgroupSize); for (int64_t dim : *staticSize) { @@ -235,8 +238,9 @@ static LogicalResult tileAndDistribute(mlir::FunctionOpInterface funcOp, [](OpBuilder &builder, Operation *operation) { SmallVector tileSizesVal; auto copyOp = dyn_cast(operation); - if (!copyOp) + if (!copyOp) { return tileSizesVal; + } SmallVector staticSize = getNativeDstShape(copyOp); for (int64_t dim : staticSize) { tileSizesVal.push_back(arith::ConstantIndexOp::create( @@ -308,8 +312,9 @@ static Value createFlatId(mlir::FunctionOpInterface funcOp, static void hoistAlloc(mlir::FunctionOpInterface funcOp) { SmallVector allocs; funcOp.walk([&](memref::AllocOp alloc) { - if (alloc.getOperands().empty()) + if (alloc.getOperands().empty()) { allocs.push_back(alloc); + } }); for (memref::AllocOp alloc : allocs) { alloc->moveBefore(&(*funcOp.getBlocks().begin()), @@ -325,15 +330,17 @@ static void removeRedundantBarriers(mlir::FunctionOpInterface funcOp) { Operation *prevOp = copyOp->getPrevNode(); SmallVector redundantBarriers; while (prevOp) { - if (isa(prevOp)) + if (isa(prevOp)) { redundantBarriers.push_back(prevOp); - else + } else { 
break; + } prevOp = prevOp->getPrevNode(); } if (prevOp && hasMarker(prevOp, getCopyToWorkgroupMemoryMarker())) { - for (Operation *op : redundantBarriers) + for (Operation *op : redundantBarriers) { op->erase(); + } } } }); @@ -345,8 +352,9 @@ static int64_t numIteration(scf::ForOp forOp) { auto ubCstOp = forOp.getUpperBound().getDefiningOp(); auto stepCstOp = forOp.getStep().getDefiningOp(); if (!lbCstOp || !ubCstOp || !stepCstOp || lbCstOp.value() < 0 || - ubCstOp.value() < 0 || stepCstOp.value() < 0) + ubCstOp.value() < 0 || stepCstOp.value() < 0) { return 0; + } int64_t tripCount = llvm::divideCeil(ubCstOp.value() - lbCstOp.value(), stepCstOp.value()); return tripCount; @@ -358,8 +366,9 @@ unrollSharedMemoryLoops(mlir::FunctionOpInterface funcOp, const llvm::SmallDenseSet &loopsToIgnore) { SmallVector forOpsToUnroll; funcOp.walk([&](scf::ForOp forOp) { - if (!loopsToIgnore.count(forOp)) + if (!loopsToIgnore.count(forOp)) { forOpsToUnroll.push_back(forOp); + } }); for (scf::ForOp forOp : llvm::reverse(forOpsToUnroll)) { (void)loopUnrollByFactor(forOp, numIteration(forOp)); @@ -378,11 +387,13 @@ LogicalResult gpuDistributeSharedMemoryCopy(mlir::FunctionOpInterface funcOp) { MLIRContext *context = funcOp.getContext(); SmallVector copiesToWorkgroupMem; funcOp.walk([&](linalg::GenericOp copyOp) { - if (hasMarker(copyOp, getCopyToWorkgroupMemoryMarker())) + if (hasMarker(copyOp, getCopyToWorkgroupMemoryMarker())) { copiesToWorkgroupMem.push_back(copyOp); + } }); - if (copiesToWorkgroupMem.empty()) + if (copiesToWorkgroupMem.empty()) { return success(); + } // Step 0. First clean up the IR. 
hoistAlloc(funcOp); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp index 07db6b4b82aa..3714e71e28ff 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp @@ -31,13 +31,15 @@ struct DistributeConstants final : OpDistributionPattern { DistributionSignature &signature, PatternRewriter &rewriter) const override { auto constant = dyn_cast(constantOp.getResult()); - if (!constant) + if (!constant) { return failure(); + } // Only handle splat values for now. auto attr = dyn_cast(constantOp.getValue()); - if (!attr) + if (!attr) { return failure(); + } VectorLayoutInterface layout = signature[constant]; @@ -62,8 +64,9 @@ struct DistributePoison final : OpDistributionPattern { PatternRewriter &rewriter) const override { auto poisonVal = dyn_cast(poisonOp.getResult()); - if (!poisonVal) + if (!poisonVal) { return failure(); + } SmallVector distributedShape = signature[poisonVal].getDistributedShape(); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUExpandDimensions.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUExpandDimensions.cpp new file mode 100644 index 000000000000..71cb20c85580 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUExpandDimensions.cpp @@ -0,0 +1,290 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/Transforms.h" +#include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h" +#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" +#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h" +#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/Support/DebugLog.h" +#include "llvm/Support/LogicalResult.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-codegen-gpu-expand-dimensions" + +namespace mlir::iree_compiler { + +#define GEN_PASS_DEF_GPUEXPANDDIMENSIONSPASS +#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc" + +namespace { + +struct GPUExpandDimensionsPass final + : impl::GPUExpandDimensionsPassBase { + using Base::Base; + void runOnOperation() override; +}; +} // namespace + +// Compute the expanded shape for a reassociation group. Requires the original +// dimension to be static and evenly divisible by the product of static factors +// in the target shape. 
+static FailureOr> computeExpandedGroupShape( + RewriterBase &rewriter, Location loc, OpFoldResult origDimSize, + ArrayRef groupTargetShape, unsigned iteratorDim) { + if (groupTargetShape.size() == 1) { + return SmallVector{origDimSize}; + } + + std::optional staticOrigDim = getConstantIntValue(origDimSize); + if (!staticOrigDim) { + return rewriter.notifyMatchFailure( + loc, "dimension " + Twine(iteratorDim) + + " is dynamic, but expand_dims requires static dimensions"); + } + + int64_t staticFactor = llvm::product_of( + llvm::make_filter_range(groupTargetShape, ShapedType::isStatic)); + + if (staticFactor < 1) { + return rewriter.notifyMatchFailure( + loc, "invalid expansion factor " + Twine(staticFactor) + + " for iterator dimension " + Twine(iteratorDim)); + } + + if (staticOrigDim.value() % staticFactor != 0) { + return rewriter.notifyMatchFailure( + loc, "dimension " + Twine(iteratorDim) + + " (size=" + Twine(staticOrigDim.value()) + + ") not divisible by expansion factor " + Twine(staticFactor)); + } + + return llvm::map_to_vector( + groupTargetShape, [&](int64_t size) -> OpFoldResult { + if (ShapedType::isStatic(size)) { + return rewriter.getIndexAttr(size); + } + AffineExpr s0 = rewriter.getAffineSymbolExpr(0); + return affine::makeComposedFoldedAffineApply( + rewriter, loc, s0.floorDiv(staticFactor), {origDimSize}); + }); +} + +// For an operation annotated with the `expand_dims` attribute, replace relevant +// operands with tensor.expand_shape/tensor.collapse_shape pair to materialize +// dimension expansion according to the reassociation and output_shape defined +// in the attribute. +// +// Example: +// +// ```mlir +// %0 = (..., %0, ...) { +// lowering_config = #iree_gpu.lowering_config<{ +// expand_dims = #iree_gpu.expand_dims +// [[0], [1, 2]], output_shape = [?, ?, 8]> +// }> +// } : ... 
-> tensor<4x128xf32> +// ``` +// +// becomes: +// +// ```mlir +// %expanded = tensor.expand_shape %0 [[0], [1, 2]] +// : tensor<4x128xf32> into tensor<4x16x8xf32> +// %barrier = util.optimization_barrier %expanded +// %collapsed = tensor.collapse_shape %barrier [[0], [1, 2]] +// : tensor<4x16x8xf32> into tensor<4x128xf32> +// %1 = (..., %collapsed, ...) : ... -> tensor<4x128xf32> +// ``` +static std::optional +createDimensionExpansionOps(RewriterBase &rewriter, + IREE::GPU::DimensionExpansionAttr config, Value v, + AffineMap indexingMap, linalg::LinalgOp op) { + auto tensorType = dyn_cast(v.getType()); + if (!tensorType) { + return std::nullopt; + } + + Location loc = v.getLoc(); + MLIRContext *ctx = op.getContext(); + int64_t tensorRank = tensorType.getRank(); + ArrayRef outputShape = config.getOutputShape().asArrayRef(); + SmallVector origShape = tensor::getMixedSizes(rewriter, loc, v); + + // Map each tensor dimension to its expanded shape components. + SmallVector> expandedShapes(tensorRank); + for (auto [iterDim, reassocIndices] : + llvm::enumerate(config.getReassociationIndices())) { + std::optional tensorDim = + indexingMap.getResultPosition(getAffineDimExpr(iterDim, ctx)); + if (!tensorDim.has_value()) { + continue; + } + + auto groupOutputShape = llvm::map_to_vector( + reassocIndices, [&](int64_t i) { return outputShape[i]; }); + + FailureOr> groupShape = computeExpandedGroupShape( + rewriter, loc, origShape[tensorDim.value()], groupOutputShape, iterDim); + if (failed(groupShape)) { + return std::nullopt; + } + + expandedShapes[tensorDim.value()] = std::move(groupShape.value()); + } + + // Build reassociation indices and expanded shape in tensor dimension order. 
+ SmallVector reassociation; + SmallVector expandedShape; + for (auto [tensorDim, expanded] : llvm::enumerate(expandedShapes)) { + ReassociationIndices &indices = reassociation.emplace_back(); + auto addDim = [&](OpFoldResult dim) { + indices.push_back(expandedShape.size()); + expandedShape.push_back(dim); + }; + if (expanded.empty()) { + addDim(origShape[tensorDim]); + } else { + llvm::for_each(expanded, addDim); + } + } + + // If no expansion is needed, return early. + if (llvm::equal(origShape, expandedShape)) { + return std::nullopt; + } + + auto staticShape = llvm::map_to_vector(expandedShape, [](OpFoldResult ofr) { + return getConstantIntValue(ofr).value(); + }); + + auto expandedType = RankedTensorType::get( + staticShape, tensorType.getElementType(), tensorType.getEncoding()); + + auto expandOp = tensor::ExpandShapeOp::create(rewriter, loc, expandedType, v, + reassociation, expandedShape); + Value barrier = IREE::Util::OptimizationBarrierOp::create( + rewriter, loc, expandOp.getResult()) + .getResult(0); + auto collapseOp = tensor::CollapseShapeOp::create(rewriter, loc, tensorType, + barrier, reassociation); + + return ReshapeOps{expandOp, collapseOp}; +} + +static LogicalResult expandIterationSpace(RewriterBase &rewriter, + linalg::LinalgOp op) { + auto loweringConfig = getLoweringConfig(op); + if (!loweringConfig) { + return success(); + } + auto config = IREE::GPU::getDimensionExpansion(loweringConfig); + if (!config) { + return success(); + } + + LDBG() << "Expanding dimensions for op: " << *op; + + for (OpOperand &operand : op->getOpOperands()) { + AffineMap indexingMap = op.getMatchingIndexingMap(&operand); + std::optional reshapes = createDimensionExpansionOps( + rewriter, config, operand.get(), indexingMap, op); + if (reshapes.has_value()) { + rewriter.modifyOpInPlace( + op, [&]() { operand.set(reshapes.value().collapseShapeOp); }); + } + } + + return success(); +} + +void GPUExpandDimensionsPass::runOnOperation() { + Operation *operation = 
getOperation(); + MLIRContext *context = &getContext(); + IRRewriter rewriter(context); + + SmallVector worklist; + operation->walk([&](linalg::LinalgOp op) { + if (auto cfg = getLoweringConfig(op)) { + if (IREE::GPU::getDimensionExpansion(cfg)) { + worklist.push_back(op); + } + } + }); + + for (linalg::LinalgOp op : worklist) { + rewriter.setInsertionPoint(op); + if (failed(expandIterationSpace(rewriter, op))) { + return signalPassFailure(); + } + } + + LDBG() << "After expanding dimensions: " << *operation; + + ConfigTrackingListener listener; + GreedyRewriteConfig config; + config.setListener(&listener); + + { + RewritePatternSet bubbleExpandShapePatterns(context); + linalg::ControlFusionFn controlFn = [](OpOperand *opOperand) { + return !isa_and_nonnull( + opOperand->get().getDefiningOp()); + }; + linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns, + controlFn); + IREE::LinalgExt::populateFoldReshapeOpsByExpansionPatterns( + bubbleExpandShapePatterns, controlFn); + tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns); + tensor::populateBubbleUpExpandShapePatterns(bubbleExpandShapePatterns); + linalg::FillOp::getCanonicalizationPatterns( + bubbleExpandShapePatterns, bubbleExpandShapePatterns.getContext()); + memref::populateResolveRankedShapedTypeResultDimsPatterns( + bubbleExpandShapePatterns); + if (failed(applyPatternsGreedily( + operation, std::move(bubbleExpandShapePatterns), config))) { + operation->emitOpError( + "failed in application of bubble up expand shape patterns"); + return signalPassFailure(); + } + } + + LDBG() << "After reshape propagation: " << *operation; + + { + RewritePatternSet removeBarrierOpsPatterns(context); + populateRemoveOptimizationBarrierPatterns(removeBarrierOpsPatterns); + tensor::ExpandShapeOp::getCanonicalizationPatterns(removeBarrierOpsPatterns, + context); + tensor::CollapseShapeOp::getCanonicalizationPatterns( + removeBarrierOpsPatterns, context); + 
tensor::populateFoldTensorEmptyPatterns(removeBarrierOpsPatterns); + linalg::FillOp::getCanonicalizationPatterns(removeBarrierOpsPatterns, + context); + memref::populateResolveRankedShapedTypeResultDimsPatterns( + removeBarrierOpsPatterns); + if (failed(applyPatternsGreedily(operation, + std::move(removeBarrierOpsPatterns)))) { + operation->emitOpError("failed in cleanup patterns"); + return signalPassFailure(); + } + } + + return; +} + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp index 4a8dfffa10c5..48f4e11ba04c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp @@ -175,6 +175,11 @@ struct FuseTilableDestinationProducers final : OpRewritePattern { tileableProducer = forallOp.getTiedLoopInit(iterArg) ->get() .getDefiningOp(); + // Pad fusion is handled separately as we dont want zero slice guards that + // happen by default. + if (tileableProducer && isa(tileableProducer)) { + tileableProducer = nullptr; + } if (tileableProducer) { break; } @@ -266,7 +271,9 @@ struct FuseTilableSliceProducers final return failure(); } auto tilableProducer = sliceOp.getSource().getDefiningOp(); - if (!tilableProducer) { + // Pad fusion is handled separately as we dont want zero slice guards that + // happen by default. + if (!tilableProducer || isa(tilableProducer)) { return failure(); } @@ -394,6 +401,12 @@ void GPUFuseAndHoistParallelLoopsPass::runOnOperation() { patterns.add(context); tensor::populateFoldTensorEmptyPatterns(patterns); scf::ForallOp::getCanonicalizationPatterns(patterns, context); + auto zeroSliceGuard = [](tensor::ExtractSliceOp) -> std::optional { + // Do not use zero slice gaurd. 
+ return false; + }; + patterns.add(context, + zeroSliceGuard); if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { funcOp->emitOpError("failed to apply fusion + hoisting patterns (set 3)"); return signalPassFailure(); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUGreedilyDistributeToThreads.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUGreedilyDistributeToThreads.cpp index 290713161da5..bd27e31e8282 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUGreedilyDistributeToThreads.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUGreedilyDistributeToThreads.cpp @@ -139,8 +139,9 @@ static void processRegion(RewriterBase &rewriter, Region *region) { if (auto tilableOp = dyn_cast(op)) { // Do not distribute to threads of an op wants to use DMA. if (auto useDMAConfig = - getLoweringConfig(op)) + getLoweringConfig(op)) { continue; + } tileToThreads(rewriter, tilableOp); continue; } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp index 70a501317a3e..006bb813003b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp @@ -17,6 +17,7 @@ #include "llvm/Support/InterleavedRange.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Remarks.h" #define DEBUG_TYPE "iree-codegen-gpu-heuristics" @@ -65,13 +66,28 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const GemmSize &gemmSize) { static int64_t calculateOperandsSharedMemoryUsedInBytes( const GPUMMASchedule &schedule, int64_t lhsBitwidth, int64_t rhsBitwidth, + int64_t lhsScaleBitwidth = 0, int64_t rhsScaleBitwidth = 0, int64_t numRhs = 1) { int64_t tileM = schedule.getTotalMSize() * schedule.getTotalMTileSize() * schedule.getTotalMSubgroupCount(); int64_t tileN = schedule.getTotalNSize() * schedule.getTotalNTileSize() * 
schedule.getTotalNSubgroupCount(); + + // For scaled matmul, the K dimension is split into Ko (outer) and Kb (block), + // where elements in a Kb block share the same scale. For lhs and rhs we + // account for both Ko and Kb, while for scale operands, only Ko. For regular + // matmul, scale bitwidth is 0 so the scale terms below have no effect. int64_t tileK = schedule.getTotalKSize() * schedule.getTotalKTileSize(); - return (tileM * tileK * lhsBitwidth + numRhs * tileN * tileK * rhsBitwidth) / + int64_t tileKb = schedule.kSizes.back() * schedule.kTileSizes.back(); + int64_t tileKo = tileK / tileKb; + + int64_t lhsSharedMemoryUsed = tileM * tileK * lhsBitwidth; + int64_t rhsSharedMemoryUsed = numRhs * tileN * tileK * rhsBitwidth; + int64_t aScaleSharedMemoryUsed = tileM * tileKo * lhsScaleBitwidth; + int64_t bScaleSharedMemoryUsed = numRhs * tileN * tileKo * rhsScaleBitwidth; + + return (lhsSharedMemoryUsed + rhsSharedMemoryUsed + aScaleSharedMemoryUsed + + bScaleSharedMemoryUsed) / 8; } @@ -191,8 +207,9 @@ static FailureOr fitScheduleInSharedMemory( auto decrementIfPossible = [](MutableArrayRef sizes) -> LogicalResult { for (int64_t &size : sizes) { - if (size <= 1) + if (size <= 1) { continue; + } --size; return success(); } @@ -647,9 +664,9 @@ static int64_t adjustSeedsForWgpCount(const GPUMatmulShapeType &problem, FailureOr deduceMMASchedule( const GPUMatmulShapeType &problem, ArrayRef intrinsics, const GPUMMAHeuristicSeeds &seeds, int64_t sharedMemLimitInBytes, - int64_t subgroupSize, std::optional wgpCount, bool transposedLhs, - bool transposedRhs, bool canUpcastAcc, bool mustBeAligned, - bool doCPromotion, int64_t splitReductionTripCnt) { + int64_t subgroupSize, std::optional wgpCount, Location loc, + bool transposedLhs, bool transposedRhs, bool canUpcastAcc, + bool mustBeAligned, bool doCPromotion, int64_t splitReductionTripCnt) { SmallVector sortedIntrinsics = sortMMAIntrinsics(problem, intrinsics); @@ -673,14 +690,19 @@ FailureOr deduceMMASchedule( LDBG() 
<< "Chosen MMA schedule:\n" << schedule; auto isValidSchedule = [&](const GPUMMASchedule &schedule) -> bool { - int64_t lhsBitwidth = intrinsic.aType.getIntOrFloatBitWidth(); - int64_t rhsBitwidth = intrinsic.bType.getIntOrFloatBitWidth(); - int64_t resultBitwidth = intrinsic.cType.getIntOrFloatBitWidth(); + int64_t lhsBitwidth = problem.aType.getIntOrFloatBitWidth(); + int64_t rhsBitwidth = problem.bType.getIntOrFloatBitWidth(); + int64_t resultBitwidth = problem.cType.getIntOrFloatBitWidth(); + int64_t lhsScaleBitwidth = + problem.aScaleType ? problem.aScaleType.getIntOrFloatBitWidth() : 0; + int64_t rhsScaleBitwidth = + problem.bScaleType ? problem.bScaleType.getIntOrFloatBitWidth() : 0; bool isAligned = isValidMMASchedule(problem, schedule, mustBeAligned, subgroupSize, transposedLhs, transposedRhs); int64_t sharedMemoryUsed = calculateOperandsSharedMemoryUsedInBytes( - schedule, lhsBitwidth, rhsBitwidth, problem.numHorizontallyFusedOps); + schedule, lhsBitwidth, rhsBitwidth, lhsScaleBitwidth, + rhsScaleBitwidth, problem.numHorizontallyFusedOps); // Add accumulator/result memory when it uses shared memory (LDS): // - Result needs padding in shared memory, OR // - matmul_accumulate loads accumulator from global memory via shared mem @@ -694,7 +716,15 @@ FailureOr deduceMMASchedule( LDBG() << "Available Shared Memory: " << sharedMemLimitInBytes << " bytes" << "Predicted Shared Memory Used by Schedule: " << sharedMemoryUsed << " bytes"; - return isAligned && sharedMemoryUsed <= sharedMemLimitInBytes; + + bool isValid = isAligned && sharedMemoryUsed <= sharedMemLimitInBytes; + if (isValid) { + // Only emit remark for the shared memory usage of the valid schedule. 
+ remark::analysis(loc, remark::RemarkOpts::name("SharedMemoryUsage") + .category("deduceMMASchedule")) + << std::to_string(sharedMemoryUsed); + } + return isValid; }; return fitScheduleInSharedMemory(schedule, isValidSchedule); } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h index d613fb4e8c6d..16902b4064b6 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h @@ -20,9 +20,13 @@ struct GPUMatmulShapeType { SmallVector nSizes; SmallVector kSizes; SmallVector batchSizes; + Type aType; Type bType; Type cType; + Type aScaleType; + Type bScaleType; + GemmSize gemmSize = GemmSize::NotSet; // Number of horizontally fused operations. @@ -34,11 +38,14 @@ struct GPUMatmulShapeType { int64_t numHorizontallyFusedOps = 1) : mSizes({m}), nSizes({n}), kSizes({k}), batchSizes({}), aType(a), bType(b), cType(c), numHorizontallyFusedOps(numHorizontallyFusedOps) {} + GPUMatmulShapeType(ArrayRef m, ArrayRef n, ArrayRef k, ArrayRef batch, Type a, - Type b, Type c, int64_t numHorizontallyFusedOps = 1) + Type b, Type c, Type aScale = nullptr, + Type bScale = nullptr, int64_t numHorizontallyFusedOps = 1) : mSizes(m), nSizes(n), kSizes(k), batchSizes(batch), aType(a), bType(b), - cType(c), numHorizontallyFusedOps(numHorizontallyFusedOps) {} + cType(c), aScaleType(aScale), bScaleType(bScale), + numHorizontallyFusedOps(numHorizontallyFusedOps) {} }; /// Struct containing information about a GPU MMA intrinsic type. 
@@ -147,7 +154,7 @@ struct GPUMMASchedule { FailureOr deduceMMASchedule( const GPUMatmulShapeType &problem, ArrayRef intrinsics, const GPUMMAHeuristicSeeds &seeds, int64_t sharedMemLimitInBytes, - int64_t subgroupSize, std::optional cuCount, + int64_t subgroupSize, std::optional cuCount, Location loc, bool transposedLhs = false, bool transposedRhs = false, bool canUpcastAcc = false, bool mustBeAligned = true, bool doCPromotion = false, int64_t splitReductionTripCnt = 0); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMultiBuffering.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMultiBuffering.cpp index fc2a04e2d30d..192dce67de4f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMultiBuffering.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMultiBuffering.cpp @@ -32,14 +32,16 @@ struct GPUMultiBufferingPass final SmallVector allocs; // Collect all the alloc operations. funcOp.walk([&](memref::AllocOp allocOp) { - if (hasSharedMemoryAddressSpace(allocOp.getType())) + if (hasSharedMemoryAddressSpace(allocOp.getType())) { allocs.push_back(allocOp); + } }); assert(funcOp.getBlocks().size() == 1); for (memref::AllocOp allocOp : allocs) { - if (allocOp->getParentOp() != funcOp) + if (allocOp->getParentOp() != funcOp) { allocOp->moveBefore(&*funcOp.begin()->begin()); + } } // Then perform multibuffering transformations. @@ -50,8 +52,9 @@ struct GPUMultiBufferingPass final // Skip allocations not used in a loop. 
for (Operation *user : allocOp->getUsers()) { auto loop = user->getParentOfType(); - if (!loop) + if (!loop) { return WalkResult::advance(); + } } allocs.push_back(allocOp); return WalkResult::advance(); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp index 6ae77aea687f..a4923e0dc3f3 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp @@ -36,8 +36,9 @@ using namespace mlir::iree_compiler::IREE::VectorExt; using VectorValue = TypedValue; static bool isBroadcast(AffineExpr expr) { - if (auto constExpr = dyn_cast(expr)) + if (auto constExpr = dyn_cast(expr)) { return constExpr.getValue() == 0; + } return false; } @@ -81,8 +82,9 @@ static SmallVector getTransferIndicesFromNestedLayout( // a constant less than `elementCount`, we can do this, unlocking // potential optimizations. 
bool disjoint = false; - if (std::optional offsetConst = getConstantIntValue(offset)) + if (std::optional offsetConst = getConstantIntValue(offset)) { disjoint = *offsetConst < elementCount; + } slicedIndices[pos] = affine::AffineLinearizeIndexOp::create(b, loc, ids, sizes, disjoint); } @@ -222,8 +224,9 @@ static LogicalResult populateWarpAndThreadIndices( int64_t rank = vectorLayout.getRank(); SmallVector threadIds = vectorLayout.computeThreadIds(threadId, subgroupSize, rewriter); - if (threadIds.empty() && rank != 0) + if (threadIds.empty() && rank != 0) { return failure(); + } warpIndices = SmallVector(threadIds.begin(), threadIds.begin() + rank); threadIndices = SmallVector(threadIds.begin() + rank, threadIds.begin() + 2 * rank); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPackToIntrinsics.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPackToIntrinsics.cpp index f90411db7f8f..8057b238432c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPackToIntrinsics.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPackToIntrinsics.cpp @@ -131,8 +131,9 @@ struct PackDestinationForOp final : OpRewritePattern { // Get the enclosing scf.for op. auto parentOp = yieldOp->getParentOp(); auto forOp = dyn_cast(parentOp); - if (!forOp) + if (!forOp) { return failure(); + } linalg::UnPackOp unpackOp; linalg::PackOp packOp; diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.cpp index de963c3ba1fd..a38a177f9a42 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.cpp @@ -54,11 +54,13 @@ struct FlattenTransferReadOp : public OpRewritePattern { Value source = transferReadOp.getBase(); MemRefType sourceType = dyn_cast(source.getType()); // Contiguity check is valid on tensors only. 
- if (!sourceType) + if (!sourceType) { return failure(); + } // Already 2D or lower nothing to do. - if (vectorType.getRank() < 3) + if (vectorType.getRank() < 3) { return failure(); + } // The innermost dim is always considered non-unit as it wont be dropped // Therefore, we initialize `numberOfNonUnitDims` to 1 and not 0 int numberOfNonUnitDims = 1; @@ -86,12 +88,15 @@ struct FlattenTransferReadOp : public OpRewritePattern { } int rankOfCollapsedVector = 2; // TODO: generalize this pattern, relax the requirements here. - if (transferReadOp.hasOutOfBoundsDim()) + if (transferReadOp.hasOutOfBoundsDim()) { return failure(); - if (!transferReadOp.getPermutationMap().isMinorIdentity()) + } + if (!transferReadOp.getPermutationMap().isMinorIdentity()) { return failure(); - if (transferReadOp.getMask()) + } + if (transferReadOp.getMask()) { return failure(); + } ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr( SmallVector(rankOfCollapsedVector, true)); auto newidentityMap = @@ -113,8 +118,9 @@ struct FlattenTransferReadOp : public OpRewritePattern { SmallVector subViewOffsets, subViewSizes, subViewStrides; subViewSizes.append(sourceType.getRank() - vectorType.getRank(), rewriter.getIndexAttr(1)); - for (int64_t dim : vectorType.getShape()) + for (int64_t dim : vectorType.getShape()) { subViewSizes.push_back(rewriter.getIndexAttr(dim)); + } for (int i = 0; i < sourceType.getRank(); i++) { subViewOffsets.push_back(transferReadOp.getIndices()[i]); subViewStrides.push_back(rewriter.getIndexAttr(1)); @@ -136,8 +142,9 @@ struct FlattenTransferReadOp : public OpRewritePattern { rewriter, loc, vectorTypeBroadcast, readCollapse); SmallVector transposePermutation; for (int i = 0; i < vectorType.getRank(); i++) { - if (i == vectorType.getRank() - 2) + if (i == vectorType.getRank() - 2) { continue; + } transposePermutation.push_back(i); } transposePermutation.insert(transposePermutation.begin() + @@ -186,8 +193,9 @@ struct CombineTransferReadOpBroadcast final /// Returns true 
if op is appropriate contract for promotion. static LogicalResult contractOpFilter(Operation *op) { auto linalgOp = dyn_cast(op); - if (!linalgOp) + if (!linalgOp) { return failure(); + } // Limit promotion to matmul and batch matmul, there may be generic // ops with more batch dimensions we didn't distribute and therefore // cannot find a higher bound. @@ -206,8 +214,9 @@ struct DropSharedMemoryDeallocOp : public OpRewritePattern { LogicalResult matchAndRewrite(memref::DeallocOp op, PatternRewriter &rewriter) const override { if (!hasSharedMemoryAddressSpace( - cast(op.getMemref().getType()))) + cast(op.getMemref().getType()))) { return failure(); + } rewriter.eraseOp(op); return success(); } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPipelining.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPipelining.cpp index b17a554b3684..d13f6b6ece7f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPipelining.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPipelining.cpp @@ -48,22 +48,26 @@ static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter, // speculatively. if (!isa(op)) { // Return/execute the op if it is a side effect free. - if (mlir::isMemoryEffectFree(op)) + if (mlir::isMemoryEffectFree(op)) { return op; + } // Return/execute the op if it is barrier, commit group, or ldmatrix op. if (isa(op)) + nvgpu::DeviceAsyncWaitOp>(op)) { return op; + } // Return/execute the op if it is a shared memory load. 
if (auto loadOp = dyn_cast(op)) { auto loadBaseType = cast(loadOp.getBase().getType()); - if (hasSharedMemoryAddressSpace(loadBaseType)) + if (hasSharedMemoryAddressSpace(loadBaseType)) { return op; + } } if (auto loadOp = dyn_cast(op)) { auto loadBaseType = loadOp.getMemRefType(); - if (hasSharedMemoryAddressSpace(loadBaseType)) + if (hasSharedMemoryAddressSpace(loadBaseType)) { return op; + } } // If we are here that means the operation does not have predication support // and cannot be speculatively executed. Thus, unpeeled epilogue is not @@ -107,12 +111,14 @@ static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter, /// set. static void addDepOps(llvm::SmallDenseSet &dep, Operation *op, Block *block) { - if (!dep.insert(op).second) + if (!dep.insert(op).second) { return; + } for (Value operand : op->getOperands()) { Operation *defOp = operand.getDefiningOp(); - if (defOp && defOp->getBlock() == block) + if (defOp && defOp->getBlock() == block) { addDepOps(dep, defOp, block); + } } } @@ -123,8 +129,9 @@ static void getPipelineStages(scf::ForOp forOp, std::vector> &ops, unsigned depth) { - if (!forOp->hasAttr(kPipeliningLoopMarker)) + if (!forOp->hasAttr(kPipeliningLoopMarker)) { return; + } // Track dependencies of stage 0 ops. llvm::SmallDenseSet loadDep; @@ -138,12 +145,14 @@ getPipelineStages(scf::ForOp forOp, // stage `maxDepth`. In order to have a correct scheduling even with back // edges we order stages in decreasing order. for (Operation &op : forOp.getBody()->getOperations()) { - if (!loadDep.count(&op) && !isa(op)) + if (!loadDep.count(&op) && !isa(op)) { ops.push_back(std::make_pair(&op, depth)); + } } for (Operation &op : forOp.getBody()->getOperations()) { - if (loadDep.count(&op)) + if (loadDep.count(&op)) { ops.push_back(std::make_pair(&op, 0)); + } } } @@ -156,8 +165,9 @@ static void setAsyncAnnotations(Operation *op, // copies in flight. 
bool copyBeforeLoad = schedule == PipeliningSchedulingStrategy::nvidiaTensorCore; - if (waitOp.getNumGroups()) + if (waitOp.getNumGroups()) { return; + } int numGroupInFlight = 0; if (part == scf::PipeliningOption::PipelinerPart::Kernel || part == scf::PipeliningOption::PipelinerPart::Prologue) { @@ -178,8 +188,9 @@ static void setAsyncAnnotations(Operation *op, schedule == PipeliningSchedulingStrategy::loadStoreStage0 ? 0 : 1; if (pipelineStoreStage != 0 || part != mlir::scf::PipeliningOption::PipelinerPart::Prologue || - iteration >= depth - 1) + iteration >= depth - 1) { return; + } OpBuilder b(op); barrierOp->setAttr(kPipeliningExtraBarrier, b.getUnitAttr()); } @@ -194,12 +205,14 @@ static bool setPipeliningMarkers(scf::ForOp forOp, bool pipelineStoreStage) { SmallVector barriers; for (Operation &op : forOp.getBody()->getOperations()) { // Pipeline the most inner for op that should be a flat region. - if (op.getNumRegions() > 0) + if (op.getNumRegions() > 0) { return false; + } if (isa(op)) { barriers.push_back(&op); - if (pipelineStoreStage == 0) + if (pipelineStoreStage == 0) { op.setAttr(kPipeliningFirstStage, builder.getUnitAttr()); + } } if (isa(op)) { copyToWorkgroupMemory = true; @@ -212,21 +225,26 @@ static bool setPipeliningMarkers(scf::ForOp forOp, bool pipelineStoreStage) { continue; } auto ld = dyn_cast(op); - if (!ld) + if (!ld) { continue; + } auto ldSrcType = cast(ld.getBase().getType()); - if (!hasGlobalMemoryAddressSpace(ldSrcType) || !ld->hasOneUse()) + if (!hasGlobalMemoryAddressSpace(ldSrcType) || !ld->hasOneUse()) { continue; + } auto st = dyn_cast(ld->use_begin()->getOwner()); - if (!st) + if (!st) { continue; + } auto stSrcType = cast(st.getBase().getType()); - if (!hasSharedMemoryAddressSpace(stSrcType)) + if (!hasSharedMemoryAddressSpace(stSrcType)) { continue; + } copyToWorkgroupMemory = true; ld->setAttr(kPipeliningFirstStage, builder.getUnitAttr()); - if (pipelineStoreStage == 0) + if (pipelineStoreStage == 0) { 
st->setAttr(kPipeliningFirstStage, builder.getUnitAttr()); + } } if (copyToWorkgroupMemory) { forOp->setAttr(kPipeliningLoopMarker, builder.getUnitAttr()); @@ -287,14 +305,16 @@ struct MainLoopInfo { // of some other op. void backwardSliceOfDependentOps(llvm::SetVector &dependentOps, Operation *op, Block *block) { - if (!seenDepOps.insert(op)) + if (!seenDepOps.insert(op)) { return; + } // Add the unseen op to the dependentOps and recurse on its operands. dependentOps.insert(op); for (Value operand : op->getOperands()) { Operation *defOp = operand.getDefiningOp(); - if (defOp && defOp->getBlock() == block) + if (defOp && defOp->getBlock() == block) { backwardSliceOfDependentOps(dependentOps, defOp, block); + } } } @@ -304,8 +324,9 @@ struct MainLoopInfo { void mmaOperandDefOperation(Operation *op, llvm::SetVector &defOperation, Block *block) { - if (!op) + if (!op) { return; + } // If the operations defining the mma.sync's operand is one of the // qualifying operations, add the operations to the current kgroup defining @@ -326,14 +347,16 @@ struct MainLoopInfo { void vistMmaSyncOp(Operation *op, int kgroup) { // if the operation in an `scf.yield`, we reached the end of MmaSyncOp chain // return. - if (seenMmaOps.count(op) || isa(op)) + if (seenMmaOps.count(op) || isa(op)) { return; + } seenMmaOps.insert(op); // If the kgroup is not in the vector, create a new WarpMmaOp. - if (warpOperations.size() < kgroup + 1) + if (warpOperations.size() < kgroup + 1) { warpOperations.push_back(WarpMmaOp()); + } mmaOperandDefOperation(op->getOperand(0).getDefiningOp(), warpOperations[kgroup].lhsOperations, @@ -426,8 +449,9 @@ struct MainLoopInfo { LDBG() << "-- missing warpOperations -> not schedulable"; isSchedulable = false; } - if (!isSchedulable) + if (!isSchedulable) { return; + } // Collect the dependent operations for `cp.async` in the mainloop order for // coarse-grained software pipeling. 
The deps are collected in stage order, @@ -552,8 +576,9 @@ static void getNvidiaAmpereTensorCorePipeline( // Issue mma.sync on previous loaded kgroup. for (Operation &op : forOp.getBody()->getOperations()) { - if (mainloop.warpOperations[kgroup].mmaOperations.count(&op)) + if (mainloop.warpOperations[kgroup].mmaOperations.count(&op)) { ops.push_back(std::make_pair(&op, numStages - 1)); + } } } @@ -565,8 +590,9 @@ static void getNvidiaAmpereTensorCorePipeline( // it at one place. // Schedule all cp.async and one cp.async.commit_group. for (Operation &op : forOp.getBody()->getOperations()) { - if (mainloop.copyGlobalToSharedOpDeps.count(&op)) + if (mainloop.copyGlobalToSharedOpDeps.count(&op)) { ops.push_back(std::make_pair(&op, 0 /*pipelineStage*/)); + } } ops.push_back( std::make_pair(mainloop.asyncCreateGroupOp[0], 0 /*pipelineStage*/)); @@ -585,14 +611,16 @@ static void getNvidiaAmpereTensorCorePipeline( // into one stage ahead. for (Operation &op : forOp.getBody()->getOperations()) { if (mainloop.warpOperations[0].lhsOperations.count(&op) || - mainloop.warpOperations[0].rhsOperations.count(&op)) + mainloop.warpOperations[0].rhsOperations.count(&op)) { ops.push_back(std::make_pair(&op, numStages - 2)); + } } // Issue mma.sync on for the last kgroup at the end of the mainloop. for (Operation &op : forOp.getBody()->getOperations()) { - if (mainloop.warpOperations[numKgroups - 1].mmaOperations.count(&op)) + if (mainloop.warpOperations[numKgroups - 1].mmaOperations.count(&op)) { ops.push_back(std::make_pair(&op, numStages - 1)); + } } // Prints the mainloop schedule generated for NVIDIA Ampere through native @@ -667,8 +695,9 @@ struct GPUPipeliningPass final // Remove extra barriers from the prologue assuming appropriate // multi-buffering. 
funcOp.walk([](gpu::BarrierOp barrierOp) { - if (barrierOp->hasAttr(kPipeliningExtraBarrier)) + if (barrierOp->hasAttr(kPipeliningExtraBarrier)) { barrierOp->erase(); + } }); } }; diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPromoteMatmulOperands.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPromoteMatmulOperands.cpp index d953725fd6a1..1625081df28b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPromoteMatmulOperands.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPromoteMatmulOperands.cpp @@ -68,8 +68,9 @@ void promoteResult(OpBuilder &builder, Operation *op, Value valToMakeShared) { // TODO (nirvedhmeshram) : This is fairly special case. Instead we should // just promote results before doing padding which introduces the extract // slice. - if (!valToMakeShared.hasOneUse()) + if (!valToMakeShared.hasOneUse()) { return; + } valueToReplace = extractSliceOp.getResult(); for (auto user : extractSliceOp->getUsers()) { opsToReplaceUseIn.insert(user); @@ -120,8 +121,9 @@ void promoteResult(OpBuilder &builder, Operation *op, Value valToMakeShared) { void promoteOperand(OpBuilder &builder, Operation *op, unsigned index, IREE::GPU::PromotionAttr promotionAttr) { auto dpsOp = dyn_cast(op); - if (!dpsOp) + if (!dpsOp) { return; + } // We use the convention that if we are passing an index beyond the inputs // then we promote the result of the corresponding dps init. 
if (index >= dpsOp.getNumDpsInputs()) { diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUReduceBankConflicts.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUReduceBankConflicts.cpp index 05ddd0cc95a4..d8a98dbecd33 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUReduceBankConflicts.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUReduceBankConflicts.cpp @@ -40,16 +40,19 @@ static bool hasCollapseShapeUser(memref::AllocOp allocOp) { static void padAlloc(MLIRContext *context, memref::AllocOp allocOp, unsigned paddingSizeBits) { auto allocOpShape = allocOp.getType().getShape(); - if (allocOpShape.empty()) + if (allocOpShape.empty()) { return; + } int64_t innerDim = allocOpShape.back(); - if (ShapedType::isDynamic(innerDim)) + if (ShapedType::isDynamic(innerDim)) { return; + } // Return if we have CollapseShape op as an user as padding in that case is // unsupported. - if (hasCollapseShapeUser(allocOp)) + if (hasCollapseShapeUser(allocOp)) { return; + } Type elType = allocOp.getType().getElementType(); unsigned bitwidth = @@ -125,8 +128,9 @@ static unsigned computeEffectiveExtraBytes(mlir::FunctionOpInterface funcOp, MemRefType allocType = cast(allocOp.getType()); ArrayRef shape = allocType.getShape(); - if (shape.empty()) + if (shape.empty()) { return; + } int outerProduct = 1; for (std::size_t i = 0; i < shape.size() - 1; ++i) { @@ -181,8 +185,9 @@ struct GPUReduceBankConflictsPass final return; } - if (failed(reduceSharedMemoryBankConflicts(funcOp, paddingBits))) + if (failed(reduceSharedMemoryBankConflicts(funcOp, paddingBits))) { signalPassFailure(); + } } }; @@ -198,8 +203,9 @@ LogicalResult reduceSharedMemoryBankConflicts(mlir::FunctionOpInterface funcOp, sharedMemAllocs.push_back(allocOp); } }); - for (memref::AllocOp alloc : sharedMemAllocs) + for (memref::AllocOp alloc : sharedMemAllocs) { padAlloc(funcOp->getContext(), alloc, paddingSize); + } // In the current form this always succeeds. 
return success(); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorAlloc.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorAlloc.cpp index fdf549f9d10c..9c3708e0a06b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorAlloc.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorAlloc.cpp @@ -29,8 +29,9 @@ constexpr int copyVectorNumBits = 128; /// Filter to decide which contract ops need allocations. static bool contractOpFilter(Operation *op) { auto linalgOp = dyn_cast(op); - if (!linalgOp) + if (!linalgOp) { return false; + } if (!linalg::isaContractionOpInterface(linalgOp)) { return false; @@ -39,8 +40,9 @@ static bool contractOpFilter(Operation *op) { // The workgroup specialization already makes static shapes available for the // main tile part and makes the partial tile computation small, so promoting // to shared memory for the partial tile actually hurts the performance. - if (linalgOp.hasDynamicShape()) + if (linalgOp.hasDynamicShape()) { return false; + } // Check if the shape is tile-distributable. The leading dimension must be a // multiple of the target vector size, which is 128b / the element bit width. @@ -76,8 +78,9 @@ static bool contractOpFilter(Operation *op) { /// Filter to decide which transpose ops need allocations. 
static bool transposeOpFilter(Operation *op) { auto linalgOp = dyn_cast(op); - if (!linalgOp) + if (!linalgOp) { return false; + } LinalgOpInfo opInfo(linalgOp, sharedMemTransposeFilter); return opInfo.isTranspose(); } @@ -101,18 +104,21 @@ struct SwapAllocTensorPattern final LogicalResult matchAndRewrite(bufferization::AllocTensorOp allocOp, PatternRewriter &rewriter) const override { - if (!allocOp.getCopy()) + if (!allocOp.getCopy()) { return failure(); + } auto linalgOp = allocOp.getCopy().getDefiningOp(); - if (!linalgOp) + if (!linalgOp) { return failure(); + } // Make sure we don't use the initial values for the linalg output we are // copying during the tensor allocation. unsigned resultNumber = cast(allocOp.getCopy()).getResultNumber(); OpOperand *initOperand = linalgOp.getDpsInitOperand(resultNumber); - if (linalgOp.payloadUsesValueFromOperand(initOperand)) + if (linalgOp.payloadUsesValueFromOperand(initOperand)) { return failure(); + } rewriter.setInsertionPoint(linalgOp); std::optional memorySpace = allocOp.getMemorySpace(); @@ -148,12 +154,14 @@ struct GPUTensorAllocPass final funcOp.walk([&](Operation *op) { switch (promoteSharedMemPattern) { case GPUPromoteSharedMemPattern::ContractionOpPattern: - if (contractOpFilter(op)) + if (contractOpFilter(op)) { opsToPromote.push_back(op); + } break; case GPUPromoteSharedMemPattern::TransposeOpPattern: - if (transposeOpFilter(op)) + if (transposeOpFilter(op)) { opsToPromote.push_back(op); + } break; } }); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp index 0f597ccba08b..328c8924297b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp @@ -45,8 +45,9 @@ class TileConsumerAndFuseInputProducer final LogicalResult matchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const override { - if 
(failed(filter.checkAndNotify(rewriter, op))) + if (failed(filter.checkAndNotify(rewriter, op))) { return failure(); + } // Make sure we have a PartitionableLoopInterface op here and query the tile // sizes from the partitionable loops. @@ -63,8 +64,9 @@ class TileConsumerAndFuseInputProducer final } // Mask out non reduction dimensions. for (unsigned depth : partitionedLoops) { - if (depth < tileSizes.size()) + if (depth < tileSizes.size()) { tileSizes[depth] = 0; + } } // Make sure we have a tile size for each dimension. @@ -120,11 +122,13 @@ class TileConsumerAndFuseInputProducer final return rewriter.notifyMatchFailure(consumer, "failed to tile consumer"); } - if (!fuseInputProducer) + if (!fuseInputProducer) { return tilingResult; + } // If there are no generated loops generated, fusion is immaterial. - if (tilingResult->loops.empty()) + if (tilingResult->loops.empty()) { return tilingResult; + } // Collect immediate input operands that are fusable into the tiled loop. // We have tensor extract slice ops taking slices of the untiled op. 
@@ -135,15 +139,18 @@ class TileConsumerAndFuseInputProducer final assert(tilingResult->tiledOps.size() == 1); Operation *tiledOp = tilingResult->tiledOps.front(); auto dsOp = dyn_cast(tiledOp); - if (!dsOp) + if (!dsOp) { return tilingResult; + } for (OpOperand *operand : dsOp.getDpsInputOperands()) { auto sliceOp = operand->get().getDefiningOp(); - if (!sliceOp) + if (!sliceOp) { continue; + } auto tilingOp = sliceOp.getSource().getDefiningOp(); - if (!tilingOp) + if (!tilingOp) { continue; + } if (isa(sliceOp.getSource().getDefiningOp())) { continue; } @@ -248,13 +255,15 @@ static LogicalResult tileParallelDims(mlir::FunctionOpInterface funcOp, for (TilingInterface tilingOp : computeOps) { auto attr = tilingOp->getAttr(LinalgTransforms::kLinalgTransformMarker); - if (attr == marker) + if (attr == marker) { continue; + } size_t numLoops = 0; for (auto type : tilingOp.getLoopIteratorTypes()) { - if (type == utils::IteratorType::parallel) + if (type == utils::IteratorType::parallel) { numLoops++; + } } IRRewriter rewriter(tilingOp->getContext()); rewriter.setInsertionPoint(tilingOp); @@ -263,8 +272,9 @@ static LogicalResult tileParallelDims(mlir::FunctionOpInterface funcOp, auto partitionedLoops = interfaceOp.getPartitionableLoops(kNumMaxParallelDims); // If there are no dimensions to tile skip the transformation. 
- if (partitionedLoops.empty()) + if (partitionedLoops.empty()) { continue; + } SmallVector numThreads(numLoops, rewriter.getIndexAttr(0)); int64_t id = 0, threadDim = 0; SmallVector idDims; @@ -307,8 +317,9 @@ static LogicalResult tileAndUnrollConv(mlir::FunctionOpInterface funcOp) { IRRewriter rewriter(funcOp.getContext()); SmallVector tileSizes = getAsIndexOpFoldResult( funcOp.getContext(), getTileSizes(consumerOp, 1)); - if (tileSizes.empty()) + if (tileSizes.empty()) { return success(); + } FailureOr tileAndFuseResult = scf::tileConsumerAndFuseProducersUsingSCF( @@ -375,8 +386,9 @@ struct GPUTensorTilePass final // Tile to serial loops to the wg tile size to handle reductions and other // dimension that have not been distributed. if (failed(tileReductionToSerialLoops(funcOp, /*fuseInputProducer=*/false, - /*coalesceLoops=*/false))) + /*coalesceLoops=*/false))) { return signalPassFailure(); + } LLVM_DEBUG({ llvm::dbgs() << "// --- After tile reductions:\n"; diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTile.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTile.cpp index 2fdedf111a5f..8d2397649aaa 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTile.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTile.cpp @@ -64,8 +64,9 @@ collectComputeOps(mlir::FunctionOpInterface funcOp, computeOps = getComputeOps(funcOp); for (Operation *op : computeOps) { if (auto config = - getLoweringConfig(op)) + getLoweringConfig(op)) { configs.push_back(config); + } } if (computeOps.size() > 1) { // Only keep the last compute ops. 
@@ -80,8 +81,9 @@ collectComputeOps(mlir::FunctionOpInterface funcOp, ifOps.front()->walk([&configs](Operation *op) { if (isa(op)) { if (auto config = - getLoweringConfig(op)) + getLoweringConfig(op)) { configs.push_back(config); + } } }); @@ -276,8 +278,9 @@ struct GPUTilePass final : impl::GPUTilePassBase { SmallVector computeOps; FailureOr loweringConfig = collectComputeOps(funcOp, computeOps); - if (failed(loweringConfig)) + if (failed(loweringConfig)) { return signalPassFailure(); + } assert(computeOps.size() <= 2); // Now tile the last computation op to invocations and fuse all operand @@ -286,8 +289,9 @@ struct GPUTilePass final : impl::GPUTilePassBase { for (Operation *computeOp : computeOps) { auto consumerOp = dyn_cast(computeOp); if (!consumerOp || - failed(tileAndDistributeToThreads(consumerOp, threadTileSizes))) + failed(tileAndDistributeToThreads(consumerOp, threadTileSizes))) { return signalPassFailure(); + } } LLVM_DEBUG({ diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileAndConvertConvToMatmul.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileAndConvertConvToMatmul.cpp index 8c19c39c02ff..19be905eb779 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileAndConvertConvToMatmul.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileAndConvertConvToMatmul.cpp @@ -57,8 +57,9 @@ void static removeUnitExtentDimsfromMaps(linalg::LinalgOp linalgOp, return; } SmallVector indexingMaps = linalgOp.getIndexingMapsArray(); - if (indexingMaps.empty()) + if (indexingMaps.empty()) { return; + } AffineMap inputMap = indexingMaps[0]; AffineMap filterMap = indexingMaps[1]; AffineMap outputMap = indexingMaps[2]; diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileReduction.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileReduction.cpp index 4231691da23a..a186ae6e17ba 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileReduction.cpp +++ 
b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileReduction.cpp @@ -25,12 +25,14 @@ static LogicalResult tileReduction(linalg::LinalgOp op) { SmallVector dims; op.getReductionDims(dims); SmallVector tileSize = getTileSizes(op, 1); - if (tileSize.empty()) + if (tileSize.empty()) { return success(); + } // Make sure reduction dimensions are the innermost ones. for (int i = 0; i < dims.size(); ++i) { - if (dims[dims.size() - 1 - i] != op.getNumLoops() - 1 - i) + if (dims[dims.size() - 1 - i] != op.getNumLoops() - 1 - i) { return success(); + } } IRRewriter rewriter(op.getContext()); SmallVector sizes; @@ -40,8 +42,9 @@ static LogicalResult tileReduction(linalg::LinalgOp op) { rewriter.setInsertionPoint(op); FailureOr results = scf::tileReductionUsingScf( rewriter, cast(op.getOperation()), sizes); - if (failed(results)) + if (failed(results)) { return failure(); + } rewriter.replaceOp(op, results->replacements); return success(); } @@ -50,14 +53,16 @@ static LogicalResult tileFusedOps(linalg::LinalgOp op) { IRRewriter rewriter(op.getContext()); rewriter.setInsertionPoint(op); SmallVector tileSizes = getTileSizes(op, 1); - if (tileSizes.empty()) + if (tileSizes.empty()) { return success(); + } linalg::LinalgTilingOptions tileOption; tileOption.setTileSizes(tileSizes); FailureOr tiledOps = linalg::tileLinalgOp(rewriter, op, tileOption); - if (failed(tiledOps)) + if (failed(tiledOps)) { return failure(); + } rewriter.replaceOp(op, tiledOps->tensorResults); return success(); } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp index f7ab3452fbce..da283d2f3248 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp @@ -321,8 +321,9 @@ static void applyVectorDistribution(Operation *root, while (!worklist.empty()) { Operation *op = worklist.front(); 
worklist.pop_front(); - if (op == nullptr) + if (op == nullptr) { continue; + } LLVM_DEBUG(llvm::dbgs() << "Distributing: "); LLVM_DEBUG(op->print(llvm::dbgs(), OpPrintingFlags().skipRegions())); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td index 78352069d7a6..a86a1284a28a 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td @@ -383,6 +383,14 @@ def GPUApplyPaddingLevelPass : ]; } +def GPUExpandDimensionsPass : + InterfacePass<"iree-codegen-gpu-expand-dimensions", "mlir::FunctionOpInterface"> { + let summary = "Pass to expand tensor op dims based on `expand_dims` lowering_config"; + let dependentDialects = [ + "::mlir::iree_compiler::IREE::Util::UtilDialect" + ]; +} + def GPUTensorTileToSerialLoopsPass : InterfacePass<"iree-codegen-gpu-tensor-tile-to-serial-loops", "mlir::FunctionOpInterface"> { let summary = "Pass to tile reduction dimensions for certain GPU ops"; diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp index 3f3f2a0c649b..0946e7374918 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp @@ -62,14 +62,18 @@ static bool isUniformLoad(Operation *op) { using namespace IREE::HAL; auto loadOp = dyn_cast(op); - if (!loadOp) + if (!loadOp) { return false; - if (!hasGlobalMemoryAddressSpace(loadOp.getMemRefType())) + } + if (!hasGlobalMemoryAddressSpace(loadOp.getMemRefType())) { return false; + } auto space = loadOp.getMemRefType().getMemorySpace(); auto descTypeAttr = dyn_cast_if_present(space); - if (descTypeAttr && descTypeAttr.getValue() == DescriptorType::UniformBuffer) + if (descTypeAttr && + descTypeAttr.getValue() == DescriptorType::UniformBuffer) { return true; + } auto subspan = 
loadOp.getMemRef().getDefiningOp(); if (auto fatBufferCast = @@ -77,16 +81,20 @@ static bool isUniformLoad(Operation *op) { subspan = fatBufferCast.getSource().getDefiningOp(); } - if (!subspan) + if (!subspan) { return false; + } descTypeAttr = dyn_cast_if_present( cast(subspan.getResult().getType()).getMemorySpace()); - if (descTypeAttr && descTypeAttr.getValue() == DescriptorType::UniformBuffer) + if (descTypeAttr && + descTypeAttr.getValue() == DescriptorType::UniformBuffer) { return true; + } if (auto flags = subspan.getDescriptorFlags()) { - if (bitEnumContainsAll(*flags, IREE::HAL::DescriptorFlags::ReadOnly)) + if (bitEnumContainsAll(*flags, IREE::HAL::DescriptorFlags::ReadOnly)) { return true; + } } return false; } @@ -97,18 +105,24 @@ static void moveScalarAndBindingUniformCode(gpu::WarpExecuteOnLane0Op warpOp) { /// Hoist ops without side effect as well as special binding ops. auto canBeHoisted = [](Operation *op, function_ref definedOutside) { - if (op->getNumRegions() != 0) + if (op->getNumRegions() != 0) { return false; - if (!llvm::all_of(op->getOperands(), definedOutside)) + } + if (!llvm::all_of(op->getOperands(), definedOutside)) { return false; - if (isMemoryEffectFree(op)) + } + if (isMemoryEffectFree(op)) { return true; + } if (isa(op)) + IREE::HAL::InterfaceConstantLoadOp, memref::AssumeAlignmentOp>( + op)) { return true; - if (isUniformLoad(op)) + } + if (isUniformLoad(op)) { return true; + } // Shared memory is already scoped to the workgroup and can safely be // hoisted out of the the warp op. if (auto allocOp = dyn_cast(op)) { @@ -144,8 +158,9 @@ static void moveScalarAndBindingUniformCode(gpu::WarpExecuteOnLane0Op warpOp) { } // Move all the ops marked as uniform outside of the region. 
- for (Operation *op : opsToMove) + for (Operation *op : opsToMove) { op->moveBefore(warpOp); + } } /// Pattern to convert single element vector.insert to broadcast, this is a @@ -155,8 +170,9 @@ struct InsertToBroadcast final : OpRewritePattern { LogicalResult matchAndRewrite(vector::InsertOp insertOp, PatternRewriter &rewriter) const override { - if (insertOp.getDestVectorType().getNumElements() != 1) + if (insertOp.getDestVectorType().getNumElements() != 1) { return failure(); + } rewriter.replaceOpWithNewOp( insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore()); return success(); @@ -173,8 +189,9 @@ struct WarpOpBarrier final : OpRewritePattern { warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); Operation *lastNode = yield->getPrevNode(); auto barrierOp = dyn_cast_if_present(lastNode); - if (!barrierOp) + if (!barrierOp) { return failure(); + } rewriter.setInsertionPointAfter(warpOp); (void)gpu::BarrierOp::create(rewriter, barrierOp.getLoc()); @@ -274,8 +291,9 @@ struct VectorReductionToGPUPass final }; auto distributionFn = [](Value val) { auto vecType = dyn_cast(val.getType()); - if (!vecType) + if (!vecType) { return AffineMap::get(val.getContext()); + } // Create an identity dim map of rank |vecRank|. This greedily divides // threads along the outermost vector dimensions to the innermost ones. 
int64_t vecRank = vecType.getRank(); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/WorkgroupReordering.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/WorkgroupReordering.cpp index c567745392b3..3faa7a244368 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/WorkgroupReordering.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/WorkgroupReordering.cpp @@ -139,20 +139,23 @@ struct ReorderWorkgroupsPass final .Case("", ReorderWorkgroupsStrategy::None) .Case("transpose", ReorderWorkgroupsStrategy::Transpose) .Default(failure()); - if (failed(selectedStrategy)) + if (failed(selectedStrategy)) { return failure(); + } reorderingStrategy = *selectedStrategy; return success(); } void runOnOperation() override { - if (reorderingStrategy == ReorderWorkgroupsStrategy::None) + if (reorderingStrategy == ReorderWorkgroupsStrategy::None) { return; + } FunctionOpInterface funcOp = getOperation(); - if (filterFn && failed(filterFn(funcOp))) + if (filterFn && failed(filterFn(funcOp))) { return; + } LLVM_DEBUG({ llvm::dbgs() << "--- Before reorder workgroups with workgroup counts ---"; diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel index c36e063de06b..ed2b9576967e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel @@ -17,9 +17,11 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "amdgpu_lower_coalesced_dma_to_gather_lds.mlir", "decompose_horizontally_fused_gemms.mlir", + "flatten_swizzle_hint_allocs.mlir", "gpu_alloc_private_memory_for_dps_ops.mlir", "gpu_apply_derived_thread_config.mlir", "gpu_apply_padding_online_attention.mlir", @@ -36,6 +38,8 @@ iree_lit_test_suite( "gpu_distribute_forall.mlir", "gpu_distribute_scf_for.mlir", "gpu_distribute_shared_memory.mlir", + "gpu_expand_dimensions.mlir", + "gpu_fold_swizzle_hint_ops.mlir", 
"gpu_fuse_and_hoist_forall.mlir", "gpu_generalize_named_ops.mlir", "gpu_greedily_distribute_to_threads.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt index 002a332570b4..9e626933d760 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt @@ -16,6 +16,7 @@ iree_lit_test_suite( SRCS "amdgpu_lower_coalesced_dma_to_gather_lds.mlir" "decompose_horizontally_fused_gemms.mlir" + "flatten_swizzle_hint_allocs.mlir" "gpu_alloc_private_memory_for_dps_ops.mlir" "gpu_apply_derived_thread_config.mlir" "gpu_apply_padding_online_attention.mlir" @@ -32,6 +33,8 @@ iree_lit_test_suite( "gpu_distribute_forall.mlir" "gpu_distribute_scf_for.mlir" "gpu_distribute_shared_memory.mlir" + "gpu_expand_dimensions.mlir" + "gpu_fold_swizzle_hint_ops.mlir" "gpu_fuse_and_hoist_forall.mlir" "gpu_generalize_named_ops.mlir" "gpu_greedily_distribute_to_threads.mlir" diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/flatten_swizzle_hint_allocs.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/flatten_swizzle_hint_allocs.mlir new file mode 100644 index 000000000000..b571927fa818 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/flatten_swizzle_hint_allocs.mlir @@ -0,0 +1,96 @@ +// RUN: iree-opt --allow-unregistered-dialect --pass-pipeline="builtin.module(func.func(iree-codegen-flatten-swizzle-hint-allocs))" \ +// RUN: --mlir-print-local-scope %s | FileCheck %s + +// Test: 1D alloc should NOT be flattened (already 1D). 
+func.func @skip_1d_alloc() { + %alloc = memref.alloc() : memref<2048xf32, #gpu.address_space> + %0 = iree_codegen.swizzle_hint %alloc[#iree_codegen.rotate_rows<64, 4>] : memref<2048xf32, #gpu.address_space> + "test.use"(%0) : (memref<2048xf32, #gpu.address_space>) -> () + return +} + +// CHECK-LABEL: func @skip_1d_alloc +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<2048xf32, #gpu.address_space> +// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ALLOC]][#iree_codegen.rotate_rows<64, 4>] : memref<2048xf32, #gpu.address_space> +// CHECK-NOT: memref.expand_shape +// CHECK: "test.use"(%[[HINT]]) + +// Test: 2D alloc with swizzle hint should be flattened to 1D. +func.func @flatten_2d_alloc() { + %alloc = memref.alloc() : memref<32x64xf32, #gpu.address_space> + %0 = iree_codegen.swizzle_hint %alloc[#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32, #gpu.address_space> + "test.use"(%0) : (memref<32x64xf32, #gpu.address_space>) -> () + return +} + +// CHECK-LABEL: func @flatten_2d_alloc +// CHECK: %[[ALLOC1D:.+]] = memref.alloc() : memref<2048xf32, #gpu.address_space> +// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ALLOC1D]][#iree_codegen.rotate_rows<64, 4>] : memref<2048xf32, #gpu.address_space> +// CHECK: %[[EXPAND:.+]] = memref.expand_shape %[[HINT]] {{\[\[}}0, 1{{\]\]}} output_shape [32, 64] : memref<2048xf32, #gpu.address_space> into memref<32x64xf32, #gpu.address_space> +// CHECK: "test.use"(%[[EXPAND]]) +// CHECK-NOT: memref.alloc() : memref<32x64xf32 +// CHECK-NOT: iree_codegen.swizzle_hint {{.*}} : memref<32x64xf32 +// CHECK: return + +// Test: 3D alloc with swizzle hint should be flattened to 1D. 
+func.func @flatten_3d_alloc() { + %alloc = memref.alloc() : memref<4x8x16xf32, #gpu.address_space> + %0 = iree_codegen.swizzle_hint %alloc[#iree_codegen.rotate_rows<64, 4>] : memref<4x8x16xf32, #gpu.address_space> + "test.use"(%0) : (memref<4x8x16xf32, #gpu.address_space>) -> () + return +} + +// CHECK-LABEL: func @flatten_3d_alloc +// CHECK: %[[ALLOC1D:.+]] = memref.alloc() : memref<512xf32, #gpu.address_space> +// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ALLOC1D]][#iree_codegen.rotate_rows<64, 4>] : memref<512xf32, #gpu.address_space> +// CHECK: %[[EXPAND:.+]] = memref.expand_shape %[[HINT]] {{\[\[}}0, 1, 2{{\]\]}} output_shape [4, 8, 16] : memref<512xf32, #gpu.address_space> into memref<4x8x16xf32, #gpu.address_space> +// CHECK: "test.use"(%[[EXPAND]]) +// CHECK-NOT: memref.alloc() : memref<4x8x16xf32 +// CHECK-NOT: iree_codegen.swizzle_hint {{.*}} : memref<4x8x16xf32 +// CHECK: return + +// Test: Non-alloc operand should NOT be affected. +func.func @skip_non_alloc(%arg0: memref<32x64xf32, #gpu.address_space>) { + %0 = iree_codegen.swizzle_hint %arg0[#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32, #gpu.address_space> + "test.use"(%0) : (memref<32x64xf32, #gpu.address_space>) -> () + return +} + +// CHECK-LABEL: func @skip_non_alloc +// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: memref<32x64xf32, #gpu.address_space> +// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ARG0]][#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32, #gpu.address_space> +// CHECK-NOT: memref.expand_shape +// CHECK: "test.use"(%[[HINT]]) + +// Test: Alloc with multiple uses should NOT be flattened. 
+func.func @skip_multi_use_alloc() { + %alloc = memref.alloc() : memref<32x64xf32, #gpu.address_space> + %0 = iree_codegen.swizzle_hint %alloc[#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32, #gpu.address_space> + "test.use"(%alloc) : (memref<32x64xf32, #gpu.address_space>) -> () + "test.use"(%0) : (memref<32x64xf32, #gpu.address_space>) -> () + return +} + +// CHECK-LABEL: func @skip_multi_use_alloc +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x64xf32, #gpu.address_space> +// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ALLOC]][#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32, #gpu.address_space> +// CHECK-NOT: memref.expand_shape +// CHECK: "test.use"(%[[ALLOC]]) +// CHECK: "test.use"(%[[HINT]]) + +// Test: XOR shuffle swizzle attribute. +func.func @flatten_xor_shuffle() { + %alloc = memref.alloc() : memref<16x128xi8, #gpu.address_space> + %0 = iree_codegen.swizzle_hint %alloc[#iree_codegen.xor_shuffle<128, 16>] : memref<16x128xi8, #gpu.address_space> + "test.use"(%0) : (memref<16x128xi8, #gpu.address_space>) -> () + return +} + +// CHECK-LABEL: func @flatten_xor_shuffle +// CHECK: %[[ALLOC1D:.+]] = memref.alloc() : memref<2048xi8, #gpu.address_space> +// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ALLOC1D]][#iree_codegen.xor_shuffle<128, 16>] : memref<2048xi8, #gpu.address_space> +// CHECK: %[[EXPAND:.+]] = memref.expand_shape %[[HINT]] {{\[\[}}0, 1{{\]\]}} output_shape [16, 128] : memref<2048xi8, #gpu.address_space> into memref<16x128xi8, #gpu.address_space> +// CHECK: "test.use"(%[[EXPAND]]) +// CHECK-NOT: memref.alloc() : memref<16x128xi8 +// CHECK-NOT: iree_codegen.swizzle_hint {{.*}} : memref<16x128xi8 +// CHECK: return diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir index 348bc4db92be..d5e21133494c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir +++ 
b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir @@ -735,3 +735,42 @@ module { // SERIAL: linalg.generic // SERIAL: scf.forall.in_parallel // SERIAL-NOT: mapping + +// ----- + +func.func @matmul_transpose_b_with_swizzle(%5: tensor<64x64xf32>, %6: tensor<64x1280xf16>, %7: tensor<64x1280xf16>) -> tensor<64x64xf32> { + %c4 = arith.constant 4 : index + %c1280 = arith.constant 1280 : index + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %8 = linalg.fill ins(%cst : f32) outs(%5 : tensor<64x64xf32>) -> tensor<64x64xf32> + %9 = tensor.empty() : tensor<64x1280xf16> + %swizzle_9 = iree_codegen.swizzle_hint %9[#iree_codegen.xor_shuffle<256, 32>] : tensor<64x1280xf16> + %10 = tensor.empty() : tensor<64x1280xf16> + %swizzle_10 = iree_codegen.swizzle_hint %10[#iree_codegen.xor_shuffle<256, 32>] : tensor<64x1280xf16> + %11 = scf.for %arg0 = %c0 to %c1280 step %c4 iter_args(%arg1 = %8) -> (tensor<64x64xf32>) { + %extracted_slice = tensor.extract_slice %6[0, %arg0] [64, 4] [1, 1] : tensor<64x1280xf16> to tensor<64x4xf16> + %extracted_slice_0 = tensor.extract_slice %swizzle_9[0, %arg0] [64, 4] [1, 1] : tensor<64x1280xf16> to tensor<64x4xf16> + %12 = linalg.copy {lowering_config = #iree_gpu.lowering_config<{thread = [1, 1]}>} ins(%extracted_slice : tensor<64x4xf16>) outs(%extracted_slice_0 : tensor<64x4xf16>) -> tensor<64x4xf16> + %extracted_slice_1 = tensor.extract_slice %7[0, %arg0] [64, 4] [1, 1] : tensor<64x1280xf16> to tensor<64x4xf16> + %extracted_slice_2 = tensor.extract_slice %swizzle_10[0, %arg0] [64, 4] [1, 1] : tensor<64x1280xf16> to tensor<64x4xf16> + %13 = linalg.copy {lowering_config = #iree_gpu.lowering_config<{thread = [1, 1]}>} ins(%extracted_slice_1 : tensor<64x4xf16>) outs(%extracted_slice_2 : tensor<64x4xf16>) -> tensor<64x4xf16> + %14 = linalg.matmul + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ] + 
{lowering_config = #iree_gpu.lowering_config<{thread = [4, 4]}>} + ins(%12, %13 : tensor<64x4xf16>, tensor<64x4xf16>) + outs(%arg1 : tensor<64x64xf32>) -> tensor<64x64xf32> + scf.yield %14 : tensor<64x64xf32> + } + return %11 : tensor<64x64xf32> +} + +// CHECK-LABEL: func.func @matmul_transpose_b_with_swizzle + +// THREAD-LABEL: func.func @matmul_transpose_b_with_swizzle +// THREAD: %2 = tensor.empty() : tensor<64x4xf16> +// THREAD: %3 = iree_codegen.swizzle_hint %2[#iree_codegen.xor_shuffle<256, 32>] : tensor<64x4xf16> diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_expand_dimensions.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_expand_dimensions.mlir new file mode 100644 index 000000000000..d30de91e5fc3 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_expand_dimensions.mlir @@ -0,0 +1,93 @@ +// RUN: iree-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-expand-dimensions))" | FileCheck %s + +func.func @expand_matvec(%a: tensor<4x16384xf16>, %b: tensor<1x16384xf16>) -> tensor<4x1xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %empty = tensor.empty() : tensor<4x1xf32> + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<4x1xf32>) -> tensor<4x1xf32> + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%a, %b : tensor<4x16384xf16>, tensor<1x16384xf16>) + outs(%fill : tensor<4x1xf32>) + attrs = { + lowering_config = #iree_gpu.lowering_config<{ + expand_dims = #iree_gpu.expand_dims<[[0], [1], [2, 3]], output_shape = [?, ?, ?, 8]>, + lane_basis = [[1, 1, 64, 1], [0, 1, 2, 3]], + partial_reduction = [0, 0, 64, 0], + subgroup_basis = [[1, 1, 1, 1], [0, 1, 2, 3]], + thread = [0, 0, 1, 8], + workgroup = [4, 1, 0, 0]}>} { + ^bb0(%in: f16, %in_0: f16, %out: f32): + %0 = arith.extf %in : 
f16 to f32 + %1 = arith.extf %in_0 : f16 to f32 + %2 = arith.mulf %0, %1 : f32 + %3 = arith.addf %out, %2 : f32 + linalg.yield %3 : f32 + } -> tensor<4x1xf32> + return %result : tensor<4x1xf32> +} + +// CHECK-LABEL: func.func @expand_matvec +// CHECK: %[[A_EXPAND:.*]] = tensor.expand_shape %{{.*}} {{\[}}[0], [1, 2]] output_shape [4, 2048, 8] : tensor<4x16384xf16> into tensor<4x2048x8xf16> +// CHECK: %[[B_EXPAND:.*]] = tensor.expand_shape %{{.*}} {{\[}}[0], [1, 2]] output_shape [1, 2048, 8] : tensor<1x16384xf16> into tensor<1x2048x8xf16> +// CHECK: linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "reduction"] +// CHECK-SAME: ins(%[[A_EXPAND]], %[[B_EXPAND]] : tensor<4x2048x8xf16>, tensor<1x2048x8xf16>) + +// ----- + +func.func @expand_multiple_dims(%a: tensor<4x16384xf16>, %b: tensor<4x16384xf16>) -> tensor<4x16384xf16> { + %empty = tensor.empty() : tensor<4x16384xf16> + %result = linalg.add { + lowering_config = #iree_gpu.lowering_config<{ + expand_dims = #iree_gpu.expand_dims<[[0], [1, 2, 3]], output_shape = [?, ?, 2, 4]> + }>} + ins(%a, %b : tensor<4x16384xf16>, tensor<4x16384xf16>) outs(%empty : tensor<4x16384xf16>) -> tensor<4x16384xf16> + return %result : tensor<4x16384xf16> +} + +// CHECK-LABEL: func.func @expand_multiple_dims +// CHECK: %[[A_EXPAND:.*]] = tensor.expand_shape %{{.*}} {{\[}}[0], [1, 2, 3]] output_shape [4, 2048, 2, 4] : tensor<4x16384xf16> into tensor<4x2048x2x4xf16> +// CHECK: %[[B_EXPAND:.*]] = tensor.expand_shape %{{.*}} {{\[}}[0], [1, 2, 3]] output_shape [4, 2048, 2, 4] : tensor<4x16384xf16> into tensor<4x2048x2x4xf16> +// CHECK: linalg.generic +// CHECK-SAME: ins(%[[A_EXPAND]], %[[B_EXPAND]] : tensor<4x2048x2x4xf16>, tensor<4x2048x2x4xf16>) + +// ----- + +// Verify that dynamic dimensions are gracefully handled (no expansion occurs). 
+func.func @no_expand_dynamic_dims(%a: tensor<4x?xf16>, %b: tensor<4x?xf16>) -> tensor<4x128xf16> { + %empty = tensor.empty() : tensor<4x128xf16> + %result = linalg.add { + lowering_config = #iree_gpu.lowering_config<{ + expand_dims = #iree_gpu.expand_dims<[[0], [1, 2]], output_shape = [?, ?, 8]> + }>} + ins(%a, %b : tensor<4x?xf16>, tensor<4x?xf16>) outs(%empty : tensor<4x128xf16>) -> tensor<4x128xf16> + return %result : tensor<4x128xf16> +} + +// CHECK-LABEL: func.func @no_expand_dynamic_dim +// CHECK-NOT: tensor.expand_shape +// CHECK: linalg.add +// CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<4x?xf16>, tensor<4x?xf16>) + +// ----- + +// Verify that non-divisible dimensions are gracefully handled (no expansion occurs). +func.func @no_expand_not_divisible(%a: tensor<4x127xf16>, %b: tensor<4x127xf16>) -> tensor<4x127xf16> { + %empty = tensor.empty() : tensor<4x127xf16> + %result = linalg.add { + lowering_config = #iree_gpu.lowering_config<{ + expand_dims = #iree_gpu.expand_dims<[[0], [1, 2]], output_shape = [?, ?, 8]> + }>} + ins(%a, %b : tensor<4x127xf16>, tensor<4x127xf16>) outs(%empty : tensor<4x127xf16>) -> tensor<4x127xf16> + return %result : tensor<4x127xf16> +} + +// CHECK-LABEL: func.func @no_expand_not_divisible +// CHECK-NOT: tensor.expand_shape +// CHECK: linalg.add +// CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<4x127xf16>, tensor<4x127xf16>) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fold_swizzle_hint_ops.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fold_swizzle_hint_ops.mlir new file mode 100644 index 000000000000..4fdd41f9e6cf --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fold_swizzle_hint_ops.mlir @@ -0,0 +1,120 @@ +// RUN: iree-opt --mlir-print-local-scope --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-apply-tiling-level, canonicalize, cse))" %s | FileCheck %s + +// Test: tensor.extract_slice of swizzle_hint(tensor.empty) should fold +// to swizzle_hint(tensor.empty) 
with the sliced shape. +func.func @fold_extract_slice_of_swizzle_hint() -> tensor<16x32xf32> { + %empty = tensor.empty() : tensor<64x64xf32> + %swizzle = iree_codegen.swizzle_hint %empty[#iree_codegen.rotate_rows<64, 4>] : tensor<64x64xf32> + %slice = tensor.extract_slice %swizzle[0, 0] [16, 32] [1, 1] : tensor<64x64xf32> to tensor<16x32xf32> + return %slice : tensor<16x32xf32> +} + +// CHECK-LABEL: func.func @fold_extract_slice_of_swizzle_hint +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<16x32xf32> +// CHECK: %[[SWIZZLE:.+]] = iree_codegen.swizzle_hint %[[EMPTY]][#iree_codegen.rotate_rows<64, 4>] : tensor<16x32xf32> +// CHECK: return %[[SWIZZLE]] + +// Test: tensor.extract_slice with dynamic sizes should fold correctly. +func.func @fold_extract_slice_dynamic(%size0: index, %size1: index) -> tensor { + %empty = tensor.empty() : tensor<64x64xf32> + %swizzle = iree_codegen.swizzle_hint %empty[#iree_codegen.xor_shuffle<128, 16>] : tensor<64x64xf32> + %slice = tensor.extract_slice %swizzle[0, 0] [%size0, %size1] [1, 1] : tensor<64x64xf32> to tensor + return %slice : tensor +} + +// CHECK-LABEL: func.func @fold_extract_slice_dynamic +// CHECK-SAME: %[[SIZE0:[A-Za-z0-9]+]]: index +// CHECK-SAME: %[[SIZE1:[A-Za-z0-9]+]]: index +// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[SIZE0]], %[[SIZE1]]) : tensor +// CHECK: %[[SWIZZLE:.+]] = iree_codegen.swizzle_hint %[[EMPTY]][#iree_codegen.xor_shuffle<128, 16>] : tensor +// CHECK: return %[[SWIZZLE]] + +// Test: tensor.expand_shape of swizzle_hint(tensor.empty) should fold +// to swizzle_hint(tensor.empty) with the expanded shape. 
+func.func @fold_expand_shape_of_swizzle_hint() -> tensor<4x16x64xf32> { + %empty = tensor.empty() : tensor<64x64xf32> + %swizzle = iree_codegen.swizzle_hint %empty[#iree_codegen.rotate_rows<64, 4>] : tensor<64x64xf32> + %expanded = tensor.expand_shape %swizzle [[0, 1], [2]] output_shape [4, 16, 64] : tensor<64x64xf32> into tensor<4x16x64xf32> + return %expanded : tensor<4x16x64xf32> +} + +// CHECK-LABEL: func.func @fold_expand_shape_of_swizzle_hint +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x16x64xf32> +// CHECK: %[[SWIZZLE:.+]] = iree_codegen.swizzle_hint %[[EMPTY]][#iree_codegen.rotate_rows<64, 4>] : tensor<4x16x64xf32> +// CHECK: return %[[SWIZZLE]] + +// Test: tensor.collapse_shape of swizzle_hint(tensor.empty) should fold +// to swizzle_hint(tensor.empty) with the collapsed shape. +func.func @fold_collapse_shape_of_swizzle_hint() -> tensor<64x64xf32> { + %empty = tensor.empty() : tensor<4x16x4x16xf32> + %swizzle = iree_codegen.swizzle_hint %empty[#iree_codegen.rotate_rows<64, 4>] : tensor<4x16x4x16xf32> + %collapsed = tensor.collapse_shape %swizzle [[0, 1], [2, 3]] : tensor<4x16x4x16xf32> into tensor<64x64xf32> + return %collapsed : tensor<64x64xf32> +} + +// CHECK-LABEL: func.func @fold_collapse_shape_of_swizzle_hint +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64x64xf32> +// CHECK: %[[SWIZZLE:.+]] = iree_codegen.swizzle_hint %[[EMPTY]][#iree_codegen.rotate_rows<64, 4>] : tensor<64x64xf32> +// CHECK: return %[[SWIZZLE]] + +// Negative test: extract_slice of swizzle_hint without tensor.empty source +// should NOT fold. 
+func.func @no_fold_extract_slice_non_empty(%arg0: tensor<64x64xf32>) -> tensor<16x32xf32> { + %swizzle = iree_codegen.swizzle_hint %arg0[#iree_codegen.rotate_rows<64, 4>] : tensor<64x64xf32> + %slice = tensor.extract_slice %swizzle[0, 0] [16, 32] [1, 1] : tensor<64x64xf32> to tensor<16x32xf32> + return %slice : tensor<16x32xf32> +} + +// CHECK-LABEL: func.func @no_fold_extract_slice_non_empty +// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<64x64xf32> +// CHECK: %[[SWIZZLE:.+]] = iree_codegen.swizzle_hint %[[ARG0]][#iree_codegen.rotate_rows<64, 4>] : tensor<64x64xf32> +// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[SWIZZLE]] +// CHECK: return %[[SLICE]] + +// Negative test: expand_shape of swizzle_hint without tensor.empty source +// should NOT fold. +func.func @no_fold_expand_shape_non_empty(%arg0: tensor<64x64xf32>) -> tensor<4x16x64xf32> { + %swizzle = iree_codegen.swizzle_hint %arg0[#iree_codegen.rotate_rows<64, 4>] : tensor<64x64xf32> + %expanded = tensor.expand_shape %swizzle [[0, 1], [2]] output_shape [4, 16, 64] : tensor<64x64xf32> into tensor<4x16x64xf32> + return %expanded : tensor<4x16x64xf32> +} + +// CHECK-LABEL: func.func @no_fold_expand_shape_non_empty +// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<64x64xf32> +// CHECK: %[[SWIZZLE:.+]] = iree_codegen.swizzle_hint %[[ARG0]][#iree_codegen.rotate_rows<64, 4>] : tensor<64x64xf32> +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[SWIZZLE]] +// CHECK: return %[[EXPANDED]] + +// Test: XOR shuffle swizzle attribute is preserved through folding. 
+func.func @fold_xor_shuffle_swizzle() -> tensor<8x64xf32> { + %empty = tensor.empty() : tensor<16x128xf32> + %swizzle = iree_codegen.swizzle_hint %empty[#iree_codegen.xor_shuffle<128, 16>] : tensor<16x128xf32> + %slice = tensor.extract_slice %swizzle[0, 0] [8, 64] [1, 1] : tensor<16x128xf32> to tensor<8x64xf32> + return %slice : tensor<8x64xf32> +} + +// CHECK-LABEL: func.func @fold_xor_shuffle_swizzle +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x64xf32> +// CHECK: %[[SWIZZLE:.+]] = iree_codegen.swizzle_hint %[[EMPTY]][#iree_codegen.xor_shuffle<128, 16>] : tensor<8x64xf32> +// CHECK: return %[[SWIZZLE]] + +// Test: Rank-reducing extract_slice should work correctly. +func.func @fold_rank_reducing_extract_slice() -> tensor<32xf32> { + %empty = tensor.empty() : tensor<64x64xf32> + %swizzle = iree_codegen.swizzle_hint %empty[#iree_codegen.rotate_rows<64, 4>] : tensor<64x64xf32> + %slice = tensor.extract_slice %swizzle[0, 0] [1, 32] [1, 1] : tensor<64x64xf32> to tensor<32xf32> + return %slice : tensor<32xf32> +} + +// CHECK-LABEL: func.func @fold_rank_reducing_extract_slice +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32xf32> +// CHECK: %[[SWIZZLE:.+]] = iree_codegen.swizzle_hint %[[EMPTY]][#iree_codegen.rotate_rows<64, 4>] : tensor<32xf32> +// CHECK: return %[[SWIZZLE]] + +#encoding = #iree_encoding.encoding (m, k)>, affine_map<(m, n, k) -> (k, n)>, affine_map<(m, n, k) -> (m, n)>], iteration_sizes = [?, ?, ?]> +func.func @fold_swizzle_hint_of_encoding() -> tensor<16xbf16,#encoding> { + %empty = tensor.empty() : tensor<8x16xbf16, #encoding> + %swizzle = iree_codegen.swizzle_hint %empty[#iree_codegen.rotate_rows<8, 4>] : tensor<8x16xbf16, #encoding> + %slice = tensor.extract_slice %swizzle[0, 0] [1, 16] [1, 1] : tensor<8x16xbf16, #encoding> to tensor<16xbf16,#encoding> + return %slice : tensor<16xbf16,#encoding> +} diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir 
b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir index deee7df12d3b..278dcdb89474 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir @@ -875,3 +875,72 @@ func.func @fuse_warp_and_lane_foralls_with_coalesced_dma(%src: tensor<2x2x64xf32 // CHECK: } // CHECK: } {mapping = [#gpu.thread, #gpu.thread, #gpu.thread]} // CHECK: return %[[THREAD_FORALL]] + +// ----- + +// Check that we dont make a zeroslice guard when fusing pad. +#map = affine_map<(d0) -> (d0 * 64)> +func.func @fuse_pad(%arg0: tensor, %arg1: index) -> tensor<128xf16> { + %c4 = arith.constant 4 : index + %c128 = arith.constant 128 : index + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f16 + %0 = tensor.empty() : tensor<128xf16> + %padded = tensor.pad %arg0 low[0] high[%arg1] { + ^bb0(%arg10: index): + tensor.yield %cst : f16 + } : tensor to tensor<128xf16> + %1 = scf.forall (%arg2) in (2) shared_outs(%arg3 = %0) -> (tensor<128xf16>) { + %2 = affine.apply #map(%arg2) + %extracted_slice = tensor.extract_slice %padded[%2] [64] [1] : tensor<128xf16> to tensor<64xf16> + %extracted_slice_0 = tensor.extract_slice %arg3[%2] [64] [1] : tensor<128xf16> to tensor<64xf16> + %3 = linalg.copy ins(%extracted_slice : tensor<64xf16>) outs(%extracted_slice_0 : tensor<64xf16>) -> tensor<64xf16> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg3[%2] [64] [1] : tensor<64xf16> into tensor<128xf16> + } + } {mapping = [#gpu.thread]} + return %1 : tensor<128xf16> +} + +// CHECK-LABEL: func @fuse_pad +// CHECK: scf.forall +// CHECK-NOT: scf.if +// CHECK: tensor.pad +// CHECK: linalg.copy +// CHECK: scf.forall.in_parallel +// CHECK: return + +// ----- + +// Check that we can fuse padded destinations. 
+#map = affine_map<(d0) -> (d0 * 64)> +func.func @fuse_pad_dest(%arg0: tensor<128xf16>, %arg1: index) -> tensor<128xf16> { + %c4 = arith.constant 4 : index + %c128 = arith.constant 128 : index + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f16 + %0 = tensor.empty() : tensor<128xf16> + %padded = tensor.pad %arg0 low[0] high[1] { + ^bb0(%arg10: index): + tensor.yield %cst : f16 + } : tensor<128xf16> to tensor<129xf16> + %extracted_slice_dest = tensor.extract_slice %padded[0] [128] [1] : tensor<129xf16> to tensor<128xf16> + %1 = scf.forall (%arg2) in (2) shared_outs(%arg3 = %extracted_slice_dest) -> (tensor<128xf16>) { + %2 = affine.apply #map(%arg2) + %extracted_slice = tensor.extract_slice %0[%2] [64] [1] : tensor<128xf16> to tensor<64xf16> + %extracted_slice_0 = tensor.extract_slice %arg3[%2] [64] [1] : tensor<128xf16> to tensor<64xf16> + %3 = linalg.copy ins(%extracted_slice : tensor<64xf16>) outs(%extracted_slice_0 : tensor<64xf16>) -> tensor<64xf16> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg3[%2] [64] [1] : tensor<64xf16> into tensor<128xf16> + } + } {mapping = [#gpu.thread]} + return %1 : tensor<128xf16> +} + +// CHECK-LABEL: func @fuse_pad_dest +// CHECK-NOT: tensor.pad +// CHECK: scf.forall +// CHECK-NOT: tensor.pad +// CHECK: linalg.copy +// CHECK: scf.forall.in_parallel +// CHECK: return diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir index cb209b3db860..28600bdd4130 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir @@ -305,3 +305,90 @@ func.func @promote_with_cache_swizzle_f4_no_stride(%a: tensor<2x34x34x129xf4E2M1 // CHECK-SAME: lowering_config = #iree_gpu.use_global_load_dma // CHECK-SAME: ins(%[[SWIZZLE_B]] // CHECK: 
linalg.batch_matmul {{.*}} ins(%[[PA]], %[[PB]] + +// ----- + +#lowering_config = #iree_gpu.lowering_config<{ + promote_operands = [0, 1], + promotion_types = [ + #iree_gpu.swizzle_operand>, + #iree_gpu.swizzle_operand>]}> + +func.func @promote_with_swizzle_operand(%a: tensor<32x64xf32>, %b: tensor<64x128xf32>) -> tensor<32x128xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %empty = tensor.empty() : tensor<32x128xf32> + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<32x128xf32>) -> tensor<32x128xf32> + %mm = linalg.matmul {lowering_config = #lowering_config} + ins(%a, %b : tensor<32x64xf32>, tensor<64x128xf32>) outs(%fill : tensor<32x128xf32>) -> tensor<32x128xf32> + return %mm : tensor<32x128xf32> +} + +// SwizzleOperand attribute creates swizzle_hint op with xor_shuffle +// and flattens/expands the tensor for shared memory swizzling. +// CHECK-LABEL: func.func @promote_with_swizzle_operand +// CHECK-SAME: %[[A:[A-Za-z0-9]+]]: tensor<32x64xf32> +// CHECK-SAME: %[[B:[A-Za-z0-9]+]]: tensor<64x128xf32> +// CHECK: %[[EMPTY_A:.+]] = tensor.empty() : tensor<2048xf32> +// CHECK: %[[SWIZZLE_A:.+]] = iree_codegen.swizzle_hint %[[EMPTY_A]][#iree_codegen.xor_shuffle<128, 16>] : tensor<2048xf32> +// CHECK: %[[EXPAND_A:.+]] = tensor.expand_shape %[[SWIZZLE_A]] {{\[\[}}0, 1{{\]\]}} output_shape [32, 64] : tensor<2048xf32> into tensor<32x64xf32> +// CHECK: %[[COPY_A:.+]] = linalg.copy +// CHECK-SAME: lowering_config = #iree_gpu.use_global_load_dma +// CHECK-SAME: ins(%[[A]] : tensor<32x64xf32>) outs(%[[EXPAND_A]] : tensor<32x64xf32>) +// CHECK: %[[EMPTY_B:.+]] = tensor.empty() : tensor<8192xf32> +// CHECK: %[[SWIZZLE_B:.+]] = iree_codegen.swizzle_hint %[[EMPTY_B]][#iree_codegen.xor_shuffle<256, 32>] : tensor<8192xf32> +// CHECK: %[[EXPAND_B:.+]] = tensor.expand_shape %[[SWIZZLE_B]] {{\[\[}}0, 1{{\]\]}} output_shape [64, 128] : tensor<8192xf32> into tensor<64x128xf32> +// CHECK: %[[COPY_B:.+]] = linalg.copy +// CHECK-SAME: lowering_config = 
#iree_gpu.derived_thread_config +// CHECK-SAME: ins(%[[B]] : tensor<64x128xf32>) outs(%[[EXPAND_B]] : tensor<64x128xf32>) +// CHECK: linalg.matmul {{.*}} ins(%[[COPY_A]], %[[COPY_B]] : tensor<32x64xf32>, tensor<64x128xf32>) + +// ----- + +#lowering_config = #iree_gpu.lowering_config<{ + promote_operands = [1], + promotion_types = [ + #iree_gpu.swizzle_operand>]}> + +func.func @promote_with_swizzle_operand_f16(%a: tensor<32x64xf16>, %b: tensor<64x128xf16>) -> tensor<32x128xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %empty = tensor.empty() : tensor<32x128xf32> + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<32x128xf32>) -> tensor<32x128xf32> + %mm = linalg.matmul {lowering_config = #lowering_config} + ins(%a, %b : tensor<32x64xf16>, tensor<64x128xf16>) outs(%fill : tensor<32x128xf32>) -> tensor<32x128xf32> + return %mm : tensor<32x128xf32> +} + +// SwizzleOperand with f16 element type. +// CHECK-LABEL: func.func @promote_with_swizzle_operand_f16 +// CHECK-SAME: %[[A:[A-Za-z0-9]+]]: tensor<32x64xf16> +// CHECK-SAME: %[[B:[A-Za-z0-9]+]]: tensor<64x128xf16> +// CHECK: %[[EMPTY_B:.+]] = tensor.empty() : tensor<8192xf16> +// CHECK: %[[SWIZZLE_B:.+]] = iree_codegen.swizzle_hint %[[EMPTY_B]][#iree_codegen.xor_shuffle<64, 8>] : tensor<8192xf16> +// CHECK: %[[EXPAND_B:.+]] = tensor.expand_shape %[[SWIZZLE_B]] {{\[\[}}0, 1{{\]\]}} output_shape [64, 128] : tensor<8192xf16> into tensor<64x128xf16> +// CHECK: %[[COPY_B:.+]] = linalg.copy +// CHECK-SAME: lowering_config = #iree_gpu.use_global_load_dma +// CHECK-SAME: ins(%[[B]] : tensor<64x128xf16>) outs(%[[EXPAND_B]] : tensor<64x128xf16>) +// CHECK: linalg.matmul {{.*}} ins(%[[A]], %[[COPY_B]] : tensor<32x64xf16>, tensor<64x128xf16>) + +// ----- + +#lowering_config = #iree_gpu.lowering_config<{ + promote_operands = [0], + promotion_types = [ + #iree_gpu.swizzle_operand>]}> + +func.func @swizzle_operand_no_promote_fill(%b: tensor<128x128xf32>) -> tensor<4x128xf32> { + %cst = arith.constant 0.000000e+00 : f32 + 
%empty = tensor.empty() : tensor<4x128xf32> + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<4x128xf32>) -> tensor<4x128xf32> + %mm = linalg.matmul {lowering_config = #lowering_config} + ins(%fill, %b : tensor<4x128xf32>, tensor<128x128xf32>) outs(%fill : tensor<4x128xf32>) -> tensor<4x128xf32> + return %mm : tensor<4x128xf32> +} + +// Verify that fills are not promoted even with swizzle_operand. +// CHECK-LABEL: func.func @swizzle_operand_no_promote_fill +// CHECK-NOT: iree_codegen.swizzle_hint +// CHECK-NOT: tensor.expand_shape +// CHECK: linalg.matmul +// CHECK: return diff --git a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp index a6114b60cc0a..fe13e195b271 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp @@ -96,8 +96,9 @@ getVectorSizes(Operation *op, bool useConfiguredVectorSizes) { auto ty = padOp.getResultType(); // TODO(hanchung): Infer the vector sizes for pad op after // maskedVectorize method allows dynamic result shapes. - if (!ty.hasStaticShape()) + if (!ty.hasStaticShape()) { return; + } vectorSizes = SmallVector(ty.getShape()); }) .Case([&](IREE::LinalgExt::GatherOp gatherOp) { @@ -121,10 +122,12 @@ static LogicalResult isWithinVectorSizeLimit(linalg::LinalgOp linalgOp, int64_t maxFlatVecSize = 1; for (OpOperand &operand : linalgOp->getOpOperands()) { auto type = dyn_cast(operand.get().getType()); - if (!type) + if (!type) { continue; - if (!type.hasStaticShape()) + } + if (!type.hasStaticShape()) { return failure(); + } maxFlatVecSize = std::max(maxFlatVecSize, type.getNumElements()); } return success(maxFlatVecSize < maxVectorSize); @@ -183,11 +186,13 @@ void GenericVectorizationPass::runOnOperation() { // Do not vectorize the op if the vector size is greater than or equal // to limit. 
if (enableVectorMasking) { - if (llvm::product_of(vectorSizes) >= maxVectorSize) + if (llvm::product_of(vectorSizes) >= maxVectorSize) { continue; + } } else { - if (failed(isWithinVectorSizeLimit(linalgOp, maxVectorSize))) + if (failed(isWithinVectorSizeLimit(linalgOp, maxVectorSize))) { continue; + } } } // Pad scalable dims with `false` to match the vector sizes. diff --git a/compiler/src/iree/compiler/Codegen/Common/HoistStaticallyBoundAllocations.cpp b/compiler/src/iree/compiler/Codegen/Common/HoistStaticallyBoundAllocations.cpp index 2fd1b1b6bf47..5879aedbc214 100644 --- a/compiler/src/iree/compiler/Codegen/Common/HoistStaticallyBoundAllocations.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/HoistStaticallyBoundAllocations.cpp @@ -36,8 +36,9 @@ void HoistStaticallyBoundAllocationsPass::runOnOperation() { IRRewriter rewriter(funcOp->getContext()); std::optional vscaleRange; - if (this->vscaleMax != 0 && this->vscaleMin <= this->vscaleMax) + if (this->vscaleMax != 0 && this->vscaleMin <= this->vscaleMax) { vscaleRange = {this->vscaleMin, this->vscaleMax}; + } hoistStaticallyBoundAllocationsInFunc(rewriter, funcOp, vscaleRange); diff --git a/compiler/src/iree/compiler/Codegen/Common/HoistUnrolledVectorExtractInsertSlice.cpp b/compiler/src/iree/compiler/Codegen/Common/HoistUnrolledVectorExtractInsertSlice.cpp index 9c2d0974d21e..ac8026e5d85a 100644 --- a/compiler/src/iree/compiler/Codegen/Common/HoistUnrolledVectorExtractInsertSlice.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/HoistUnrolledVectorExtractInsertSlice.cpp @@ -37,19 +37,22 @@ getUnrolledExtractSlices( SmallVector res; for (auto user : srcTensor.getUsers()) { auto extractStridedSliceOp = dyn_cast(user); - if (!extractStridedSliceOp) + if (!extractStridedSliceOp) { return failure(); + } res.push_back(extractStridedSliceOp); } - if (res.size() != insertOps.size()) + if (res.size() != insertOps.size()) { return failure(); + } std::reverse(res.begin(), res.end()); for (auto [extractOp, 
insertOp] : llvm::zip_equal(res, insertOps)) { auto offset0 = insertOp.getOffsets(); auto offset1 = extractOp.getOffsets(); - if (offset0 != offset1) + if (offset0 != offset1) { return failure(); + } } return res; @@ -72,8 +75,9 @@ getUnrolledInsertSlices(scf::ForOp forOp, BlockArgument bbArg, SmallVector res; Value v = yieldOperand.get(); auto insertStridedSliceOp = v.getDefiningOp(); - if (!insertStridedSliceOp) + if (!insertStridedSliceOp) { return failure(); + } ArrayRef vecShape = insertStridedSliceOp.getSourceVectorType().getShape(); @@ -81,8 +85,9 @@ getUnrolledInsertSlices(scf::ForOp forOp, BlockArgument bbArg, insertStridedSliceOp.getDestVectorType().getShape(); int numOps = 1; for (auto [vecSize, destSize] : llvm::zip_equal(vecShape, destShape)) { - if (destSize % vecSize) + if (destSize % vecSize) { return failure(); + } numOps *= destSize / vecSize; } @@ -91,19 +96,22 @@ getUnrolledInsertSlices(scf::ForOp forOp, BlockArgument bbArg, insertStridedSliceOp = insertStridedSliceOp.getDest() .getDefiningOp(); } - if (res.size() != numOps) + if (res.size() != numOps) { return failure(); + } std::reverse(res.begin(), res.end()); SmallVector expectedOffsets(vecShape.size(), 0); for (vector::InsertStridedSliceOp op : res) { SmallVector offsets = getI64SubArray(op.getOffsets()); - if (expectedOffsets != offsets) + if (expectedOffsets != offsets) { return failure(); + } expectedOffsets.back() += vecShape.back(); for (int pos = expectedOffsets.size() - 1; pos > 0; pos--) { - if (expectedOffsets[pos] != destShape[pos]) + if (expectedOffsets[pos] != destShape[pos]) { break; + } expectedOffsets[pos] = 0; expectedOffsets[pos - 1] += vecShape[pos - 1]; } @@ -189,11 +197,13 @@ static scf::ForOp hoistUnrolledVectorExtractInsert(RewriterBase &rewriter, LLVM_DEBUG(DBGS() << "Consider " << it.value() << "\n"); OpOperand &ret = yield->getOpOperand(it.index()); auto insertOps = getUnrolledInsertSlices(forOp, it.value(), ret); - if (failed(insertOps)) + if (failed(insertOps)) { 
continue; + } auto extractOps = getUnrolledExtractSlices(it.value(), insertOps.value()); - if (failed(extractOps)) + if (failed(extractOps)) { continue; + } newForOp = hoistVectorExtractInsertSlice(rewriter, extractOps.value(), insertOps.value(), it.value()); break; diff --git a/compiler/src/iree/compiler/Codegen/Common/IREECodegenCanonicalizer.cpp b/compiler/src/iree/compiler/Codegen/Common/IREECodegenCanonicalizer.cpp index 20bc8b4b9d6e..45ac22dcd333 100644 --- a/compiler/src/iree/compiler/Codegen/Common/IREECodegenCanonicalizer.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/IREECodegenCanonicalizer.cpp @@ -24,8 +24,9 @@ namespace { /// shape is same as the size of the subview. In such cases, the subview can /// be folded into its source. static bool isTrivialSubViewOp(memref::SubViewOp subviewOp) { - if (subviewOp.getSourceType().getRank() != subviewOp.getType().getRank()) + if (subviewOp.getSourceType().getRank() != subviewOp.getType().getRank()) { return false; + } if (!areAllConstantIntValue(subviewOp.getMixedOffsets(), 0) || !areAllConstantIntValue(subviewOp.getMixedStrides(), 1)) { @@ -81,8 +82,9 @@ class DynamicTrivialSubViewOpFolder final LogicalResult matchAndRewrite(memref::SubViewOp subViewOp, PatternRewriter &rewriter) const override { - if (!isTrivialSubViewOp(subViewOp)) + if (!isTrivialSubViewOp(subViewOp)) { return failure(); + } if (subViewOp.getSourceType() == subViewOp.getType()) { rewriter.replaceOp(subViewOp, subViewOp.getSource()); return success(); @@ -105,8 +107,9 @@ struct IREECodegenCanonicalizerPass final GreedySimplifyRegionLevel::Normal); RewritePatternSet owningPatterns(context); - for (auto *dialect : context->getLoadedDialects()) + for (auto *dialect : context->getLoadedDialects()) { dialect->getCanonicalizationPatterns(owningPatterns); + } for (RegisteredOperationName op : context->getRegisteredOperations()) { if (op.getStringRef() == memref::CopyOp::getOperationName()) { owningPatterns.add(context); diff --git 
a/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp b/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp index c6acdd001089..3aabe98df1e9 100644 --- a/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp @@ -63,9 +63,10 @@ static FailureOr defaultAllocationFn(OpBuilder &builder, Location loc, // type memory space; that's runtime allocations. So erase and fallback to // the default 0 memory space. It is fine given this is just the default // allocator; backends are expected to control by themselves. - if (isa(storage)) + if (isa(storage)) { type = MemRefType::get(type.getShape(), type.getElementType(), type.getLayout()); + } } return memref::AllocOp::create(builder, loc, type, dynamicSizes).getResult(); } @@ -172,12 +173,14 @@ eliminateEmptyTensors(RewriterBase &rewriter, Operation *op, const OneShotBufferizationOptions &options) { // Analyze IR. OneShotAnalysisState state(op, options); - if (failed(analyzeOp(op, state))) + if (failed(analyzeOp(op, state))) { return failure(); + } // Rewrite tensor.empty ops that are anchored on specific ops. - if (failed(bufferization::eliminateEmptyTensors(rewriter, op, state))) + if (failed(bufferization::eliminateEmptyTensors(rewriter, op, state))) { return failure(); + } return success(); } @@ -215,11 +218,13 @@ void EliminateEmptyTensorsPass::runOnOperation() { auto bufferizationOptions = getBufferizationOptions(); OneShotAnalysisState state(funcOp, bufferizationOptions); // Analyze IR. - if (failed(analyzeOp(funcOp, state))) + if (failed(analyzeOp(funcOp, state))) { return signalPassFailure(); + } // Eliminate empty tensors. 
- if (failed(bufferization::eliminateEmptyTensors(rewriter, funcOp, state))) + if (failed(bufferization::eliminateEmptyTensors(rewriter, funcOp, state))) { return signalPassFailure(); + } } // The following is copied from bufferization::runOneShotBufferize with @@ -229,10 +234,12 @@ runIREEOneShotBufferize(Operation *op, const IREEOneShotBufferizationOptions &options, bufferization::BufferizationState &state) { OneShotAnalysisState analyzeState(op, options); - if (failed(analyzeOp(op, analyzeState))) + if (failed(analyzeOp(op, analyzeState))) { return failure(); - if (options.testAnalysisOnly) + } + if (options.testAnalysisOnly) { return success(); + } return bufferization::runOneShotBufferize(op, options, state); } @@ -302,10 +309,12 @@ std::unique_ptr> createIREEComprehensiveBufferizePass( std::optional allocationFn, std::optional memCpyFn) { - if (!allocationFn) + if (!allocationFn) { allocationFn = defaultAllocationFn; - if (!memCpyFn) + } + if (!memCpyFn) { memCpyFn = defaultMemCpyFn; + } return std::make_unique(allocationFn.value(), memCpyFn.value()); } diff --git a/compiler/src/iree/compiler/Codegen/Common/IREEExpandStridedMetadata.cpp b/compiler/src/iree/compiler/Codegen/Common/IREEExpandStridedMetadata.cpp index d2a9ed3ca3ed..a02432b02be6 100644 --- a/compiler/src/iree/compiler/Codegen/Common/IREEExpandStridedMetadata.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/IREEExpandStridedMetadata.cpp @@ -158,8 +158,9 @@ struct ConvertMemRefExtractMetadataToIREECodegen using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op, PatternRewriter &rewriter) const override { - if (!getSourceInterfaceBinding(op.getSource())) + if (!getSourceInterfaceBinding(op.getSource())) { return failure(); + } // Replace with iree_codegen version which doesn't fold. 
rewriter.replaceOpWithNewOp( op, op.getSource()); @@ -173,8 +174,9 @@ struct ResolveExtractMetadataFromHalInterfaceBindingSubspan LogicalResult matchAndRewrite(IREE::Codegen::ExtractStridedMetadataOp op, PatternRewriter &rewriter) const override { auto binding = getSourceInterfaceBinding(op.getSource()); - if (!binding) + if (!binding) { return failure(); + } auto memRefType = cast(binding->getResult().getType()); auto loc = op.getLoc(); @@ -287,8 +289,9 @@ struct ConvertIREECodegenExtractMetadataToMemRef // Pattern ResolveExtractMetadataFromHalInterfaceBindingSubspan must // resolve these first to preserve SSA links through buffer binding // optimizations. - if (getSourceInterfaceBinding(op.getSource())) + if (getSourceInterfaceBinding(op.getSource())) { return failure(); + } // Only convert ops that don't have HAL bindings (or are already resolved). rewriter.replaceOpWithNewOp( diff --git a/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp index 1980cd3d220a..0039de5fd7e3 100644 --- a/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp @@ -161,8 +161,9 @@ static void updateNamedSequenceOp( seenNames.insert(newSeqName); // Skip updating ForeachMatchOp if the NamedSequenceOp is not used in it. 
- if (!namedSequenceToUser.contains(op)) + if (!namedSequenceToUser.contains(op)) { return; + } ForeachMatchOp foreachMatchOp = namedSequenceToUser[op]; @@ -408,12 +409,8 @@ static FailureOr emitLinkedDefaultTuningSpec(ModuleOp module) { SmallVector mergedActions; for (ForeachMatchOp foreachMatchOp : foreachMatchOps) { - ArrayAttr matchers = foreachMatchOp.getMatchers(); - ArrayAttr actions = foreachMatchOp.getActions(); - for (auto [matcher, action] : llvm::zip_equal(matchers, actions)) { - mergedMatchers.push_back(cast(matcher)); - mergedActions.push_back(cast(action)); - } + llvm::append_range(mergedMatchers, foreachMatchOp.getMatchers()); + llvm::append_range(mergedActions, foreachMatchOp.getActions()); } Region ®ion = newEntryPoint.getRegion(); @@ -422,8 +419,8 @@ static FailureOr emitLinkedDefaultTuningSpec(ModuleOp module) { builder.setInsertionPointToStart(body); auto mergedForeachMatch = ForeachMatchOp::create( builder, loc, resultTypes, newEntryPoint.getArgument(0), - /* forwarded_inputs = */ ValueRange(), - /* restrictRoot = */ nullptr, /* flattenResults = */ nullptr, + /*forwarded_inputs=*/ValueRange(), + /*restrict_root=*/false, /*flatten_results=*/false, builder.getArrayAttr(mergedMatchers), builder.getArrayAttr(mergedActions)); transform::YieldOp::create(builder, loc, mergedForeachMatch->getResult(0)); diff --git a/compiler/src/iree/compiler/Codegen/Common/LowerUKernelDescriptors.cpp b/compiler/src/iree/compiler/Codegen/Common/LowerUKernelDescriptors.cpp index f35fd48c60e0..74b460ebeafb 100644 --- a/compiler/src/iree/compiler/Codegen/Common/LowerUKernelDescriptors.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/LowerUKernelDescriptors.cpp @@ -120,8 +120,9 @@ convertToUKernelGeneric(RewriterBase &rewriter, Operation *op, StringRef name, provider.createAndReplaceWithUkernelOp( rewriter, name, targetConfiguration, op, tensorInputs, tensorOutputs, otherOperands); - if (retVal) + if (retVal) { return retVal.value(); + } } // Default ukernel generic op 
is created when a provider doesn't exist or when // the provider doesn't implement the replacement method. diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncoding.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncoding.cpp index c606152324b1..d8686b812d6a 100644 --- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncoding.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncoding.cpp @@ -204,8 +204,9 @@ materializeFuncOpEncodings(FunctionOpInterface funcOp, // the pipeline. if (isa(consumer) && isa_and_nonnull(producer) && - !producer->hasOneUse()) + !producer->hasOneUse()) { return false; + } return true; }); memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp index 5b50c41becbd..b87bfc53e342 100644 --- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp @@ -299,11 +299,12 @@ static Value generateEncodingTransferOps(RewriterBase &rewriter, Value src, Value value = src; if (srcType.getEncoding()) { value = IREE::Encoding::UnsetEncodingOp::create( - rewriter, src.getLoc(), srcType.dropEncoding(), value, dynamicDims); + rewriter, src.getLoc(), srcType.dropEncoding(), value, dynamicDims, + /*encodingDims=*/ValueRange{}); } if (destType.getEncoding()) { - value = IREE::Encoding::SetEncodingOp::create(rewriter, src.getLoc(), - destType, value); + value = IREE::Encoding::SetEncodingOp::create( + rewriter, src.getLoc(), destType, value, /*encodingDims=*/ValueRange{}); } return value; } @@ -460,8 +461,9 @@ struct MaterializeOperation : public OpConversionPattern { this->template getTypeConverter(); FailureOr convertedOp = lowerOpWithEncoding(rewriter, op, adaptor.getOperands(), *converter); - if (failed(convertedOp)) + if (failed(convertedOp)) 
{ return failure(); + } rewriter.replaceOp(op, convertedOp.value()); return success(); @@ -704,8 +706,9 @@ void populateMaterializeEncodingPatterns( auto resultType = dyn_cast( subspanOp.getResult().getType()); // For types that are not `TensorExt::DispatchTensorType` mark as legal. - if (!resultType) + if (!resultType) { return true; + } return resultType == typeConverter.convertType(resultType); }); target.addIllegalOp( storeOp.getTargetType()); // For types that are not `TensorExt::DispatchTensorType` mark as legal. - if (!resultType) + if (!resultType) { return true; + } return resultType == typeConverter.convertType(resultType); }); target.addDynamicallyLegalOp( @@ -724,12 +728,13 @@ void populateMaterializeEncodingPatterns( auto resultType = dyn_cast( loadOp.getSourceType()); // For types that are not `TensorExt::DispatchTensorType` mark as legal. - if (!resultType) + if (!resultType) { return true; + } return resultType == typeConverter.convertType(resultType); }); target.addDynamicallyLegalOp([](func::ReturnOp returnOp) { - return !llvm::any_of(returnOp.getOperandTypes(), + return llvm::none_of(returnOp.getOperandTypes(), isRankedTensorTypeWithEncoding); }); diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeTuningSpecsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeTuningSpecsPass.cpp index a11fa423ddf4..b008bfeaeb50 100644 --- a/compiler/src/iree/compiler/Codegen/Common/MaterializeTuningSpecsPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeTuningSpecsPass.cpp @@ -149,9 +149,10 @@ getDefaultTuningSpec(ModuleOp module, #ifndef NDEBUG if (succeeded(defaultTransformLibrary) && - failed(mlir::verify(*defaultTransformLibrary))) + failed(mlir::verify(*defaultTransformLibrary))) { return (*defaultTransformLibrary).emitError() << "Default tuning spec from " << storageAttr << " failed to verify"; + } #endif return defaultTransformLibrary; diff --git a/compiler/src/iree/compiler/Codegen/Common/MathTransformPass.cpp 
b/compiler/src/iree/compiler/Codegen/Common/MathTransformPass.cpp index cd8d08df66c1..532730af1b85 100644 --- a/compiler/src/iree/compiler/Codegen/Common/MathTransformPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/MathTransformPass.cpp @@ -142,8 +142,9 @@ static bool predicateDeviceLibImpl(StringRef name, bool hasFastExp = isROCMBackend(target); // If fast exp is not available, don't use device-lib implementations. - if (!hasFastExp) + if (!hasFastExp) { return false; + } // Only apply to erf for now. StringRef erf = math::ErfOp::getOperationName(); diff --git a/compiler/src/iree/compiler/Codegen/Common/MemrefCopyToLinalg.cpp b/compiler/src/iree/compiler/Codegen/Common/MemrefCopyToLinalg.cpp index 0839fb76a747..f1c85625e0a5 100644 --- a/compiler/src/iree/compiler/Codegen/Common/MemrefCopyToLinalg.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/MemrefCopyToLinalg.cpp @@ -25,8 +25,9 @@ struct MemrefCopyOpToLinalg : public OpRewritePattern { Operation *linalgCopy = createLinalgCopyOp(rewriter, copyOp.getLoc(), copyOp.getSource(), copyOp.getTarget(), copyOp->getAttrs()); - if (!linalgCopy) + if (!linalgCopy) { return failure(); + } rewriter.replaceOp(copyOp, linalgCopy->getResults()); return success(); } diff --git a/compiler/src/iree/compiler/Codegen/Common/NormalizeLoopBounds.cpp b/compiler/src/iree/compiler/Codegen/Common/NormalizeLoopBounds.cpp index 6a237b1d8d3a..d9802f233b33 100644 --- a/compiler/src/iree/compiler/Codegen/Common/NormalizeLoopBounds.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/NormalizeLoopBounds.cpp @@ -137,8 +137,9 @@ LogicalResult normalizeLoopBounds(RewriterBase &rewriter, scf::ForOp forOp) { LogicalResult normalizeLoopBounds(RewriterBase &rewriter, scf::ForallOp forallOp) { OpBuilder::InsertionGuard g(rewriter); - if (forallOp.isNormalized()) + if (forallOp.isNormalized()) { return success(); + } // `scf.forall` requires that all lbs/ubs/steps/ivs are index type so no need // to check here. 
diff --git a/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp b/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp index e65554190878..ef4fbe5ad5f9 100644 --- a/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp @@ -49,8 +49,9 @@ class OptimizeTensorInsertExtractSlicesPass final static bool canBeHoisted(LoopLikeOpInterface loopLike, SubsetInsertionOpInterface insertion) { // Do not move terminators. - if (insertion->hasTrait()) + if (insertion->hasTrait()) { return false; + } // Walk the nested operations and check that all used values are either // defined outside of the loop or in a nested region, but not at the level of @@ -58,8 +59,10 @@ static bool canBeHoisted(LoopLikeOpInterface loopLike, auto walkFn = [&](Operation *child) { for (OpOperand &operand : child->getOpOperands()) { // Ignore values defined in a nested region. - if (insertion->isAncestor(operand.get().getParentRegion()->getParentOp())) + if (insertion->isAncestor( + operand.get().getParentRegion()->getParentOp())) { continue; + } if (!loopLike.isDefinedOutsideOfLoop(operand.get()) && &operand != &insertion.getSourceOperand()) { return WalkResult::interrupt(); @@ -310,8 +313,9 @@ struct FoldMaskedTransferRAW : OpRewritePattern { [](Value v) { return !isZeroInteger(v); }) || llvm::any_of(writeOp.getIndices(), [](Value v) { return !isZeroInteger(v); })) && - (op.getIndices() != writeOp.getIndices())) + (op.getIndices() != writeOp.getIndices())) { return failure(); + } // Work only with minor identity mappings. 
if (!op.getPermutationMap().isMinorIdentity() || diff --git a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp index 1aee9088d051..d639f8d79edf 100644 --- a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp @@ -44,8 +44,9 @@ class TransposeUnitDimToShapeCast unsigned numNonUnitSrcDim = llvm::count_if(op.getSourceVectorType().getShape(), [](int64_t dim) { return dim != 1; }); - if (numNonUnitSrcDim > 1) + if (numNonUnitSrcDim > 1) { return failure(); + } rewriter.replaceOpWithNewOp( op, op.getResultVectorType(), op.getVector()); return success(); diff --git a/compiler/src/iree/compiler/Codegen/Common/PadDynamicAlloc.cpp b/compiler/src/iree/compiler/Codegen/Common/PadDynamicAlloc.cpp index b7794ffaea06..46c6b559f175 100644 --- a/compiler/src/iree/compiler/Codegen/Common/PadDynamicAlloc.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/PadDynamicAlloc.cpp @@ -23,15 +23,18 @@ namespace mlir::iree_compiler { /// compute alloc sizes. 
static Value skipAffineMaxZero(Value dim) { auto affineMax = dim.getDefiningOp(); - if (!affineMax) + if (!affineMax) { return dim; + } for (AffineExpr expr : affineMax.getMap().getResults()) { if (auto cst = dyn_cast(expr)) { - if (cst.getValue() == 0) + if (cst.getValue() == 0) { continue; + } } else if (auto symExpr = dyn_cast(expr)) { - if (symExpr.getPosition() == 0) + if (symExpr.getPosition() == 0) { continue; + } } return dim; } @@ -62,8 +65,9 @@ static LogicalResult padAlloc(MLIRContext *context, AllocLikeOp allocOp, dimSize = *ub; sizes.push_back(dim); } - if (dynamicDimIdx == 0) + if (dynamicDimIdx == 0) { return success(); + } Type elType = allocOp.getType().getElementType(); MemRefType allocType = MemRefType::get(shape, elType, AffineMap(), allocOp.getType().getMemorySpace()); @@ -98,8 +102,9 @@ struct PadDynamicAllocPass final SmallVector allocs; funcOp.walk([&](memref::AllocOp allocOp) { allocs.push_back(allocOp); }); for (memref::AllocOp alloc : allocs) { - if (failed(padAlloc(context, alloc, solver))) + if (failed(padAlloc(context, alloc, solver))) { return signalPassFailure(); + } } // Collect all the alloca operations. 
@@ -107,8 +112,9 @@ struct PadDynamicAllocPass final funcOp.walk( [&](memref::AllocaOp allocaOp) { allocas.push_back(allocaOp); }); for (memref::AllocaOp alloca : allocas) { - if (failed(padAlloc(context, alloca, solver))) + if (failed(padAlloc(context, alloca, solver))) { return signalPassFailure(); + } } } }; diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td index e604891e593a..b8a6fd88a0fd 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td @@ -496,6 +496,11 @@ def FlattenMemRefSubspanPass : Pass<"iree-codegen-flatten-memref-subspan", "Modu }]; } +def FlattenSwizzleHintAllocsPass : + InterfacePass<"iree-codegen-flatten-swizzle-hint-allocs", "mlir::FunctionOpInterface"> { + let summary = "Flattens allocations associated with iree_codegen.swizzle_hint ops"; +} + def FoldAffineMinInDistributedLoopsPass : InterfacePass<"iree-codegen-fold-affinemin-in-distributed-loops", "mlir::FunctionOpInterface"> { let summary = "Fold `affine.min` ops in distributed loops"; @@ -881,6 +886,21 @@ def RematerializeParallelOpsPass : let summary = "Pass to rematerialize and merge parallel ops into consumers."; } +def RemoveIndexHintsPass : + InterfacePass<"iree-codegen-remove-index-hints", "mlir::FunctionOpInterface"> { + let summary = "Remove iree_codegen.index_hint operations"; + let description = [{ + This pass removes all iree_codegen.index_hint operations by replacing + them with their input values (pass-through semantics). + + Index hints are used to convey optimization information to downstream + passes and should be cleaned up once that information has been consumed. 
+ }]; + let dependentDialects = [ + "IREE::Codegen::IREECodegenDialect" + ]; +} + def RemoveSingleIterationLoopPass : InterfacePass<"iree-codegen-remove-single-iteration-loop", "mlir::FunctionOpInterface"> { let summary = "Remove distributed loop with single iteration."; diff --git a/compiler/src/iree/compiler/Codegen/Common/PropagateConstantOffsets.cpp b/compiler/src/iree/compiler/Codegen/Common/PropagateConstantOffsets.cpp index a16b661dc5e1..dd4ad8651fae 100644 --- a/compiler/src/iree/compiler/Codegen/Common/PropagateConstantOffsets.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/PropagateConstantOffsets.cpp @@ -105,8 +105,9 @@ struct FoldApplySymbolOrDimSum final : OpRewritePattern { replacements.reserve(map.getNumInputs()); int64_t numDims = map.getNumDims(); auto getCurrExpr = [&](int64_t i) -> AffineExpr { - if (i >= numDims) + if (i >= numDims) { return rewriter.getAffineSymbolExpr(i - numDims); + } return rewriter.getAffineDimExpr(i); }; bool didReplace = false; @@ -157,8 +158,9 @@ struct PropagateConstantAddsThroughLinearize final int64_t runningOffset = 0; Value zero = nullptr; auto getZero = [&]() { - if (zero) + if (zero) { return zero; + } zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); return zero; }; @@ -252,8 +254,9 @@ struct FoldDivisibleConstantMulsIntoLinearize final SmallVector newStaticBasis; Value zero = nullptr; auto getZero = [&]() { - if (zero) + if (zero) { return zero; + } zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); return zero; }; diff --git a/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp b/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp index febda9283240..e06fb170c6af 100644 --- a/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp @@ -7,6 +7,7 @@ #include "iree/compiler/Codegen/Common/Transforms.h" #include 
"iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" +#include "iree/compiler/Codegen/Dialect/Codegen/Transforms/Transforms.h" #include "iree/compiler/Codegen/Utils/GPUUtils.h" #include "iree/compiler/Codegen/Utils/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -32,11 +33,13 @@ getExpandedShape(SmallVector reIndices, SmallVectorImpl &expandedShape, SmallVectorImpl &totalInnerSizes) { auto destType = dyn_cast(dest.getType()); - if (!destType) + if (!destType) { return failure(); + } // TODO (nirvedhmeshram): Support rank reducing parallel_insert_slice. - if (reIndices.size() != destType.getShape().size()) + if (reIndices.size() != destType.getShape().size()) { return failure(); + } // Iterator to insert outer sizes. auto outerShapeIdx = 0; for (auto [reassociations, destSize] : @@ -57,13 +60,15 @@ getExpandedShape(SmallVector reIndices, for (int64_t reasociation : llvm::drop_begin(reassociations)) { int64_t expandedInnerSize = sliceStaticSizes[reasociation]; // It is not safe to do this pattern if inner dimensions are dynamic. - if (ShapedType::isDynamic(expandedInnerSize)) + if (ShapedType::isDynamic(expandedInnerSize)) { return failure(); + } expandedShape.push_back(expandedInnerSize); totalInnerSize *= expandedInnerSize; } - if (destSize % totalInnerSize != 0) + if (destSize % totalInnerSize != 0) { return failure(); + } totalInnerSizes.push_back(totalInnerSize); // insert the outer size in front of any inner sizes. 
expandedShape.insert(expandedShape.begin() + outerShapeIdx, @@ -87,20 +92,26 @@ static LogicalResult verifyAndCollectExpandableUsers( continue; } auto extractSliceOp = dyn_cast(user); - if (!extractSliceOp) + if (!extractSliceOp) { return failure(); - if (extractSliceOp.getMixedSizes() != parallelInsertOp.getMixedSizes()) + } + if (extractSliceOp.getMixedSizes() != parallelInsertOp.getMixedSizes()) { return failure(); - if (extractSliceOp.getMixedOffsets() != parallelInsertOp.getMixedOffsets()) + } + if (extractSliceOp.getMixedOffsets() != + parallelInsertOp.getMixedOffsets()) { return failure(); + } for (Operation *user : extractSliceOp->getUsers()) { auto expandShapeOp = dyn_cast(user); - if (!expandShapeOp) + if (!expandShapeOp) { return failure(); + } SmallVector expandReIndices = expandShapeOp.getReassociationIndices(); - if (reIndices != expandReIndices) + if (reIndices != expandReIndices) { return failure(); + } } expandableUsers.push_back(extractSliceOp); } @@ -186,8 +197,9 @@ struct ExpandDestinationForallOp final auto collapseOp = parallelInsertOp.getSource().getDefiningOp(); // No collapse op to hoist out. - if (!collapseOp) + if (!collapseOp) { return failure(); + } // Ignore trivially foldable collapse ops. 
if (collapseOp.getSrcType().getRank() == @@ -203,8 +215,9 @@ struct ExpandDestinationForallOp final int64_t tiedResultIdx = tiedResult.getResultNumber(); auto forallOp = dyn_cast(tiedResult.getOwner()); - if (!forallOp) + if (!forallOp) { return failure(); + } SmallVector expandedDestShape; SmallVector totalInnerSizes; @@ -226,16 +239,19 @@ struct ExpandDestinationForallOp final auto storeOp = dyn_cast(foralluser); if (storeOp && isFullSlice(storeOp, storeOp.getTargetType(), - storeOp.getTargetDims())) + storeOp.getTargetDims())) { continue; + } auto storeToBufferOp = dyn_cast(foralluser); - if (!storeToBufferOp) + if (!storeToBufferOp) { return failure(); + } MemRefType bufferType = storeToBufferOp.getBuffer().getType(); if (failed(memref::ExpandShapeOp::computeExpandedType( - bufferType, expandedDestShape, reIndices))) + bufferType, expandedDestShape, reIndices))) { return failure(); + } } // This allows us to assume that the extract/inserts in the loop are @@ -412,6 +428,8 @@ void PropagateReshapesByExpansionPass::runOnOperation() { }; linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns, bubbleUpExpansionControlFn); + IREE::Codegen::populateFoldReshapeOpsByExpansionPatterns( + bubbleExpandShapePatterns, bubbleUpExpansionControlFn); // Add patterns to do some additional cleanup (on top of canonicalizations // that can be done later) of reshape ops. tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns); diff --git a/compiler/src/iree/compiler/Codegen/Common/RemoveIndexHints.cpp b/compiler/src/iree/compiler/Codegen/Common/RemoveIndexHints.cpp new file mode 100644 index 000000000000..27d5bf87f91d --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/RemoveIndexHints.cpp @@ -0,0 +1,41 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler { + +#define GEN_PASS_DEF_REMOVEINDEXHINTSPASS +#include "iree/compiler/Codegen/Common/Passes.h.inc" + +namespace { + +/// Pass to remove all iree_codegen.index_hint operations by replacing them +/// with their input values. +struct RemoveIndexHintsPass final + : impl::RemoveIndexHintsPassBase { + void runOnOperation() override { + FunctionOpInterface funcOp = getOperation(); + IRRewriter rewriter(funcOp.getContext()); + + SmallVector indexHintOps; + funcOp.walk([&](IREE::Codegen::IndexHintOp hintOp) { + indexHintOps.push_back(hintOp); + }); + + for (auto hintOp : indexHintOps) { + hintOp.getResult().replaceAllUsesWith(hintOp.getInput()); + rewriter.eraseOp(hintOp); + } + } +}; + +} // namespace + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/ReshapePatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/ReshapePatterns.cpp index d53084d66cea..a1b82815e72f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ReshapePatterns.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/ReshapePatterns.cpp @@ -99,8 +99,9 @@ struct FoldCollapseShapeIntoInterfaceTensorLoad auto reshapeSrcType = cast(reshapeSrc.getType()); auto loadOp = reshapeSrc.getDefiningOp(); - if (!loadOp) + if (!loadOp) { return failure(); + } // Make sure we are loading the full incoming subspan. Otherwise we cannot // simply adjust the subspan's resultant type later. 
@@ -110,8 +111,9 @@ struct FoldCollapseShapeIntoInterfaceTensorLoad auto subspanOp = loadOp.getSource() .getDefiningOp(); - if (!subspanOp) + if (!subspanOp) { return failure(); + } OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(subspanOp); @@ -200,8 +202,9 @@ struct FoldExpandShapeIntoInterfaceTensorLoad auto subspanOp = loadOp.getSource() .getDefiningOp(); - if (!subspanOp) + if (!subspanOp) { return failure(); + } OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(subspanOp); @@ -305,8 +308,9 @@ struct FoldExpandShapeIntoInterfaceTensorStore auto subspanOp = storeOp.getTarget() .getDefiningOp(); - if (!subspanOp) + if (!subspanOp) { return failure(); + } OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(subspanOp); diff --git a/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp b/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp index bc2f6aef0c0c..b6f5994755e5 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -157,6 +158,22 @@ static void swizzleGatherToLDS(RewriterBase &rewriter, }); } +static LogicalResult +verifyFlatContiguousSwizzleHintOp(IREE::Codegen::SwizzleHintOp hintOp) { + auto memrefType = cast(hintOp.getOperand().getType()); + // Swizzle hints require flat (rank 1) memrefs. + // For rank 1, allow dynamic memrefs or static contiguous row-major memrefs. 
+ if ((memrefType.getRank() != 1 || !memrefType.getLayout().isIdentity()) || + (memrefType.hasStaticShape() && + !memref::isStaticShapeAndContiguousRowMajor(memrefType))) { + hintOp.emitError() + << "swizzle hint operand must be a contiguous flat memref, got " + << hintOp.getOperand().getType(); + return failure(); + } + return success(); +} + /// Resolves all hints. Walks all direct users and splits them into loads and /// stores. If any user is not a swizzle-able load or store, bail out and /// silently drop the optimization hint. @@ -189,7 +206,7 @@ static void resolveHintOp(RewriterBase &rewriter, } if (auto gatherToLDSOp = dyn_cast(user)) { // Ignore swizzleHint on Dst Operand. Gather_to_lds writes elements of a - // subgroup contiguously in order of lane ID + // subgroup contiguously in order of lane ID. if (gatherToLDSOp.getDst() == hintOp) { continue; } @@ -242,6 +259,9 @@ void ResolveSwizzleHintsPass::runOnOperation() { // silently pass through for that hint. IRRewriter rewriter(funcOp->getContext()); for (IREE::Codegen::SwizzleHintOp hintOp : hintOps) { + if (failed(verifyFlatContiguousSwizzleHintOp(hintOp))) { + return signalPassFailure(); + } resolveHintOp(rewriter, hintOp); } diff --git a/compiler/src/iree/compiler/Codegen/Common/StripCompilationInfoPass.cpp b/compiler/src/iree/compiler/Codegen/Common/StripCompilationInfoPass.cpp index b742e844888f..66b088db3684 100644 --- a/compiler/src/iree/compiler/Codegen/Common/StripCompilationInfoPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/StripCompilationInfoPass.cpp @@ -21,8 +21,9 @@ struct StripFuncOpTranslationInfo final using OpInterfaceRewritePattern::OpInterfaceRewritePattern; LogicalResult matchAndRewrite(mlir::FunctionOpInterface funcOp, PatternRewriter &rewriter) const final { - if (!getTranslationInfo(funcOp)) + if (!getTranslationInfo(funcOp)) { return failure(); + } rewriter.modifyOpInPlace(funcOp, [&]() { // If the function has translation info, erase it. 
@@ -38,8 +39,9 @@ struct StripLinalgOpCompilationInfo final using OpInterfaceRewritePattern::OpInterfaceRewritePattern; LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, PatternRewriter &rewriter) const final { - if (!getCompilationInfo(linalgOp) && !getLoweringConfig(linalgOp)) + if (!getCompilationInfo(linalgOp) && !getLoweringConfig(linalgOp)) { return failure(); + } rewriter.modifyOpInPlace(linalgOp, [&]() { if (getCompilationInfo(linalgOp)) { // Erase the compilation info configuration if it exists. diff --git a/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp index b3c49aebb466..b29265471b81 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp @@ -138,8 +138,9 @@ static void updateTensorDimInfo( auto resultType = cast(result.getType()); int dimOperandIndex = 0; for (auto [index, shape] : llvm::enumerate(resultType.getShape())) { - if (ShapedType::isStatic(shape)) + if (ShapedType::isStatic(shape)) { continue; + } updateTensorDimInfo(result, index, dimOperands[dimOperandIndex++], solver, divisibilityInfo, rangeInfo); } @@ -185,8 +186,9 @@ static void updateTensorDimInfo( LLVM_DEBUG({ for (auto [resultIndex, result] : llvm::enumerate(op->getResults())) { auto tensorType = dyn_cast(result.getType()); - if (!tensorType) + if (!tensorType) { continue; + } for (auto index : llvm::seq(0, tensorType.getRank())) { std::optional range; std::optional divisibility; diff --git a/compiler/src/iree/compiler/Codegen/Common/TensorToVectorVectorizePad.cpp b/compiler/src/iree/compiler/Codegen/Common/TensorToVectorVectorizePad.cpp index 03672231dfd6..1f41ff706aa5 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TensorToVectorVectorizePad.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TensorToVectorVectorizePad.cpp @@ -34,8 +34,9 @@ static Value 
getAsIndexValue(OpFoldResult attrOrValue, OpBuilder &builder, Location loc) { IntegerAttr attr; if (Value val = dyn_cast(attrOrValue)) { - if (val.getType().isIndex()) + if (val.getType().isIndex()) { return val; + } matchPattern(val, m_Constant(&attr)); } else { attr = cast(cast(attrOrValue)); @@ -84,13 +85,15 @@ struct VectorizePadWithConditions final PatternRewriter &rewriter) const override { // Static result shape is needed to reading padded dimensions in an // unrolled manner. - if (!padOp.getType().hasStaticShape()) + if (!padOp.getType().hasStaticShape()) { return failure(); + } // Only support constant padding value cases. Value paddingValue = padOp.getConstantPaddingValue(); - if (!paddingValue) + if (!paddingValue) { return failure(); + } Attribute paddingAttr; if (!matchPattern(paddingValue, m_Constant(&paddingAttr))) { return failure(); @@ -127,8 +130,9 @@ struct VectorizePadWithConditions final SmallVector paddedDimLBs(tensorRank); SmallVector paddedDimUBs(tensorRank); for (int i = 0; i < tensorRank; ++i) { - if (isConstantZero(lowPads[i]) && isConstantZero(highPads[i])) + if (isConstantZero(lowPads[i]) && isConstantZero(highPads[i])) { continue; + } paddedDimIndices.push_back(i); auto srcDimSize = @@ -147,8 +151,9 @@ struct VectorizePadWithConditions final loc, SplatElementsAttr::get(fullVectorType, {paddingAttr})); auto sliceVectorShape = llvm::to_vector(paddedTensorShape); - for (int dim : paddedDimIndices) + for (int dim : paddedDimIndices) { sliceVectorShape[dim] = 1; + } auto sliceVectorType = VectorType::get(dropLeadingOne(sliceVectorShape), elementType); Value cstSliceVector = rewriter.createOrFold( @@ -157,8 +162,9 @@ struct VectorizePadWithConditions final // Calculate the total count of all padded dimensions. We need to generate // vector read ops with scf.if guards for each of them. 
int totalCount = 1; - for (int dim : paddedDimIndices) + for (int dim : paddedDimIndices) { totalCount *= paddedTensorShape[dim]; + } auto zeroIndex = rewriter.createOrFold(loc, 0); auto trueAttr = rewriter.getBoolAttr(true); diff --git a/compiler/src/iree/compiler/Codegen/Common/TestExecutablePreprocessing.cpp b/compiler/src/iree/compiler/Codegen/Common/TestExecutablePreprocessing.cpp index 217d20ba1040..2779c7ff4486 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TestExecutablePreprocessing.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TestExecutablePreprocessing.cpp @@ -29,8 +29,9 @@ struct TestExecutablePreprocessingPass final // whatever it needed to the executable instead. getOperation()->walk([&](IREE::HAL::ExecutableVariantOp variantOp) { auto configAttr = variantOp.getTarget().getConfiguration(); - if (!configAttr) + if (!configAttr) { return; + } auto replacementAttr = configAttr.getAs("replace_i64"); if (!replacementAttr) { // Skip variants that don't request modification. diff --git a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp index fe52ed733687..59c230de7494 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp @@ -106,8 +106,9 @@ getTileAndDistributeConfig(ArrayRef computeOps, partitionableLoopsSet.insert(partitionableLoops.begin(), partitionableLoops.end()); for (auto loopId : llvm::seq(0, tileSizes.size())) { - if (partitionableLoopsSet.count(loopId)) + if (partitionableLoopsSet.count(loopId)) { continue; + } tileSizes[loopId] = 0; } @@ -181,10 +182,12 @@ static LogicalResult lowerDispatchWorkgroupCountForDagRootOp( // slowest varying. 
SmallVector numWorkgroups; for (auto partitionedLoop : llvm::reverse(partitionedLoops)) { - if (partitionedLoop >= tileSizes.size()) + if (partitionedLoop >= tileSizes.size()) { continue; - if (isZeroInteger(tileSizes[partitionedLoop])) + } + if (isZeroInteger(tileSizes[partitionedLoop])) { continue; + } Value numTileAlongDim = getValueOrCreateConstantIndexOp( rewriter, loc, numTiles[partitionedLoop]); if (numWorkgroups.size() == maxWorkgroupParallelDims) { diff --git a/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp b/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp index 126edd72b9c3..1a4a5a40c943 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp @@ -35,23 +35,26 @@ void fuseProducersOfSlices(RewriterBase &rewriter, auto fusableProducer = candidateSlice.getSource().getDefiningOp(); - if (!fusableProducer) + if (!fusableProducer) { continue; + } std::optional controlFnResult = options.fusionControlFn(candidateSlice, cast(candidateSlice.getSource()), /*destinationInitArg=*/false); - if (!controlFnResult) + if (!controlFnResult) { continue; + } // The operands of the fused producer might themselves be slices of // values produced by operations that implement the `TilingInterface`. // Add these operations to the worklist. 
std::optional fusedResult = scf::tileAndFuseProducerOfSlice(rewriter, candidateSlice, loops); - if (!fusedResult) + if (!fusedResult) { continue; + } for (auto newSlice : fusedResult->generatedSlices) { worklist.push(newSlice); @@ -70,8 +73,9 @@ void collectTiledAndFusedOps(Operation *rootOp, for (OpOperand &operand : current->getOpOperands()) { Operation *producer = operand.get().getDefiningOp(); if (!producer || !isa(producer) || - result.count(producer)) + result.count(producer)) { continue; + } worklist.push_back(producer); result.insert(producer); } @@ -181,10 +185,11 @@ fuseConsumersIntoForall(RewriterBase &rewriter, ArrayRef tiledOps, // list of slices to handle. Otherwise, insert it into the right // position based on dominance. auto *it = llvm::lower_bound(candidates, entry, comp); - if (it != candidates.end() && it->fusableUser == fusableUser) + if (it != candidates.end() && it->fusableUser == fusableUser) { *it = std::move(entry); - else + } else { candidates.insert(it, std::move(entry)); + } } } } @@ -250,15 +255,17 @@ collectTiledAndFusedOps(Operation *op, Operation *current = worklist.pop_back_val(); for (OpOperand &operand : current->getOpOperands()) { auto producer = operand.get().getDefiningOp(); - if (!producer || ops.contains(producer) || exclude.contains(producer)) + if (!producer || ops.contains(producer) || exclude.contains(producer)) { continue; + } worklist.push_back(producer); ops.insert(producer); } for (auto user : current->getUsers()) { auto consumer = dyn_cast(user); - if (!consumer || ops.contains(consumer) || exclude.contains(consumer)) + if (!consumer || ops.contains(consumer) || exclude.contains(consumer)) { continue; + } worklist.push_back(consumer); ops.insert(consumer); } @@ -374,8 +381,9 @@ LogicalResult applyTileAndFuseToEachRoot( // We dont want this for reduction tiling as it can lead to large tensors // being yielded. 
if (tilingLevel != IREE::GPU::TilingLevel::Reduction && - tilingLevel != IREE::GPU::TilingLevel::PartialReduction) + tilingLevel != IREE::GPU::TilingLevel::PartialReduction) { yieldProducerReplacement = yieldReplacementsFor.contains(owner); + } bool shouldFuse = false; if (auto tilingOwner = dyn_cast(owner)) { shouldFuse = !payloadOps.contains(tilingOwner); @@ -440,7 +448,7 @@ LogicalResult applyTileAndFuseToEachRoot( SmallVector opsToReplace{tilingInterfaceOp}; llvm::append_range(opsToReplace, tiledResults->fusedProducers); for (Operation *toReplace : opsToReplace) { - for (OpResult res : toReplace->getResults()) + for (OpResult res : toReplace->getResults()) { if (auto replacement = tiledResults->replacements.lookup(res)) { Operation *replacementOp = replacement.getDefiningOp(); rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) { @@ -448,6 +456,7 @@ LogicalResult applyTileAndFuseToEachRoot( return dominanceInfo.properlyDominates(replacementOp, user); }); } + } if (toReplace->use_empty()) { rewriter.eraseOp(toReplace); diff --git a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp index f3d156cc37c7..6e7005a67a12 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp @@ -216,8 +216,9 @@ static bool verifyComputeOpsAfterDistribution(FunctionOpInterface funcOp) { /// for the DPS `user`. Returns false if the user is not in DPS. 
static bool isUsedAsInit(Operation *producer, Operation *user) { auto dpsIface = dyn_cast(user); - if (!dpsIface) + if (!dpsIface) { return false; + } ValueRange results = producer->getResults(); return llvm::any_of(dpsIface.getDpsInits(), [&](Value operand) { return llvm::is_contained(results, operand); @@ -251,8 +252,9 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() { // an init of a DPS op, the user currently cannot be fused. Having a // replacement for it would attempt fusion and fail, so avoid such cases. if (llvm::any_of(op->getUsers(), [&](Operation *user) { - if (isUsedAsInit(op, user)) + if (isUsedAsInit(op, user)) { return false; + } return dominanceInfo.properlyDominates(tilableOp, user) || !tiledAndFusedOps.contains(user); })) { diff --git a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp index cb9718857383..d3aec86e681f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp @@ -36,14 +36,16 @@ static SmallVector fillInterchangeVector(ArrayRef interchangeVector, size_t iterationDomainSize) { SmallVector filledVector; - for (auto v : interchangeVector) + for (auto v : interchangeVector) { filledVector.push_back(v); + } if (filledVector.size() < iterationDomainSize) { auto range = llvm::seq(filledVector.size(), iterationDomainSize); filledVector.append(range.begin(), range.end()); } - if (filledVector.size() > iterationDomainSize) + if (filledVector.size() > iterationDomainSize) { filledVector.resize(iterationDomainSize); + } return filledVector; } @@ -208,8 +210,9 @@ static LogicalResult replaceStoresWithTiledVersion( storeOps.push_back(storeOp); } } - if (storeOps.empty()) + if (storeOps.empty()) { return success(); + } if (storeOps.size() != 1) { return rewriter.notifyMatchFailure(untiledValue.getOwner(), "expected a 
single store for the op"); @@ -398,9 +401,10 @@ tileDispatchUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, }); // 4. Generate the tiled implementation within the inner most loop. - if (!tilingResult.loops.empty()) + if (!tilingResult.loops.empty()) { rewriter.setInsertionPoint( tilingResult.loops.back().getBody()->getTerminator()); + } FailureOr tiledImplementation = op.getTiledImplementation(rewriter, offsets, sizes); if (failed(tiledImplementation)) { @@ -480,8 +484,9 @@ getAllFusableProducerUses(Operation *untiledOp, for (auto tiledOp : llvm::reverse(tiledOps)) { for (OpOperand &operand : llvm::reverse(tiledOp->getOpOperands())) { auto sliceOp = operand.get().getDefiningOp(); - if (!sliceOp || sliceOp.getSource().getDefiningOp() != untiledOp) + if (!sliceOp || sliceOp.getSource().getDefiningOp() != untiledOp) { continue; + } sliceOps.push_back(sliceOp); } } @@ -572,8 +577,9 @@ struct SwapExtractSliceWithDispatchTensorLoad PatternRewriter &rewriter) const override { auto loadOp = sliceOp.getSource() .getDefiningOp(); - if (!loadOp) + if (!loadOp) { return failure(); + } SmallVector combinedOffsets, combinedSizes, combinedStrides; if (failed(affine::mergeOffsetsSizesAndStrides( @@ -602,8 +608,9 @@ struct SwapExtractSliceWithTensorEmpty LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const override { auto emptyTensorOp = sliceOp.getSource().getDefiningOp(); - if (!emptyTensorOp) + if (!emptyTensorOp) { return failure(); + } SmallVector mixedSizes = sliceOp.getMixedSizes(); if (mixedSizes.size() != sliceOp.getType().getRank()) { @@ -611,8 +618,9 @@ struct SwapExtractSliceWithTensorEmpty rankReducedMixedSizes.reserve(sliceOp.getType().getRank()); auto droppedDims = sliceOp.getDroppedDims(); for (auto [index, size] : llvm::enumerate(mixedSizes)) { - if (droppedDims.test(index)) + if (droppedDims.test(index)) { continue; + } rankReducedMixedSizes.push_back(size); } std::swap(mixedSizes, rankReducedMixedSizes); diff 
--git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp index fa95a8151c3e..b11f9f0e962e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp @@ -69,8 +69,9 @@ class TransformDialectInterpreterPass final } if (failed(transform::applyTransformNamedSequence( payloadRoot, transformEntryPoint, transformModule, - options.enableExpensiveChecks(true)))) + options.enableExpensiveChecks(true)))) { return signalPassFailure(); + } } }; } // namespace diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel index 60f8016aa9e3..1b6a4d0a58ce 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "CommonExtensionsOps.td", ], diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp index 1cb8615d66e6..ac26d8146758 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp @@ -88,8 +88,9 @@ void mlir::iree_compiler::registerTransformDialectCommonExtension( //===---------------------------------------------------------------------===// static void addOperands(Operation *op, SetVector &operandSet) { - if (!op) + if (!op) { return; + } TypeSwitch(op) .Case([&](linalg::LinalgOp linalgOp) { SmallVector inputOperands = linalgOp.getDpsInputs(); @@ -103,12 +104,14 @@ static void addOperands(Operation 
*op, SetVector &operandSet) { template static bool setFusedOpOperandLimit(OpOperand *fusedOperand) { Operation *producer = fusedOperand->get().getDefiningOp(); - if (!producer) + if (!producer) { return false; + } Operation *consumer = fusedOperand->getOwner(); SetVector fusedOpOperands; - if (producer->getNumResults() != 1) + if (producer->getNumResults() != 1) { return false; + } addOperands(consumer, fusedOpOperands); fusedOpOperands.remove(producer->getResult(0)); addOperands(producer, fusedOpOperands); @@ -148,8 +151,9 @@ void transform_dialect::ApplyUnrollVectorsGpuMmaSyncPatternsOp:: populatePatterns(RewritePatternSet &patterns) { auto unrollOrder = [](Operation *op) -> std::optional> { auto contract = dyn_cast(op); - if (!contract) + if (!contract) { return std::nullopt; + } return mlir::iree_compiler::gpuMmaUnrollOrder(contract); }; vector::populateVectorUnrollPatterns( @@ -171,8 +175,9 @@ void transform_dialect::ApplyUnrollVectorsGpuWmmaSyncPatternsOp:: populatePatterns(RewritePatternSet &patterns) { auto unrollOrder = [](Operation *op) -> std::optional> { auto contract = dyn_cast(op); - if (!contract) + if (!contract) { return std::nullopt; + } return mlir::iree_compiler::gpuMmaUnrollOrder(contract); }; vector::populateVectorUnrollPatterns( @@ -280,8 +285,9 @@ static bool isAscendingRelativeMapping(ArrayRef mapping) { static FailureOr flattenForallOp(RewriterBase &rewriter, scf::ForallOp forallOp) { - if (!forallOp.getMapping().has_value()) + if (!forallOp.getMapping().has_value()) { return forallOp->emitError("mapping must be present"); + } SmallVector mapping = llvm::to_vector(forallOp.getMapping()->getValue()); if (!(llvm::all_of(mapping, llvm::IsaPred) || @@ -403,20 +409,21 @@ static LogicalResult rewriteForallToWorkgroup(RewriterBase &rewriter, Attribute bX = gpu::GPUBlockMappingAttr::get(ctx, gpu::MappingId::DimX); Attribute bY = gpu::GPUBlockMappingAttr::get(ctx, gpu::MappingId::DimY); Attribute bZ = gpu::GPUBlockMappingAttr::get(ctx, 
gpu::MappingId::DimZ); - if (forallOp.getNumResults() > 0) + if (forallOp.getNumResults() > 0) { return forallOp->emitError( "only bufferized scf.forall lowers to workgroup"); - if (forallOp.getRank() > 3) + } + if (forallOp.getRank() > 3) { return forallOp->emitError( "scf.forall with rank > 3 does not lower to workgroup"); + } - if (!forallOp.getMapping().has_value()) + if (!forallOp.getMapping().has_value()) { return forallOp->emitError("mapping must be present"); + } SmallVector blockMapping = llvm::to_vector(forallOp.getMapping()->getValue()); - if (llvm::any_of(blockMapping, [](Attribute map) { - return !isa(map); - })) { + if (!llvm::all_of(blockMapping, llvm::IsaPred)) { return forallOp->emitError("mapping must be #gpu.block"); } @@ -492,10 +499,12 @@ DiagnosedSilenceableFailure transform_dialect::ForallToWorkgroupOp::applyToOne( scf::ForallOp topLevelForallOp; auto walkResult = target->walk([&](scf::ForallOp forallOp) { - if (forallOp->getParentOfType()) + if (forallOp->getParentOfType()) { return WalkResult::advance(); - if (topLevelForallOp) + } + if (topLevelForallOp) { return WalkResult::interrupt(); + } topLevelForallOp = forallOp; return WalkResult::advance(); }); @@ -506,8 +515,9 @@ DiagnosedSilenceableFailure transform_dialect::ForallToWorkgroupOp::applyToOne( } rewriter.setInsertionPoint(topLevelForallOp); - if (failed(rewriteForallToWorkgroup(rewriter, topLevelForallOp))) + if (failed(rewriteForallToWorkgroup(rewriter, topLevelForallOp))) { return mlir::emitDefiniteFailure(target, "rewriteForallToWorkgroup failed"); + } return DiagnosedSilenceableFailure::success(); } @@ -531,29 +541,34 @@ transform_dialect::GpuDistributeSharedMemoryCopyOp::applyToOne( // Look for ops that move to workgroup memory and mark as copies for // distribution. 
target.walk([&](linalg::GenericOp copyOp) { - if (copyOp.getNumDpsInputs() != 1 || copyOp.getNumDpsInits() != 1) + if (copyOp.getNumDpsInputs() != 1 || copyOp.getNumDpsInits() != 1) { return; + } auto dest = dyn_cast>(copyOp.getDpsInitOperand(0)->get()); - if (!dest) + if (!dest) { return; + } MemRefType destType = dest.getType(); // Check if the only operation in the possible copy op region is a // terminator. Block &body = copyOp.getRegion().front(); - if (!std::begin(body)->hasTrait()) + if (!std::begin(body)->hasTrait()) { return; + } auto destSpace = dyn_cast_if_present(destType.getMemorySpace()); - if (!destSpace) + if (!destSpace) { return; + } // The destination space must be shared memory. - if (destSpace.getValue() != gpu::GPUDialect::getWorkgroupAddressSpace()) + if (destSpace.getValue() != gpu::GPUDialect::getWorkgroupAddressSpace()) { return; + } // Mark this copy operation as a copy to workgroup memory. setMarker(copyOp, getCopyToWorkgroupMemoryMarker()); @@ -682,8 +697,9 @@ transform_dialect::IREEApplyLoopIndependentCodeMotionOp::applyToOne( // Do not hoist from scf.forall ops. These capture isolated computations // that will be mapped to a certain level in the GPU hierarchy (e.g., // GPU blocks), so hoisting is not desired. - if (!isa(loopLike.getOperation())) + if (!isa(loopLike.getOperation())) { moveLoopInvariantCode(loopLike); + } }); // For now, put single loop promotion as part of licm. Underlying // implementations perform splice operations which shouldn't need @@ -803,16 +819,18 @@ static LogicalResult gpuComprehensiveBufferizeCopyFn(OpBuilder &builder, hasSharedMemoryAddressSpace(cast(to.getType()))) { needsBarrier = true; } - if (needsBarrier) + if (needsBarrier) { gpu::BarrierOp::create(builder, loc); + } // TODO: ideally we should use linalg.copy which was recently reintroduced // as an OpDSL named op. However, IREE-specific patterns to cleanup spurious // post-bufferization copies do not trigger properly. 
// So we keep using `createLinalgCopyOp` which builds a GenericOp. // linalg::CopyOp::create(builder, loc, from, to); mlir::iree_compiler::createLinalgCopyOp(builder, loc, from, to); - if (needsBarrier) + if (needsBarrier) { gpu::BarrierOp::create(builder, loc); + } return success(); } @@ -889,8 +907,9 @@ DiagnosedSilenceableFailure transform_dialect::IREEBufferizeOp::apply( return mlir::emitDefiniteFailure(target, "greedy pattern application failed"); } - if (listener.failed()) + if (listener.failed()) { return listener.checkAndResetError(); + } } // 2. Run one-shot-bufferize, without the pass baggage. @@ -933,9 +952,10 @@ transform_dialect::IREEEliminateEmptyTensorsOp::applyToOne( ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state) { if (failed( - eliminateEmptyTensors(rewriter, target, getBufferizationOptions()))) + eliminateEmptyTensors(rewriter, target, getBufferizationOptions()))) { return emitDefaultDefiniteFailure(target) << "failed to eliminate tensor.empty ops"; + } return DiagnosedSilenceableFailure::success(); } @@ -983,8 +1003,9 @@ transform_dialect::ShareForallOperandsOp::applyToOne( llvm::to_vector(llvm::seq(0, forallOp.getOutputs().size())); } for (int64_t outputIdx : getShareOperands()) { - if (outputIdx < 0 || outputIdx >= forallOp.getOutputs().size()) + if (outputIdx < 0 || outputIdx >= forallOp.getOutputs().size()) { return mlir::emitDefiniteFailure(forallOp, "operand idx overflow"); + } Value toShare = forallOp.getOutputs()[outputIdx]; if (std::distance(toShare.getUses().begin(), toShare.getUses().end()) != 2) { @@ -997,8 +1018,9 @@ transform_dialect::ShareForallOperandsOp::applyToOne( tensor::ExtractSliceOp extractSliceOp; for (Operation *user : toShare.getUsers()) { extractSliceOp = dyn_cast(user); - if (extractSliceOp) + if (extractSliceOp) { break; + } } if (!extractSliceOp) { /*return mlir::emitSilenceableFailure( @@ -1013,10 +1035,12 @@ transform_dialect::ShareForallOperandsOp::applyToOne( // 
(i.e., same source/target, offsets, sizes and strides). auto isMatchingParallelInsertSlice = [&](Operation &op) { auto insertSlice = dyn_cast(&op); - if (!insertSlice) + if (!insertSlice) { return false; - if (insertSlice.getDest() != bbArg) + } + if (insertSlice.getDest() != bbArg) { return false; + } return llvm::equal(insertSlice.getMixedOffsets(), extractSliceOp.getMixedOffsets()) && llvm::equal(insertSlice.getMixedSizes(), @@ -1115,8 +1139,9 @@ applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp, FailureOr fuseConsumerResults = scf::tileAndFuseConsumerOfSlices(rewriter, target, loops); - if (failed(fuseConsumerResults)) + if (failed(fuseConsumerResults)) { return failure(); + } // Report back the relevant handles to the transform op. originalConsumerOps.push_back( diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp index 92ae83c6c9cb..945ffdffb3a3 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp @@ -7,6 +7,7 @@ #include "iree/compiler/Codegen/Common/Transforms.h" #include "iree/compiler/Codegen/Common/CombineLayoutTransformation.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" +#include "iree/compiler/Dialect/Util/IR/UtilOps.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallVectorExtras.h" #include "mlir/Analysis/SliceAnalysis.h" @@ -228,8 +229,9 @@ swapExpandShapeWithSlice(RewriterBase &rewriter, auto isZeroOffsetAndFullSize = [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) { - if (!isZeroInteger(offset)) + if (!isZeroInteger(offset)) { return false; + } FailureOr maybeEqual = ValueBoundsConstraintSet::areEqual(sliceSize, size); return llvm::succeeded(maybeEqual) && maybeEqual.value(); @@ -274,8 +276,9 @@ swapExpandShapeWithSlice(RewriterBase &rewriter, // Offset = cumulative product of leading unit extracted dims. 
for (; i < e; ++i) { int64_t expandedDim = indices[i]; - if (!isOneInteger(sizes[expandedDim])) + if (!isOneInteger(sizes[expandedDim])) { break; + } basis.push_back(outputShape[expandedDim]); delinOffsets.push_back(offsets[expandedDim]); @@ -718,8 +721,9 @@ swapCollapseShapeWithSlice(RewriterBase &rewriter, for (; idx < reassocGroupSize; ++idx) { int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; - if (currentCollapsedsize < expandedShapeSize) + if (currentCollapsedsize < expandedShapeSize) { break; + } // We need to make sure that the slice size can be set to the shape size // and the offset to 0. @@ -817,4 +821,23 @@ void populateSwapExtractWithCollapsePattern(RewritePatternSet &patterns) { patterns.add(patterns.getContext()); } +namespace { + +struct RemoveOptimizationBarrier final + : public OpRewritePattern { + using Base::Base; + + LogicalResult matchAndRewrite(IREE::Util::OptimizationBarrierOp barrierOp, + PatternRewriter &rewriter) const override { + rewriter.replaceOp(barrierOp, barrierOp.getOperands()); + return success(); + } +}; + +} // namespace + +void populateRemoveOptimizationBarrierPatterns(RewritePatternSet &patterns) { + patterns.insert(patterns.getContext()); +} + } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.h b/compiler/src/iree/compiler/Codegen/Common/Transforms.h index b6c0067857b5..413cfb17a584 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Transforms.h +++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.h @@ -211,6 +211,20 @@ void populateCombineRelayoutOpPatterns( /// Populate patterns to fuse tilable consumers of forall ops into it. 
void populateFuseTilableForallConsumersPattern(RewritePatternSet &patterns); +//===----------------------------------------------------------------------===// +// Utilities for iteration space expansion transformations +//===----------------------------------------------------------------------===// + +/// Helper struct to hold the expand/collapse shape ops created for dimension +/// expansion or blocking transformations. +struct ReshapeOps { + tensor::ExpandShapeOp expandShapeOp; + tensor::CollapseShapeOp collapseShapeOp; +}; + +/// Populate patterns to remove optimization barriers. +void populateRemoveOptimizationBarrierPatterns(RewritePatternSet &patterns); + } // namespace mlir::iree_compiler #endif // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMS_H_ diff --git a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp index 7ea7a24be18c..bbf0572abaef 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp @@ -62,8 +62,9 @@ namespace mlir::iree_compiler { static Value convertElementType(OpBuilder &b, Location loc, Type targetType, Value source) { Type sourceType = source.getType(); - if (sourceType == targetType) + if (sourceType == targetType) { return source; + } if (isa(sourceType) && isa(targetType)) { unsigned sourceBitWidth = sourceType.getIntOrFloatBitWidth(); unsigned destBitWidth = targetType.getIntOrFloatBitWidth(); @@ -82,8 +83,9 @@ static std::optional getLegalizedType(Type t) { if (auto shapedType = dyn_cast(t)) { std::optional legalizedElementType = legalizeStorageElementType(shapedType); - if (!legalizedElementType) + if (!legalizedElementType) { return std::nullopt; + } return RankedTensorType::get(shapedType.getShape(), legalizedElementType.value(), shapedType.getEncoding()); @@ -117,8 +119,9 @@ struct TypePropagationTypeConverter : public TypeConverter { 
TypePropagationTypeConverter() { addConversion([](Type t) { auto convertedType = getLegalizedType(t); - if (!convertedType) + if (!convertedType) { return t; + } return convertedType.value(); }); } diff --git a/compiler/src/iree/compiler/Codegen/Common/UserConfig.cpp b/compiler/src/iree/compiler/Codegen/Common/UserConfig.cpp index 2ce80a85b4f7..5e0dca7ccbb8 100644 --- a/compiler/src/iree/compiler/Codegen/Common/UserConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/UserConfig.cpp @@ -19,8 +19,9 @@ setUserConfig(mlir::FunctionOpInterface entryPointFn, Operation *computeOp, } auto info = compilationInfo.getTranslationInfo(); - if (failed(setTranslationInfo(entryPointFn, info))) + if (failed(setTranslationInfo(entryPointFn, info))) { return failure(); + } setLoweringConfig(computeOp, compilationInfo.getLoweringConfig()); eraseCompilationInfo(computeOp); diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel index 4c1ceff17ddd..af3c70e2ef28 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "add_fmfs.mlir", "affinemin_canonicalization.mlir", @@ -27,19 +28,17 @@ iree_lit_test_suite( "bufferize_dispatch_tensor_load_store.mlir", "canonicalize_early_bufferization_ops.mlir", "canonicalize_interface_load_store.mlir", - "forall_to_for.mlir", "check_for_config.mlir", "combine_layout_transformation.mlir", "convert_accgemm_to_gemm.mlir", - "convert_bf16_to_uint16_buffers.mlir", "convert_bf16_arith_to_f32.mlir", + "convert_bf16_to_uint16_buffers.mlir", "convert_hal_descriptor_type_to_gpu_address_space.mlir", "convert_to_destination_passing_style.mlir", "convert_unsupported_float_arith.mlir", "convert_workgroup_forall_to_pcf.mlir", "convolution_to_igemm.mlir", "convolutions.mlir", - 
"erase_dead_alloc_and_stores.mlir", "decompose_affine_ops.mlir", "decompose_boundary_pack_unpack_ops.mlir", "decompose_conv2d.mlir", @@ -49,6 +48,7 @@ iree_lit_test_suite( "decompose_softmax.mlir", "eliminate_empty_tensors.mlir", "emulate_narrow_type.mlir", + "erase_dead_alloc_and_stores.mlir", "erase_hal_descriptor_type.mlir", "extract_address_computation.mlir", "fission_transfer_ops_control_flow.mlir", @@ -60,14 +60,15 @@ iree_lit_test_suite( "fold_reshape_into_interface_tensor.mlir", "fold_split_reduction_workgroup_mapping_loops.mlir", "fold_tensor_extract_op.mlir", + "forall_to_for.mlir", "forop_canonicalization.mlir", "generic_vectorization.mlir", "hoist_statically_bound_allocations.mlir", "hoist_unrolled_vector_extract_insert_slice.mlir", "iree_codegen_canonicalize.mlir", "iree_comprehensive_bufferize.mlir", - "iree_expand_strided_metadata_with_subview_expansion.mlir", "iree_expand_strided_metadata.mlir", + "iree_expand_strided_metadata_with_subview_expansion.mlir", "iree_inject_assume_alignment.mlir", "iree_loop_invariant_code_motion.mlir", "link_tuning_specs.mlir", @@ -108,11 +109,12 @@ iree_lit_test_suite( "reductions.mlir", "rematerialize_parallel_ops.mlir", "remove_dead_allocs.mlir", + "remove_index_hints.mlir", "remove_single_iteration_loop.mlir", - "resolve_swizzle_hints.mlir", - "resolve_workgroup_count_hints.mlir", "repeated_matcher_use.mlir", "replace_slow_min_max_ops.mlir", + "resolve_swizzle_hints.mlir", + "resolve_workgroup_count_hints.mlir", "specialize_exports.mlir", "strip_compilation_info.mlir", "test_partitionable_loops_interface.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt index 2c095be422ef..0295c5c09fb0 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt @@ -104,6 +104,7 @@ iree_lit_test_suite( "reductions.mlir" "rematerialize_parallel_ops.mlir" 
"remove_dead_allocs.mlir" + "remove_index_hints.mlir" "remove_single_iteration_loop.mlir" "repeated_matcher_use.mlir" "replace_slow_min_max_ops.mlir" diff --git a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir index 70ccb68955d8..467d72156786 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir @@ -459,3 +459,205 @@ func.func @no_swap_rank_reducing_slice(%arg0: tensor<3x6xi8>) -> tensor<3xi16> { // CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<3x6xi8> // CHECK-NEXT: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]] // CHECK-NEXT: iree_tensor_ext.bitcast %[[SLICE]] + +// ----- + +// Test propagating collapse_shape producer through inner_tiled op. +// Using proper 2D matmul indexing maps with MFMA_F32_16x16x16_F16 layout. +// Tensor shapes: LHS[outer_m, outer_k, 16, 16], RHS[outer_k, outer_n, 16, 16], ACC[outer_m, outer_n, 16, 16] +#contraction_accesses = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] +func.func @propagate_collapse_through_inner_tiled( + %src: tensor<2x3x4x16x16xf16>, %rhs: tensor<4x2x16x16xf16>, %out: tensor<6x2x16x16xf32>) + -> tensor<6x2x16x16xf32> { + // Collapse the first two outer dims of LHS: [2,3] -> [6] + %collapsed = tensor.collapse_shape %src [[0, 1], [2], [3], [4]] + : tensor<2x3x4x16x16xf16> into tensor<6x4x16x16xf16> + %result = iree_codegen.inner_tiled ins(%collapsed, %rhs) outs(%out) { + indexing_maps = #contraction_accesses, + iterator_types = [#linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type], + kind = #iree_gpu.mma_layout, + permutations = [array, array, array], + semantics = #iree_gpu.mma_semantics + } : tensor<6x4x16x16xf16>, tensor<4x2x16x16xf16> into tensor<6x2x16x16xf32> + return %result : 
tensor<6x2x16x16xf32> +} + +// CHECK-LABEL: func @propagate_collapse_through_inner_tiled +// CHECK-SAME: %[[SRC:[A-Za-z0-9]+]]: tensor<2x3x4x16x16xf16> +// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<4x2x16x16xf16> +// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: tensor<6x2x16x16xf32> +// CHECK: %[[EXPANDED_OUT:.+]] = tensor.expand_shape %[[OUT]] {{\[}}[0, 1], [2], [3], [4]{{\]}} +// CHECK-SAME: : tensor<6x2x16x16xf32> into tensor<2x3x2x16x16xf32> +// CHECK: %[[INNER_TILED:.+]] = iree_codegen.inner_tiled +// CHECK-SAME: ins(%[[SRC]], %[[RHS]]) +// CHECK-SAME: outs(%[[EXPANDED_OUT]]) +// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d2, d3)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>] +// CHECK-SAME: iterator_types = [#linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type] +// CHECK-SAME: : tensor<2x3x4x16x16xf16>, tensor<4x2x16x16xf16> into tensor<2x3x2x16x16xf32> +// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[INNER_TILED]] {{\[}}[0, 1], [2], [3], [4]{{\]}} +// CHECK-SAME: : tensor<2x3x2x16x16xf32> into tensor<6x2x16x16xf32> +// CHECK: return %[[COLLAPSED]] + +// ----- + +// Test propagating expand_shape consumer through inner_tiled op. 
+#contraction_accesses2 = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] +func.func @propagate_expand_through_inner_tiled( + %lhs: tensor<6x4x16x16xf16>, %rhs: tensor<4x2x16x16xf16>, %out: tensor<6x2x16x16xf32>) + -> tensor<2x3x2x16x16xf32> { + %result = iree_codegen.inner_tiled ins(%lhs, %rhs) outs(%out) { + indexing_maps = #contraction_accesses2, + iterator_types = [#linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type], + kind = #iree_gpu.mma_layout, + permutations = [array, array, array], + semantics = #iree_gpu.mma_semantics + } : tensor<6x4x16x16xf16>, tensor<4x2x16x16xf16> into tensor<6x2x16x16xf32> + %expanded = tensor.expand_shape %result [[0, 1], [2], [3], [4]] + output_shape [2, 3, 2, 16, 16] : tensor<6x2x16x16xf32> into tensor<2x3x2x16x16xf32> + return %expanded : tensor<2x3x2x16x16xf32> +} + +// CHECK-LABEL: func @propagate_expand_through_inner_tiled +// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: tensor<6x4x16x16xf16> +// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<4x2x16x16xf16> +// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: tensor<6x2x16x16xf32> +// CHECK-DAG: %[[EXPANDED_OUT:.+]] = tensor.expand_shape %[[OUT]] {{\[}}[0, 1], [2], [3], [4]{{\]}} +// CHECK-SAME: : tensor<6x2x16x16xf32> into tensor<2x3x2x16x16xf32> +// CHECK-DAG: %[[EXPANDED_LHS:.+]] = tensor.expand_shape %[[LHS]] {{\[}}[0, 1], [2], [3], [4]{{\]}} +// CHECK-SAME: : tensor<6x4x16x16xf16> into tensor<2x3x4x16x16xf16> +// CHECK: %[[INNER_TILED:.+]] = iree_codegen.inner_tiled +// CHECK-SAME: ins(%[[EXPANDED_LHS]], %[[RHS]]) +// CHECK-SAME: outs(%[[EXPANDED_OUT]]) +// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d3, d2)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>] +// CHECK-SAME: iterator_types = [#linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type] +// CHECK-SAME: : tensor<2x3x4x16x16xf16>, 
tensor<4x2x16x16xf16> into tensor<2x3x2x16x16xf32> +// CHECK: return %[[INNER_TILED]] + +// ----- + +// Test that reshape touching inner dimensions is NOT propagated. +#contraction_accesses3 = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] +func.func @no_propagate_inner_dim_reshape( + %src: tensor<6x4x16x2x8xf16>, %rhs: tensor<4x2x16x16xf16>, %out: tensor<6x2x16x16xf32>) + -> tensor<6x2x16x16xf32> { + // Collapsing inner dims [3,4] which are part of inner tile - should NOT propagate. + %collapsed = tensor.collapse_shape %src [[0], [1], [2], [3, 4]] + : tensor<6x4x16x2x8xf16> into tensor<6x4x16x16xf16> + %result = iree_codegen.inner_tiled ins(%collapsed, %rhs) outs(%out) { + indexing_maps = #contraction_accesses3, + iterator_types = [#linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type], + kind = #iree_gpu.mma_layout, + permutations = [array, array, array], + semantics = #iree_gpu.mma_semantics + } : tensor<6x4x16x16xf16>, tensor<4x2x16x16xf16> into tensor<6x2x16x16xf32> + return %result : tensor<6x2x16x16xf32> +} + +// CHECK-LABEL: func @no_propagate_inner_dim_reshape +// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape +// CHECK: iree_codegen.inner_tiled ins(%[[COLLAPSED]], + +// ----- + +// Test propagating collapse_shape producer through inner_tiled op with dynamic outer shapes. 
+#contraction_accesses_dyn1 = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] +func.func @propagate_collapse_through_inner_tiled_dynamic( + %src: tensor, %rhs: tensor<4x2x16x16xf16>, %out: tensor) + -> tensor { + // Collapse the first two outer dims of LHS: [?, 3] -> [?*3] + %collapsed = tensor.collapse_shape %src [[0, 1], [2], [3], [4]] + : tensor into tensor + %result = iree_codegen.inner_tiled ins(%collapsed, %rhs) outs(%out) { + indexing_maps = #contraction_accesses_dyn1, + iterator_types = [#linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type], + kind = #iree_gpu.mma_layout, + permutations = [array, array, array], + semantics = #iree_gpu.mma_semantics + } : tensor, tensor<4x2x16x16xf16> into tensor + return %result : tensor +} + +// CHECK-LABEL: func @propagate_collapse_through_inner_tiled_dynamic +// CHECK-SAME: %[[SRC:[A-Za-z0-9]+]]: tensor +// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<4x2x16x16xf16> +// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: tensor +// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[SRC]], %c0 +// CHECK: %[[EXPANDED_OUT:.+]] = tensor.expand_shape %[[OUT]] {{\[}}[0, 1], [2], [3], [4]{{\]}} +// CHECK-SAME: output_shape [%[[DIM]], 3, 2, 16, 16] +// CHECK-SAME: : tensor into tensor +// CHECK: %[[INNER_TILED:.+]] = iree_codegen.inner_tiled +// CHECK-SAME: ins(%[[SRC]], %[[RHS]]) +// CHECK-SAME: outs(%[[EXPANDED_OUT]]) +// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d2, d3)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>] +// CHECK-SAME: iterator_types = [#linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type] +// CHECK-SAME: : tensor, tensor<4x2x16x16xf16> into tensor +// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[INNER_TILED]] {{\[}}[0, 1], [2], [3], [4]{{\]}} +// CHECK-SAME: : tensor into tensor +// CHECK: return %[[COLLAPSED]] + +// ----- + +// 
Test propagating expand_shape consumer through inner_tiled op with dynamic outer shapes. +#contraction_accesses_dyn2 = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] +func.func @propagate_expand_through_inner_tiled_dynamic( + %lhs: tensor, %rhs: tensor<4x2x16x16xf16>, %out: tensor, + %dyn_dim: index) + -> tensor { + %result = iree_codegen.inner_tiled ins(%lhs, %rhs) outs(%out) { + indexing_maps = #contraction_accesses_dyn2, + iterator_types = [#linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type], + kind = #iree_gpu.mma_layout, + permutations = [array, array, array], + semantics = #iree_gpu.mma_semantics + } : tensor, tensor<4x2x16x16xf16> into tensor + %expanded = tensor.expand_shape %result [[0, 1], [2], [3], [4]] + output_shape [%dyn_dim, 3, 2, 16, 16] : tensor into tensor + return %expanded : tensor +} + +// CHECK-LABEL: func @propagate_expand_through_inner_tiled_dynamic +// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: tensor +// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<4x2x16x16xf16> +// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: tensor +// CHECK-SAME: %[[DYN_DIM:[A-Za-z0-9]+]]: index +// CHECK-DAG: %[[EXPANDED_OUT:.+]] = tensor.expand_shape %[[OUT]] {{\[}}[0, 1], [2], [3], [4]{{\]}} +// CHECK-SAME: output_shape [%[[DYN_DIM]], 3, 2, 16, 16] +// CHECK-SAME: : tensor into tensor +// CHECK-DAG: %[[EXPANDED_LHS:.+]] = tensor.expand_shape %[[LHS]] {{\[}}[0, 1], [2], [3], [4]{{\]}} +// CHECK-SAME: output_shape [%[[DYN_DIM]], 3, 4, 16, 16] +// CHECK-SAME: : tensor into tensor +// CHECK: %[[INNER_TILED:.+]] = iree_codegen.inner_tiled +// CHECK-SAME: ins(%[[EXPANDED_LHS]], %[[RHS]]) +// CHECK-SAME: outs(%[[EXPANDED_OUT]]) +// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d3, d2)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>] +// CHECK-SAME: iterator_types = [#linalg.iterator_type, #linalg.iterator_type, 
#linalg.iterator_type, #linalg.iterator_type] +// CHECK-SAME: : tensor, tensor<4x2x16x16xf16> into tensor +// CHECK: return %[[INNER_TILED]] diff --git a/compiler/src/iree/compiler/Codegen/Common/test/remove_index_hints.mlir b/compiler/src/iree/compiler/Codegen/Common/test/remove_index_hints.mlir new file mode 100644 index 000000000000..198a2193df57 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/test/remove_index_hints.mlir @@ -0,0 +1,34 @@ +// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-codegen-remove-index-hints))' %s | FileCheck %s + +// Test: index_hint with lane_constant is removed. +// CHECK-LABEL: func.func @remove_lane_constant_hint +// CHECK-NOT: iree_codegen.index_hint +// CHECK: return %arg0 +func.func @remove_lane_constant_hint(%arg0: index) -> index { + %hint = iree_codegen.index_hint %arg0(#iree_gpu.lane_constant<16>) : index + return %hint : index +} + +// ----- + +// Test: index_hint with lane_increment is removed. +// CHECK-LABEL: func.func @remove_lane_increment_hint +// CHECK-NOT: iree_codegen.index_hint +// CHECK: return %arg0 +func.func @remove_lane_increment_hint(%arg0: index) -> index { + %hint = iree_codegen.index_hint %arg0(#iree_gpu.lane_increment<16>) : index + return %hint : index +} + +// ----- + +// Test: Multiple hints in sequence are all removed. 
+// CHECK-LABEL: func.func @remove_multiple_hints +// CHECK-NOT: iree_codegen.index_hint +// CHECK: arith.addi %arg0, %arg1 +func.func @remove_multiple_hints(%arg0: index, %arg1: index) -> index { + %hint0 = iree_codegen.index_hint %arg0(#iree_gpu.lane_constant<16>) : index + %hint1 = iree_codegen.index_hint %arg1(#iree_gpu.lane_increment<16>) : index + %sum = arith.addi %hint0, %hint1 : index + return %sum : index +} diff --git a/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir b/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir index 4a77f8d8d2f3..6b6b9030a3ba 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir @@ -322,3 +322,24 @@ func.func @swizzle_raw_buffer_to_lds_ignore_dst_op(%global : memref<32768xi8, #a // CHECK: %[[LDSOFFSET:.+]] = arith.constant 0 : index // CHECK: %[[LDS:.+]] = memref.alloc() : memref<32768xi8, #gpu.address_space> // CHECK: amdgpu.gather_to_lds %[[SRC]][%[[SWOFF]]], %[[LDS]][%[[LDSOFFSET]]] + +// ----- + +// Verify that swizzle_hint fails on non-flat (rank > 1) memrefs. +func.func @swizzle_hint_non_flat_memref_error(%src: memref<32x64xf32>) -> vector<4xf32> { + // expected-error @+1 {{swizzle hint operand must be a contiguous flat memref, got 'memref<32x64xf32>'}} + %0 = iree_codegen.swizzle_hint %src[#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32> + %offset = arith.constant 0 : index + %1 = vector.load %0[%offset, %offset] : memref<32x64xf32>, vector<4xf32> + return %1: vector<4xf32> +} + +// Verify that swizzle_hint fails on non-contiguous memrefs. 
+func.func @swizzle_hint_non_contiguous_memref_error() -> vector<4xf32> { + %src = memref.alloc() : memref<32x64xf32, strided<[2, 1], offset: 0>> + // expected-error @+1 {{swizzle hint operand must be a contiguous flat memref, got 'memref<32x64xf32, strided<[2, 1]>>'}} + %0 = iree_codegen.swizzle_hint %src[#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32, strided<[2, 1], offset: 0>> + %offset = arith.constant 0 : index + %1 = vector.load %0[%offset, %offset] : memref<32x64xf32, strided<[2, 1], offset: 0>>, vector<4xf32> + return %1: vector<4xf32> +} diff --git a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/BUILD.bazel index 676a4fd8c4b9..52c452d0804c 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/BUILD.bazel @@ -21,6 +21,7 @@ exports_files([ iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "IREECPUAttrs.td", "IREECPUDialect.td", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/test/BUILD.bazel index c7ebf7648fd1..9f62cb2c2724 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "invalid.mlir", "roundtrip.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel index 1612effd8a3d..84708c271ef0 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel @@ -25,6 +25,7 @@ exports_files([ iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "IREECodegenAttrs.td", "IREECodegenDialect.td", diff --git 
a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp index 0a281311227f..c60a7a2b0a37 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp @@ -137,11 +137,13 @@ void LoweringConfigTilingLevelAttr::print(mlir::AsmPrinter &printer) const { [&](auto pair) { auto [tileSize, isScalable] = pair; // Wrap scalable sizes in square brackets. - if (isScalable) + if (isScalable) { printer << '['; + } printer << tileSize; - if (isScalable) + if (isScalable) { printer << ']'; + } }); } printer << ']'; @@ -163,8 +165,9 @@ Attribute LoweringConfigTilingLevelAttr::parse(mlir::AsmParser &parser, auto parseListOfSizes = [&](SmallVector *scalableFlags = nullptr, bool prefixChecked = false) -> FailureOr> { - if (!prefixChecked && parser.parseLSquare()) + if (!prefixChecked && parser.parseLSquare()) { return failure(); + } if (parser.parseOptionalRSquare().succeeded()) { // Empty list. 
return SmallVector(); @@ -177,15 +180,18 @@ Attribute LoweringConfigTilingLevelAttr::parse(mlir::AsmParser &parser, expectScalableSizes && parser.parseOptionalLSquare().succeeded(); int64_t size = 0; if (parser.parseInteger(size) || - (isScalable && parser.parseRSquare())) + (isScalable && parser.parseRSquare())) { return failure(); + } sizes.push_back(size); - if (scalableFlags) + if (scalableFlags) { scalableFlags->push_back(isScalable); + } return success(); }); - if (failed(listParse) || parser.parseRSquare()) + if (failed(listParse) || parser.parseRSquare()) { return failure(); + } return sizes; }; SmallVector scalableFlags; @@ -193,8 +199,9 @@ Attribute LoweringConfigTilingLevelAttr::parse(mlir::AsmParser &parser, // Case 1: Simple list of tile sizes, e.g.: // [0, [32], 16] auto tileSizes = parseListOfSizes(&scalableFlags, /*prefixChecked=*/true); - if (failed(tileSizes)) + if (failed(tileSizes)) { return {}; + } return parser.getChecked( loc, parser.getContext(), *tileSizes, ArrayRef{}, scalableFlags); @@ -202,15 +209,18 @@ Attribute LoweringConfigTilingLevelAttr::parse(mlir::AsmParser &parser, // Case 2: sizes and interchange, e.g.: // {sizes = [0, [32], 16], interchange = [0, 1, 2]} if (parser.parseLBrace() || parser.parseKeyword("sizes") || - parser.parseEqual()) + parser.parseEqual()) { return {}; + } auto tileSizes = parseListOfSizes(&scalableFlags); if (failed(tileSizes) || parser.parseComma() || - parser.parseKeyword("interchange") || parser.parseEqual()) + parser.parseKeyword("interchange") || parser.parseEqual()) { return {}; + } auto tileInterchange = parseListOfSizes(); - if (failed(tileInterchange) || parser.parseRBrace()) + if (failed(tileInterchange) || parser.parseRBrace()) { return {}; + } return parser.getChecked( loc, parser.getContext(), *tileSizes, *tileInterchange, scalableFlags); } @@ -218,8 +228,9 @@ Attribute LoweringConfigTilingLevelAttr::parse(mlir::AsmParser &parser, LogicalResult LoweringConfigTilingLevelAttr::verify( function_ref 
emitError, ArrayRef tileSizes, ArrayRef tileInterchange, ArrayRef scalableFlags) { - if (!scalableFlags.empty() && scalableFlags.size() != tileSizes.size()) + if (!scalableFlags.empty() && scalableFlags.size() != tileSizes.size()) { return emitError() << "scalable flags length does not match tile sizes"; + } return success(); } @@ -254,29 +265,33 @@ LoweringConfigAttr::get(MLIRContext *context, TileSizesListTypeRef tileSizes, TileSizesListType LoweringConfigAttr::getTileSizeVals() const { TileSizesListType tileSizes; - for (auto &level : getTilingLevels()) + for (auto &level : getTilingLevels()) { tileSizes.push_back(SmallVector(level.getSizes())); + } return tileSizes; } SmallVector LoweringConfigAttr::getTileSizeVals(unsigned level) const { auto levels = getTilingLevels(); - if (level >= levels.size()) + if (level >= levels.size()) { return {}; + } return SmallVector(levels[level].getSizes()); } ScalableTileFlagsListType LoweringConfigAttr::getScalableTileFlagVals() { ScalableTileFlagsListType scalableFlags; - for (auto &level : getTilingLevels()) + for (auto &level : getTilingLevels()) { scalableFlags.push_back(SmallVector(level.getScalableFlags())); + } return scalableFlags; } SmallVector LoweringConfigAttr::getScalableTileFlagVals(unsigned level) { auto levels = getTilingLevels(); - if (level >= levels.size()) + if (level >= levels.size()) { return {}; + } SmallVector scalableFlags(levels[level].getScalableFlags()); // Extend the scalable flags with `false` to match the length of the sizes. 
scalableFlags.resize(levels[level].getSizes().size()); @@ -286,8 +301,9 @@ SmallVector LoweringConfigAttr::getScalableTileFlagVals(unsigned level) { SmallVector LoweringConfigAttr::getTileInterchangeVals(unsigned level) const { auto levels = getTilingLevels(); - if (level >= levels.size()) + if (level >= levels.size()) { return {}; + } return SmallVector(levels[level].getInterchange()); } @@ -338,8 +354,9 @@ bool LoweringConfigAttr::hasWorkgroupTilingLevel() const { LogicalResult LoweringConfigAttr::verify(function_ref emitError, LoweringConfigTilingLevelsAttr levels) { - if (!levels) + if (!levels) { return emitError() << "missing lowering config levels"; + } return success(); } @@ -516,23 +533,27 @@ static OpFoldResult getMinimumConstantOffsetValue(OpBuilder &b, Location loc, OpFoldResult offset, int64_t rotationInvariant) { auto value = dyn_cast_if_present(offset); - if (!value) + if (!value) { return offset; + } auto add = value.getDefiningOp(); - if (!add) + if (!add) { return offset; + } llvm::APInt constant; - if (!matchPattern(add.getRhs(), m_ConstantInt(&constant))) + if (!matchPattern(add.getRhs(), m_ConstantInt(&constant))) { return offset; + } int64_t constantOffset = constant.getSExtValue(); int64_t baseMod = constantOffset % rotationInvariant; // Skip constructing the new apply if it's not needed (c < rotationInvariant). 
- if (baseMod == constantOffset) + if (baseMod == constantOffset) { return offset; + } Value modOffset = arith::ConstantIndexOp::create(b, loc, baseMod); // If the original add is nsw/nuw, then the new add must also be given we're @@ -798,14 +819,16 @@ void eraseTranslationInfo(FunctionOpInterface funcOp) { SmallVector getTileSizes(Operation *op, unsigned level) { IREE::Codegen::LoweringConfigAttrInterface configAttr = getLoweringConfig(op); - if (!configAttr) + if (!configAttr) { return {}; + } return configAttr.getStaticTilingLevelSizes(level, op); } SmallVector getTileSizes(OpBuilder &b, Operation *op, unsigned level) { IREE::Codegen::LoweringConfigAttrInterface configAttr = getLoweringConfig(op); - if (!configAttr) + if (!configAttr) { return {}; + } return llvm::map_to_vector(configAttr.getTilingLevelSizes(b, level, op), [&](OpFoldResult s) -> Value { return getValueOrCreateConstantIndexOp( @@ -856,8 +879,9 @@ bool hasRootOpInfo(Operation *op) { IREE::Codegen::UKernelProviderInterface getUKernelProviderFromTarget(DictionaryAttr dict) { - if (!dict) + if (!dict) { return {}; + } return dict.getAs( kUKernelProviderName); } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h index 1f2325226948..acc87a0cb1d0 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h @@ -103,8 +103,9 @@ template FailureOr getLoweringConfigCarryingOp(ArrayRef computeOps) { for (Operation *op : computeOps) { - if (getLoweringConfig(op)) + if (getLoweringConfig(op)) { return op; + } } return failure(); } @@ -117,8 +118,9 @@ getLoweringConfigCarryingOp(ArrayRef computeOps) { template FailureOr getFirstLoweringConfig(ArrayRef computeOps) { FailureOr op = getLoweringConfigCarryingOp(computeOps); - if (failed(op)) + if (failed(op)) { return failure(); + } return 
getLoweringConfig(*op); } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.cpp index 9c22481e8663..48b5b02f8dbb 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.cpp @@ -211,8 +211,9 @@ IREECodegenDialect::verifyOperationAttribute(Operation *op, } } - if (symbol != kTuningSpecEntrypointAttrName) + if (symbol != kTuningSpecEntrypointAttrName) { return success(); + } const std::string requiredByEntrypointMessage = " (required by '" + std::string(kTuningSpecEntrypointAttrName) + "')"; diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td index bf63c8850e45..5fa6512c6433 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td @@ -619,6 +619,11 @@ def IREECodegen_AnySwizzleAttr : Attr; + def IREECodegen_UKernelProviderInterface : AttrInterface<"UKernelProviderInterface"> { let cppNamespace = "::mlir::iree_compiler::IREE::Codegen"; diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.cpp index cf1e47557a39..98e6571b4cfc 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.cpp @@ -42,8 +42,9 @@ LogicalResult ExtractStridedMetadataOp::inferReturnTypes( ExtractStridedMetadataOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { auto sourceType = dyn_cast(adaptor.getSource().getType()); - if (!sourceType) + if (!sourceType) { return failure(); + } unsigned sourceRank = sourceType.getRank(); 
IndexType indexType = IndexType::get(context); @@ -55,8 +56,9 @@ LogicalResult ExtractStridedMetadataOp::inferReturnTypes( // Offset. inferredReturnTypes.push_back(indexType); // Sizes and strides. - for (unsigned i = 0; i < sourceRank * 2; ++i) + for (unsigned i = 0; i < sourceRank * 2; ++i) { inferredReturnTypes.push_back(indexType); + } return success(); } @@ -282,8 +284,9 @@ LogicalResult InnerTiledOp::verify() { SmallVector indexingMaps = getIndexingMapsArray(); // Verify that an indexing map was specified for each operand. - if (indexingMaps.size() != expectedNumIns + expectedNumOuts) + if (indexingMaps.size() != expectedNumIns + expectedNumOuts) { return emitOpError("expected an indexing map for each operand"); + } // Verify that each index map has 'numIterators' inputs, no symbols, and // that the number of map outputs equals the rank of its associated @@ -292,9 +295,10 @@ LogicalResult InnerTiledOp::verify() { for (const auto &it : llvm::enumerate(indexingMaps)) { auto index = it.index(); auto map = it.value(); - if (map.getNumSymbols() != 0) + if (map.getNumSymbols() != 0) { return emitOpError("expected indexing map ") << index << " to have no symbols"; + } auto shapedType = opTypes[index]; unsigned rank = shapedType.getRank(); // Verify that the map has the right number of inputs, outputs, and indices. 
@@ -370,9 +374,11 @@ LogicalResult InnerTiledOp::verify() { } static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) { - for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) - if (targetExpr == map.getResult(i)) + for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { + if (targetExpr == map.getResult(i)) { return i; + } + } return -1; } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.td index 8c2bdab9cff7..a59e341f9657 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.td @@ -173,9 +173,9 @@ def IREECodegen_SwizzleHintOp : Op:$operand, + let arguments = (ins AnyRankedTensorOrMemRef:$operand, IREECodegen_AnySwizzleAttr:$swizzle); - let results = (outs RankedTensorOrMemRefOf<[AnyType], [1]>:$result); + let results = (outs AnyRankedTensorOrMemRef:$result); let assemblyFormat = [{ $operand `[` $swizzle attr-dict `]` `:` type($result) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.cpp index e9c4c751a0cd..d437d0771740 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.cpp @@ -329,8 +329,9 @@ struct UKernelOpsBufferizationInterface SmallVector nonTensorResultValues; for (OpResult result : op->getResults()) { Type resultType = result.getType(); - if (isa(resultType)) + if (isa(resultType)) { continue; + } nonTensorResultTypes.push_back(resultType); nonTensorResultValues.push_back(result); } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/BUILD.bazel index 588d9dea767c..53856548c216 100644 --- 
a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "invalid.mlir", "lowering_config_attr.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/BUILD.bazel new file mode 100644 index 000000000000..69c2f3cf089d --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/BUILD.bazel @@ -0,0 +1,35 @@ +# Copyright 2026 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_compiler_cc_library( + name = "IREECodegenTransforms", + srcs = [ + "ReshapeFusion.cpp", + ], + hdrs = [ + "Transforms.h", + ], + deps = [ + "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DialectUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:LinalgTransforms", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/CMakeLists.txt new file mode 100644 index 000000000000..05bdbb4f53f4 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/CMakeLists.txt @@ -0,0 +1,33 @@ 
+################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/BUILD.bazel # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_cc_library( + NAME + IREECodegenTransforms + HDRS + "Transforms.h" + SRCS + "ReshapeFusion.cpp" + DEPS + LLVMSupport + MLIRIR + MLIRLinalgDialect + MLIRLinalgTransforms + MLIRSupport + MLIRTensorDialect + MLIRTransformUtils + MLIRTransforms + iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect + PUBLIC +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/ReshapeFusion.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/ReshapeFusion.cpp new file mode 100644 index 000000000000..f216e054b92a --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/ReshapeFusion.cpp @@ -0,0 +1,315 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" +#include "iree/compiler/Codegen/Dialect/Codegen/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir::iree_compiler::IREE::Codegen { + +//===----------------------------------------------------------------------===// +// Utility Functions +//===----------------------------------------------------------------------===// + +namespace { + +/// Check if an InnerTiledOp can be expanded by propagating a reshape through +/// it. The main real condition is that the inner dimensions of the op are not +/// expanded. Otherwise, we artificially restrict to single result inner_tiled +/// ops for now. +static LogicalResult +canExpandInnerTiledOp(InnerTiledOp op, OpOperand *fusedOperand, + ArrayRef reassociation) { + // Only single result inner_tiled ops are tested or used anywhere, so restrict + // to single result for now. + if (op->getNumResults() != 1) { + return failure(); + } + + // Only outer dims can be expanded because inner dims depend on the `kind` + // attribute's implementation. + int64_t outerRank = + op.getIndexingMapsArray()[fusedOperand->getOperandNumber()] + .getNumResults(); + if (llvm::any_of(reassociation.drop_front(outerRank), + [](ArrayRef group) { return group.size() != 1; })) { + return failure(); + } + return success(); +} + +/// Expand an InnerTiledOp by propagating a reshape through it. +/// `fusedOperand` is the operand connected to the reshape. +/// `reassociation` describes how the collapsed dims map to expanded dims. +/// `expandedShape` is the full expanded shape (outer + inner dims). +/// `expandedValue` is the expanded value to replace the fused operand. 
+/// `outputReassociations` will be cleared and filled with the reassociation +/// indices for each output, to be used for collapsing the result back to its +/// original shape. +/// The outer dimensions of the InnerTiledOp are expected to not be expanded, +/// which is enforced by the canExpandInnerTiledOp precondition. +static InnerTiledOp expandInnerTiledOp( + InnerTiledOp op, OpOperand *fusedOperand, + ArrayRef reassociation, + ArrayRef expandedShape, Value expandedValue, + SmallVectorImpl> &outputReassociations, + PatternRewriter &rewriter) { + assert(reassociation.size() == + cast(fusedOperand->get().getType()).getRank() && + "expected reassociation rank to match fused operand rank"); + + // Build mapping: iterDim -> list of (expandedIterDim, size). + SmallVector indexingMaps = op.getIndexingMapsArray(); + AffineMap fusedMap = indexingMaps[fusedOperand->getOperandNumber()]; + int64_t numIterDims = fusedMap.getNumDims(); + SmallVector>> iterDimExpansion( + numIterDims); + int64_t expandedDimCounter = 0; + for (auto [resultIdx, expr] : llvm::enumerate(fusedMap.getResults())) { + int64_t iterDim = cast(expr).getPosition(); + for (int64_t expandedOperandIdx : reassociation[resultIdx]) { + iterDimExpansion[iterDim].push_back( + {expandedDimCounter++, expandedShape[expandedOperandIdx]}); + } + } + // Iteration dims outside the fused map's results are independent from the + // expansion, but update their dim position to account for earlier expanded + // dims. Get iteration domain to query sizes of dims not in the fused operand. 
+ SmallVector iterationDomain = op.getIterationDomain(rewriter); + for (int64_t i = 0; i < numIterDims; ++i) { + if (iterDimExpansion[i].empty()) { + iterDimExpansion[i].push_back( + {expandedDimCounter++, iterationDomain[i].size}); + } + } + + SmallVector newIndexingMaps; + SmallVector newOperands; + outputReassociations.clear(); + Location loc = op.getLoc(); + for (OpOperand &operand : op->getOpOperands()) { + AffineMap origMap = indexingMaps[operand.getOperandNumber()]; + auto operandType = cast(operand.get().getType()); + int64_t operandOuterRank = origMap.getNumResults(); + int64_t innerRank = operandType.getRank() - operandOuterRank; + SmallVector newMapResults; + SmallVector operandReassoc; + SmallVector expandedOperandSizes; + int64_t dimCounter = 0; + for (AffineExpr expr : origMap.getResults()) { + int64_t iterDim = cast(expr).getPosition(); + ReassociationIndices group; + for (auto [expandedDim, size] : iterDimExpansion[iterDim]) { + newMapResults.push_back(getAffineDimExpr(expandedDim, op.getContext())); + group.push_back(dimCounter++); + expandedOperandSizes.push_back(size); + } + operandReassoc.push_back(group); + } + // Inner dims are never expanded. + for (int64_t i = 0; i < innerRank; ++i) { + operandReassoc.push_back({dimCounter++}); + expandedOperandSizes.push_back(tensor::getMixedSize( + rewriter, loc, operand.get(), operandOuterRank + i)); + } + newIndexingMaps.push_back( + AffineMap::get(expandedDimCounter, 0, newMapResults, op.getContext())); + + // Store output reassociations for later use. 
+ if (operand.getOperandNumber() >= op.getNumInputs()) { + outputReassociations.push_back(operandReassoc); + } + + if (&operand == fusedOperand) { + newOperands.push_back(expandedValue); + continue; + } + + if (llvm::all_of(operandReassoc, [](ArrayRef group) { + return group.size() == 1; + })) { + newOperands.push_back(operand.get()); + continue; + } + + SmallVector staticShape; + std::tie(staticShape, std::ignore) = + decomposeMixedValues(expandedOperandSizes); + auto expandedType = + RankedTensorType::get(staticShape, operandType.getElementType()); + newOperands.push_back(tensor::ExpandShapeOp::create( + rewriter, loc, expandedType, operand.get(), operandReassoc, + expandedOperandSizes)); + } + + // Expand iterator types. + SmallVector newIterTypes; + for (auto [idx, iterType] : llvm::enumerate(op.getIteratorTypesArray())) { + newIterTypes.append(iterDimExpansion[idx].size(), iterType); + } + + int64_t numInputs = op.getNumInputs(); + SmallVector newInputs(newOperands.begin(), + newOperands.begin() + numInputs); + SmallVector newOutputs(newOperands.begin() + numInputs, + newOperands.end()); + + // Permutations are unchanged, since they are for inner dims, but we need to + // convert from ArrayAttr to SmallVector>. + std::optional>> newPermutations; + if (auto permAttr = op.getPermutations()) { + newPermutations = llvm::map_to_vector( + permAttr->getAsRange(), [](DenseI64ArrayAttr perm) { + return SmallVector(perm.asArrayRef()); + }); + } + + return InnerTiledOp::create(rewriter, loc, newInputs, newOutputs, + newIndexingMaps, newIterTypes, op.getKind(), + op.getSemantics(), newPermutations); +} + +//===----------------------------------------------------------------------===// +// Patterns +//===----------------------------------------------------------------------===// + +/// Pattern to propagate a tensor::CollapseShapeOp through a consumer +/// InnerTiledOp. The collapsed dimensions must not include any inner dimensions +/// of the InnerTiledOp. 
+/// +/// Example: +/// %collapsed = tensor.collapse_shape %src [[0, 1], ...] +/// %result = inner_tiled ins(%collapsed, ...) outs(%out) +/// => +/// %expanded_out = tensor.expand_shape %out [[0, 1], ...] +/// %result = inner_tiled ins(%src, ...) outs(%expanded_out) +/// %collapsed_result = tensor.collapse_shape %result [[0, 1], ...] +struct FoldProducerCollapseShapeWithInnerTiled + : public OpRewritePattern { + FoldProducerCollapseShapeWithInnerTiled(MLIRContext *context, + linalg::ControlFusionFn controlFn, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + controlFn(std::move(controlFn)) {} + + LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseOp, + PatternRewriter &rewriter) const override { + if (!collapseOp->hasOneUse()) { + return failure(); + } + OpOperand &use = *collapseOp->use_begin(); + auto innerTiledOp = dyn_cast(use.getOwner()); + if (!innerTiledOp || !controlFn(&use)) { + return failure(); + } + if (failed(canExpandInnerTiledOp(innerTiledOp, &use, + collapseOp.getReassociationIndices()))) { + return failure(); + } + + SmallVector expandedShape = tensor::getMixedSizes( + rewriter, collapseOp.getLoc(), collapseOp.getSrc()); + SmallVector> outputReassociations; + InnerTiledOp expandedOp = expandInnerTiledOp( + innerTiledOp, &use, collapseOp.getReassociationIndices(), expandedShape, + collapseOp.getSrc(), outputReassociations, rewriter); + + SmallVector results; + for (auto [idx, result] : llvm::enumerate(expandedOp.getResults())) { + auto resultType = + cast(innerTiledOp.getResultTypes()[idx]); + results.push_back(tensor::CollapseShapeOp::create( + rewriter, innerTiledOp.getLoc(), resultType, result, + outputReassociations[idx])); + } + rewriter.replaceOp(innerTiledOp, results); + return success(); + } + +private: + linalg::ControlFusionFn controlFn; +}; + +/// Pattern to propagate a tensor::ExpandShapeOp consumer back through an +/// InnerTiledOp producer. 
The expanded dimensions must not include any inner +/// dimensions of the InnerTiledOp. +/// +/// Example: +/// %result = inner_tiled ins(%lhs, ...) outs(%out) +/// %expanded = tensor.expand_shape %result [[0, 1], ...] +/// => +/// %expanded_lhs = tensor.expand_shape %lhs [[0, 1], ...] +/// %expanded_out = tensor.expand_shape %out [[0, 1], ...] +/// %result = inner_tiled ins(%expanded_lhs, ...) outs(%expanded_out) +struct FoldConsumerExpandShapeWithInnerTiled + : public OpRewritePattern { + FoldConsumerExpandShapeWithInnerTiled(MLIRContext *context, + linalg::ControlFusionFn controlFn, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + controlFn(std::move(controlFn)) {} + + LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp, + PatternRewriter &rewriter) const override { + auto producerResult = dyn_cast(expandOp.getSrc()); + if (!producerResult) { + return failure(); + } + auto innerTiledOp = dyn_cast(producerResult.getOwner()); + if (!innerTiledOp || !controlFn(&expandOp.getSrcMutable())) { + return failure(); + } + + int64_t resultIdx = producerResult.getResultNumber(); + OpOperand *outputOperand = innerTiledOp.getDpsInitOperand(resultIdx); + if (failed(canExpandInnerTiledOp(innerTiledOp, outputOperand, + expandOp.getReassociationIndices()))) { + return failure(); + } + + // The DPS init will be expanded in the same way as the result, so insert + // the expand_shape on the init first in order to reuse the + // expandInnerTiledOp transformation utility. 
+ SmallVector expandedShape = expandOp.getMixedOutputShape(); + SmallVector staticShape; + std::tie(staticShape, std::ignore) = decomposeMixedValues(expandedShape); + auto sourceType = cast(outputOperand->get().getType()); + auto expandedType = + RankedTensorType::get(staticShape, sourceType.getElementType()); + auto expandedInit = tensor::ExpandShapeOp::create( + rewriter, expandOp.getLoc(), expandedType, outputOperand->get(), + expandOp.getReassociationIndices(), expandedShape); + + SmallVector> outputReassociations; + InnerTiledOp expandedOp = expandInnerTiledOp( + innerTiledOp, outputOperand, expandOp.getReassociationIndices(), + expandedShape, expandedInit, outputReassociations, rewriter); + rewriter.replaceOp(expandOp, expandedOp.getResult(resultIdx)); + return success(); + } + +private: + linalg::ControlFusionFn controlFn; +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Populate Functions +//===----------------------------------------------------------------------===// + +void populateFoldReshapeOpsByExpansionPatterns( + RewritePatternSet &patterns, + const linalg::ControlFusionFn &controlFoldingReshapes) { + patterns.add(patterns.getContext(), + controlFoldingReshapes); +} + +} // namespace mlir::iree_compiler::IREE::Codegen diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/Transforms.h new file mode 100644 index 000000000000..ef83658add96 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/Transforms.h @@ -0,0 +1,28 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_TRANSFORMS_TRANSFORMS_H_ +#define IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_TRANSFORMS_TRANSFORMS_H_ + +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir::iree_compiler::IREE::Codegen { + +//===----------------------------------------------------------------------===// +// Populate functions. +//===----------------------------------------------------------------------===// + +/// Populate patterns to propagate reshapes by expansion. This folds +/// tensor.expand_shape and tensor.collapse_shape ops with their producer +/// and consumer operations respectively. +void populateFoldReshapeOpsByExpansionPatterns( + RewritePatternSet &patterns, + const linalg::ControlFusionFn &controlFoldingReshapes); + +} // namespace mlir::iree_compiler::IREE::Codegen + +#endif // IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_TRANSFORMS_TRANSFORMS_H_ diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp index 0ac641b33520..bb17a14f64d8 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp @@ -69,13 +69,15 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, static llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const ScalableTileFlags &scalableTileFlags) { - if (scalableTileFlags.empty()) + if (scalableTileFlags.empty()) { return os; + } os << "scalableTiles = ["; for (unsigned i = 0; i < scalableTileFlags.size(); ++i) { os << (scalableTileFlags[i] ? 
"true" : "false"); - if (i + 1 < scalableTileFlags.size()) + if (i + 1 < scalableTileFlags.size()) { os << ", "; + } } return os; } @@ -279,8 +281,9 @@ deserializeEncodingInfo(DictionaryAttr attr) { } if (attr.contains("scalableTiles")) { auto value = attr.getNamed("scalableTiles"); - if (!value || !isa(value->getValue())) + if (!value || !isa(value->getValue())) { return std::nullopt; + } ScalableTileFlags res = llvm::map_to_vector( cast(value->getValue()), [](Attribute a) { return cast(a).getValue(); }); diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/BUILD.bazel index e82cc074eb70..aa1a4f229316 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/BUILD.bazel @@ -23,6 +23,7 @@ exports_files([ iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "IREEGPUAttrs.td", "IREEGPUDialect.td", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.cpp index 8a2ed597c344..cd55d18c0efb 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.cpp @@ -177,8 +177,9 @@ SmallVector deriveThreadTileSizes(Operation *op) { .Case( [&](IREE::LinalgExt::MapScatterOp scatterOp) -> SmallVector { ShapedType inputType = scatterOp.getInputType(); - if (!inputType.hasStaticShape()) + if (!inputType.hasStaticShape()) { return {}; + } ArrayRef loopBounds = inputType.getShape(); int64_t elemBits = inputType.getElementTypeBitWidth(); int64_t vectorSize = kPreferredCopyNumBits / elemBits; diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.cpp index a04af998e191..b8c27fd9cf06 100644 --- 
a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.cpp @@ -151,4 +151,11 @@ std::optional> getPaddingList(LoweringConfigAttr config, return getIntegerVector(array); } +constexpr StringLiteral kDimensionExpansionName = "expand_dims"; + +DimensionExpansionAttr getDimensionExpansion(LoweringConfigAttr config) { + return config.getAttributes().getAs( + kDimensionExpansionName); +} + } // namespace mlir::iree_compiler::IREE::GPU diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h index ad175556116d..e5b8b4730715 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h @@ -62,6 +62,9 @@ IREE::GPU::LoweringConfigAttr setPromotedOperandsList( std::optional> getPaddingList(LoweringConfigAttr config, bool paddingConv = false); +/// Helper to retrieve dimension expansion config from lowering config. 
+DimensionExpansionAttr getDimensionExpansion(LoweringConfigAttr config); + } // namespace mlir::iree_compiler::IREE::GPU #endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_IR_GPULOWERINGCONFIGUTILS_H_ diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp index 9965b093eff2..888df6b06d0f 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp @@ -7,11 +7,14 @@ #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h" +#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h" #include "iree/compiler/Dialect/LinalgExt/Utils/MatchUtils.h" +#include "iree/compiler/Utils/EncodingUtils.h" #include "iree/compiler/Utils/Indexing.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLForwardCompat.h" @@ -699,8 +702,9 @@ Attribute MMAAttr::getDistributionMappingKind() const { OpFoldResult MMAAttr::getDistributionWorkerCount(OpBuilder &, Location, Operation *) const { - if (!getDistributionMappingKind()) + if (!getDistributionMappingKind()) { return OpFoldResult(); + } return getAsIndexOpFoldResult(getContext(), getSubgroupSize()); } @@ -780,6 +784,59 @@ MMAAttr::buildUnderlyingOperations(OpBuilder &builder, Location loc, return failure(); } +/// Creates index_hint ops wrapping delinearized lane ID values. +/// The `delinearizedLaneId` values come from delinearizing the lane ID using +/// `basis`, with the innermost/fastest-varying dimension last. 
+/// +/// Non-final indices get lane_constant hints (uniform across lane groups). +/// The final index gets lane_increment hint (increments within lane group). +/// The group size is derived from the innermost basis element. +/// Indices with a unit basis are ignored, and given a lane_constant hint. +static SmallVector +createTransposeLoadIndexHint(OpBuilder &builder, Location loc, + ValueRange delinearizedLaneId, + ArrayRef basis) { + // Need at least 2 dimensions for transpose load pattern. + if (delinearizedLaneId.size() < 2) { + return SmallVector(delinearizedLaneId.begin(), + delinearizedLaneId.end()); + } + + // Find the index of the innermost non-unit (> 1) basis element. + // This determines which result gets the lane-increment hint. + // Size-1 dimensions produce constant 0 outputs regardless of lane ID, + // so they don't contribute to the meaningful group structure. + int64_t groupSize = 1; + size_t incrementResultIdx = delinearizedLaneId.size() - 1; + // The delinearized indices could have N or N + 1 results, and the basis + // elements are aligned with the last N results, so iterate backwards + // together. + for (size_t i = 1; i <= basis.size(); ++i) { + groupSize = basis[basis.size() - i]; + incrementResultIdx = delinearizedLaneId.size() - i; + if (groupSize > 1) { + break; + } + } + + auto laneConstantAttr = + IREE::GPU::LaneConstantAttr::get(builder.getContext(), groupSize); + auto laneIncrementAttr = IREE::GPU::LaneIncrementAttr::get( + builder.getContext(), groupSize, /*step=*/1); + + SmallVector results; + for (auto [i, value] : llvm::enumerate(delinearizedLaneId)) { + // The result corresponding to innermost non-unit basis gets lane-increment; + // all other results get lane-constant hints. + Attribute hint = (i == incrementResultIdx) ? 
Attribute(laneIncrementAttr) + : Attribute(laneConstantAttr); + auto hintOp = IREE::Codegen::IndexHintOp::create(builder, loc, value, hint); + results.push_back(hintOp.getResult()); + } + + return results; +} + static LogicalResult populateCanonicalOffsetsSizesAndStrides( OpBuilder &builder, Location loc, Value laneId, ArrayRef permutation, MMASingleSubgroupLayout subgroupLayout, @@ -817,6 +874,12 @@ static LogicalResult populateCanonicalOffsetsSizesAndStrides( auto splitLaneId = affine::AffineDelinearizeIndexOp::create( builder, loc, laneId, vtidBasis, /*hasOuterBound=*/false); + // Wrap delinearize results with index_hint ops for transpose load. + // The delinearize results are already in the correct order + // (innermost/fastest-varying dimension is last). + SmallVector hintedSplitLaneId = createTransposeLoadIndexHint( + builder, loc, splitLaneId.getResults(), vtidBasis); + // Each thread grabs `element` contiguous data, so the vtid needs to be // multiplied by `element` to get the next bunch of data. // vtid: virtual thread id @@ -828,7 +891,7 @@ static LogicalResult populateCanonicalOffsetsSizesAndStrides( // worsen the generated code quality. 
for (auto [splitResultIdx, element] : llvm::zip_equal(dimToVtid, subgroupLayout.element)) { - Value vtid = splitLaneId.getResult(splitResultIdx); + Value vtid = hintedSplitLaneId[splitResultIdx]; int64_t vtidLen = vtidBasis[splitResultIdx - 1]; if (element != 1) { vtid = affine::AffineLinearizeIndexOp::create( @@ -1831,8 +1894,9 @@ DataTiledScaledMMAAttr::verifyIndexingMaps(ArrayRef maps) const { std::optional TargetAttr::getCUDAComputeCapability() const { StringRef arch = getArch(); - if (!arch.starts_with("sm_")) + if (!arch.starts_with("sm_")) { return false; + } APInt version; if (arch.substr(3).getAsInteger(10, version)) { return false; @@ -1843,14 +1907,16 @@ std::optional TargetAttr::getCUDAComputeCapability() const { bool TargetAttr::supportsTF32InputMMAOps() const { // TODO: scan the list of MMA ops to decude after plumbing through support // for NVIDIA TensorCore MMA ops. - if (auto cc = getCUDAComputeCapability()) + if (auto cc = getCUDAComputeCapability()) { return cc >= 80; + } return false; } bool TargetAttr::supportsSyncMMAOps() const { - if (auto cc = getCUDAComputeCapability()) + if (auto cc = getCUDAComputeCapability()) { return cc >= 80; + } return false; } @@ -1983,8 +2049,9 @@ getLoopBounds(ArrayRef loopRanges, for (auto [loopRange, givenTileSize] : llvm::zip_equal(loopRanges, givenTileSizes)) { // No loop if the tile size is 0. 
- if (isZeroInteger(givenTileSize)) + if (isZeroInteger(givenTileSize)) { continue; + } lbs.push_back(loopRange.offset); ubs.push_back(loopRange.size); steps.push_back(givenTileSize); @@ -2275,6 +2342,15 @@ Value PromoteWithCacheSwizzleAttr::promoteOperand( return cacheSwizzlePromotionImpl(builder, operand, getCopyConfig()); } +//===----------------------------------------------------------------------===// +// SwizzleOperandAttr +//===----------------------------------------------------------------------===// + +Value SwizzleOperandAttr::promoteOperand(mlir::OpBuilder &builder, + mlir::OpOperand &operand) const { + return swizzlePromotionImpl(builder, operand, getCopyConfig(), getSwizzle()); +} + //===----------------------------------------------------------------------===// // LaneIdAttr //===----------------------------------------------------------------------===// @@ -2309,6 +2385,89 @@ GPUPipelineOptionsAttr GPUPipelineOptionsAttr::get( } //===----------------------------------------------------------------------===// +// DimensionExpansionAttr +//===----------------------------------------------------------------------===// + +DimensionExpansionAttr +DimensionExpansionAttr::get(MLIRContext *context, + ArrayRef reassociations, + ArrayRef outputShape) { + Builder b(context); + SmallVector reassociationAttrs; + for (const ReassociationIndices &indices : reassociations) { + SmallVector indexAttrs; + for (int64_t idx : indices) { + indexAttrs.push_back(b.getI64IntegerAttr(idx)); + } + reassociationAttrs.push_back(b.getArrayAttr(indexAttrs)); + } + ArrayAttr reassociationAttr = b.getArrayAttr(reassociationAttrs); + DenseI64ArrayAttr outputShapeAttr = b.getDenseI64ArrayAttr(outputShape); + return get(context, reassociationAttr, outputShapeAttr); +} + +LogicalResult +DimensionExpansionAttr::verify(function_ref emitError, + ArrayAttr reassociations, + DenseI64ArrayAttr outputShape) { + if (reassociations.empty()) { + return emitError() << "reassociations cannot be 
empty"; + } + + int64_t nextExpected = 0; + + for (auto [groupIdx, attr] : llvm::enumerate(reassociations)) { + auto indexArray = dyn_cast(attr); + if (!indexArray) { + return emitError() << "reassociation at index " << groupIdx + << " must be an array"; + } + + if (indexArray.empty()) { + return emitError() << "reassociation group " << groupIdx + << " cannot be empty"; + } + + int numDynamicDims = 0; + for (auto [innerIdx, idxAttr] : llvm::enumerate(indexArray)) { + auto intAttr = dyn_cast(idxAttr); + if (!intAttr) { + return emitError() << "reassociation index at [" << groupIdx << "][" + << innerIdx << "] must be an integer"; + } + + int64_t idx = intAttr.getInt(); + if (idx != nextExpected) { + return emitError() << "reassociation indices must form contiguous " + << "sequence; expected dimension " << nextExpected + << " at [" << groupIdx << "][" << innerIdx + << "], got " << idx; + } + + if (outputShape[idx] == ShapedType::kDynamic) { + numDynamicDims++; + } + + nextExpected++; + } + + if (numDynamicDims > 1) { + return emitError() + << "reassociation group " << groupIdx + << " has multiple dynamic dimensions; at most 1 allowed"; + } + } + + ArrayRef outputShapeArray = outputShape.asArrayRef(); + if (nextExpected != static_cast(outputShapeArray.size())) { + return emitError() << "reassociations cover " << nextExpected + << " dimensions, but output_shape has rank " + << outputShapeArray.size(); + } + + return success(); +} + // Index Hint Attributes //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h index 8a05e5b7ea64..920081181beb 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h @@ -295,6 +295,10 @@ StringRef getTilingLevelName(GPU::TilingLevel level); Value 
cacheSwizzlePromotionImpl(OpBuilder &builder, OpOperand &operand, Attribute attr); +Value swizzlePromotionImpl(OpBuilder &builder, OpOperand &operand, + Attribute attr, + Codegen::SwizzleAttrInterface swizzle); + } // namespace mlir::iree_compiler::IREE::GPU // clang-format off diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td index f14a5ef98c1c..9766a468ee59 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td @@ -131,6 +131,51 @@ def IREEGPU_PromoteWithCacheSwizzle : ); } +def IREEGPU_SwizzleOperand : + AttrDef + ]> { + let mnemonic = "swizzle_operand"; + let summary = [{ + Indicate promotion of an operand with setting an xor swizzle value. + }]; + let description = [{ + During matmul operand promotion, we generate copies associated to a + particular matmul operand with specific lowering configuration optimized + for coalesced loads. This attribute carries information on how accesses to + memrefs or tensors associated to a particular copy should be swizzled. + + This information is used to create a swizzle hint op on the alloc + associated with the copy. Ultimately, this aims to modify memory accesses + to minimize bank conflicts. For example, + + ```mlir + %0 = tensor_ext.dispatch.tensor.load : tensor + %1 = linalg.matmul ins(%0, ...) + ``` + + Becomes with `#iree_gpu.swizzle_operand<#iree_gpu.use_global_load_dma>` + + ```mlir + %0 = tensor_ext.dispatch.tensor.load : tensor + %1 = tensor.empty() + %2 = swizzle_hint_op %1 xor_shuffle(256, 32) + %3 = linalg.copy lowering_config = #iree_gpu.use_global_load_dma ins(%0) outs(%1) + %4 = linalg.matmul ins(%3, ...) + ``` + + With intelligent selection of `row_width` and `access_width`, this should + minimize bank conflicts. 
+ }]; + let assemblyFormat = "`<` struct(params) `>`"; + let parameters = (ins + "Attribute":$copy_config, + IREECodegen_SwizzleAttrParameter:$swizzle + ); +} + //===----------------------------------------------------------------------===// // GPU Workgroup Processor (WGP) Level Feature/Limit Attributes //===----------------------------------------------------------------------===// @@ -960,6 +1005,62 @@ def IREEGPU_GPUPipelineOptionsAttr : AttrDef { + let mnemonic = "expand_dims"; + let cppNamespace = "::mlir::iree_compiler::IREE::GPU"; + + let summary = [{Attribute for describing static dimension expansion.}]; + let description = [{ + This attribute describes how dimensions in an iteration space should be + expanded. Each original dimension can either remain unchanged or be + split into multiple dimensions. The semantics are similar to the familiar + `tensor.expand_shape` operation. + + The reassociations parameter specifies the mapping from original dimensions + to expanded dimensions. For example, [[0], [1], [2, 3]] means: + - Original dimension 0 maps to output dimension 0 + - Original dimension 1 maps to output dimension 1 + - Original dimension 2 is split into output dimensions 2 and 3 + + The output_shape parameter specifies the sizes of the expanded dimensions. + If the size is ShapedType::kDynamic, the size is determined from the product + of the rest of the static tile sizes in the respective reassociation group. + There can be at most one dynamic size per reassociation group. 
+ + Example: #iree_gpu.expand_dims<[[0], [1], [2, 3]], output_shape = [?, ?, ?, 8]> + }]; + + let parameters = (ins + "ArrayAttr":$reassociations, + "DenseI64ArrayAttr":$output_shape + ); + + let builders = [ + AttrBuilder<(ins + "ArrayRef":$reassociations, + "ArrayRef":$outputShape)> + ]; + + let assemblyFormat = "`<` $reassociations `,` `output_shape` `=` custom($output_shape) `>`"; + + let extraClassDeclaration = [{ + SmallVector getReassociationIndices() { + return llvm::to_vector<4>(llvm::map_range( + getReassociations().getAsRange(), + [](ArrayAttr arrayAttr) -> ReassociationIndices { + return llvm::to_vector<2>(llvm::map_range( + arrayAttr.getAsRange(), + [](IntegerAttr idx) { return idx.getInt(); })); + })); + } + }]; + + let genVerifyDecl = 1; +} + // Lane Index Hint Attributes //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/PromotionImpls.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/PromotionImpls.cpp index 080c2e14485c..6f683bf87f2d 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/PromotionImpls.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/PromotionImpls.cpp @@ -8,9 +8,11 @@ #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" +#include "llvm/Support/DebugLog.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -18,13 +20,14 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/Interfaces/TilingInterface.h" +#define DEBUG_TYPE "iree-codegen-promotion-utils" + namespace mlir::iree_compiler::IREE::GPU 
{ /// Helper to insert copy with the specified attr. Value promoteValue(OpBuilder &builder, Location loc, Value v, Attribute attr) { auto tensorType = cast(v.getType()); SmallVector mixedSizes = tensor::getMixedSizes(builder, loc, v); - Value empty = tensor::EmptyOp::create(builder, loc, mixedSizes, tensorType.getElementType()); auto copy = linalg::CopyOp::create(builder, loc, v, empty); @@ -32,28 +35,35 @@ Value promoteValue(OpBuilder &builder, Location loc, Value v, Attribute attr) { return copy.getResult(0); } -/// Inserts a `linalg.copy` directly before the given operation on the -/// specified operand, for example with operand index = 1: -/// -/// %2 = linalg.matmul ins(%0, %1) -/// -/// becomes -/// -/// %empty = tensor.empty() -/// %copy = linalg.copy %1 to %empty { -/// lowering_config = #iree_gpu.{derived_thread_config|use_global_dma}} -/// linalg.matmul ins(%0, %copy) -/// -/// If the producer is already a tilable op, the producer is just annotated with -/// the underlying attribute. -/// Additionally we can also promote results so in above example we will -/// generate for index = 2 : -/// %out_buffer = bufferization.alloc_tensor -/// %copy1 = linalg.copy %2 to %out_buffer -/// %copy2 = linalg.copy %copy1 to %empty { -/// lowering_config = #iree_gpu.derived_thread_config} -Value defaultPromotionImpl(OpBuilder &builder, OpOperand &operand, - Attribute attr) { +// Helper to insert a swizzle hint op and flatten the associated alloc. 
+Value swizzlePromoteValue(OpBuilder &builder, Location loc, Value v, + Attribute attr, + Codegen::SwizzleAttrInterface swizzle) { + auto tensorType = cast(v.getType()); + int64_t numElements = tensorType.getNumElements(); + SmallVector sizes = tensor::getMixedSizes(builder, loc, v); + bool hasStaticShape = tensorType.hasStaticShape(); + if (hasStaticShape) { + sizes = {builder.getIndexAttr(numElements)}; + } + Value alloc = + tensor::EmptyOp::create(builder, loc, sizes, tensorType.getElementType()); + + // Only generate a swizzle hint op if the shape is static. + if (hasStaticShape) { + Value swizzled = + IREE::Codegen::SwizzleHintOp::create(builder, loc, alloc, swizzle); + alloc = tensor::ExpandShapeOp::create( + builder, loc, tensorType, swizzled, + {llvm::to_vector(llvm::seq(tensorType.getRank()))}); + } + auto copy = linalg::CopyOp::create(builder, loc, v, alloc); + setLoweringConfig(copy, attr); + return copy.getResult(0); +} + +std::optional promotionImpl(OpBuilder &builder, OpOperand &operand, + Attribute attr) { if (auto producer = operand.get().getDefiningOp()) { // Skip promotion of fills. if (isa(producer)) { @@ -78,11 +88,73 @@ Value defaultPromotionImpl(OpBuilder &builder, OpOperand &operand, if (!tensorType) { return operand.get(); } + return std::nullopt; +} +/// Inserts a `linalg.copy` directly before the given operation on the +/// specified operand, for example with operand index = 1: +/// +/// ```mlir +/// %2 = linalg.matmul ins(%0, %1) +/// ``` +/// +/// becomes +/// +/// ```mlir +/// %empty = tensor.empty() +/// %copy = linalg.copy %1 to %empty { +/// lowering_config = #iree_gpu.{derived_thread_config|use_global_dma}} +/// linalg.matmul ins(%0, %copy) +/// ``` +/// +/// If the producer is already a tilable op, the producer is just annotated with +/// the underlying attribute. 
+/// Additionally we can also promote results so in above example we will +/// generate for index = 2 : +/// +/// ```mlir +/// %out_buffer = bufferization.alloc_tensor +/// %copy1 = linalg.copy %2 to %out_buffer +/// %copy2 = linalg.copy %copy1 to %empty { +/// lowering_config = #iree_gpu.derived_thread_config} +/// ``` +Value defaultPromotionImpl(OpBuilder &builder, OpOperand &operand, + Attribute attr) { + std::optional promotedValue = promotionImpl(builder, operand, attr); + if (promotedValue.has_value()) { + return promotedValue.value(); + } return promoteValue(builder, operand.getOwner()->getLoc(), operand.get(), attr); } +/// Inserts a `linalg.copy` directly before the given operation on the +/// specified operand, similar to the defaultPromotionImpl. +/// The difference is this also assigns a `iree_codegen.swizzle_hint` op +/// to the generated `tensor.empty` op. +/// For example: +/// ```mlir +/// %2 = linalg.matmul ins(%0, %1) +/// ``` +/// becomes +/// ```mlir +/// %empty = tensor.empty() +/// %swizzle = iree_codegen.swizzle_hint %empty[...] +/// %copy = linalg.copy %1 to %swizzle { +/// lowering_config = #iree_gpu.{derived_thread_config|use_global_dma}} +/// linalg.matmul ins(%0, %copy) +/// ``` +Value swizzlePromotionImpl(OpBuilder &builder, OpOperand &operand, + Attribute attr, + Codegen::SwizzleAttrInterface swizzle) { + std::optional promotedValue = promotionImpl(builder, operand, attr); + if (promotedValue.has_value()) { + return promotedValue.value(); + } + return swizzlePromoteValue(builder, operand.getOwner()->getLoc(), + operand.get(), attr, swizzle); +} + /// Inserts a `linalg.copy` directly before the given operation on the /// specified operand, and also inserts a buffer_resource_cast on the producing /// dispatch input if possible. 
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/BUILD.bazel index 88b751a73805..6698292164a1 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "bufferize_coalesced_gather_dma.mlir", "canonicalize.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp index 7a65414761bb..cd573602e77e 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp @@ -338,7 +338,7 @@ getContractionHeuristicSeeds(GPUMatmulShapeType problem, bool isGemm, /// due to padding requirements or because the operation has an existing /// accumulator that needs to be loaded from global memory (matmul_accumulate). static std::optional getMmaScheduleFromProblemAndTarget( - IREE::GPU::TargetAttr target, GPUMatmulShapeType problem, + IREE::GPU::TargetAttr target, GPUMatmulShapeType problem, Location loc, bool transposedLhs, bool transposedRhs, bool isGemm, bool mustBeAligned = true, bool doCPromotion = false, bool scaled = false, int64_t splitReductionTripCnt = 0) { @@ -348,10 +348,12 @@ static std::optional getMmaScheduleFromProblemAndTarget( for (IREE::GPU::ScaledMMAAttr smma : target.getWgp().getScaledMma()) { // Intrinsics that do not specify a distribution kind cannot be // distributed. 
- if (!smma.getDistributionMappingKind()) + if (!smma.getDistributionMappingKind()) { continue; - if (smma.getSubgroupSize() != targetSubgroupSize) + } + if (smma.getSubgroupSize() != targetSubgroupSize) { continue; + } auto [m, n, k, kB] = smma.getScaledMNKShape(); SmallVector elementTypes; @@ -365,10 +367,12 @@ static std::optional getMmaScheduleFromProblemAndTarget( for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) { // Intrinsics that do not specify a distribution kind cannot be // distributed. - if (!mma.getDistributionMappingKind()) + if (!mma.getDistributionMappingKind()) { continue; - if (mma.getSubgroupSize() != targetSubgroupSize) + } + if (mma.getSubgroupSize() != targetSubgroupSize) { continue; + } auto [mSize, nSize, kSize] = mma.getMNKShape(); auto [aType, bType, cType] = mma.getABCElementTypes(); @@ -379,8 +383,9 @@ static std::optional getMmaScheduleFromProblemAndTarget( return std::nullopt; } - assert(problem.aType == problem.bType && - "expected the same aType and bType."); + if (problem.aType != problem.bType) { + return std::nullopt; + } GemmCutoff gemmCutoffs = computeGemmCutoffsForAI(target, problem.aType, scaled); @@ -432,7 +437,7 @@ static std::optional getMmaScheduleFromProblemAndTarget( // First try to find a schedule with an exactly matching intrinsic. 
std::optional schedule = deduceMMASchedule( problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize, - wgpCount, transposedLhs, transposedRhs, /*canUpcastAcc=*/false, + wgpCount, loc, transposedLhs, transposedRhs, /*canUpcastAcc=*/false, /*mustBeAligned=*/mustBeAligned, doCPromotion, splitReductionTripCnt); return schedule; } @@ -773,24 +778,24 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize( assert((operands.size() == 3 || scaled) && "expected 3 operands"); assert((operands.size() == 5 || !scaled) && "expected 5 operands"); - Value lhs = operands[0]; - Value rhs = operands[1]; - - Value init = operands[2]; + Type lhsElemType = getElementTypeOrSelf(operands[0]); + Type rhsElemType = getElementTypeOrSelf(operands[1]); + Type initElemType = getElementTypeOrSelf(operands[2]); + Type lhsScaleType; + Type rhsScaleType; if (scaled) { - init = operands[4]; assert(llvm::all_of(operands, [](Value a) { return isa(a.getType()); }) && "All operands must be a shaped type"); - assert(*getRank(lhs) > *getRank(operands[2]) && - *getRank(rhs) > *getRank(operands[3]) && + assert(*getRank(operands[0]) > *getRank(operands[2]) && + *getRank(operands[1]) > *getRank(operands[3]) && "Expected operand #0 (lhs) and operand #1 (rhs) to have a greater " "rank than their corresponding scales, operand #2 (lhs_scale) and " "operand #3 (rhs_scale)"); + lhsScaleType = getElementTypeOrSelf(operands[2]); + rhsScaleType = getElementTypeOrSelf(operands[3]); + initElemType = getElementTypeOrSelf(operands[4]); } - Type lhsElemType = getElementTypeOrSelf(lhs); - Type rhsElemType = getElementTypeOrSelf(rhs); - Type initElemType = getElementTypeOrSelf(init); // Intentionally padded GEMM proved to be beneficial for performance for // the following layouts: 1) [M, K] x [K, N] 2) [M, K] x [N, K] // Therefore we disallow padding only when LHS is transposed. 
@@ -800,7 +805,9 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize( getDimBoundsNoPad(batchDims), lhsElemType, rhsElemType, - initElemType}; + initElemType, + lhsScaleType, + rhsScaleType}; // Accumulator needs shared memory if: // - Padding requires C promotion, OR @@ -809,8 +816,9 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize( (couldNeedPadding && CPromoteIfPadding) || hasExistingAccumulator; bool mustBeAligned = true; + Location loc = operands[0].getLoc(); std::optional schedule = getMmaScheduleFromProblemAndTarget( - target, problem, transposedLhs, transposedRhs, isGemm, + target, problem, loc, transposedLhs, transposedRhs, isGemm, /*mustBeAligned=*/true, doCPromotion, scaled, splitReductionTripCnt); if (!schedule && canSupportUnaligned) { @@ -820,8 +828,8 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize( // accumulator. bool doCPromotionUnaligned = CPromoteIfPadding || hasExistingAccumulator; schedule = getMmaScheduleFromProblemAndTarget( - target, problem, transposedLhs, transposedRhs, isGemm, mustBeAligned, - doCPromotionUnaligned, scaled, splitReductionTripCnt); + target, problem, loc, transposedLhs, transposedRhs, isGemm, + mustBeAligned, doCPromotionUnaligned, scaled, splitReductionTripCnt); } if (!schedule) { @@ -889,7 +897,7 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize( // Attach the MMA schedule as an attribute to the entry point export function // for later access in the pipeline. 
- MLIRContext *context = lhs.getContext(); + MLIRContext *context = target.getContext(); Builder b(context); SmallVector attrs = { {"workgroup", b.getI64ArrayAttr(workgroupTileSizes)}, @@ -1550,8 +1558,9 @@ LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target, int64_t lossFactor = 32; for (; lossFactor >= 1; lossFactor >>= 1) { - if (distributeToThreads(numThreads, lossFactor) == 1) + if (distributeToThreads(numThreads, lossFactor) == 1) { break; + } } } @@ -1729,8 +1738,9 @@ setDirectConvolutionLoweringConfig(IREE::GPU::TargetAttr target, return failure(); } - if (target.getWgp().getMma().empty()) + if (target.getWgp().getMma().empty()) { return failure(); + } const int64_t targetSubgroupSize = target.getPreferredSubgroupSize(); const int64_t splitReductionTripCnt = getSplitReductionTripCount(entryPoint); @@ -1872,17 +1882,17 @@ setDirectConvolutionLoweringConfig(IREE::GPU::TargetAttr target, bool transposedRhs = rhsKPos > nPos; bool mustBeAligned = true; std::optional schedule = getMmaScheduleFromProblemAndTarget( - target, problem, transposedLhs, transposedRhs, /*isGemm=*/false, - mustBeAligned, /*doCPromotion=*/false, /*scaled=*/false, - splitReductionTripCnt); + target, problem, linalgOp.getLoc(), transposedLhs, transposedRhs, + /*isGemm=*/false, mustBeAligned, /*doCPromotion=*/false, + /*scaled=*/false, splitReductionTripCnt); if (!schedule && canSupportUnaligned) { LDBG() << "Attempting to deduce unaligned TileAndFuse MMA schedule"; mustBeAligned = false; schedule = getMmaScheduleFromProblemAndTarget( - target, problem, transposedLhs, transposedRhs, /*isGemm=*/false, - mustBeAligned, /*doCPromotion=*/false, /*scaled=*/false, - splitReductionTripCnt); + target, problem, linalgOp.getLoc(), transposedLhs, transposedRhs, + /*isGemm=*/false, mustBeAligned, /*doCPromotion=*/false, + /*scaled=*/false, splitReductionTripCnt); } if (!schedule) { LDBG() << "Failed to deduce TileAndFuse MMA schedule"; diff --git 
a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp index d8f738e88c73..4093391acd77 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp @@ -110,8 +110,9 @@ TargetAttr createTargetAttr(const TargetDetails &details, StringRef arch, SmallVector mmaAttrs; mmaAttrs.reserve(wgp->mmaCount); - for (int i = 0; i < wgp->mmaCount; ++i) + for (int i = 0; i < wgp->mmaCount; ++i) { mmaAttrs.push_back(MMAAttr::get(context, wgp->mmaOps[i])); + } SmallVector scaledMmaAttrs; scaledMmaAttrs.reserve(wgp->scaledMmaCount); @@ -814,10 +815,12 @@ std::optional getARMGPUTargetDetails(StringRef target) { } StringRef normalizeARMGPUTarget(StringRef target) { - if (target == "valhall") + if (target == "valhall") { return "valhall1"; - if (target.starts_with("valhall")) + } + if (target.starts_with("valhall")) { return target; + } return llvm::StringSwitch(target.lower()) .Cases({"mali-g715", "mali-g615"}, "valhall4") @@ -954,15 +957,19 @@ std::optional getNVIDIAGPUTargetDetails(StringRef target) { } StringRef normalizeNVIDIAGPUTarget(StringRef target) { - if (target.starts_with("sm_")) + if (target.starts_with("sm_")) { return target; + } - if (target.starts_with("rtx40")) + if (target.starts_with("rtx40")) { return "sm_89"; - if (target.starts_with("rtx30")) + } + if (target.starts_with("rtx30")) { return "sm_86"; - if (target.starts_with("rtx20")) + } + if (target.starts_with("rtx20")) { return "sm_75"; + } return llvm::StringSwitch(target.lower()) .Case("a100", "sm_80") @@ -1002,22 +1009,26 @@ const WgpDetails *getAdrenoWgpDetails() { } bool verifyQualcommGPUTarget(StringRef target) { - if (target == "adreno") + if (target == "adreno") { return true; + } StringRef t = target; - if (!t.consume_front("adreno-")) + if (!t.consume_front("adreno-")) { return false; + } // 
The can exist an optional L at the end. - if (t.ends_with("l")) + if (t.ends_with("l")) { t = t.drop_back(); + } // Check whether we have a product number unsigned number = 0; // StringRef::consumeInteger() returns true to signify errors. - if (t.size() != 3 || t.consumeInteger(10, number)) + if (t.size() != 3 || t.consumeInteger(10, number)) { return false; + } return true; } @@ -1036,8 +1047,9 @@ std::optional getQualcommGPUTargetDetails(StringRef target) { // Adreno-750: https://vulkan.gpuinfo.org/displayreport.php?id=27414 // Adreno-740: https://vulkan.gpuinfo.org/displayreport.php?id=19218 // Adreno-730: https://vulkan.gpuinfo.org/displayreport.php?id=19382 - if (verifyQualcommGPUTarget(target)) + if (verifyQualcommGPUTarget(target)) { return TargetDetails{adrenoWgp, nullptr}; + } return std::nullopt; } @@ -1103,9 +1115,11 @@ TargetAttr getMetalTargetDetails(MLIRContext *context) { TargetAttr getCUDATargetDetails(StringRef target, StringRef features, MLIRContext *context) { - if (std::optional details = getNVIDIAGPUTargetDetails(target)) + if (std::optional details = + getNVIDIAGPUTargetDetails(target)) { return createTargetAttr(*details, normalizeNVIDIAGPUTarget(target), features, context); + } return nullptr; } @@ -1147,8 +1161,9 @@ StringRef normalizeHIPTarget(StringRef target) { StringRef normalizeVulkanAMDGPUTarget(StringRef target) { // We cannot accept rdnaN as a target for LLVM AMDGPU backend; so the // following is only meant for Vulkan but not HIP. 
- if (target.starts_with("rdna")) + if (target.starts_with("rdna")) { return target; + } return normalizeAMDGPUTarget(target); } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ReductionConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ReductionConfigUtils.cpp index 6e432f24518d..242accb8bdcd 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ReductionConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ReductionConfigUtils.cpp @@ -4,6 +4,8 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h" +#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" #include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" @@ -126,13 +128,15 @@ static LogicalResult checkSingleCombiner(linalg::LinalgOp op) { SmallVector combinerOps; if (matchReduction(op.getRegionOutputArgs(), index, combinerOps) && combinerOps.size() == 1) { - if (foundSingleReductionOutput) + if (foundSingleReductionOutput) { return failure(); + } foundSingleReductionOutput = true; continue; } - if (!op.getMatchingIndexingMap(&initOpOperand).isIdentity()) + if (!op.getMatchingIndexingMap(&initOpOperand).isIdentity()) { return failure(); + } } if (!foundSingleReductionOutput) { return failure(); @@ -294,10 +298,44 @@ getVectorDistributeReductionConfig( int subgroup = partialReductionSize / subgroupStride; int64_t subgroupBasis = (subgroup == 0) ? 
1 : subgroup; - partialReductionTileSizes[lastReductionDim] = partialReductionSize; - threadTileSizes[lastReductionDim] = threadLoads; - threadCounts[lastReductionDim] = threadBasis; - subGroupCounts[lastReductionDim] = subgroupBasis; + SmallVector reassociations; + SmallVector outputShape; + + // We require the reduction dimension to be evenly divisible by threadLoads + // because the current expansion strategy doesn't support padding. + if (ShapedType::isStaticShape(bounds) && threadLoads > 1 && + bounds[lastReductionDim] % threadLoads == 0) { + workgroupTileSizes.push_back(0); + partialReductionTileSizes.push_back(0); + threadTileSizes.push_back(0); + threadCounts.push_back(1); + subGroupCounts.push_back(1); + mapping.push_back(mapping.size()); + + int64_t outer = lastReductionDim; + int64_t inner = lastReductionDim + 1; + + for (int64_t i = 0; i < op.getNumLoops(); ++i) { + if (i == lastReductionDim) { + int64_t idx = outputShape.size(); + reassociations.push_back({idx, idx + 1}); + outputShape.append({ShapedType::kDynamic, threadLoads}); + } else { + reassociations.push_back({static_cast(outputShape.size())}); + outputShape.push_back(ShapedType::kDynamic); + } + } + + partialReductionTileSizes[outer] = partialReductionSize / threadLoads; + threadTileSizes[inner] = threadLoads; + threadCounts[outer] = threadBasis; + subGroupCounts[outer] = subgroupBasis; + } else { + partialReductionTileSizes[lastReductionDim] = partialReductionSize; + threadTileSizes[lastReductionDim] = threadLoads; + threadCounts[lastReductionDim] = threadBasis; + subGroupCounts[lastReductionDim] = subgroupBasis; + } ArrayAttr subgroupBasisAttr = b.getArrayAttr( {b.getI64ArrayAttr(subGroupCounts), b.getI64ArrayAttr(mapping)}); @@ -305,13 +343,20 @@ getVectorDistributeReductionConfig( ArrayAttr threadBasisAttr = b.getArrayAttr( {b.getI64ArrayAttr(threadCounts), b.getI64ArrayAttr(mapping)}); - NamedAttribute configAttrs[] = { - NamedAttribute("workgroup", b.getI64ArrayAttr(workgroupTileSizes)), 
- NamedAttribute("partial_reduction", + SmallVector configAttrs = { + b.getNamedAttr("workgroup", b.getI64ArrayAttr(workgroupTileSizes)), + b.getNamedAttr("partial_reduction", b.getI64ArrayAttr(partialReductionTileSizes)), - NamedAttribute("thread", b.getI64ArrayAttr(threadTileSizes)), - NamedAttribute("lane_basis", threadBasisAttr), - NamedAttribute("subgroup_basis", subgroupBasisAttr)}; + b.getNamedAttr("thread", b.getI64ArrayAttr(threadTileSizes)), + b.getNamedAttr("lane_basis", threadBasisAttr), + b.getNamedAttr("subgroup_basis", subgroupBasisAttr), + }; + + if (!reassociations.empty()) { + auto dimExpandAttr = + DimensionExpansionAttr::get(context, reassociations, outputShape); + configAttrs.emplace_back(b.getNamedAttr("expand_dims", dimExpandAttr)); + } auto configDict = b.getDictionaryAttr(configAttrs); auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict); @@ -623,8 +668,9 @@ LogicalResult setReductionConfig(IREE::GPU::TargetAttr target, } } - if (subgroupSize == 0) + if (subgroupSize == 0) { return failure(); + } FailureOr bitWidth = getBitWidth(op); if (failed(bitWidth)) { diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/BUILD.bazel index 7fcef1e752f8..898344518ba9 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "IREEGPUExtensionsOps.td", ], diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel index 405750a481cf..5cc241ac631f 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel +++ 
b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "convert_to_multi_mma.mlir", "distribute_inner_tiled.mlir", @@ -28,8 +29,8 @@ iree_lit_test_suite( "transform_fuse_extract_slice_with_forall.mlir", "transform_fuse_forall.mlir", "transform_lower_barrier_region.mlir", - "vectorize_iree_gpu_ops.mlir", "unroll_multi_mma.mlir", + "vectorize_iree_gpu_ops.mlir", ], include = ["*.mlir"], ), diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/distribute_inner_tiled.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/distribute_inner_tiled.mlir index 8d1224dc9bc6..a5ba4ea8910d 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/distribute_inner_tiled.mlir +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/distribute_inner_tiled.mlir @@ -35,17 +35,19 @@ module attributes { transform.with_named_sequence } { // CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]: tensor<2x2x16x16xf32> // CHECK: scf.forall (%[[LANE_ID:.+]]) in (64) shared_outs(%[[ITER_ARG:.+]] = %[[ACC]]) -> (tensor<2x2x16x16xf32>) // CHECK: %[[ID:.+]]:3 = affine.delinearize_index %[[LANE_ID]] into (4, 16) -// CHECK: %[[ID1:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (4, 4) -// CHECK: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[ID]]#2, %[[ID1]]] +// CHECK: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<16>) : index +// CHECK: %[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<16>) : index +// CHECK: %[[ID1:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (4, 4) +// CHECK: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[COL]], %[[ID1]]] // CHECK-SAME: [2, 2, 1, 4] [1, 1, 1, 1] : tensor<2x2x16x16xf16> to tensor<2x2x1x4xf16> -// CHECK: %[[RHS_SLICE:.+]] = tensor.extract_slice 
%[[RHS]][0, 0, %[[ID1]], %[[ID]]#2] +// CHECK: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[ID1]], %[[COL]]] // CHECK-SAME: [2, 2, 4, 1] [1, 1, 1, 1] : tensor<2x2x16x16xf16> to tensor<2x2x4x1xf16> -// CHECK: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ITER_ARG]][0, 0, %[[ID1]], %[[ID]]#2] +// CHECK: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ITER_ARG]][0, 0, %[[ID1]], %[[COL]]] // CHECK-SAME: [2, 2, 4, 1] [1, 1, 1, 1] : tensor<2x2x16x16xf32> to tensor<2x2x4x1xf32> // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: : tensor<2x2x1x4xf16>, tensor<2x2x4x1xf16> into tensor<2x2x4x1xf32> // CHECK: scf.forall.in_parallel -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ITER_ARG]][0, 0, %[[ID1]], %[[ID]]#2] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ITER_ARG]][0, 0, %[[ID1]], %[[COL]]] // CHECK-SAME: [2, 2, 4, 1] [1, 1, 1, 1] : tensor<2x2x4x1xf32> into tensor<2x2x16x16xf32> // CHECK: mapping = [#iree_gpu.lane_id<0>] @@ -87,17 +89,19 @@ module attributes { transform.with_named_sequence } { // CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]: tensor<2x2x16x16xi32> // CHECK: scf.forall (%[[LANE_ID:.+]]) in (64) shared_outs(%[[ITER_ARG:.+]] = %[[ACC]]) -> (tensor<2x2x16x16xi32>) // CHECK: %[[ID:.+]]:3 = affine.delinearize_index %[[LANE_ID]] into (4, 16) -// CHECK: %[[ID1:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (4, 8) -// CHECK: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[ID]]#2, %[[ID1]]] +// CHECK: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<16>) : index +// CHECK: %[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<16>) : index +// CHECK: %[[ID1:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (4, 8) +// CHECK: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[COL]], %[[ID1]]] // CHECK-SAME: [2, 2, 1, 8] [1, 1, 1, 1] : tensor<2x2x16x32xi8> to tensor<2x2x1x8xi8> -// CHECK: 
%[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[ID]]#2, %[[ID1]]] +// CHECK: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[COL]], %[[ID1]]] // CHECK-SAME: [2, 2, 1, 8] [1, 1, 1, 1] : tensor<2x2x16x32xi8> to tensor<2x2x1x8xi8> -// CHECK: %[[ID1_2:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (4, 4) -// CHECK: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ITER_ARG]][0, 0, %[[ID1_2]], %[[ID]]#2] +// CHECK: %[[ID1_2:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (4, 4) +// CHECK: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ITER_ARG]][0, 0, %[[ID1_2]], %[[COL]]] // CHECK-SAME: [2, 2, 4, 1] [1, 1, 1, 1] : tensor<2x2x16x16xi32> to tensor<2x2x4x1xi32> // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: : tensor<2x2x1x8xi8>, tensor<2x2x1x8xi8> into tensor<2x2x4x1xi32> // CHECK: scf.forall.in_parallel -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ITER_ARG]][0, 0, %[[ID1_2]], %[[ID]]#2] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ITER_ARG]][0, 0, %[[ID1_2]], %[[COL]]] // CHECK-SAME: [2, 2, 4, 1] [1, 1, 1, 1] : tensor<2x2x4x1xi32> into tensor<2x2x16x16xi32> // CHECK: mapping = [#iree_gpu.lane_id<0>] diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp index d6da7774b767..8d4d88cfccf5 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp @@ -38,8 +38,9 @@ getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands, if (isa(opOperand.get().getType())) { FailureOr resultBuffer = getBuffer(rewriter, opOperand.get(), options, state); - if (failed(resultBuffer)) + if (failed(resultBuffer)) { return failure(); + } result.push_back(*resultBuffer); } else { 
result.push_back(opOperand.get()); @@ -121,8 +122,9 @@ struct BarrierRegionOpBufferizationInterface memrefType = bufferization::getBufferType( barrierOp.getOperand(argNum), options, state, invocationStack); } - if (failed(memrefType)) + if (failed(memrefType)) { return failure(); + } return cast(*memrefType); } @@ -207,8 +209,9 @@ struct ValueBarrierOpBufferizationInterface auto srcMemrefType = bufferization::getBufferType( barrierOp.getInputs()[cast(value).getResultNumber()], options, state, invocationStack); - if (failed(srcMemrefType)) + if (failed(srcMemrefType)) { return failure(); + } return cast(*srcMemrefType); } @@ -280,8 +283,9 @@ struct YieldOpBufferizationInterface if (isa(value.getType())) { FailureOr maybeBuffer = getBuffer(rewriter, value, options, state); - if (failed(maybeBuffer)) + if (failed(maybeBuffer)) { return failure(); + } newResults.push_back(*maybeBuffer); } else { newResults.push_back(value); @@ -443,8 +447,9 @@ struct BufferResourceCastOpBufferizationInterface assert(value.getDefiningOp() == castOp && "invalid value"); auto srcMemrefType = bufferization::getBufferType( castOp.getInput(), options, state, invocationStack); - if (failed(srcMemrefType)) + if (failed(srcMemrefType)) { return failure(); + } auto baseMemrefType = cast(srcMemrefType.value()); if (!hasStorageBufferMemSpace(baseMemrefType)) { diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/DistributeInnerTiledToLanes.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/DistributeInnerTiledToLanes.cpp index 18f45ed12e30..42f96aca3611 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/DistributeInnerTiledToLanes.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/DistributeInnerTiledToLanes.cpp @@ -63,8 +63,9 @@ LogicalResult fuseProducersGreedily(RewriterBase &rewriter, // Materialize the slice of the producer in place. 
std::optional fusedProducer = scf::tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, loops); - if (!fusedProducer) + if (!fusedProducer) { continue; + } // We have no way to know whether a multi-use value can be yielded from the // parallel loop so never yield a replacement. @@ -73,8 +74,9 @@ LogicalResult fuseProducersGreedily(RewriterBase &rewriter, for (auto tiledOp : fusedProducer->tiledOps) { for (OpOperand &operand : tiledOp->getOpOperands()) { auto sliceOp = operand.get().getDefiningOp(); - if (!sliceOp) + if (!sliceOp) { continue; + } candidates.push_back(sliceOp); } } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp index d67ced9ed05c..8f95eed4c316 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp @@ -101,24 +101,44 @@ static FailureOr createSharedAllocDestination(RewriterBase &rewriter, return failure(); } - auto empty = forallOp.getDpsInits()[0].getDefiningOp(); + // Skip swizzle hint ops. + Operation *destination = forallOp.getDpsInits()[0].getDefiningOp(); + if (auto swizzleOp = dyn_cast(destination)) { + destination = swizzleOp->getOperand(0).getDefiningOp(); + } + // Fail if the destination is not a `tensor.empty` op and cannot be trivially // converted to a `bufferization.alloc_tensor`. + auto empty = dyn_cast(destination); if (!empty) { return failure(); } // Create a `bufferization.alloc_tensor` op with memory space // `#gpu.address_space`. 
+ Location loc = empty->getLoc(); OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(empty); Attribute sharedMemoryAddrSpace = gpu::AddressSpaceAttr::get( rewriter.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace()); auto allocTensor = bufferization::AllocTensorOp::create( - rewriter, empty->getLoc(), cast(empty.getResult().getType()), + rewriter, loc, cast(empty.getResult().getType()), empty.getDynamicSizes(), /*copy=*/Value(), /*size_hint=*/Value(), /*memory_space=*/sharedMemoryAddrSpace); + + // If the original `tensor.empty` has a swizzle hint, apply it to the new + // allocation. Note that if there is a swizzle hint, it will be the only user + // of the `tensor.empty` op. + if (auto swizzleHintOp = + dyn_cast(*empty->getUsers().begin())) { + assert(empty->hasOneUse() && + "a tensor.empty op with a swizzle hint applied, should have the " + "swizzle hint as its only user"); + auto newSwizzle = IREE::Codegen::SwizzleHintOp::create( + rewriter, loc, allocTensor.getResult(), swizzleHintOp.getSwizzle()); + return newSwizzle.getResult(); + } return allocTensor.getResult(); } @@ -465,9 +485,8 @@ fuseNestedLaneAndWarpForalls(RewriterBase &rewriter, scf::ForallOp warpForallOp, scf::ForallOp laneForallOp) { // Verify mappings. 
if (!warpForallOp.getMapping() || - !llvm::all_of(*warpForallOp.getMapping(), [](Attribute mappingAttr) { - return isa(mappingAttr); - })) { + !llvm::all_of(*warpForallOp.getMapping(), + llvm::IsaPred)) { return rewriter.notifyMatchFailure(warpForallOp, "not a warp forall op"); } if (!laneForallOp.getMapping() || laneForallOp.getMapping()->size() != 1 || @@ -1281,8 +1300,9 @@ convertScaledContractionToInnerTiledMma( lhsInnerPerm, rhsInnerPerm, sc1InnerPerm, sc2InnerPerm, accInnerPerm}; SmallVector identityPerm = {0, 1}; if (lhsInnerPerm == identityPerm && rhsInnerPerm == identityPerm && - accInnerPerm == identityPerm) + accInnerPerm == identityPerm) { perms = std::nullopt; + } IREE::Codegen::LoweringConfigAttrInterface maybeLoweringConfig = getLoweringConfig(linalgOp); @@ -1424,8 +1444,9 @@ FailureOr convertContractionToInnerTiledMma( SmallVector identityPerm = {0, 1}; if (lhsInnerPerm == identityPerm && rhsInnerPerm == identityPerm && - accInnerPerm == identityPerm) + accInnerPerm == identityPerm) { perms = std::nullopt; + } IREE::Codegen::LoweringConfigAttrInterface maybeLoweringConfig = getLoweringConfig(linalgOp); @@ -1875,12 +1896,15 @@ void populateIREEGPUVectorUnrollPatterns(RewritePatternSet &patterns) { //===---------------------------------------------------------------------===// static bool isLaneMappableForall(scf::ForallOp forallOp) { - if (forallOp.getNumResults() > 0) + if (forallOp.getNumResults() > 0) { return false; - if (forallOp.getRank() != 1) + } + if (forallOp.getRank() != 1) { return false; - if (!forallOp.getMapping().has_value()) + } + if (!forallOp.getMapping().has_value()) { return false; + } Attribute mapping = *forallOp.getMapping()->getValue().begin(); if (mapping != IREE::GPU::LaneIdAttr::get(forallOp.getContext(), 0)) { return false; @@ -2065,4 +2089,107 @@ void populateIREEGPULowerValueBarrierPatterns(RewritePatternSet &patterns) { patterns.add(patterns.getContext()); } 
+//===----------------------------------------------------------------------===// +// SwizzleHintOp Fold Patterns +//===----------------------------------------------------------------------===// + +// The following patterns are adapted from the populateFoldTensorEmptyPatterns +// in upstream llvm-project. The main change is to add support for folding with +// swizzle_hint ops from IREE. Once swizzle_hint ops are more widely used and +// proven stable, we could consider upstreaming this extension. + +namespace { +struct FoldSwizzleHintOpWithExtractSliceOp final + : OpRewritePattern { + using Base::Base; + LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, + PatternRewriter &rewriter) const override { + // Check for swizzle_hint op source. + auto swizzleHintOp = + sliceOp.getSource().getDefiningOp(); + if (!swizzleHintOp) { + return failure(); + } + + // Check for tensor.empty source. + auto emptyOp = swizzleHintOp.getOperand().getDefiningOp(); + if (!emptyOp) { + return failure(); + } + + // Check for single use. + if (!emptyOp->hasOneUse()) { + return failure(); + } + + // Create new tensor.empty op. tensor.extract_slice may be rank-reducing; + // its dynamic sizes must be preserved as well as its result type. 
+ Location loc = sliceOp.getLoc(); + auto sliceType = cast(sliceOp.getType()); + auto tensorType = + RankedTensorType::get(sliceType.getShape(), sliceType.getElementType(), + sliceType.getEncoding()); + auto newEmptyOp = + tensor::EmptyOp::create(rewriter, loc, tensorType, sliceOp.getSizes()); + rewriter.replaceOpWithNewOp( + sliceOp, newEmptyOp, swizzleHintOp.getSwizzle()); + return success(); + } +}; + +template +struct FoldSwizzleHintOpWithReshapeOp final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ReshapeOp reshapeOp, + PatternRewriter &rewriter) const override { + auto swizzleHintOp = + reshapeOp.getSrc() + .template getDefiningOp(); + if (!swizzleHintOp) { + return failure(); + } + auto emptyOp = + swizzleHintOp.getOperand().template getDefiningOp(); + if (!emptyOp) { + return failure(); + } + + // Check for single use. + if (!emptyOp->hasOneUse()) { + return failure(); + } + + // Reify result shape. + Location loc = reshapeOp.getLoc(); + ReifiedRankedShapedTypeDims resultShapes; + if (failed(reifyResultShapes(rewriter, reshapeOp, resultShapes)) || + !llvm::hasSingleElement(resultShapes)) { + return failure(); + } + + // Create new tensor.empty op. 
+ Value emptyTensor = + tensor::EmptyOp::create(rewriter, loc, resultShapes[0], + reshapeOp.getResultType().getElementType(), + reshapeOp.getResultType().getEncoding()); + Value newSwizzleHintOp = IREE::Codegen::SwizzleHintOp::create( + rewriter, loc, emptyTensor, swizzleHintOp.getSwizzle()); + if (newSwizzleHintOp.getType() != reshapeOp.getResultType()) { + rewriter.replaceOpWithNewOp( + reshapeOp, reshapeOp.getResultType(), newSwizzleHintOp); + } else { + rewriter.replaceOp(reshapeOp, newSwizzleHintOp); + } + return success(); + } +}; + +} // namespace + +void populateFoldSwizzleHintOpPatterns(RewritePatternSet &patterns) { + patterns.add, + FoldSwizzleHintOpWithReshapeOp, + FoldSwizzleHintOpWithExtractSliceOp>(patterns.getContext()); +} + } // namespace mlir::iree_compiler::IREE::GPU diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h index 70d9c3522b73..dcdd11f4232a 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h @@ -195,6 +195,9 @@ void populateIREEGPUVectorUnrollPatterns( void populateIREEGPUVectorUnrollPatterns(RewritePatternSet &patterns); void populateIREEGPUVectorizationPatterns(RewritePatternSet &patterns); +// Populate patterns to fold tensor.empty ops through swizzle hint ops. 
+void populateFoldSwizzleHintOpPatterns(RewritePatternSet &patterns); + } // namespace mlir::iree_compiler::IREE::GPU #endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_TRANSFORMS_H_ diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel index 3868bd7e564b..46574c8f7180 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "combine_barrier_regions.mlir", "distribute_inner_tiled_to_lanes.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_inner_tiled_to_lanes.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_inner_tiled_to_lanes.mlir index 211ba3232414..5b1abd255d0a 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_inner_tiled_to_lanes.mlir +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_inner_tiled_to_lanes.mlir @@ -97,15 +97,17 @@ module { // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<8x2x32x8xf16> // CHECK: scf.forall (%[[LANEID:.+]]) in (64) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<2x2x4x8x32xf32>) // CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (2, 32) -// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (2, 4) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[ID]]#2, %[[IDY]]] [2, 8, 1, 4] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[ID]]#2, %[[IDY]]] [8, 2, 1, 4] -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, 0, %[[IDY]], %[[ID]]#2] [2, 2, 4, 4, 1] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<32>) : index +// CHECK-DAG: 
%[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<32>) : index +// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (2, 4) +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[COL]], %[[IDY]]] [2, 8, 1, 4] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[COL]], %[[IDY]]] [8, 2, 1, 4] +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, 0, %[[IDY]], %[[COL]]] [2, 2, 4, 4, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]] // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<2x8x1x4xf16>, tensor<8x2x1x4xf16> into tensor<2x2x4x4x1xf32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, 0, %[[IDY]], %[[ID]]#2] [2, 2, 4, 4, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, 0, %[[IDY]], %[[COL]]] [2, 2, 4, 4, 1] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -137,15 +139,17 @@ module { // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<8x2x32x8xf16> // CHECK: scf.forall (%[[LANEID:.+]]) in (64) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<2x2x32x4x8xf32>) // CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (2, 32) -// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (2, 4) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[ID]]#2, %[[IDY]]] [2, 8, 1, 4] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[ID]]#2, %[[IDY]]] [8, 2, 1, 4] -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, %[[ID]]#2, 0, %[[IDY]]] [2, 2, 1, 4, 4] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<32>) : index +// CHECK-DAG: %[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<32>) : index +// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index 
disjoint [%[[ROW]], %c0] by (2, 4) +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[COL]], %[[IDY]]] [2, 8, 1, 4] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[COL]], %[[IDY]]] [8, 2, 1, 4] +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, %[[COL]], 0, %[[IDY]]] [2, 2, 1, 4, 4] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]] // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<2x8x1x4xf16>, tensor<8x2x1x4xf16> into tensor<2x2x1x4x4xf32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, %[[ID]]#2, 0, %[[IDY]]] [2, 2, 1, 4, 4] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, %[[COL]], 0, %[[IDY]]] [2, 2, 1, 4, 4] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -177,15 +181,17 @@ module { // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<8x2x32x8xi8> // CHECK: scf.forall (%[[LANEID:.+]]) in (64) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<2x2x4x8x32xi32>) // CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (2, 32) -// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (2, 4) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[ID]]#2, %[[IDY]]] [2, 8, 1, 4] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[ID]]#2, %[[IDY]]] [8, 2, 1, 4] -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, 0, %[[IDY]], %[[ID]]#2] [2, 2, 4, 4, 1] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<32>) : index +// CHECK-DAG: %[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<32>) : index +// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (2, 4) +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[COL]], %[[IDY]]] [2, 8, 1, 4] 
+// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[COL]], %[[IDY]]] [8, 2, 1, 4] +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, 0, %[[IDY]], %[[COL]]] [2, 2, 4, 4, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]] // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<2x8x1x4xi8>, tensor<8x2x1x4xi8> into tensor<2x2x4x4x1xi32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, 0, %[[IDY]], %[[ID]]#2] [2, 2, 4, 4, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, 0, %[[IDY]], %[[COL]]] [2, 2, 4, 4, 1] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -217,16 +223,19 @@ module { // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<8x2x16x16xf16> // CHECK: scf.forall (%[[LANEID:.+]]) in (32) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<2x2x8x2x16xf32>) // CHECK-DAG: %[[ID_1:.+]]:2 = affine.delinearize_index %[[LANEID]] into (16) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[ID_1]]#1, 0] [2, 8, 1, 16] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[ID_1]]#1, 0] [8, 2, 1, 16] +// CHECK-DAG: %[[ROW_1:.+]] = iree_codegen.index_hint %[[ID_1]]#1(#iree_gpu.lane_increment<16>) : index +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[ROW_1]], 0] [2, 8, 1, 16] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[ROW_1]], 0] [8, 2, 1, 16] // CHECK-DAG: %[[ID_2:.+]]:3 = affine.delinearize_index %[[LANEID]] into (2, 16) +// CHECK-DAG: %[[ROW_2:.+]] = iree_codegen.index_hint %[[ID_2]]#1(#iree_gpu.lane_constant<16>) : index +// CHECK-DAG: %[[COL_2:.+]] = iree_codegen.index_hint %[[ID_2]]#2(#iree_gpu.lane_increment<16>) : index // Note: ID_2#1 and I_2#2 should not be delinearize outputs once we move to linearized indexing -// CHECK-DAG: %[[ACC_SLICE:.+]] = 
tensor.extract_slice %[[ACC]][0, 0, 0, %[[ID_2]]#1, %[[ID_2]]#2] [2, 2, 8, 1, 1] +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, 0, %[[ROW_2]], %[[COL_2]]] [2, 2, 8, 1, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]] // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<2x8x1x16xf16>, tensor<8x2x1x16xf16> into tensor<2x2x8x1x1xf32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, 0, %[[ID_2]]#1, %[[ID_2]]#2] [2, 2, 8, 1, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, 0, %[[ROW_2]], %[[COL_2]]] [2, 2, 8, 1, 1] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -251,14 +260,16 @@ func.func @distribute_MFMA_F32_16x16x4_F32(%lhs: tensor<16x4xf32>, %rhs: tensor< // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<4x16xf32> // CHECK: scf.forall (%[[LANEID:.+]]) in (64) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<16x16xf32>) // CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (4, 16) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[ID]]#2, %[[ID]]#1] [1, 1] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][%[[ID]]#1, %[[ID]]#2] [1, 1] -// CHECK-DAG: %[[IDZ:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (4, 4) -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][%[[IDZ]], %[[ID]]#2] [4, 1] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<16>) : index +// CHECK-DAG: %[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<16>) : index +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[COL]], %[[ROW]]] [1, 1] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][%[[ROW]], %[[COL]]] [1, 1] +// CHECK-DAG: %[[IDZ:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (4, 4) +// CHECK-DAG: %[[ACC_SLICE:.+]] = 
tensor.extract_slice %[[ACC]][%[[IDZ]], %[[COL]]] [4, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<1x1xf32>, tensor<1x1xf32> into tensor<4x1xf32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][%[[IDZ]], %[[ID]]#2] [4, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][%[[IDZ]], %[[COL]]] [4, 1] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -283,15 +294,17 @@ func.func @distribute_F32_16x16x32_F8E4M3FNUZ(%lhs: tensor<16x32xf8E4M3FNUZ>, %r // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<32x16xf8E4M3FNUZ> // CHECK: scf.forall (%[[LANEID:.+]]) in (64) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<16x16xf32>) // CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (4, 16) -// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (4, 8) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[ID]]#2, %[[IDY]]] [1, 8] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][%[[IDY]], %[[ID]]#2] [8, 1] -// CHECK-DAG: %[[IDZ:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (4, 4) -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][%[[IDZ]], %[[ID]]#2] [4, 1] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<16>) : index +// CHECK-DAG: %[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<16>) : index +// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (4, 8) +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[COL]], %[[IDY]]] [1, 8] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][%[[IDY]], %[[COL]]] [8, 1] +// CHECK-DAG: %[[IDZ:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (4, 4) +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][%[[IDZ]], %[[COL]]] [4, 1] // CHECK: %[[MMA:.+]] = 
iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<1x8xf8E4M3FNUZ>, tensor<8x1xf8E4M3FNUZ> into tensor<4x1xf32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][%[[IDZ]], %[[ID]]#2] [4, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][%[[IDZ]], %[[COL]]] [4, 1] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -316,15 +329,17 @@ func.func @distribute_I32_32x32x16_I8(%lhs: tensor<32x16xi8>, %rhs: tensor<16x32 // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<16x32xi8> // CHECK: scf.forall (%[[LANEID:.+]]) in (64) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<4x8x32xi32>) // CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (2, 32) -// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (2, 8) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[ID]]#2, %[[IDY]]] [1, 8] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][%[[IDY]], %[[ID]]#2] [8, 1] -// CHECK-DAG: %[[IDZ:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (2, 4) -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, %[[IDZ]], %[[ID]]#2] [4, 4, 1] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<32>) : index +// CHECK-DAG: %[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<32>) : index +// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (2, 8) +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[COL]], %[[IDY]]] [1, 8] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][%[[IDY]], %[[COL]]] [8, 1] +// CHECK-DAG: %[[IDZ:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (2, 4) +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, %[[IDZ]], %[[COL]]] [4, 4, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) 
outs(%[[ACC_SLICE]]) // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<1x8xi8>, tensor<8x1xi8> into tensor<4x4x1xi32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, %[[IDZ]], %[[ID]]#2] [4, 4, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, %[[IDZ]], %[[COL]]] [4, 4, 1] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -349,13 +364,15 @@ func.func @distribute_WMMAR3_F16_16x16x16_F16(%lhs: tensor<16x16xf16>, %rhs: ten // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<16x16xf16> // CHECK: scf.forall (%[[LANEID:.+]]) in (32) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<16x8x2xf16>) // CHECK-DAG: %[[ID:.+]]:2 = affine.delinearize_index %[[LANEID]] into (16) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[ID]]#1, 0] [1, 16] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, %[[ID]]#1] [16, 1] -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, %[[ID]]#1] [16, 1, 1] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_increment<16>) : index +// CHECK-DAG: %[[COL:.+]] = iree_codegen.index_hint %c0(#iree_gpu.lane_constant<16>) : index +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[ROW]], 0] [1, 16] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, %[[ROW]]] [16, 1] +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, %[[COL]], %[[ROW]]] [16, 1, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<1x16xf16>, tensor<16x1xf16> into tensor<16x1x1xf16> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, %[[ID]]#1] [16, 1, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, %[[COL]], %[[ROW]]] [16, 1, 1] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -387,15 +404,18 @@ module { // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: 
tensor<8x2x16x16xi8> // CHECK: scf.forall (%[[LANEID:.+]]) in (32) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<2x2x8x2x16xi32>) // CHECK-DAG: %[[ID:.+]]:2 = affine.delinearize_index %[[LANEID]] into (16) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[ID]]#1, 0] [2, 8, 1, 16] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[ID]]#1, 0] [8, 2, 1, 16] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_increment<16>) : index +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[ROW]], 0] [2, 8, 1, 16] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[ROW]], 0] [8, 2, 1, 16] // CHECK-DAG: %[[ID_ACC:.+]]:3 = affine.delinearize_index %[[LANEID]] into (2, 16) -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, 0, %[[ID_ACC]]#1, %[[ID_ACC]]#2] [2, 2, 8, 1, 1] +// CHECK-DAG: %[[ROW_ACC:.+]] = iree_codegen.index_hint %[[ID_ACC]]#1(#iree_gpu.lane_constant<16>) : index +// CHECK-DAG: %[[COL_ACC:.+]] = iree_codegen.index_hint %[[ID_ACC]]#2(#iree_gpu.lane_increment<16>) : index +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, 0, %[[ROW_ACC]], %[[COL_ACC]]] [2, 2, 8, 1, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]] // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<2x8x1x16xi8>, tensor<8x2x1x16xi8> into tensor<2x2x8x1x1xi32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, 0, %[[ID_ACC]]#1, %[[ID_ACC]]#2] [2, 2, 8, 1, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, 0, %[[ROW_ACC]], %[[COL_ACC]]] [2, 2, 8, 1, 1] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -420,14 +440,16 @@ func.func @distribute_WMMAR4_F16_16x16x16_F16(%lhs: tensor<16x16xf16>, %rhs: ten // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<16x16xf16> // CHECK: 
scf.forall (%[[LANEID:.+]]) in (32) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<16x16xf16>) // CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (2, 16) -// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (2, 8) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[ID]]#2, %[[IDY]]] [1, 8] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][%[[IDY]], %[[ID]]#2] [8, 1] -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][%[[IDY]], %[[ID]]#2] [8, 1] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<16>) : index +// CHECK-DAG: %[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<16>) : index +// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (2, 8) +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[COL]], %[[IDY]]] [1, 8] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][%[[IDY]], %[[COL]]] [8, 1] +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][%[[IDY]], %[[COL]]] [8, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<1x8xf16>, tensor<8x1xf16> into tensor<8x1xf16> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][%[[IDY]], %[[ID]]#2] [8, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][%[[IDY]], %[[COL]]] [8, 1] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -459,15 +481,17 @@ module { // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<8x2x16x16xi8> // CHECK: scf.forall (%[[LANEID:.+]]) in (32) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<2x2x16x16xi32>) // CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (2, 16) -// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (2, 8) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, 
%[[ID]]#2, %[[IDY]]] [2, 8, 1, 8] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[ID]]#2, %[[IDY]]] [8, 2, 1, 8] -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, %[[IDY]], %[[ID]]#2] [2, 2, 8, 1] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<16>) : index +// CHECK-DAG: %[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<16>) : index +// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (2, 8) +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[COL]], %[[IDY]]] [2, 8, 1, 8] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[COL]], %[[IDY]]] [8, 2, 1, 8] +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, %[[IDY]], %[[COL]]] [2, 2, 8, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]] // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<2x8x1x8xi8>, tensor<8x2x1x8xi8> into tensor<2x2x8x1xi32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, %[[IDY]], %[[ID]]#2] [2, 2, 8, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, %[[IDY]], %[[COL]]] [2, 2, 8, 1] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -492,15 +516,17 @@ func.func @distribute_WMMA_F32_16x16x4_F32(%lhs: tensor<16x4xf32>, %rhs: tensor< // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<4x16xf32> // CHECK: scf.forall (%[[LANEID:.+]]) in (32) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<16x16xf32>) // CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (2, 16) -// CHECK-DAG: %[[IDX:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (2, 2) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[ID]]#2, %[[IDX]]] [1, 2] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][%[[IDX]], 
%[[ID]]#2] [2, 1] -// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (2, 8) -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][%[[IDY]], %[[ID]]#2] [8, 1] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<16>) : index +// CHECK-DAG: %[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<16>) : index +// CHECK-DAG: %[[IDX:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (2, 2) +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[COL]], %[[IDX]]] [1, 2] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][%[[IDX]], %[[COL]]] [2, 1] +// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (2, 8) +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][%[[IDY]], %[[COL]]] [8, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<1x2xf32>, tensor<2x1xf32> into tensor<8x1xf32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][%[[IDY]], %[[ID]]#2] [8, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][%[[IDY]], %[[COL]]] [8, 1] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -525,15 +551,17 @@ func.func @distribute_WMMA_F32_16x16x128_F8E4M3FN(%lhs: tensor<16x128xf8E4M3FN>, // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<128x16xf8E4M3FN> // CHECK: scf.forall (%[[LANEID:.+]]) in (32) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<16x16xf32>) // CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (2, 16) -// CHECK-DAG: %[[IDX:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (2, 64) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[ID]]#2, %[[IDX]]] [1, 64] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][%[[IDX]], %[[ID]]#2] [64, 1] -// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ID]]#1, 
%c0] by (2, 8) -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][%[[IDY]], %[[ID]]#2] [8, 1] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<16>) : index +// CHECK-DAG: %[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<16>) : index +// CHECK-DAG: %[[IDX:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (2, 64) +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[COL]], %[[IDX]]] [1, 64] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][%[[IDX]], %[[COL]]] [64, 1] +// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (2, 8) +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][%[[IDY]], %[[COL]]] [8, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<1x64xf8E4M3FN>, tensor<64x1xf8E4M3FN> into tensor<8x1xf32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][%[[IDY]], %[[ID]]#2] [8, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][%[[IDY]], %[[COL]]] [8, 1] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -1087,16 +1115,18 @@ func.func @scaled_matmul_f32_16x16x128_b32_fp4_fp8(%lhs: tensor<3x5x1x16x4x32xf4 // CHECK-SAME: %[[RHS_SCALE:[A-Za-z0-9]+]]: tensor<5x7x4x16xf8E8M0FNU> // CHECK: scf.forall (%[[LANEID:.+]]) in (64) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<3x7x16x16xf32>) // CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (4, 16) -// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (4, 4) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, 0, %[[ID]]#2, %[[ID]]#1, 0] [3, 5, 1, 1, 1, 32] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, 0, %[[ID]]#1, 0, %[[ID]]#2] [5, 1, 7, 1, 32, 1] -// CHECK-DAG: %[[LHS_SCALE_SLICE:.+]] = tensor.extract_slice %[[LHS_SCALE]][0, 0, 
%[[ID]]#2, %[[ID]]#1] [3, 5, 1, 1] -// CHECK-DAG: %[[RHS_SCALE_SLICE:.+]] = tensor.extract_slice %[[RHS_SCALE]][0, 0, %[[ID]]#1, %[[ID]]#2] [5, 7, 1, 1] -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, %[[IDY]], %[[ID]]#2] [3, 7, 4, 1] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<16>) : index +// CHECK-DAG: %[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<16>) : index +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, 0, %[[COL]], %[[ROW]], 0] [3, 5, 1, 1, 1, 32] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, 0, %[[ROW]], 0, %[[COL]]] [5, 1, 7, 1, 32, 1] +// CHECK-DAG: %[[LHS_SCALE_SLICE:.+]] = tensor.extract_slice %[[LHS_SCALE]][0, 0, %[[COL]], %[[ROW]]] [3, 5, 1, 1] +// CHECK-DAG: %[[RHS_SCALE_SLICE:.+]] = tensor.extract_slice %[[RHS_SCALE]][0, 0, %[[ROW]], %[[COL]]] [5, 7, 1, 1] +// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (4, 4) +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, %[[IDY]], %[[COL]]] [3, 7, 4, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]], %[[LHS_SCALE_SLICE]], %[[RHS_SCALE_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]]] // CHECK-SAME: : tensor<3x5x1x1x1x32xf4E2M1FN>, tensor<5x1x7x1x32x1xf8E4M3FN>, tensor<3x5x1x1xf8E8M0FNU>, tensor<5x7x1x1xf8E8M0FNU> into tensor<3x7x4x1xf32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, %[[IDY]], %[[ID]]#2] [3, 7, 4, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, %[[IDY]], %[[COL]]] [3, 7, 4, 1] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -1137,16 +1167,16 @@ func.func @scaled_matmul_trb_f32_16x16x128_b32_fp4_fp8(%lhs: tensor<3x5x4x16x4x3 // CHECK-SAME: %[[LHS_SCALE:[A-Za-z0-9]+]]: tensor<3x5x16x4xf8E8M0FNU> // CHECK-SAME: %[[RHS_SCALE:[A-Za-z0-9]+]]: 
tensor<5x7x16x4xf8E8M0FNU> // CHECK: scf.forall (%[[LANEID:.+]]) in (64) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<3x7x16x16xf32>) -// CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (4, 16) -// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (4, 4) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, 0, %[[ID]]#2, %[[ID]]#1, 0] [3, 5, 4, 1, 1, 32] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, 0, %[[ID]]#2, %[[ID]]#1, 0] [5, 4, 7, 1, 1, 32] -// CHECK-DAG: %[[LHS_SCALE_SLICE:.+]] = tensor.extract_slice %[[LHS_SCALE]][0, 0, %[[ID]]#2, %[[ID]]#1] [3, 5, 1, 1] -// CHECK-DAG: %[[RHS_SCALE_SLICE:.+]] = tensor.extract_slice %[[RHS_SCALE]][0, 0, %[[ID]]#2, %[[ID]]#1] [5, 7, 1, 1] -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, %[[IDY]], %[[ID]]#2] [3, 7, 4, 1] +// CHECK-DAG: iree_codegen.index_hint {{.*}}(#iree_gpu.lane_constant<16>) : index +// CHECK-DAG: iree_codegen.index_hint {{.*}}(#iree_gpu.lane_increment<16>) : index +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]]{{.*}} [3, 5, 4, 1, 1, 32] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]]{{.*}} [5, 4, 7, 1, 1, 32] +// CHECK-DAG: %[[LHS_SCALE_SLICE:.+]] = tensor.extract_slice %[[LHS_SCALE]]{{.*}} [3, 5, 1, 1] +// CHECK-DAG: %[[RHS_SCALE_SLICE:.+]] = tensor.extract_slice %[[RHS_SCALE]]{{.*}} [5, 7, 1, 1] +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]]{{.*}} [3, 7, 4, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]], %[[LHS_SCALE_SLICE]], %[[RHS_SCALE_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: : tensor<3x5x4x1x1x32xf4E2M1FN>, tensor<5x4x7x1x1x32xf8E4M3FN>, tensor<3x5x1x1xf8E8M0FNU>, tensor<5x7x1x1xf8E8M0FNU> into tensor<3x7x4x1xf32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, %[[IDY]], %[[ID]]#2] [3, 7, 4, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]] // CHECK: mapping = 
[#iree_gpu.lane_id<0>] // ----- @@ -1183,16 +1213,16 @@ func.func @scaled_matmul_trb_f32_32x32x64_b32_fp4_fp8(%lhs: tensor<3x5x1x32x2x32 // CHECK-SAME: %[[LHS_SCALE:[A-Za-z0-9]+]]: tensor<3x5x32x2xf8E8M0FNU> // CHECK-SAME: %[[RHS_SCALE:[A-Za-z0-9]+]]: tensor<5x7x32x2xf8E8M0FNU> // CHECK: scf.forall (%[[LANEID:.+]]) in (64) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<3x7x4x8x32xf32>) -// CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (2, 32) -// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (2, 4) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, 0, %[[ID]]#2, %[[ID]]#1, 0] [3, 5, 1, 1, 1, 32] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, 0, %[[ID]]#2, %[[ID]]#1, 0] [5, 1, 7, 1, 1, 32] -// CHECK-DAG: %[[LHS_SCALE_SLICE:.+]] = tensor.extract_slice %[[LHS_SCALE]][0, 0, %[[ID]]#2, %[[ID]]#1] [3, 5, 1, 1] -// CHECK-DAG: %[[RHS_SCALE_SLICE:.+]] = tensor.extract_slice %[[RHS_SCALE]][0, 0, %[[ID]]#2, %[[ID]]#1] [5, 7, 1, 1] -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, 0, %[[IDY]], %[[ID]]#2] [3, 7, 4, 4, 1] +// CHECK-DAG: iree_codegen.index_hint {{.*}}(#iree_gpu.lane_constant<32>) : index +// CHECK-DAG: iree_codegen.index_hint {{.*}}(#iree_gpu.lane_increment<32>) : index +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]]{{.*}} [3, 5, 1, 1, 1, 32] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]]{{.*}} [5, 1, 7, 1, 1, 32] +// CHECK-DAG: %[[LHS_SCALE_SLICE:.+]] = tensor.extract_slice %[[LHS_SCALE]]{{.*}} [3, 5, 1, 1] +// CHECK-DAG: %[[RHS_SCALE_SLICE:.+]] = tensor.extract_slice %[[RHS_SCALE]]{{.*}} [5, 7, 1, 1] +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]]{{.*}} [3, 7, 4, 4, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]], %[[LHS_SCALE_SLICE]], %[[RHS_SCALE_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: : tensor<3x5x1x1x1x32xf4E2M1FN>, tensor<5x1x7x1x1x32xf8E4M3FN>, 
tensor<3x5x1x1xf8E8M0FNU>, tensor<5x7x1x1xf8E8M0FNU> into tensor<3x7x4x4x1xf32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, 0, %[[IDY]], %[[ID]]#2] [3, 7, 4, 4, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/ExternalInterfaces/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/PCF/ExternalInterfaces/test/BUILD.bazel index 63fd7bf4c497..cdb68196b384 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/ExternalInterfaces/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/ExternalInterfaces/test/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "bufferize.mlir", ], diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/BUILD.bazel index 1c16ae00608c..ca840f957570 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "PCFBase.td", "PCFInterfaces.td", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/PCFTypes.cpp b/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/PCFTypes.cpp index 16a9ff3b5e30..1a2e62edb373 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/PCFTypes.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/PCFTypes.cpp @@ -98,10 +98,11 @@ void ShapedRefType::print(AsmPrinter &printer) const { ArrayRef shape = getShape(); for (int64_t dim : shape) { - if (ShapedType::isDynamic(dim)) + if (ShapedType::isDynamic(dim)) { printer << '?'; - else + } else { printer << dim; + } printer << 'x'; } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/test/BUILD.bazel 
b/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/test/BUILD.bazel index 7fa829131cde..4a76b1f09230 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "control_flow_ops.mlir", "folders.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/ConvertSRefToMemRef.cpp b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/ConvertSRefToMemRef.cpp index 74de12f6eaf6..8d0e5d04c5af 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/ConvertSRefToMemRef.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/ConvertSRefToMemRef.cpp @@ -1115,8 +1115,10 @@ struct ConvertWhileOp final : OpConversionPattern { auto newOp = scf::WhileOp::create(rewriter, op.getLoc(), resultTypes, inits); for (auto i : {0u, 1u}) { - if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter))) + if (failed( + rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter))) { return failure(); + } auto &dstRegion = newOp.getRegion(i); rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end()); } @@ -1159,17 +1161,20 @@ void ConvertSRefToMemRefPass::runOnOperation() { // only implements context specific conversions. 
auto isLegallyTypedOp = [&](Operation *op) -> bool { for (Type type : op->getResultTypes()) { - if (isIllegalType(type)) + if (isIllegalType(type)) { return false; + } } for (Type type : op->getOperandTypes()) { - if (isIllegalType(type)) + if (isIllegalType(type)) { return false; + } } for (Region ®ion : op->getRegions()) { for (Type type : region.getArgumentTypes()) { - if (isIllegalType(type)) + if (isIllegalType(type)) { return false; + } } } if (auto funcInterface = dyn_cast(op)) { diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/BUILD.bazel index 8e3484164d2f..49034482436a 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "convert_forall_to_loops.mlir", "convert_sref_to_memref.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/BUILD.bazel index c165a4adc525..212b45745f7a 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/BUILD.bazel @@ -23,11 +23,12 @@ exports_files([ iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "VectorExtAttrs.td", "VectorExtBase.td", - "VectorExtOps.td", "VectorExtInterfaces.td", + "VectorExtOps.td", ], include = ["*.td"], ), diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp index 25d9d5c3ab61..344026ded253 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp @@ -303,9 +303,7 @@ 
NestedLayoutAttr::getRecombinedLayout(ArrayRef layouts, ArrayRef maps, AffineMap resultMap) { constexpr int64_t kInvalid = -1; - if (llvm::any_of(layouts, [](VectorLayoutInterface layout) { - return !isa(layout); - })) { + if (!llvm::all_of(layouts, llvm::IsaPred)) { return NestedLayoutAttr(); } MLIRContext *context = resultMap.getContext(); @@ -435,11 +433,13 @@ NestedLayoutAttr::computeThreadIds(Value threadId, int64_t subgroupSize, SmallVector subgroupDimToResult, threadDimToResult; if (failed(basisFromSizesStrides(getSubgroupTile(), getSubgroupStrides(), - subgroupBasis, subgroupDimToResult))) + subgroupBasis, subgroupDimToResult))) { return {}; + } if (failed(basisFromSizesStrides(getThreadTile(), getThreadStrides(), - threadBasis, threadDimToResult))) + threadBasis, threadDimToResult))) { return {}; + } // Add the subgroup_size to the end of the subgroup delinearization basis. subgroupBasis.push_back(subgroupSize); diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.cpp index f8f422d0c334..b9a8c076015a 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.cpp @@ -68,8 +68,9 @@ static ParseResult parseIndexVecs(OpAsmParser &parser, SmallVectorImpl &indexVecs, SmallVectorImpl &indexVecTypes, ArrayAttr &indexed) { - if (parser.parseLSquare()) + if (parser.parseLSquare()) { return failure(); + } SMLoc loc; SmallVector indexedArr; @@ -127,11 +128,13 @@ static void printIndexVecs(OpAsmPrinter &p, Operation *op, static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op, SmallVector elidedAttrs = {}) { elidedAttrs.push_back(TransferGatherOp::getOperandSegmentSizeAttr()); - if (op.getPermutationMap().isMinorIdentity()) + if (op.getPermutationMap().isMinorIdentity()) { elidedAttrs.push_back(op.getPermutationMapAttrName()); + } // Elide 
in_bounds attribute if all dims are out-of-bounds. - if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; })) + if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; })) { elidedAttrs.push_back(op.getInBoundsAttrName()); + } p.printOptionalAttrDict(op->getAttrs(), elidedAttrs); } @@ -140,8 +143,9 @@ void TransferGatherOp::print(OpAsmPrinter &p) { printIndexVecs(p, *this, getIndexVecs(), getIndexVecs().getTypes(), getIndexedAttr()); p << ", " << getPadding(); - if (getMask()) + if (getMask()) { p << ", " << getMask(); + } printTransferAttrs(p, *this, {"indexed"}); p << " : " << getShapedType() << ", " << getType(); } @@ -151,9 +155,10 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds) { - if (!isa(shapedType)) + if (!isa(shapedType)) { return op->emitOpError( "requires source to be a memref or ranked tensor type"); + } Type elementType = shapedType.getElementType(); DataLayout dataLayout = DataLayout::closest(op); @@ -272,30 +277,36 @@ LogicalResult TransferGatherOp::verify() { : VectorType(); auto sourceElementType = shapedType.getElementType(); - if (static_cast(getIndices().size()) != shapedType.getRank()) + if (static_cast(getIndices().size()) != shapedType.getRank()) { return emitOpError("requires ") << shapedType.getRank() << " indices"; + } if (failed(verifyTransferOp(cast(getOperation()), shapedType, vectorType, maskType, - inferredMaskType, permutationMap, getInBounds()))) + inferredMaskType, permutationMap, + getInBounds()))) { return failure(); + } if (auto sourceVectorElementType = dyn_cast(sourceElementType)) { // Source has vector element type. // Check that 'sourceVectorElementType' and 'paddingType' types match. 
- if (sourceVectorElementType != paddingType) + if (sourceVectorElementType != paddingType) { return emitOpError( "requires source element type and padding type to match."); + } } else { // Check that 'paddingType' is valid to store in a vector type. - if (!VectorType::isValidElementType(paddingType)) + if (!VectorType::isValidElementType(paddingType)) { return emitOpError("requires valid padding vector elemental type"); + } // Check that padding type and vector element types match. - if (paddingType != sourceElementType) + if (paddingType != sourceElementType) { return emitOpError( "requires formal padding and source of the same elemental type"); + } } if (failed(verifyPermutationMap(permutationMap, @@ -353,28 +364,34 @@ ParseResult TransferGatherOp::parse(OpAsmParser &parser, OpAsmParser::UnresolvedOperand maskInfo; // Parsing with support for paddingValue. if (parser.parseOperand(sourceInfo) || - parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square)) + parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square)) { return failure(); + } SmallVector indexVecTypes; ArrayAttr indexed; - if (parseIndexVecs(parser, indexVecInfo, indexVecTypes, indexed)) + if (parseIndexVecs(parser, indexVecInfo, indexVecTypes, indexed)) { return failure(); + } result.addAttribute("indexed", indexed); - if (parser.parseComma() || parser.parseOperand(paddingInfo)) + if (parser.parseComma() || parser.parseOperand(paddingInfo)) { return failure(); + } ParseResult hasMask = parser.parseOptionalComma(); if (hasMask.succeeded()) { - if (parser.parseOperand(maskInfo)) + if (parser.parseOperand(maskInfo)) { return failure(); + } } // Parse attributes and types. if (parser.parseOptionalAttrDict(result.attributes) || - parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types)) + parser.getCurrentLocation(&typesLoc) || + parser.parseColonTypeList(types)) { return failure(); + } // Check if number of types given are correct. 
int64_t nRequiredTypes = 2; @@ -387,10 +404,12 @@ ParseResult TransferGatherOp::parse(OpAsmParser &parser, // sourceTy, resultTy auto shapedType = dyn_cast(types[0]); VectorType vectorType = dyn_cast(types[1]); - if (!shapedType || !isa(shapedType)) + if (!shapedType || !isa(shapedType)) { return parser.emitError(typesLoc, "requires memref or ranked tensor type"); - if (!vectorType) + } + if (!vectorType) { return parser.emitError(typesLoc, "requires vector type"); + } auto permMapAttrName = TransferGatherOp::getPermutationMapAttrName(result.name); Attribute permMapAttr = result.attributes.get(permMapAttrName); @@ -414,12 +433,14 @@ ParseResult TransferGatherOp::parse(OpAsmParser &parser, parser.resolveOperands(indexVecInfo, indexVecTypes, typesLoc, result.operands) || parser.resolveOperand(paddingInfo, shapedType.getElementType(), - result.operands)) + result.operands)) { return failure(); + } if (hasMask.succeeded()) { - if (dyn_cast(shapedType.getElementType())) + if (dyn_cast(shapedType.getElementType())) { return parser.emitError( maskInfo.location, "does not support masks with vector element type"); + } if (vectorType.getRank() != permMap.getNumResults()) { return parser.emitError(typesLoc, "expected the same rank for the vector and the " @@ -428,8 +449,9 @@ ParseResult TransferGatherOp::parse(OpAsmParser &parser, // Instead of adding the mask type as an op type, compute it based on the // vector type and the permutation map (to keep the type signature small). 
auto maskType = vector::inferTransferOpMaskType(vectorType, permMap); - if (parser.resolveOperand(maskInfo, maskType, result.operands)) + if (parser.resolveOperand(maskInfo, maskType, result.operands)) { return failure(); + } } result.addAttribute(TransferGatherOp::getOperandSegmentSizeAttr(), builder.getDenseI32ArrayAttr( diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/test/BUILD.bazel index f4b3ce71c919..18b6c1bbbda4 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/test/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "canonicalize.mlir", "invalid.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BufferizationInterfaces.cpp b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BufferizationInterfaces.cpp index d54ba2e05d06..c0df96ac24e6 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BufferizationInterfaces.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BufferizationInterfaces.cpp @@ -55,8 +55,9 @@ struct TransferGatherOpInterface "only tensor types expected"); FailureOr buffer = getBuffer(rewriter, gatherOp.getBase(), options, state); - if (failed(buffer)) + if (failed(buffer)) { return failure(); + } replaceOpWithNewBufferizedOp( rewriter, gatherOp, gatherOp.getVectorType(), *buffer, gatherOp.getIndices(), gatherOp.getIndexVecs(), gatherOp.getIndexed(), diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/BUILD.bazel index d6f29e9f09ca..55e1e50d8431 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/BUILD.bazel +++ 
b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "vector_ext_fold_unit_extent_dims.mlir", "vectorize_vector_ext_ops.mlir", diff --git a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/test/BUILD.bazel index b1524933aa58..9f62cb2c2724 100644 --- a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/test/BUILD.bazel @@ -15,9 +15,10 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ - "roundtrip.mlir", "invalid.mlir", + "roundtrip.mlir", ], include = ["*.mlir"], ), diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel index 1501c7838e1d..e9812c35456f 100644 --- a/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel @@ -22,11 +22,12 @@ exports_files([ iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "PartitionableLoopsInterface.td", "ProcessorOpInterfaces.td", - "UKernelOpInterface.td", "TensorMaskingOpInterface.td", + "UKernelOpInterface.td", ], include = ["*.td"], ), diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp index ae2d87c67fa6..b6198d6edbae 100644 --- a/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp +++ b/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp @@ -149,15 +149,17 @@ struct DispatchTensorStoreOpInterface auto maybeBuffer = getBuffer(rewriter, storeOp->getOpOperand(0).get(), options, state); - if (failed(maybeBuffer)) + if (failed(maybeBuffer)) { return failure(); + } Value srcMemref = *maybeBuffer; 
// If everything bufferized inplace, no copy is needed. We wrote to the // target buffer already. The copy folds away in that case. if (failed(options.createMemCpy(rewriter, storeOp->getLoc(), srcMemref, - target))) + target))) { return failure(); + } rewriter.eraseOp(storeOp); return success(); @@ -176,8 +178,9 @@ struct LoadFromBufferOpInterface getSourceSubspanMemref( cast>(loadFromBufferOp.getBuffer())); // Conservatively return false if the subspan is not found. - if (!subspanOp) + if (!subspanOp) { return false; + } std::optional descriptorFlags = subspanOp->getDescriptorFlags(); return !descriptorFlags.has_value() || @@ -219,15 +222,17 @@ struct StoreToBufferOpInterface auto storeOp = cast(op); FailureOr maybeBuffer = getBuffer(rewriter, storeOp.getTensor(), options, state); - if (failed(maybeBuffer)) + if (failed(maybeBuffer)) { return failure(); + } Value srcMemref = *maybeBuffer; // If everything bufferized inplace, no copy is needed. We wrote to the // target buffer already. The copy folds away in that case. if (failed(options.createMemCpy(rewriter, storeOp.getLoc(), srcMemref, - storeOp.getBuffer()))) + storeOp.getBuffer()))) { return failure(); + } rewriter.eraseOp(storeOp); return success(); @@ -285,13 +290,15 @@ static LogicalResult bufferizeLinalgExtOp(RewriterBase &rewriter, rewriter.setInsertionPoint(op); // Nothing to do. This op is already bufferized. - if (dspOp.hasPureBufferSemantics()) + if (dspOp.hasPureBufferSemantics()) { return success(); + } // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need // basis. - if (!dspOp.hasPureTensorSemantics()) + if (!dspOp.hasPureTensorSemantics()) { return op->emitError() << "op does not have tensor semantics"; + } // New input operands for the cloned op. 
SmallVector newOperands, newOutputBuffers; @@ -305,8 +312,9 @@ static LogicalResult bufferizeLinalgExtOp(RewriterBase &rewriter, } if (!dspOp.isDpsInit(&opOperand)) { auto maybeBuffer = getBuffer(rewriter, opOperand.get(), options, state); - if (failed(maybeBuffer)) + if (failed(maybeBuffer)) { return failure(); + } // Input operands are never written to. newOperands.push_back(*maybeBuffer); continue; @@ -319,8 +327,9 @@ static LogicalResult bufferizeLinalgExtOp(RewriterBase &rewriter, FailureOr resultBuffer = getBuffer( rewriter, aliasingOpOperands.getAliases().front().opOperand->get(), options, state); - if (failed(resultBuffer)) + if (failed(resultBuffer)) { return failure(); + } newOperands.push_back(*resultBuffer); newOutputBuffers.push_back(*resultBuffer); } @@ -385,8 +394,9 @@ getSourceAndDestFromPackUnPackOp(RewriterBase &rewriter, OpTy op, static_assert(llvm::is_one_of::value); Value source; auto maybeBuffer = getBuffer(rewriter, op.getSource(), options, state); - if (failed(maybeBuffer)) + if (failed(maybeBuffer)) { return failure(); + } source = *maybeBuffer; Value dest; @@ -397,8 +407,9 @@ getSourceAndDestFromPackUnPackOp(RewriterBase &rewriter, OpTy op, FailureOr resultBuffer = getBuffer( rewriter, aliasingOpOperands.getAliases().front().opOperand->get(), options, state); - if (failed(resultBuffer)) + if (failed(resultBuffer)) { return failure(); + } dest = *resultBuffer; return std::make_pair(source, dest); } @@ -412,8 +423,9 @@ static LogicalResult bufferizePackOp(RewriterBase &rewriter, linalg::PackOp op, auto maybeSrcAndDest = getSourceAndDestFromPackUnPackOp(rewriter, op, options, state); - if (failed(maybeSrcAndDest)) + if (failed(maybeSrcAndDest)) { return failure(); + } auto [source, dest] = *maybeSrcAndDest; // Set insertion point now that potential alloc/dealloc are introduced. 
@@ -438,8 +450,9 @@ static LogicalResult bufferizeUnPackOp(RewriterBase &rewriter, auto maybeSrcAndDest = getSourceAndDestFromPackUnPackOp(rewriter, op, options, state); - if (failed(maybeSrcAndDest)) + if (failed(maybeSrcAndDest)) { return failure(); + } auto [source, dest] = *maybeSrcAndDest; // Set insertion point now that potential alloc/dealloc are introduced. @@ -482,8 +495,9 @@ struct PackUnPackOpInterface auto dspOp = cast(op); // The i-th "out" tensor may alias with the i-th OpResult. - if (dspOp.isDpsInit(&opOperand)) + if (dspOp.isDpsInit(&opOperand)) { return {dspOp.getTiedOpResult(&opOperand)}; + } return {}; } @@ -493,10 +507,11 @@ struct PackUnPackOpInterface auto dspOp = cast(op); // The i-th "out" tensor may alias with the i-th OpResult. - if (dspOp.isDpsInit(&opOperand)) + if (dspOp.isDpsInit(&opOperand)) { return {AliasingValue(dspOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent, /*isDefinite=*/false)}; + } return {}; } @@ -531,8 +546,9 @@ struct DispatchTensorLoadOpSubsetInterface // DispatchTensorStoreOp result that bufferizes inplace. auto loadOp = cast(op); auto storeOp = dyn_cast(op); - if (!storeOp) + if (!storeOp) { return false; + } return equivalenceFn(loadOp.getSource(), storeOp.getTarget()); } @@ -556,8 +572,9 @@ struct DispatchTensorStoreOpSubsetInterface // DispatchTensorLoadOp result that bufferizes inplace. 
auto storeOp = cast(op); auto loadOp = dyn_cast(op); - if (!loadOp) + if (!loadOp) { return false; + } return equivalenceFn(loadOp.getSource(), storeOp.getTarget()); } diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp index a9f9c229e76a..cbb37d3c4d2e 100644 --- a/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp +++ b/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp @@ -300,6 +300,8 @@ void registerPartitionableLoopsInterfaceModels(DialectRegistry ®istry) { *ctx); IREE::LinalgExt::MapScatterOp::attachInterface< AllParallelAsPartitionableLoops>(*ctx); + IREE::LinalgExt::MapGatherOp::attachInterface< + AllParallelAsPartitionableLoops>(*ctx); }); registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) { tensor::PadOp::attachInterface< diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp index d8f6caf063a5..766aa6aac6da 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp @@ -116,8 +116,9 @@ struct ConvertHALEntryPointFuncOp LogicalResult matchAndRewrite(func::FuncOp stdFuncOp, func::FuncOpAdaptor operands, ConversionPatternRewriter &rewriter) const override { - if (!stdFuncOp.isPublic()) + if (!stdFuncOp.isPublic()) { return failure(); + } FunctionType fnType = stdFuncOp.getFunctionType(); if (fnType.getNumInputs() != 0 || fnType.getNumResults() != 0) { stdFuncOp->emitWarning() @@ -773,8 +774,9 @@ struct RewriteCallOpABI : public OpRewritePattern { PatternRewriter &rewriter) const override { auto symbol = dyn_cast(callOp.getCallableForCallee()); auto flatSymbol = dyn_cast_if_present(symbol); - if (!flatSymbol) + if (!flatSymbol) { return failure(); + } // Ensure the target function is extern. 
// To support conversion inserting calls in local patterns that can't add @@ -821,8 +823,9 @@ struct RewriteExternCallOpToDynamicImportCallOp // Ignore indirect calls (they're probably already converted imports). auto symbol = dyn_cast(callOp.getCallableForCallee()); auto flatSymbol = dyn_cast_if_present(symbol); - if (!flatSymbol) + if (!flatSymbol) { return failure(); + } // Ensure the target function is extern. // To support conversion inserting calls in local patterns that can't add @@ -1139,8 +1142,9 @@ void ConvertToLLVMPass::runOnOperation() { RewritePatternSet patterns(&getContext()); patterns.insert(abi, typeConverter); - if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) + if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) { return signalPassFailure(); + } } // Post conversion patterns. diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/DispatchABI.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/DispatchABI.cpp index d82b1b6ec9ea..51158f9939d4 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/DispatchABI.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/DispatchABI.cpp @@ -339,8 +339,9 @@ HALDispatchABI::getProcessorType(MLIRContext *context, llvm::sys::ScopedLock lock(sMutex); auto structType = LLVM::LLVMStructType::getIdentified(context, "iree_hal_processor_v0_t"); - if (structType.isInitialized()) + if (structType.isInitialized()) { return structType; + } auto uint64Type = IntegerType::get(context, 64); SmallVector fieldTypes; @@ -365,8 +366,9 @@ HALDispatchABI::getEnvironmentType(MLIRContext *context, llvm::sys::ScopedLock lock(sMutex); auto structType = LLVM::LLVMStructType::getIdentified( context, "iree_hal_executable_environment_v0_t"); - if (structType.isInitialized()) + if (structType.isInitialized()) { return structType; + } auto opaquePtrType = LLVM::LLVMPointerType::get(context); SmallVector fieldTypes; @@ -399,8 +401,9 @@ HALDispatchABI::getDispatchStateType(MLIRContext *context, 
llvm::sys::ScopedLock lock(sMutex); auto structType = LLVM::LLVMStructType::getIdentified( context, "iree_hal_executable_dispatch_state_v0_t"); - if (structType.isInitialized()) + if (structType.isInitialized()) { return structType; + } auto uint8Type = IntegerType::get(context, 8); auto uint16Type = IntegerType::get(context, 16); @@ -453,8 +456,9 @@ HALDispatchABI::getWorkgroupStateType(MLIRContext *context, llvm::sys::ScopedLock lock(sMutex); auto structType = LLVM::LLVMStructType::getIdentified( context, "iree_hal_executable_workgroup_state_v0_t"); - if (structType.isInitialized()) + if (structType.isInitialized()) { return structType; + } auto uint16Type = IntegerType::get(context, 16); auto uint32Type = IntegerType::get(context, 32); @@ -583,8 +587,9 @@ static StringRef getDimName(int32_t dim) { // the ops if MLIR or LLVM is likely to reject them. static bool isLocationValidForDI(Location loc) { // Unknown locations are passed as null and DI doesn't like that. - if (isa(loc)) + if (isa(loc)) { return false; + } // MLIR currently can't handle name-only locations. We do this check to ensure // there's at least one real location MLIR can pass along. 
if (auto callLoc = dyn_cast(loc)) { @@ -604,11 +609,13 @@ static bool isLocationValidForDI(Location loc) { static Value buildArgDI(Operation *forOp, int argNum, Value value, Twine name, LLVM::DITypeAttr type, OpBuilder &builder) { - if (!clVerboseDebugInfo) + if (!clVerboseDebugInfo) { return value; + } auto loc = forOp->getLoc(); - if (!isLocationValidForDI(loc)) + if (!isLocationValidForDI(loc)) { return value; + } auto scopeAttr = getLocalScopeAttr(forOp); LLVM::DbgValueOp::create(builder, loc, value, LLVM::DILocalVariableAttr::get( @@ -621,11 +628,13 @@ static Value buildArgDI(Operation *forOp, int argNum, Value value, Twine name, static Value buildValueDI(Operation *forOp, Value value, Twine name, LLVM::DITypeAttr type, OpBuilder &builder) { - if (!clVerboseDebugInfo) + if (!clVerboseDebugInfo) { return value; + } auto loc = forOp->getLoc(); - if (!isLocationValidForDI(loc)) + if (!isLocationValidForDI(loc)) { return value; + } auto scopeAttr = getLocalScopeAttr(forOp); LLVM::DbgValueOp::create(builder, loc, value, LLVM::DILocalVariableAttr::get( @@ -789,7 +798,7 @@ MemRefDescriptor HALDispatchABI::loadBinding(Operation *forOp, int64_t ordinal, // requested range is valid. auto [strides, offset] = memRefType.getStridesAndOffset(); if (memRefType.hasStaticShape() && - !llvm::any_of(strides, ShapedType::isDynamic) && + llvm::none_of(strides, ShapedType::isDynamic) && ShapedType::isStatic(offset)) { return MemRefDescriptor::fromStaticShape(builder, loc, *typeConverter, memRefType, basePtrValue); @@ -1379,8 +1388,9 @@ Value HALDispatchABI::getIndexValue(Location loc, int64_t value, Value HALDispatchABI::castValueToType(Location loc, Value value, Type resultType, OpBuilder &builder) { // NOTE: we should handle more cases here (and proper sign extension). 
- if (value.getType() == resultType) + if (value.getType() == resultType) { return value; + } return builder.createOrFold(loc, resultType, value); } diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp index 7dbb37bda123..ec7957be9d21 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp @@ -201,8 +201,9 @@ static void getRangeBounds(TilingInterface op, SmallVectorImpl &lb, SmallVector loopRange = op.getIterationDomain(builder); auto getStaticValue = [](OpFoldResult ofr) -> int64_t { std::optional intVal = getConstantIntValue(ofr); - if (!intVal) + if (!intVal) { return ShapedType::kDynamic; + } return intVal.value(); }; lb = llvm::map_to_vector(loopRange, @@ -332,8 +333,9 @@ static int64_t getVectorSize(mlir::FunctionOpInterface entryPointFn, static int64_t getVectorSize(mlir::FunctionOpInterface entryPointFn, ShapedType shapedType) { Type elementType = shapedType.getElementType(); - if (!elementType.isIntOrFloat()) + if (!elementType.isIntOrFloat()) { return 1; + } unsigned byteWidth = IREE::Util::getRoundedElementByteWidth(elementType); return getVectorSize(entryPointFn, byteWidth); } @@ -385,12 +387,14 @@ getMinTilingSizesForEachDim(mlir::FunctionOpInterface entryPointFn, for (auto [index, map] : llvm::enumerate(op.getIndexingMapsArray())) { // Check the fastest varying dimension of the operand. Set the vector size // of the corresponding loop to the vector size. - if (map.getNumResults() == 0) + if (map.getNumResults() == 0) { continue; + } auto fastestVaryingDimExpr = dyn_cast(map.getResults().back()); - if (!fastestVaryingDimExpr) + if (!fastestVaryingDimExpr) { continue; + } unsigned fastestVaryingDim = fastestVaryingDimExpr.getPosition(); // If the indexing map has result it has to be a shaped type. 
@@ -923,8 +927,9 @@ getDefaultDistributedLevelTileSizes(Operation *op, // Final fix up of the tile sizes to make sure that they divide the problem // size to make it vectorizable. for (auto i : llvm::seq(0, distributedTileSizes.size())) { - if (!distributedTileSizes[i]) + if (!distributedTileSizes[i]) { continue; + } distributedTileSizes[i] = getMaxDistributionTileSize( lbs[i], ubs[i], distributedTileSizes[i], adjustedMinTileSizes[i], config.allowIncompleteTile); @@ -950,12 +955,14 @@ static void splitParallelAndReductionTiles( llvm::enumerate(tilingOp.getLoopIteratorTypes())) { if (iteratorType == utils::IteratorType::parallel) { reductionSizes[index] = 0; - if (reductionScalableFlags) + if (reductionScalableFlags) { (*reductionScalableFlags)[index] = false; + } } else { parallelSizes[index] = 0; - if (parallelScalableFlags) + if (parallelScalableFlags) { (*parallelScalableFlags)[index] = false; + } } } } @@ -965,8 +972,9 @@ static void setAlwaysVectorizeSizes(linalg::LinalgOp op, SmallVector staticLoopRanges = op.getStaticLoopRanges(); for (auto [index, size, iterType] : llvm::enumerate(staticLoopRanges, op.getIteratorTypesArray())) { - if (ShapedType::isStatic(size)) + if (ShapedType::isStatic(size)) { continue; + } vecTileSizes[index] = 1; } LDBG() << "Set always-vectorize sizes: " << vecTileSizes; @@ -1372,8 +1380,9 @@ setMatmulPeelingRootConfig(mlir::FunctionOpInterface entryPointFn, // The LLVM backend struggles to legalize non-power-of-two scalable vectors, // hence the extra rounding up. 
for (auto [index, size] : llvm::enumerate(roundedVecTileSizes)) { - if (!size) + if (!size) { continue; + } roundedVecTileSizes[index] = roundUpToPow2(size, /*predicate=*/inputVecScalableTileFlags[index]); @@ -1501,8 +1510,9 @@ static FailureOr nonWideningLinalgElementType(linalg::LinalgOp op) { } assert(!inputAndOutputElementTypes.empty() && "expected linalg op to have input and output types"); - if (!llvm::all_equal(inputAndOutputElementTypes)) + if (!llvm::all_equal(inputAndOutputElementTypes)) { return failure(); + } return inputAndOutputElementTypes[0]; } @@ -1522,8 +1532,9 @@ static void getMatmulVectorSizesUsingFullVectorHeuristics( mlir::FunctionOpInterface entryPointFn, linalg::LinalgOp op, int64_t vectorSize, SmallVectorImpl &sizes, SmallVectorImpl &scalableSizeFlags) { - if (sizes.empty()) + if (sizes.empty()) { getDefaultMatmulVectorSizes(op, vectorSize, sizes, scalableSizeFlags); + } // Find the smallest type size in the matmul. SmallVector matmulTypes; @@ -1534,11 +1545,13 @@ static void getMatmulVectorSizesUsingFullVectorHeuristics( int64_t minSize = std::numeric_limits::max(); for (Type mmType : matmulTypes) { - if (auto shType = dyn_cast(mmType)) + if (auto shType = dyn_cast(mmType)) { mmType = shType.getElementType(); + } - if (mmType.isSignlessIntOrFloat()) + if (mmType.isSignlessIntOrFloat()) { minSize = std::min(minSize, int64_t{mmType.getIntOrFloatBitWidth()}); + } } LDBG() << "Smallest type found: " << minSize << " bits"; @@ -1567,14 +1580,16 @@ getMatmulRISCVVectorSizes(mlir::FunctionOpInterface entryPointFn, linalg::LinalgOp op, int64_t vectorSize, SmallVectorImpl &sizes, SmallVectorImpl &scalableSizeFlags) { - if (sizes.empty()) + if (sizes.empty()) { getDefaultMatmulVectorSizes(op, vectorSize, sizes, scalableSizeFlags); + } // TODO: support widening matmul. // Determines n dimension tile size with VLEN for // nonWideningLinalgElementType. 
FailureOr elementType = nonWideningLinalgElementType(op); - if (failed(elementType)) + if (failed(elementType)) { return; + } // nativeVectorSize is cacluated with VLEN and LMUL=2. int64_t nativeVectorSize = getNativeVectorSizeInBytes(entryPointFn); @@ -1591,8 +1606,9 @@ getMatmulRISCVVectorSizes(mlir::FunctionOpInterface entryPointFn, } FailureOr cDims = linalg::inferContractionDims(op); - if (failed(cDims) || cDims->m.size() != 1) + if (failed(cDims) || cDims->m.size() != 1) { return; + } // Use 7 x lmul4 to fully utilize vector registers. sizes[0] = 7; // Calculate tile size for the main vector dimension (N). @@ -1620,12 +1636,14 @@ getMatmulAArch64SMEVectorSizes(linalg::LinalgOp op, // Double-check the operation is one that is supported for lowering to ArmSME. Operation *rawOp = op.getOperation(); if (!(IREE::LinalgExt::isPureMatmul(rawOp) || - isa(rawOp))) + isa(rawOp))) { return; + } auto elementType = nonWideningLinalgElementType(op); - if (failed(elementType)) + if (failed(elementType)) { return; + } // TODO(macdue): Come up with some heuristics to pick the appropriate tiling // for SME, i.e. optimal layout based on static sizes. @@ -2023,9 +2041,10 @@ getMmt4dLoweringConfig(linalg::LinalgOp op, DictionaryAttr targetConfig) { bool scalableTilesFound = false; // If scalable vectorization is enabled, adjust the vector tile sizes and the // corresponding scalable flags. - if (targetConfig && isScalableVectorizationEnabled()) + if (targetConfig && isScalableVectorizationEnabled()) { scalableTilesFound = adjustVectorSizesForScalableVectorization( op, targetConfig, M0, N0, vecTileSizes, vecScalableTileFlags); + } // In the existence of scalable tiles, we do not yet support limiting vector // sizes as this assumes static tile sizes. // TODO: extend this mechanism to handle _scalable_ tile sizes as well. @@ -2133,8 +2152,9 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn, // but it does not know that it is working on packed domain. 
We need to take // inner tile sizes into account and adjust the distribution tile sizes. for (auto [pos, size] : llvm::zip_equal(dimPos, innerTiles)) { - if (distTileSizes[pos] == 0 || ShapedType::isDynamic(size)) + if (distTileSizes[pos] == 0 || ShapedType::isDynamic(size)) { continue; + } distTileSizes[pos] = distTileSizes[pos] / size; distTileSizes[pos] = std::max(distTileSizes[pos], int64_t{1}); } @@ -2191,8 +2211,9 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn, ArrayRef dimPos = op.getInnerDimsPos(); for (auto [pos, size, scalable] : llvm::zip_equal(dimPos, innerTiles, scalableFlags)) { - if (distTileSizes[pos] == 0 || ShapedType::isDynamic(size)) + if (distTileSizes[pos] == 0 || ShapedType::isDynamic(size)) { continue; + } int64_t alignedTileSize = llvm::alignTo(distTileSizes[pos], size); distTileSizes[pos] = roundUpToPow2(alignedTileSize, scalable); } @@ -2520,8 +2541,9 @@ static void getTransposeX86VectorSizes( linalg::GenericOp genericOp, IREE::HAL::ExecutableTargetAttr targetAttr, ArrayRef minTileSizes, SmallVectorImpl &sizes) { if (!targetAttr || !hasAVX2Feature(targetAttr.getConfiguration()) || - !x86TransposeLoweringPrecondition(genericOp)) + !x86TransposeLoweringPrecondition(genericOp)) { return; + } if (llvm::count_if(minTileSizes, [](int64_t tileSize) { return tileSize > 1; }) != 2) { @@ -2561,12 +2583,14 @@ static void getTransposeX86VectorSizes( static void getTransposeAArch64VectorSizes( linalg::GenericOp genericOp, IREE::HAL::ExecutableTargetAttr targetAttr, SmallVectorImpl &sizes, SmallVectorImpl &scalableFlags) { - if (!targetAttr || !isLinalgGeneric2DTranspose(genericOp)) + if (!targetAttr || !isLinalgGeneric2DTranspose(genericOp)) { return; + } auto elementType = nonWideningLinalgElementType(genericOp); - if (failed(elementType)) + if (failed(elementType)) { return; + } if (hasSMEFeature(targetAttr.getConfiguration()) && isScalableVectorizationEnabled() && !clDisableArmSMETiling) { @@ -2599,12 +2623,14 @@ 
getTransposeVectorSizes(mlir::FunctionOpInterface entryPointFn, scalableFlags); } - if (tileSizes.empty()) + if (tileSizes.empty()) { return std::nullopt; + } // If scalable flags are empty, assume target doesn't care about scalability. - if (scalableFlags.empty()) + if (scalableFlags.empty()) { scalableFlags = SmallVector(tileSizes.size(), false); + } LDBG() << "Transpose vector sizes: " << tileSizes; LDBG() << "Transpose vector scalable flags: " << scalableFlags; @@ -2621,15 +2647,17 @@ setTransposeLikeOpRootConfig(mlir::FunctionOpInterface entryPointFn, assert(!getLoweringConfig(genericOp) && "expected lowering_config is not set"); - if (!linalgOpInfo.isTranspose()) + if (!linalgOpInfo.isTranspose()) { return failure(); + } LDBG() << "Setting transpose-like op root configuration"; std::optional vecDims = getTransposeVectorSizes( entryPointFn, genericOp, linalgOpInfo, targetMLTransInfo); - if (!vecDims) + if (!vecDims) { return failure(); + } auto [vecSizes, vecScalableDims] = *vecDims; @@ -2667,10 +2695,12 @@ static LogicalResult setElementwiseGenericOpRootConfig( LDBG() << "Setting elementwise generic op root configuration"; unsigned numLoops = genericOp.getNumLoops(); - if (numLoops == 0) + if (numLoops == 0) { return failure(); - if (!linalg::isElementwise(genericOp)) + } + if (!linalg::isElementwise(genericOp)) { return failure(); + } DistributionHeuristicConfig distConfig; distConfig.allowIncompleteTile = true; @@ -2797,13 +2827,15 @@ enum class Conv2DDimOrder { static Conv2DDimOrder getConv2DDimOrder(Operation *op) { if (isa(op)) + linalg::PoolingNchwMaxOp>(op)) { return Conv2DDimOrder::Nchw; + } if (isa(op)) + linalg::DepthwiseConv2DNhwcHwcOp>(op)) { return Conv2DDimOrder::Nhwc; + } llvm::llvm_unreachable_internal("unsupported conv op"); } @@ -2890,42 +2922,54 @@ getNhwcConvVectorSizes(mlir::FunctionOpInterface entryPointFn, if (targetAttr) { DictionaryAttr targetConfig = targetAttr.getConfiguration(); if (isX86(targetConfig)) { - if (is2DConvOp(op)) + if 
(is2DConvOp(op)) { return {1, 1, 8, vectorSize, 1, 1, 8}; - if (is2DDepthConvOp(op)) + } + if (is2DDepthConvOp(op)) { return {1, 1, 8, vectorSize, 1, 3}; - if (is2DPoolingOp(op)) + } + if (is2DPoolingOp(op)) { return {1, 1, 8, vectorSize, 1, 8}; + } llvm_unreachable("unsupported conv"); } if (isRISCV(targetConfig)) { - if (is2DConvOp(op)) + if (is2DConvOp(op)) { return {1, 1, 8, vectorSize * 2, 1, 1, 8}; - if (is2DDepthConvOp(op)) + } + if (is2DDepthConvOp(op)) { return {1, 1, 8, vectorSize, 1, 3}; - if (is2DPoolingOp(op)) + } + if (is2DPoolingOp(op)) { return {1, 1, 8, vectorSize * 2, 1, 8}; + } llvm_unreachable("unsupported conv"); } if (isAArch64(targetConfig)) { - if (is2DConvOp(op)) + if (is2DConvOp(op)) { return {1, 1, 32, 64, 1, 1, 16}; - if (is2DDepthConvOp(op)) + } + if (is2DDepthConvOp(op)) { return {1, 1, 4, 4, 1, 4}; - if (is2DPoolingOp(op)) + } + if (is2DPoolingOp(op)) { return {1, 1, 32, 64, 1, 16}; + } llvm_unreachable("unsupported conv"); } } // Get default hard-coded tile sizes if we couldn't compute anything // better. 
- if (is2DConvOp(op)) + if (is2DConvOp(op)) { return {1, 1, vectorSize, vectorSize, 1, 1, vectorSize}; - if (is2DDepthConvOp(op)) + } + if (is2DDepthConvOp(op)) { return {1, 1, vectorSize, vectorSize, 1, vectorSize}; - if (is2DPoolingOp(op)) + } + if (is2DPoolingOp(op)) { return {1, 1, vectorSize, vectorSize, 1, vectorSize}; + } llvm_unreachable("unsupported conv"); } @@ -3713,8 +3757,9 @@ adjustTileSizesForRootUnPackOp(mlir::FunctionOpInterface entryPointFn, linalgOp.getNumLoops(), false); for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) { auto unpackOp = opOperand->get().getDefiningOp(); - if (!unpackOp) + if (!unpackOp) { continue; + } foundUnPackOp = true; auto idxMap = linalgOp.getMatchingIndexingMap(opOperand); @@ -3732,19 +3777,22 @@ adjustTileSizesForRootUnPackOp(mlir::FunctionOpInterface entryPointFn, ArrayRef dimPos = unpackOp.getInnerDimsPos(); for (auto [pos, size, scalable] : llvm::zip_equal(dimPos, innerTiles, scalableFlags)) { - if (ShapedType::isDynamic(size)) + if (ShapedType::isDynamic(size)) { continue; + } auto dimExpr = dyn_cast(idxMap.getResult(pos)); - if (!dimExpr) + if (!dimExpr) { return failure(); + } int mappedPos = dimExpr.getPosition(); alignedSizes[mappedPos] = std::lcm(alignedSizes[mappedPos], size); vecParallelScalableTileFlags[mappedPos] = scalable; } } - if (!foundUnPackOp) + if (!foundUnPackOp) { return success(); + } LDBG() << "The tile sizes for each dimension should be aligned to " << alignedSizes; @@ -3758,8 +3806,9 @@ adjustTileSizesForRootUnPackOp(mlir::FunctionOpInterface entryPointFn, for (IREE::CPU::LoweringConfigLevelInfo &info : tilingInfo) { SmallVector &tileSizes = info.sizes; for (auto idx : llvm::seq(0, tileSizes.size())) { - if (tileSizes[idx] == 0) + if (tileSizes[idx] == 0) { continue; + } int64_t alignedTileSize = llvm::alignTo(tileSizes[idx], alignedSizes[idx]); tileSizes[idx] = roundUpToPow2( @@ -3938,13 +3987,15 @@ setTranslationInfoAndRootConfig(mlir::FunctionOpInterface entryPointFn, ArrayRef 
computeOps) { // Make sure that lowering_config is not preset on any compute ops. for (auto computeOp : computeOps) { - if (getLoweringConfig(computeOp)) + if (getLoweringConfig(computeOp)) { return failure(); + } } FailureOr rootOp = getRootOperation(computeOps); - if (failed(rootOp)) + if (failed(rootOp)) { return failure(); + } Operation *rootOperation = rootOp.value(); // Handle the case with no known root operation. diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPU2DScalableTo1DScalable.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPU2DScalableTo1DScalable.cpp index 48e404cea3ab..1961fe1e6125 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPU2DScalableTo1DScalable.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPU2DScalableTo1DScalable.cpp @@ -91,8 +91,9 @@ class LLVMCPU2DScalableTo1DScalablePass }; static bool opKnownToSupport2DScalableVectorizationWithArmSME(Operation *op) { - if (auto genericOp = dyn_cast(op)) + if (auto genericOp = dyn_cast(op)) { return isLinalgGeneric2DTranspose(genericOp); + } return isa(op); } @@ -206,16 +207,18 @@ dropScalabilityFromUnsupportedOperations(mlir::FunctionOpInterface funcOp, scf::SCFTilingOptions options; setSCFTileSizes(options, tilingOp, loopTileSizes, /*tileScalableFlags=*/{}); auto tilingResult = scf::tileUsingSCF(rewriter, tilingOp, options); - if (failed(tilingResult)) + if (failed(tilingResult)) { return failure(); + } // Update the lowering config of the new tiled operations. 
IREE::CPU::LoweringConfigAttr newLoweringConfig = getLoweringConfigWithNewVectorSizes(loweringConfigAttr, *vectorSizes, newScalableFlags); for (auto *newOp : tilingResult->tiledOps) { - if (isa(newOp)) + if (isa(newOp)) { setLoweringConfig(newOp, newLoweringConfig); + } } rewriter.replaceOp(tilingOp, tilingResult->replacements); @@ -225,8 +228,9 @@ dropScalabilityFromUnsupportedOperations(mlir::FunctionOpInterface funcOp, void LLVMCPU2DScalableTo1DScalablePass::runOnOperation() { if (failed(dropScalabilityFromUnsupportedOperations(getOperation(), - assumeArmSME))) + assumeArmSME))) { signalPassFailure(); + } } } // namespace diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUAssignConstantOrdinals.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUAssignConstantOrdinals.cpp index e3eeb49d3ed0..815fd344b0e8 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUAssignConstantOrdinals.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUAssignConstantOrdinals.cpp @@ -23,8 +23,9 @@ struct LLVMCPUAssignConstantOrdinalsPass // Get a constant key -> ordinal mapping. auto keyOrdinals = variantOp.gatherConstantOrdinals(); - if (keyOrdinals.empty()) + if (keyOrdinals.empty()) { return; + } // Update placeholders to hold the concrete ordinal values. // Eventually MLIR or LLVM will inline them. 
@@ -33,8 +34,9 @@ struct LLVMCPUAssignConstantOrdinalsPass llvm::make_early_inc_range(moduleOp.getOps())) { auto keyAttr = globalOp->getAttr( IREE::HAL::ExecutableConstantBlockOp::getKeyAttrName()); - if (!keyAttr) + if (!keyAttr) { continue; + } auto it = keyOrdinals.find(keyAttr); if (it == keyOrdinals.end()) { globalOp.emitOpError() diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUAssignImportOrdinals.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUAssignImportOrdinals.cpp index 4827f7c3ef08..b7e6726d1560 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUAssignImportOrdinals.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUAssignImportOrdinals.cpp @@ -37,13 +37,15 @@ struct LLVMCPUAssignImportOrdinalsPass for (auto globalOp : llvm::make_early_inc_range(moduleOp.getOps())) { auto keyAttr = globalOp->getAttrOfType(importKeyAttr); - if (!keyAttr) + if (!keyAttr) { continue; + } uniqueKeys.insert(keyAttr); ordinalGlobals[keyAttr].push_back(globalOp); } - if (uniqueKeys.empty()) + if (uniqueKeys.empty()) { return; + } auto sortedKeys = uniqueKeys.takeVector(); llvm::stable_sort(sortedKeys, [](auto lhs, auto rhs) { return lhs.getValue() < rhs.getValue(); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUCheckIRBeforeLLVMConversion.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUCheckIRBeforeLLVMConversion.cpp index 95bd62f46f81..d35cc4c2df47 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUCheckIRBeforeLLVMConversion.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUCheckIRBeforeLLVMConversion.cpp @@ -40,8 +40,9 @@ struct LLVMCPUCheckIRBeforeLLVMConversionPass /// defined for HAL LLVMCPU target). static LogicalResult checkStackAllocationSize(mlir::FunctionOpInterface funcOp) { - if (funcOp.getFunctionBody().empty()) + if (funcOp.getFunctionBody().empty()) { return success(); + } // In rare cases where the attribute is not present in the module, a value of // 32KB will be taken. 
@@ -73,8 +74,9 @@ checkStackAllocationSize(mlir::FunctionOpInterface funcOp) { int allocaSize = 1; auto allocaType = cast(allocaOp.getType()); for (auto dimSize : allocaType.getShape()) { - if (ShapedType::isDynamic(dimSize)) + if (ShapedType::isDynamic(dimSize)) { continue; + } allocaSize *= dimSize; } for (auto operand : allocaOp.getDynamicSizes()) { diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMmt4dVectorLowering.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMmt4dVectorLowering.cpp index 3e66355ad9a6..1d38122ce7ed 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMmt4dVectorLowering.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMmt4dVectorLowering.cpp @@ -48,13 +48,15 @@ void LLVMCPUMmt4dVectorLoweringPass::runOnOperation() { std::optional numLoops; funcOp.walk([&](vector::ContractionOp op) { - if (numLoops) + if (numLoops) { return signalPassFailure(); + } numLoops = op.getIndexingMapsArray()[0].getNumDims(); }); // No vector.contract op to optimize. - if (!numLoops) + if (!numLoops) { return; + } { // Fold consumer add ops into the contraction op itself. diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUPeel.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUPeel.cpp index 2a688c74523e..1edcb02b504e 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUPeel.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUPeel.cpp @@ -30,8 +30,9 @@ namespace { // stages. 
void collectLoopsToPeel(Operation *op, llvm::SmallSetVector &loopsToPeel) { - if (!iree_compiler::getLoweringConfig(op)) + if (!iree_compiler::getLoweringConfig(op)) { return; + } int maxNumLoopsToPeel = TypeSwitch(op) .Case([](auto linalgOp) { @@ -44,8 +45,9 @@ void collectLoopsToPeel(Operation *op, for (int i = 0; i < maxNumLoopsToPeel; ++i) { op = op->getParentOfType(); auto loop = cast_or_null(op); - if (!loop || iree_compiler::isTiledAndDistributedLoop(loop)) + if (!loop || iree_compiler::isTiledAndDistributedLoop(loop)) { break; + } LDBG() << "Loop to peel\n " << *op; loopsToPeel.insert(loop); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp index 4f5555e841fb..e6d8e35d45c0 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp @@ -223,8 +223,9 @@ static LogicalResult verifyLoweringConfiguration(FunctionOpInterface funcOp, return WalkResult::advance(); } auto loweringConfig = getLoweringConfig(op); - if (!loweringConfig) + if (!loweringConfig) { return WalkResult::advance(); + } return verificationFn(op, loweringConfig); }); return failure(walkResult.wasInterrupted()); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTile.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTile.cpp index b02963a6dcd2..14bd061c545f 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTile.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTile.cpp @@ -57,15 +57,17 @@ void LLVMCPUTilePass::runOnOperation() { SmallVector computeOps = getComputeOps(funcOp); for (auto computeOp : computeOps) { auto op = dyn_cast(computeOp); - if (!op || op.getLoopIteratorTypes().empty()) + if (!op || op.getLoopIteratorTypes().empty()) { continue; + } // For now do not tile `tensor.pad` operations. 
The `tensor.pad` // operations might be those introduced by the padding-based // codegeneration strategy. Those are not meant to be tiled again. // Need a better way for handling this, but this works for now. - if (isa(computeOp)) + if (isa(computeOp)) { continue; + } IREE::Codegen::LoweringConfigAttrInterface maybeLoweringConfig = getLoweringConfig(op); @@ -104,8 +106,9 @@ void LLVMCPUTilePass::runOnOperation() { std::move(tileScalableFlags)); FailureOr tiledResults = scf::tileUsingSCF(rewriter, op, options); - if (failed(tiledResults)) + if (failed(tiledResults)) { continue; + } rewriter.replaceOp(op, tiledResults->replacements); } diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorTransposeLowering.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorTransposeLowering.cpp index 263c336d894a..83768362d7ec 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorTransposeLowering.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorTransposeLowering.cpp @@ -24,8 +24,9 @@ static bool has16x16Transpose(mlir::FunctionOpInterface funcOp) { bool res = false; funcOp.walk([&](vector::TransposeOp op) { auto srcGtOneDims = isTranspose2DSlice(op); - if (failed(srcGtOneDims)) + if (failed(srcGtOneDims)) { return WalkResult::advance(); + } VectorType srcType = op.getSourceVectorType(); int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value())); int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value())); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions/BUILD.bazel index 8e73e87b943c..522ae9359b0e 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "LLVMCPUExtensionsOps.td", ], diff --git 
a/compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.cpp index 3c53e1ad239d..9ba70c0ce6a4 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.cpp @@ -104,36 +104,43 @@ bool hasI8mmFeature(DictionaryAttr targetConfig) { bool isLinalgGeneric2DTranspose(linalg::GenericOp genericOp) { // Check op has 2 dimensions. - if (genericOp.getNumLoops() != 2) + if (genericOp.getNumLoops() != 2) { return false; + } // Check op has single input and output. - if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1) + if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1) { return false; + } // Check all iterators are parallel. - if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) + if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) { return false; + } // Check that the two indexing maps are a permutation of each other. SmallVector indexingMaps = genericOp.getIndexingMapsArray(); bool isTranspose = (indexingMaps[0].isPermutation() && indexingMaps[1].isIdentity()) || (indexingMaps[1].isPermutation() && indexingMaps[0].isIdentity()); - if (!isTranspose) + if (!isTranspose) { return false; + } // Make sure the region only contains a yield op. Block &body = genericOp.getRegion().front(); - if (!llvm::hasSingleElement(body)) + if (!llvm::hasSingleElement(body)) { return false; + } auto yieldOp = cast(body.getTerminator()); // The yield op should return the block argument corresponding to the input. 
auto yieldArg = dyn_cast(yieldOp.getValues()[0]); - if (!yieldArg || yieldArg.getArgNumber() != 0 || yieldArg.getOwner() != &body) + if (!yieldArg || yieldArg.getArgNumber() != 0 || + yieldArg.getOwner() != &body) { return false; + } return true; } diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp index a5f09bba5a1b..974527d74fed 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp @@ -853,10 +853,11 @@ class MMTKernelGenerator { // the constraints string. Not confusing at all! inputs.append(lhs.begin(), lhs.end()); for (const auto &v : rhs) { - if (cast(v.getType()).getNumElements() == 1) + if (cast(v.getType()).getNumElements() == 1) { inputs.push_back(extract(rewriter, loc, v, 0)); - else + } else { inputs.push_back(v); + } } inputs.append(acc.begin(), acc.end()); // Create the inline asm op. @@ -1039,8 +1040,9 @@ struct MMT_8x4x8_i8i8i32_Aarch64Dotprod_Intrinsics Value inLhs = getUnpromotedInput(I8Type, I32Type, lhs); Value inRhs = getUnpromotedInput(I8Type, I32Type, rhs); - if (!inLhs || !inRhs) + if (!inLhs || !inRhs) { return failure(); + } auto loc = contractionOp.getLoc(); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp index b0995cafd0a8..7f44dd67bcfe 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp @@ -37,11 +37,13 @@ void ConvertToDynamicSharedMemory(ModuleOp moduleOp) { moduleOp.walk([&](LLVM::AddressOfOp addressOfOp) { // Check that the global associated with this addressOfOp has shared memory // space. 
- if (addressOfOp.getGlobal(symbolTableCollection).getAddrSpace() == 3) + if (addressOfOp.getGlobal(symbolTableCollection).getAddrSpace() == 3) { addressOfOps.push_back(addressOfOp); + } }); - if (addressOfOps.size() == 0) + if (addressOfOps.size() == 0) { return; + } OpBuilder builder(moduleOp); builder.setInsertionPoint(&moduleOp.front()); auto type = @@ -118,8 +120,9 @@ struct ScalarizeMathOp : public OpRewritePattern { LogicalResult matchAndRewrite(MathOpTy mathOp, PatternRewriter &rewriter) const override { auto vecType = dyn_cast(mathOp.getType()); - if (!vecType) + if (!vecType) { return failure(); + } Location loc = mathOp.getLoc(); Value newVector = arith::ConstantOp::create(rewriter, loc, vecType, rewriter.getZeroAttr(vecType)); @@ -151,8 +154,9 @@ struct ConvertSharedMemAllocOp : public OpRewritePattern { LogicalResult matchAndRewrite(memref::AllocOp allocOp, PatternRewriter &rewriter) const override { - if (!hasSharedMemoryAddressSpace(allocOp.getType())) + if (!hasSharedMemoryAddressSpace(allocOp.getType())) { return failure(); + } ArrayRef shape = allocOp.getType().getShape(); if (ShapedType::isDynamicShape(shape)) { return failure(); @@ -164,15 +168,16 @@ struct ConvertSharedMemAllocOp : public OpRewritePattern { } else { // If no alignment specified align at least to the size of an element. Type elType = allocOp.getType().getElementType(); - if (auto shapeType = dyn_cast(elType)) + if (auto shapeType = dyn_cast(elType)) { alignement = shapeType.getNumElements() * shapeType.getElementTypeBitWidth() / 8; - else if (elType.isIndex()) { + } else if (elType.isIndex()) { auto mod = allocOp->getParentOfType(); LowerToLLVMOptions options(mod.getContext(), DataLayout(mod)); alignement = options.getIndexBitwidth() / 8; - } else + } else { alignement = elType.getIntOrFloatBitWidth() / 8; + } } // In CUDA workgroup memory is represented by a global variable. 
MemRefType allocType = allocOp.getType(); @@ -262,8 +267,9 @@ class ConvertFunc : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { FunctionType fnType = funcOp.getFunctionType(); (void)fnType; - if (!funcOp.isPublic()) + if (!funcOp.isPublic()) { return failure(); + } // illegal FuncOp must have 0 inputs. assert(fnType.getNumInputs() == 0 && fnType.getNumResults() == 0); @@ -296,8 +302,9 @@ class ConvertFunc : public ConvertOpToLLVMPattern { FailureOr> maybeBindingsInfo = analyzeSubspans(subspans, numBindings, getTypeConverter()); - if (failed(maybeBindingsInfo)) + if (failed(maybeBindingsInfo)) { return failure(); + } auto bindingsInfo = std::move(*maybeBindingsInfo); SmallVector llvmInputTypes; @@ -309,8 +316,9 @@ class ConvertFunc : public ConvertOpToLLVMPattern { // All the push constants are i32 and go at the end of the argument list. llvmInputTypes.resize(numBindings + numConstants, rewriter.getI32Type()); - if (!llvmInputTypes.empty()) + if (!llvmInputTypes.empty()) { signatureConverter.addInputs(llvmInputTypes); + } // Construct newFunc with all attributes except return type & symbol name. SmallVector funcAttrs; @@ -384,8 +392,9 @@ struct ConvertIREEBindingSubspanOp final ConversionPatternRewriter &rewriter) const override { // Bail until nested under an LLVMFuncOp. 
auto llvmFuncOp = op->getParentOfType(); - if (!llvmFuncOp) + if (!llvmFuncOp) { return failure(); + } assert(llvmFuncOp.getNumArguments() > 0); Location loc = op->getLoc(); @@ -399,7 +408,7 @@ struct ConvertIREEBindingSubspanOp final auto [strides, offset] = memrefType.getStridesAndOffset(); if (memrefType.hasStaticShape() && - !llvm::any_of(strides, ShapedType::isDynamic) && + llvm::none_of(strides, ShapedType::isDynamic) && ShapedType::isStatic(offset)) { auto desc = MemRefDescriptor::fromStaticShape( rewriter, loc, *getTypeConverter(), memrefType, llvmBufferBasePtr); @@ -489,8 +498,9 @@ struct ConvertIREEConstantOp final ConversionPatternRewriter &rewriter) const override { // Bail until nested under an LLVMFuncOp. auto llvmFuncOp = op->getParentOfType(); - if (!llvmFuncOp) + if (!llvmFuncOp) { return failure(); + } assert(llvmFuncOp.getNumArguments() > 0); auto ireeConstantOp = cast(op); @@ -572,8 +582,9 @@ struct HALInterfaceWorkgroupOpsConverter final gpu::Dimension::z}; NewOpTy newOp = rewriter.replaceOpWithNewOp(op, op.getType(), dimAttr[index]); - if (IntegerAttr bound = op.getUpperBoundAttr()) + if (IntegerAttr bound = op.getUpperBoundAttr()) { newOp.setUpperBoundAttr(bound); + } return success(); } }; @@ -602,23 +613,26 @@ struct ConvertIREEUtilAssumeIntOp final ConversionPatternRewriter &rewriter) const override { // Bail until nested under an LLVMFuncOp. auto llvmFuncOp = op->getParentOfType(); - if (!llvmFuncOp) + if (!llvmFuncOp) { return failure(); + } Location loc = op.getLoc(); auto updateConds = [&](std::optional &conds, Value cond) { - if (!conds) + if (!conds) { conds = cond; - else + } else { conds = LLVM::AndOp::create(rewriter, loc, *conds, cond); + } }; // Materialize the assumptions that aren't atteched directly to arguments // in order to account for the fact that i64 inputs get passed in as a pair // of i32 constants. 
for (auto [idx, mlirVal, llvmVal] : llvm::enumerate(op.getOperands(), adaptor.getOperands())) { - if (mlirVal.getDefiningOp()) + if (mlirVal.getDefiningOp()) { continue; + } std::optional conds; Type type = llvmVal.getType(); auto [min, max] = op.getUnionedUnsignedRange(idx); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp index aced4d9490f0..9688fc8597cf 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp @@ -224,8 +224,9 @@ static LogicalResult setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target, mlir::FunctionOpInterface entryPoint, linalg::LinalgOp op) { - if (target.getWgp().getMma().empty()) + if (target.getWgp().getMma().empty()) { return failure(); + } const int64_t targetSubgroupSize = target.getPreferredSubgroupSize(); @@ -303,15 +304,17 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target, intrinsics.reserve(target.getWgp().getMma().size()); MLIRContext *context = op.getContext(); for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) { - if (mma.getSubgroupSize() != targetSubgroupSize) + if (mma.getSubgroupSize() != targetSubgroupSize) { continue; + } storeMmaInfo(mma, intrinsics); // Skip adding any virtual intrinsics since they are not tested for // convolutions. } - if (intrinsics.empty()) + if (intrinsics.empty()) { return failure(); + } // TODO: Replace the below with algorithm described in // https://github.com/iree-org/iree/discussions/21506. @@ -330,12 +333,12 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target, // First try to find a schedule with an exactly matching intrinsic. FailureOr schedule = deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes, - targetSubgroupSize, wgpCount); + targetSubgroupSize, wgpCount, op.getLoc()); if (failed(schedule)) { // Then try again by allowing upcasting accumulator. 
schedule = deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes, - targetSubgroupSize, wgpCount, + targetSubgroupSize, wgpCount, op.getLoc(), /*transposedLhs*/ false, /*transposedRhs*/ false, /*canUpcastAcc=*/true); } @@ -429,9 +432,11 @@ debugPrintContractionInfo(StringRef label, unsigned numLoops, contractionDims.n, contractionDims.k}; std::string dimSymbols(numLoops, '*'); for (auto [idx, val] : llvm::enumerate(dimSymbols)) { - for (auto [letter, dim] : llvm::zip_equal(StringRef("bmnk"), dimVals)) - if (llvm::is_contained(dim, idx)) + for (auto [letter, dim] : llvm::zip_equal(StringRef("bmnk"), dimVals)) { + if (llvm::is_contained(dim, idx)) { val = letter; + } + } } DBGS() << "Contraction dims: " << llvm::interleaved_array(dimSymbols) << "\n"; DBGS() << label << ": " << llvm::interleaved_array(sizes) << "\n"; @@ -441,8 +446,9 @@ static LogicalResult setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, mlir::FunctionOpInterface entryPoint, linalg::LinalgOp op) { - if (target.getWgp().getMma().empty()) + if (target.getWgp().getMma().empty()) { return failure(); + } const int64_t targetSubgroupSize = target.getPreferredSubgroupSize(); @@ -515,9 +521,16 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, // all instances of schedule->m/nSubgroupCounts[0], // schedule->m/n/kTileSizes[0] and schedule->m/n/kSizes[0] need to use the // full list of sizes instead of just the first element. - GPUMatmulShapeType problem{ - {bounds[mDim]}, {bounds[nDim]}, {bounds[kDim]}, getDimBounds(batchDims), - lhsElemType, rhsElemType, initElemType, numHorizontallyFusedOps}; + GPUMatmulShapeType problem{{bounds[mDim]}, + {bounds[nDim]}, + {bounds[kDim]}, + getDimBounds(batchDims), + lhsElemType, + rhsElemType, + initElemType, + /*aScaleType=*/nullptr, + /*bScaleType=*/nullptr, + numHorizontallyFusedOps}; // Helper fn to store mma information. 
auto storeMmaInfo = [](IREE::GPU::MmaInterfaceAttr mma, @@ -531,14 +544,16 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, intrinsics.reserve(target.getWgp().getMma().size()); MLIRContext *context = op.getContext(); for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) { - if (mma.getSubgroupSize() != targetSubgroupSize) + if (mma.getSubgroupSize() != targetSubgroupSize) { continue; + } storeMmaInfo(mma, intrinsics); // Skip adding any virtual intrinsics since they are not tested for matmuls. } - if (intrinsics.empty()) + if (intrinsics.empty()) { return failure(); + } GPUMMAHeuristicSeeds seeds; @@ -582,13 +597,13 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, // First try to find a schedule with an exactly matching intrinsic. std::optional schedule = deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes, - targetSubgroupSize, wgpCount); + targetSubgroupSize, wgpCount, op.getLoc()); if (!schedule) { // Then try again by allowing upcasting accumulator. 
- schedule = deduceMMASchedule(problem, intrinsics, seeds, - maxSharedMemoryBytes, targetSubgroupSize, - wgpCount, transposedLhs, transposedRhs, - /*canUpcastAcc=*/true); + schedule = + deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes, + targetSubgroupSize, wgpCount, op.getLoc(), + transposedLhs, transposedRhs, /*canUpcastAcc=*/true); } if (!schedule) { @@ -697,8 +712,9 @@ setAttentionPipelineAttributes(IREE::GPU::TargetAttr target, static LogicalResult setAttentionIntrinsicBasedVectorDistributionConfig( IREE::GPU::TargetAttr target, mlir::FunctionOpInterface entryPoint, IREE::LinalgExt::AttentionOp op) { - if (target.getWgp().getMma().empty()) + if (target.getWgp().getMma().empty()) { return failure(); + } const int64_t targetSubgroupSize = target.getPreferredSubgroupSize(); @@ -779,8 +795,9 @@ static LogicalResult setAttentionIntrinsicBasedVectorDistributionConfig( intrinsics.reserve(target.getWgp().getMma().size()); MLIRContext *context = op.getContext(); for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) { - if (mma.getSubgroupSize() != targetSubgroupSize) + if (mma.getSubgroupSize() != targetSubgroupSize) { continue; + } storeMmaInfo(mma, intrinsics); // Store info on virtual intrinsics based on current mma if any for (IREE::GPU::VirtualMMAIntrinsic virtualIntrinsic : @@ -791,8 +808,9 @@ static LogicalResult setAttentionIntrinsicBasedVectorDistributionConfig( } } - if (intrinsics.empty()) + if (intrinsics.empty()) { return failure(); + } // We assume that P uses the element type of V for input // and both matmuls have f32 as output. It is possible to use other element @@ -1337,8 +1355,9 @@ setVectorDistributionConfig(IREE::GPU::TargetAttr target, Operation *computeOp) { // We haven't properly plumbed through MMA op layouts and conversions for CUDA // to target NVIDIA GPUs. So disable the vector distribution pass for it. 
- if (!isROCmBackend(target)) + if (!isROCmBackend(target)) { return failure(); + } if (!clGPUEnableVectorDistribution) { LDBG() << "Vector Distribution not enabled, skipping..."; @@ -1402,8 +1421,9 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target, staticNonUnitParallelDimCount += bounds[nDim] != 1 && ShapedType::isStatic(bounds[nDim]); } - if (staticNonUnitParallelDimCount <= 1) + if (staticNonUnitParallelDimCount <= 1) { return failure(); + } // Don't consider operations that don't have a broadcast, those should go // through reductions. @@ -1463,8 +1483,9 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target, } std::optional subgroupSize = std::nullopt; - if (!subgroupSizes.empty()) + if (!subgroupSizes.empty()) { subgroupSize = subgroupSizes.front(); + } // For the LLVMGPUTileAndFuse pipeline, we need to split tile sizes // for workgroup, thread, and reduction. @@ -1592,8 +1613,9 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target, int64_t tileK = config.tileSize[2]; // Since specialization doesn't work for K loop and peeling is not enabled yet // we pick a tileK size that is aligned on the K size. - if (ShapedType::isDynamic(sizeK)) + if (ShapedType::isDynamic(sizeK)) { tileK = 1; + } while (sizeK % tileK != 0) { tileK >>= 1; } @@ -1773,8 +1795,9 @@ static LogicalResult setRootDefaultConfig(IREE::GPU::TargetAttr target, shape.back() % (workgroupSize[0] * vectorSize) != 0) { vectorSize /= 2; } - if (vectorSize == 1) // assume there is fastpath + slowpath + if (vectorSize == 1) { // assume there is fastpath + slowpath vectorSize = 4; + } int64_t problemSize = llvm::product_of(shape); if ((problemSize / (preferredSubgroupSize * vectorSize)) < 64) { vectorSize = 1; @@ -1788,8 +1811,9 @@ static LogicalResult setRootDefaultConfig(IREE::GPU::TargetAttr target, int64_t id = 0; for (int64_t dim : llvm::reverse(shape)) { // Unit loops are already skipped. 
- if (dim == 1) + if (dim == 1) { continue; + } if (dim < flatWG) { skipInnerTiling++; workgroupSize[id] = dim; @@ -1799,8 +1823,9 @@ static LogicalResult setRootDefaultConfig(IREE::GPU::TargetAttr target, } flatWG = flatWG / dim; id++; - if (flatWG <= 1 || id >= workgroupSize.size()) + if (flatWG <= 1 || id >= workgroupSize.size()) { break; + } } break; } @@ -1831,8 +1856,9 @@ static LogicalResult setRootDefaultConfig(IREE::GPU::TargetAttr target, workgroupTileSizes[depth - 1] = 0; skipInnerTiling--; id++; - if (id >= workgroupSize.size()) + if (id >= workgroupSize.size()) { break; + } continue; } workgroupTileSizes[depth - 1] = workgroupSize[id] * vectorSize; @@ -1873,12 +1899,14 @@ static bool isMatvecLike(linalg::LinalgOp linalgOp) { // TODO: Allow for matvec with fused dequantization. FailureOr dims = linalg::inferContractionDims(linalgOp); - if (failed(dims)) + if (failed(dims)) { return false; + } // TODO: Support batch matvec. - if (!dims->batch.empty()) + if (!dims->batch.empty()) { return false; + } if (dims->m.size() >= 2 || dims->n.size() >= 2 || !llvm::hasSingleElement(dims->k)) { @@ -2000,8 +2028,9 @@ static LogicalResult setArgmaxUkernelConfig( op.getReductionDims(reductionDims); // Currently Argmax UKernel only support 1 reduction dim. - if (reductionDims.size() != 1) + if (reductionDims.size() != 1) { return failure(); + } // Make sure reduction dimensions are static and innermost ones. SmallVector bounds = op.getStaticLoopRanges(); @@ -2075,14 +2104,16 @@ static bool distributeToOneDim(const int64_t inputDim, // Handle 4 elements per thread for the innermost dimension. We need // this for vectorized load. 
chosenTileSize = 4; - if (inputDim % (dim * chosenTileSize) != 0) + if (inputDim % (dim * chosenTileSize) != 0) { continue; + } } else { - for (int64_t t = residualTilingFactor; t >= 1; t >>= 1) + for (int64_t t = residualTilingFactor; t >= 1; t >>= 1) { if (inputDim % (dim * t) == 0) { chosenTileSize = t; break; } + } } if (chosenTileSize) { wgDimSize = dim; @@ -2185,8 +2216,9 @@ static LogicalResult setConvolutionConfig( // OC -> x if (!distributeToOneDim(oc, /*isInnerMostDim=*/true, residualThreads, residualTilingFactor, workgroupSize[0], - workgroupTileSizes[3])) + workgroupTileSizes[3])) { return failure(); + } // Deduce the configruation for the OW and OH dimension. Try to make them // even if possible given we typically have images with the same height @@ -2212,10 +2244,11 @@ static LogicalResult setConvolutionConfig( auto pipeline = CodeGenPipeline::LLVMGPUVectorize; TileSizesListType tileSizes; // Add reduction tile sizes. - if (isNCHW) + if (isNCHW) { workgroupTileSizes.append({4, 1, 1}); - else if (isNHWC) + } else if (isNHWC) { workgroupTileSizes.append({1, 1, 4}); + } tileSizes.push_back(workgroupTileSizes); // Tile along OH by size 1 to enable downsizing 2-D convolution to 1-D. 
@@ -2364,8 +2397,9 @@ static void propagateLoweringConfig(Operation *rootOperation, if (IREE::Codegen::LoweringConfigAttrInterface config = getLoweringConfig(rootOperation)) { for (auto op : computeOps) { - if (op == rootOperation) + if (op == rootOperation) { continue; + } setLoweringConfig(op, config); } } @@ -2376,8 +2410,9 @@ static void propagateLoweringConfig(Operation *rootOperation, //===----------------------------------------------------------------------===// LogicalResult initGPULaunchConfig(FunctionOpInterface funcOp) { IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp); - if (!target) + if (!target) { return funcOp.emitError("missing GPU target in #hal.executable.target"); + } auto exportOp = getEntryPoint(funcOp); if (!getTranslationInfo(funcOp) && exportOp) { @@ -2500,8 +2535,9 @@ LogicalResult initGPULaunchConfig(FunctionOpInterface funcOp) { return success(); } - if (failed(setRootConfig(target, funcOp, rootOperation))) + if (failed(setRootConfig(target, funcOp, rootOperation))) { return funcOp.emitOpError("failed to set root config"); + } if (IREE::Codegen::TranslationInfoAttr translationInfo = getTranslationInfo(funcOp)) { diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUAssignConstantOrdinals.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUAssignConstantOrdinals.cpp index 6428cf2bad99..a7841657b4f2 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUAssignConstantOrdinals.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUAssignConstantOrdinals.cpp @@ -23,8 +23,9 @@ struct LLVMGPUAssignConstantOrdinalsPass // Get a constant key -> ordinal mapping. auto keyOrdinals = variantOp.gatherConstantOrdinals(); - if (keyOrdinals.empty()) + if (keyOrdinals.empty()) { return; + } // Update placeholders to hold the concrete ordinal values. // Eventually MLIR or LLVM will inline them. 
@@ -33,8 +34,9 @@ struct LLVMGPUAssignConstantOrdinalsPass llvm::make_early_inc_range(moduleOp.getOps())) { auto keyAttr = globalOp->getAttr( IREE::HAL::ExecutableConstantBlockOp::getKeyAttrName()); - if (!keyAttr) + if (!keyAttr) { continue; + } auto it = keyOrdinals.find(keyAttr); if (it == keyOrdinals.end()) { globalOp.emitOpError() diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastAddressSpaceFunction.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastAddressSpaceFunction.cpp index 3164c15f210c..77a691ce3cdb 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastAddressSpaceFunction.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastAddressSpaceFunction.cpp @@ -66,8 +66,9 @@ struct LLVMGPUCastAddressSpaceFunctionPass final SymbolTable::lookupSymbolIn(moduleOp, callee)); if (fnDecl) { SmallVector callArgumentTypes; - for (auto op : newOperands) + for (auto op : newOperands) { callArgumentTypes.push_back(op.getType()); + } FunctionType functionType = rewriter.getFunctionType( callArgumentTypes, fnDecl->getResultTypes()); fnDecl.setType(functionType); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp index 6fa62b9d5e33..f09091de6ffe 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp @@ -71,8 +71,9 @@ void LLVMGPULowerExecutableTargetPass::runOnOperation() { FunctionOpInterface funcOp = getOperation(); IREE::Codegen::TranslationInfoAttr translationInfo = getTranslationInfo(funcOp); - if (!translationInfo) + if (!translationInfo) { return; + } std::optional maybePipeline = getFunctionOpInterfacePassManager(funcOp); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUSelectLoweringStrategy.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUSelectLoweringStrategy.cpp 
index 04be0ad56e92..804f4d88ca2f 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUSelectLoweringStrategy.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUSelectLoweringStrategy.cpp @@ -43,8 +43,9 @@ static LogicalResult verifyLoweringConfiguration( IREE::Codegen::TranslationInfoAttr translationInfo) { auto walkResult = funcOp.walk([&](Operation *op) -> WalkResult { auto loweringConfig = getLoweringConfig(op); - if (!loweringConfig) + if (!loweringConfig) { return success(); + } if (translationInfo.getDispatchLoweringPassPipeline() == IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUVectorDistribute) { diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp index 968879feacb2..546ce726be7b 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp @@ -56,13 +56,15 @@ static void populateVectorUnrollPatterns(RewritePatternSet &patterns, bool useMmaSyncShape) { auto unrollOrder = [](Operation *op) -> std::optional> { auto contract = dyn_cast(op); - if (!contract) + if (!contract) { return std::nullopt; + } return gpuMmaUnrollOrder(contract); }; auto getNativeShape = [useMmaSyncShape](Operation *op) { - if (useMmaSyncShape) + if (useMmaSyncShape) { return getMmaNativeVectorSize(op); + } return getWmmaNativeVectorSize(op); }; vector::populateVectorUnrollPatterns( diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp index d546b1f426b7..852238b5be91 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp @@ -115,13 +115,15 @@ calculateDistributedTileSize(ArrayRef numElements, OpBuilder &builder, unsigned 
idIdx = 0; std::reverse(distributedDim.begin(), distributedDim.end()); for (unsigned depth : partitionedLoops) { - if (depth >= blockTileSize.size()) + if (depth >= blockTileSize.size()) { continue; + } tileSizesVal[depth] = arith::ConstantIndexOp::create( builder, operation->getLoc(), llvm::divideCeil(blockTileSize[depth], distributedDim[idIdx++])); - if (idIdx == kNumMaxParallelDims) + if (idIdx == kNumMaxParallelDims) { break; + } } return tileSizesVal; } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp index 95bca2146d31..d69e8fe68a10 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp @@ -66,8 +66,9 @@ struct PromoteContractOperands final Value promoteToElementType(Location loc, RewriterBase &rewriter, Value v, Type dstElementType) const { Type elementType = getElementTypeOrSelf(v.getType()); - if (elementType == dstElementType) + if (elementType == dstElementType) { return v; + } // vector.contract only allows extension on operands. assert(elementType.getIntOrFloatBitWidth() <= @@ -75,11 +76,13 @@ struct PromoteContractOperands final "vector.contract does not allow truncation of operands"); Type promotedType = dstElementType; - if (auto vecType = dyn_cast(v.getType())) + if (auto vecType = dyn_cast(v.getType())) { promotedType = vecType.clone(promotedType); + } - if (isa(dstElementType)) + if (isa(dstElementType)) { return arith::ExtFOp::create(rewriter, loc, promotedType, v); + } // For integer types, vector.contract only supports signless integer types // and promotion happens via sign extension. 
return arith::ExtSIOp::create(rewriter, loc, promotedType, v); @@ -409,8 +412,9 @@ struct ContractToChainFMA final : OpRewritePattern { static std::optional getDimPosition(AffineMap map, unsigned dim) { for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { - if (map.getDimPosition(i) == dim) + if (map.getDimPosition(i) == dim) { return i; + } } return std::nullopt; } @@ -419,8 +423,9 @@ struct ContractToChainFMA final : OpRewritePattern { ArrayAttr iteratorTypes) { SmallVector dimsIdx; for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { - if (vector::isReductionIterator(iteratorTypes[map.getDimPosition(i)])) + if (vector::isReductionIterator(iteratorTypes[map.getDimPosition(i)])) { dimsIdx.push_back(i); + } } return dimsIdx; } @@ -506,8 +511,10 @@ struct UnrollElementwiseOps final : public RewritePattern { LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) + if (!OpTrait::hasElementwiseMappableTraits(op) || + op->getNumResults() != 1) { return failure(); + } Location loc = op->getLoc(); VectorType dstVecTy = dyn_cast(op->getResult(0).getType()); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index 5a80b5335d49..d3f8639a497d 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -12,6 +12,7 @@ #include "iree/compiler/Codegen/Common/CombineLayoutTransformation.h" #include "iree/compiler/Codegen/Common/GPU/Passes.h" #include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" #include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h" #include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h" #include "iree/compiler/Codegen/LLVMGPU/Passes.h" @@ -171,8 +172,9 @@ static LogicalResult gpuCopyFn(OpBuilder &builder, 
Location loc, Value from, if (hasSharedMemoryAddressSpace(cast(to.getType()))) { needsBarrier = true; } - if (needsBarrier) + if (needsBarrier) { gpu::BarrierOp::create(builder, loc); + } Operation *copy = memref::CopyOp::create(builder, loc, from, to); if (needsBarrier) { setMarker(copy, getCopyToWorkgroupMemoryMarker()); @@ -188,14 +190,16 @@ static LogicalResult canReorderWorkgroups(FunctionOpInterface funcOp) { if (!target) { return failure(); } - if (target.getBackend() != "rocm") + if (target.getBackend() != "rocm") { return success(); + } // Workgroup reordering on ROCm currently requires all workgrup counts to be // static. SmallVector workgroupCounts = getStaticNumWorkgroups(funcOp); - if (llvm::any_of(workgroupCounts, ShapedType::isDynamic)) + if (llvm::any_of(workgroupCounts, ShapedType::isDynamic)) { return failure(); + } // This is further restricted to 2D+ grids as we reorder along the X and Y // workgroup IDs. @@ -399,9 +403,8 @@ LogicalResult isAtBoundary(Operation *op) { return success(); } } else if (isa(op)) { - if (llvm::all_of(op->getUsers(), [](Operation *user) { - return isa(user); - })) { + if (llvm::all_of(op->getUsers(), + llvm::IsaPred)) { return success(); } } @@ -579,6 +582,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager, funcPassManager.addPass(createCSEPass()); // Step 9. Remaining post-bufferization optimizations/lowerings. 
+ funcPassManager.addPass(createFlattenSwizzleHintAllocsPass()); funcPassManager.addPass(createPropagateDispatchSizeBoundsPass()); funcPassManager.addPass(IREE::GPU::createLowerIREEGPUOpsPass()); funcPassManager.addPass(createUnrollAnnotatedLoopsPass()); @@ -690,8 +694,9 @@ static LogicalResult gpuVectorCopyFn(OpBuilder &builder, Location loc, if (hasSharedMemoryAddressSpace(cast(to.getType()))) { needsBarrier = true; } - if (needsBarrier) + if (needsBarrier) { gpu::BarrierOp::create(builder, loc); + } VectorType vectorType = VectorType::get(fromType.getShape(), fromType.getElementType()); Value c0 = arith::ConstantIndexOp::create(builder, loc, 0); @@ -736,6 +741,8 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager, funcPassManager.addPass(createCSEPass()); funcPassManager.addPass(createGPUPromoteMatmulOperandsPass()); + funcPassManager.addPass(createGPUExpandDimensionsPass()); + // Tile to reduction loops. { GPUApplyTilingLevelPassOptions options; @@ -1138,7 +1145,8 @@ void buildLLVMGPUCodegenPassPipeline(OpPassManager &variantPassManager, FunctionLikeNest(modulePassManager) .addPass( [&] { return createLLVMGPULowerExecutableTargetPass(options); }) - .addPass(createVerifyWorkgroupDistributionPass); + .addPass(createVerifyWorkgroupDistributionPass) + .addPass(createRemoveIndexHintsPass); if (clPatchFuncOps) { modulePassManager.addPass(createPatchFuncOpsPass()); } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLAnnotateKernelForTranslation.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLAnnotateKernelForTranslation.cpp index 75fcc0f758fa..4520aa619026 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLAnnotateKernelForTranslation.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLAnnotateKernelForTranslation.cpp @@ -102,16 +102,19 @@ annotateKernelForTranslation(LLVM::LLVMFuncOp funcOp, // attribute. 
FailureOr chipset = getChipsetVersion(builder.getContext(), targetAttr); - if (failed(chipset)) + if (failed(chipset)) { return variantOp.emitError() << "failed to parse amdgpu chipset"; + } - if (chipset->majorVersion != 9 || *chipset < amdgpu::Chipset(9, 4, 0)) + if (chipset->majorVersion != 9 || *chipset < amdgpu::Chipset(9, 4, 0)) { return success(); + } auto inRegAttrName = builder.getStringAttr(LLVM::LLVMDialect::getInRegAttrName()); - for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) + for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) { funcOp.setArgAttr(i, inRegAttrName, unitAttr); + } return success(); } @@ -142,8 +145,9 @@ struct ROCDLAnnotateKernelForTranslationPass final // Un-exported functions are library functions or otherwise not kernels, so // don't need these annotations. - if (!exportOp) + if (!exportOp) { return; + } if (failed(annotateKernelForTranslation(funcOp, variantOp, exportOp))) { return signalPassFailure(); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLBufferInstructionsOptimization.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLBufferInstructionsOptimization.cpp index 9bc90f177f54..9b69bf72063c 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLBufferInstructionsOptimization.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLBufferInstructionsOptimization.cpp @@ -82,13 +82,15 @@ void simplifyMaskOps(RewriterBase &rewriter, vector::CreateMaskOp maskOp) { for (Operation *user : maskOp.getResult().getUsers()) { auto readOp = dyn_cast(user); // Only TransferReadOps are supported. - if (!readOp) + if (!readOp) { continue; + } auto sourceType = dyn_cast(readOp.getBase().getType()); // only supported for fat raw buffers. - if (!sourceType || !hasAMDGPUFatRawBufferAddressSpace(sourceType)) + if (!sourceType || !hasAMDGPUFatRawBufferAddressSpace(sourceType)) { continue; + } SmallVector inBounds = readOp.getInBoundsValues(); // Only supported for reads that are fully in_bounds. 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLConfigureBufferInstructions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLConfigureBufferInstructions.cpp index 7d9f8df3488e..584d433bddff 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLConfigureBufferInstructions.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLConfigureBufferInstructions.cpp @@ -46,8 +46,9 @@ static Value stripIntegerCasts(Value val) { /// loads, which is a conservative approximatino for workgroup-uniformity that /// can be made more extensive if needed. static bool isDefinitelyWorkgroupUniform(Value arg) { - if (!arg) + if (!arg) { return true; + } SetVector dependencies; BackwardSliceOptions opts; arg = stripIntegerCasts(arg); @@ -60,8 +61,9 @@ static bool isDefinitelyWorkgroupUniform(Value arg) { getBackwardSlice(arg, &dependencies, opts); assert(result.succeeded()); return llvm::all_of(dependencies, [&](Operation *op) { - if (matchPattern(op, m_Constant())) + if (matchPattern(op, m_Constant())) { return true; + } if (isa(op)) { return true; } @@ -116,13 +118,15 @@ struct ROCDLConfigureBufferInstructionsPass final : impl::ROCDLConfigureBufferInstructionsPassBase< ROCDLConfigureBufferInstructionsPass> { void runOnOperation() override { - if (!clROCDLlEnableBufferInstructions) + if (!clROCDLlEnableBufferInstructions) { return; + } mlir::FunctionOpInterface funcOp = getOperation(); // Is this really the best way to skip this pass on non-rocdl targets? IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp); - if (!target || !target.isAMD()) + if (!target || !target.isAMD()) { return; + } // Initialize the DataFlowSolver with IntegerRangeAnalysis. 
DataFlowSolver solver; diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel index 3227ac21ea14..a152783be7e4 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "LLVMGPUExtensionsOps.td", ], diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp index 166b61e87f12..e6ea60f915f2 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp @@ -96,8 +96,9 @@ transform_dialect::MapNestedForallToGpuThreadsOp::applyToOne( mlir::transform::gpu::mapNestedForallToThreadsImpl( rewriter, transformOp, target, getWorkgroupDims(), getSubgroupSize(), getSyncAfterDistribution()); - if (!diag.succeeded()) + if (!diag.succeeded()) { return diag; + } IREE::Codegen::TranslationInfoAttr updatedTranslationInfo = IREE::Codegen::TranslationInfoAttr::get( @@ -161,8 +162,9 @@ replaceAllUsesOfLaneWithin(RewriterBase &b, Value laneId = executeOp.getLaneid(); bool applied = false; for (Operation *user : llvm::make_early_inc_range(laneId.getUsers())) { - if (!executeOp->isProperAncestor(user)) + if (!executeOp->isProperAncestor(user)) { continue; + } b.startOpModification(user); user->replaceUsesOfWith(laneId, zero); b.finalizeOpModification(user); @@ -179,47 +181,61 @@ replaceAllUsesOfLaneWithin(RewriterBase &b, static FailureOr isThreadIdxxZeroPredicate(scf::IfOp ifOp) { if (!ifOp || ifOp.getNumResults() > 0 || ifOp.getThenRegion().getBlocks().size() != 1 || - !ifOp.getElseRegion().empty()) + 
!ifOp.getElseRegion().empty()) { return failure(); + } auto pred = ifOp.getCondition().getDefiningOp(); - if (!pred) + if (!pred) { return failure(); + } auto EQ = arith::CmpIPredicate::eq; auto SLT = arith::CmpIPredicate::slt; auto SLE = arith::CmpIPredicate::sle; auto ULT = arith::CmpIPredicate::ult; auto ULE = arith::CmpIPredicate::ule; if (auto threadIdOp = pred.getLhs().getDefiningOp()) { - if (threadIdOp.getDimension() != gpu::Dimension::x) + if (threadIdOp.getDimension() != gpu::Dimension::x) { return failure(); - if (pred.getPredicate() == EQ && isZeroInteger(pred.getRhs())) + } + if (pred.getPredicate() == EQ && isZeroInteger(pred.getRhs())) { return threadIdOp; - if (pred.getPredicate() == SLE && isZeroInteger(pred.getRhs())) + } + if (pred.getPredicate() == SLE && isZeroInteger(pred.getRhs())) { return threadIdOp; - if (pred.getPredicate() == ULE && isZeroInteger(pred.getRhs())) + } + if (pred.getPredicate() == ULE && isZeroInteger(pred.getRhs())) { return threadIdOp; - if (pred.getPredicate() == SLT && isOneInteger(pred.getRhs())) + } + if (pred.getPredicate() == SLT && isOneInteger(pred.getRhs())) { return threadIdOp; - if (pred.getPredicate() == ULT && isOneInteger(pred.getRhs())) + } + if (pred.getPredicate() == ULT && isOneInteger(pred.getRhs())) { return threadIdOp; + } } auto SGT = arith::CmpIPredicate::sgt; auto SGE = arith::CmpIPredicate::sge; auto UGT = arith::CmpIPredicate::ugt; auto UGE = arith::CmpIPredicate::uge; if (auto threadIdOp = pred.getRhs().getDefiningOp()) { - if (threadIdOp.getDimension() != gpu::Dimension::x) + if (threadIdOp.getDimension() != gpu::Dimension::x) { return failure(); - if (pred.getPredicate() == EQ && isZeroInteger(pred.getLhs())) + } + if (pred.getPredicate() == EQ && isZeroInteger(pred.getLhs())) { return threadIdOp; - if (pred.getPredicate() == SGE && isZeroInteger(pred.getLhs())) + } + if (pred.getPredicate() == SGE && isZeroInteger(pred.getLhs())) { return threadIdOp; - if (pred.getPredicate() == UGE && 
isZeroInteger(pred.getLhs())) + } + if (pred.getPredicate() == UGE && isZeroInteger(pred.getLhs())) { return threadIdOp; - if (pred.getPredicate() == SGT && isOneInteger(pred.getLhs())) + } + if (pred.getPredicate() == SGT && isOneInteger(pred.getLhs())) { return threadIdOp; - if (pred.getPredicate() == UGT && isOneInteger(pred.getLhs())) + } + if (pred.getPredicate() == UGT && isOneInteger(pred.getLhs())) { return threadIdOp; + } } return failure(); } @@ -235,8 +251,9 @@ rewriteScfIfAsWarpExecuteOnLane0(RewriterBase &rewriter, Location loc, // Bail if cond is not `if (threadIdx.x == 0)`. FailureOr maybeThreadIdxxOp = isThreadIdxxZeroPredicate(ifOp); - if (failed(maybeThreadIdxxOp)) + if (failed(maybeThreadIdxxOp)) { return failure(); + } // All the code below will be executed on a single warp given a // fixed (threadIdxy, threadIdxz). Note, we reuse @@ -384,8 +401,9 @@ static OpOperand *getWarpResult(gpu::WarpExecuteOnLane0Op warpOp, Value yieldValues = yieldOperand.get(); Operation *definedOp = yieldValues.getDefiningOp(); if (definedOp && fn(definedOp)) { - if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty()) + if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty()) { return &yieldOperand; + } } } return {}; @@ -414,15 +432,17 @@ struct WarpOpLoad : public OpRewritePattern { LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred); - if (!operand) + if (!operand) { return failure(); + } auto load = operand->get().getDefiningOp(); unsigned operandIndex = operand->getOperandNumber(); Value distributedVal = warpOp.getResult(operandIndex); auto indices = llvm::to_vector_of(load.getIndices()); - if (!indices.empty()) + if (!indices.empty()) { return failure(); + } OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(warpOp); @@ -458,17 +478,20 @@ struct HoistSharedMemoryAlloc : public OpRewritePattern { using 
Base::Base; LogicalResult matchAndRewrite(memref::AllocOp alloc, PatternRewriter &rewriter) const override { - if (!iree_compiler::hasSharedMemoryAddressSpace(alloc.getType())) + if (!iree_compiler::hasSharedMemoryAddressSpace(alloc.getType())) { return failure(); + } auto warpParent = alloc->getParentOfType(); - if (!warpParent) + if (!warpParent) { return failure(); + } alloc->moveBefore(warpParent); // Conservatively move the dealloc after the warpOp. This may // extend the liverange of the allocation but is always correct. for (Operation *user : alloc->getUsers()) { - if (isa(user)) + if (isa(user)) { user->moveAfter(warpParent); + } } return success(); } @@ -488,8 +511,9 @@ static void populateMultiReductionLoweringPatterns(Operation *target, static AffineMap simpleDistributionFunction(Value val) { AffineMap map = AffineMap::get(val.getContext()); auto vecType = dyn_cast(val.getType()); - if (!vecType) + if (!vecType) { return map; + } // Create a map (d0, d1) -> (d1) to distribute along the inner // dimension. Once we support n-d distribution we can add more // complex cases. 
@@ -673,15 +697,17 @@ transform_dialect::VectorToMMAConversionOp::applyToOne( auto diag = DiagnosedSilenceableFailure::success(); if (getUseWmma()) { - if (failed(convertVectorToMMAOps(rewriter, target))) + if (failed(convertVectorToMMAOps(rewriter, target))) { return mlir::emitDefiniteFailure( target, "vector to wmma patterns failed to apply"); + } return listener.checkAndResetError(); } - if (failed(convertVectorToNVVMCompatibleMMASync(rewriter, funcOp))) + if (failed(convertVectorToNVVMCompatibleMMASync(rewriter, funcOp))) { return mlir::emitDefiniteFailure(target, "vector to mma patterns failed to apply"); + } DEBUG_WITH_TYPE(DEBUG_VECTOR_TO_MMA, { @@ -694,10 +720,11 @@ transform_dialect::VectorToMMAConversionOp::applyToOne( RewritePatternSet f32ToTF32patterns(funcOp.getContext()); nvgpu::populateMmaSyncF32ToTF32Patterns(f32ToTF32patterns, nvgpu::MmaSyncF32Lowering::TF32); - if (failed( - applyPatternsGreedily(funcOp, std::move(f32ToTF32patterns), config))) + if (failed(applyPatternsGreedily(funcOp, std::move(f32ToTF32patterns), + config))) { return mlir::emitDefiniteFailure( target, "vector to mma F32ToTF32 patterns failed to apply"); + } return listener.checkAndResetError(); } @@ -826,8 +853,9 @@ static bool isKnownNoEffectsOpWithoutInterface(Operation *op) { /// Returns `true` if the op is defines the parallel region that is subject to /// barrier synchronization. static bool isParallelRegionBoundary(Operation *op) { - if (op->hasAttr("__parallel_region_boundary_for_test")) + if (op->hasAttr("__parallel_region_boundary_for_test")) { return true; + } // We consider functions inside executable variants . return isa(op); @@ -871,12 +899,14 @@ collectEffects(Operation *op, bool ignoreBarriers = true) { // Skip over barriers to avoid infinite recursion (those barriers would ask // this barrier again). - if (ignoreBarriers && isa(op)) + if (ignoreBarriers && isa(op)) { return true; + } // Skip over ops that we know have no effects. 
- if (isKnownNoEffectsOpWithoutInterface(op)) + if (isKnownNoEffectsOpWithoutInterface(op)) { return true; + } // Collect effect instances the operation. Note that the implementation of // getEffects erases all effect instances that have the type other than the @@ -891,9 +921,11 @@ collectEffects(Operation *op, if (op->hasTrait()) { for (auto ®ion : op->getRegions()) { for (auto &block : region) { - for (auto &innerOp : block) - if (!collectEffects(&innerOp, effects, ignoreBarriers)) + for (auto &innerOp : block) { + if (!collectEffects(&innerOp, effects, ignoreBarriers)) { return false; + } + } } } return true; @@ -915,8 +947,9 @@ static bool getEffectsBefore(Operation *op, SmallVectorImpl &effects, bool stopAtBarrier) { - if (!op->getBlock()) + if (!op->getBlock()) { return true; + } // If there is a non-structured control flow, bail. Region *region = op->getBlock()->getParent(); @@ -930,23 +963,27 @@ getEffectsBefore(Operation *op, for (Operation *it = op->getPrevNode(); it != nullptr; it = it->getPrevNode()) { if (isa(it)) { - if (stopAtBarrier) + if (stopAtBarrier) { return true; - else + } else { continue; + } } - if (!collectEffects(it, effects)) + if (!collectEffects(it, effects)) { return false; + } } } // Stop if reached the parallel region boundary. - if (isParallelRegionBoundary(op->getParentOp())) + if (isParallelRegionBoundary(op->getParentOp())) { return true; + } // Otherwise, keep collecting above the parent operation. - if (!getEffectsBefore(op->getParentOp(), effects, stopAtBarrier)) + if (!getEffectsBefore(op->getParentOp(), effects, stopAtBarrier)) { return false; + } // If the op is loop-like, collect effects from the trailing operations until // we hit a barrier because they can executed before the current operation by @@ -971,16 +1008,18 @@ getEffectsBefore(Operation *op, // If the parent operation is not guaranteed to execute its (single-block) // region once, walk the block. 
bool conservative = false; - if (!hasSingleExecutionBody(op->getParentOp())) + if (!hasSingleExecutionBody(op->getParentOp())) { op->getParentOp()->walk([&](Operation *in) { - if (conservative) + if (conservative) { return WalkResult::interrupt(); + } if (!collectEffects(in, effects)) { conservative = true; return WalkResult::interrupt(); } return WalkResult::advance(); }); + } return !conservative; } @@ -995,8 +1034,9 @@ static bool getEffectsAfter(Operation *op, SmallVectorImpl &effects, bool stopAtBarrier) { - if (!op->getBlock()) + if (!op->getBlock()) { return true; + } // If there is a non-structured control flow, bail. Region *region = op->getBlock()->getParent(); @@ -1006,25 +1046,30 @@ getEffectsAfter(Operation *op, } // Collect all effects after the op. - if (op != &op->getBlock()->back()) + if (op != &op->getBlock()->back()) { for (Operation *it = op->getNextNode(); it != nullptr; it = it->getNextNode()) { if (isa(it)) { - if (stopAtBarrier) + if (stopAtBarrier) { return true; + } continue; } - if (!collectEffects(it, effects)) + if (!collectEffects(it, effects)) { return false; + } } + } // Stop if reached the parallel region boundary. - if (isParallelRegionBoundary(op->getParentOp())) + if (isParallelRegionBoundary(op->getParentOp())) { return true; + } // Otherwise, keep collecting below the parent operation. - if (!getEffectsAfter(op->getParentOp(), effects, stopAtBarrier)) + if (!getEffectsAfter(op->getParentOp(), effects, stopAtBarrier)) { return false; + } // If the op is loop-like, collect effects from the leading operations until // we hit a barrier because they can executed after the current operation by @@ -1041,8 +1086,9 @@ getEffectsAfter(Operation *op, // operation `op2` at iteration `i-1` and the side effects must be ordered // appropriately. 
if (isSequentialLoopLike(op->getParentOp())) { - if (isa(op->getBlock()->front())) + if (isa(op->getBlock()->front())) { return true; + } bool exact = collectEffects(&op->getBlock()->front(), effects); return getEffectsAfter(&op->getBlock()->front(), effects, @@ -1053,16 +1099,18 @@ getEffectsAfter(Operation *op, // If the parent operation is not guaranteed to execute its (single-block) // region once, walk the block. bool conservative = false; - if (!hasSingleExecutionBody(op->getParentOp())) + if (!hasSingleExecutionBody(op->getParentOp())) { op->getParentOp()->walk([&](Operation *in) { - if (conservative) + if (conservative) { return WalkResult::interrupt(); + } if (!collectEffects(in, effects)) { conservative = true; return WalkResult::interrupt(); } return WalkResult::advance(); }); + } return !conservative; } @@ -1071,8 +1119,9 @@ getEffectsAfter(Operation *op, static Value getBase(Value v) { while (true) { Operation *definingOp = v.getDefiningOp(); - if (!definingOp) + if (!definingOp) { break; + } bool shouldContinue = TypeSwitch(v.getDefiningOp()) @@ -1090,8 +1139,9 @@ static Value getBase(Value v) { return true; }) .Default([](Operation *) { return false; }); - if (!shouldContinue) + if (!shouldContinue) { break; + } } return v; } @@ -1163,8 +1213,9 @@ static bool maybeCaptured(Value v) { } std::optional knownCaptureStatus = getKnownCapturingStatus(user, v); - if (!knownCaptureStatus || *knownCaptureStatus) + if (!knownCaptureStatus || *knownCaptureStatus) { return true; + } } } @@ -1227,20 +1278,24 @@ static bool mayAlias(Value first, Value second) { // Non-equivalent distinct bases and globals cannot alias. At this point, we // have already filtered out based on values being equal and global name being // equal. 
- if ((isDistinct[0] || isGlobal[0]) && (isDistinct[1] || isGlobal[1])) + if ((isDistinct[0] || isGlobal[0]) && (isDistinct[1] || isGlobal[1])) { return false; + } bool isArg[] = {isFunctionArgument(first), isFunctionArgument(second)}; // Distinct bases (allocations) cannot have been passed as an argument. - if ((isDistinct[0] && isArg[1]) || (isDistinct[1] && isArg[0])) + if ((isDistinct[0] && isArg[1]) || (isDistinct[1] && isArg[0])) { return false; + } // Non-captured base distinct values cannot conflict with another base value. - if (isDistinct[0] && !maybeCaptured(first)) + if (isDistinct[0] && !maybeCaptured(first)) { return false; - if (isDistinct[1] && !maybeCaptured(second)) + } + if (isDistinct[1] && !maybeCaptured(second)) { return false; + } // Otherwise, conservatively assume aliasing. DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, DBGS_ALIAS() << "-> may alias!\n"); @@ -1263,8 +1318,9 @@ static bool mayAlias(MemoryEffects::EffectInstance a, Value v2) { /// cannot alias. static bool mayAlias(MemoryEffects::EffectInstance a, MemoryEffects::EffectInstance b) { - if (a.getResource()->getResourceID() != b.getResource()->getResourceID()) + if (a.getResource()->getResourceID() != b.getResource()->getResourceID()) { return false; + } if (Value v2 = b.getValue()) { return mayAlias(a, v2); } else if (Value v = a.getValue()) { @@ -1287,8 +1343,9 @@ haveConflictingEffects(ArrayRef beforeEffects, for (const MemoryEffects::EffectInstance &before : beforeEffects) { for (const MemoryEffects::EffectInstance &after : afterEffects) { // If cannot alias, definitely no conflict. - if (!mayAlias(before, after)) + if (!mayAlias(before, after)) { continue; + } // Read/read is not a conflict. if (isa(before.getEffect()) && @@ -1313,8 +1370,9 @@ haveConflictingEffects(ArrayRef beforeEffects, // conflicts. // 2. either the program is ill-formed and we are in undefined behavior // territory. 
- if (isa(before.getEffect())) + if (isa(before.getEffect())) { continue; + } // Other kinds of effects create a conflict, e.g. read-after-write. LLVM_DEBUG( diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp index 14f462c45f17..2c476276990b 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp @@ -83,8 +83,9 @@ struct MaskResult { }; static MaskResult getMask(Operation *op) { auto transferRead = dyn_cast(op); - if (!transferRead || !transferRead.getMask()) + if (!transferRead || !transferRead.getMask()) { return MaskResult{}; + } vector::ExtractOp maybeExtractOp = transferRead.getMask().getDefiningOp(); auto maskOp = @@ -111,8 +112,9 @@ static MaskResult getMask(Operation *op) { static Value getMaskValue(RewriterBase &rewriter, Operation *op) { MaskResult maskResult = getMask(op); - if (!maskResult.maskOp) + if (!maskResult.maskOp) { return Value(); + } Value count = maskResult.maskOp->getOperands().back(); vector::ExtractOp maybeExtractOp = maskResult.maybeExtractOp; if (maybeExtractOp) { @@ -142,14 +144,18 @@ static Value getValueStored(Operation *writeOp) { } static Operation::operand_range getIndices(Operation *op) { - if (auto vectorReadOp = dyn_cast(op)) + if (auto vectorReadOp = dyn_cast(op)) { return vectorReadOp.getIndices(); - if (auto vectorStoreOp = dyn_cast(op)) + } + if (auto vectorStoreOp = dyn_cast(op)) { return vectorStoreOp.getIndices(); - if (auto transferReadOp = dyn_cast(op)) + } + if (auto transferReadOp = dyn_cast(op)) { return transferReadOp.getIndices(); - if (auto transferWriteOp = dyn_cast(op)) + } + if (auto transferWriteOp = dyn_cast(op)) { return transferWriteOp.getIndices(); + } llvm_unreachable("unsupported op type"); } @@ -196,8 +202,9 @@ void createAsyncGroups(RewriterBase &rewriter, mlir::FunctionOpInterface funcOp, llvm::SmallSetVector 
copyToSharedMem; // Look for all the copy that can be converted to async copy ops. funcOp.walk([&](Operation *writeOp) { - if (!isContiguousStore(writeOp)) + if (!isContiguousStore(writeOp)) { return WalkResult::advance(); + } LDBG() << "--candidate writeOp: " << *writeOp; Value vectorVal = getValueStored(writeOp); if (cast(vectorVal.getType()).getRank() != 1) { @@ -242,8 +249,9 @@ void createAsyncGroups(RewriterBase &rewriter, mlir::FunctionOpInterface funcOp, if (!resultsInSupportedAsyncCopy(cast(loadBase.getType()), getIndices(readOp), vecType) || !resultsInSupportedAsyncCopy(cast(storeBase.getType()), - getIndices(writeOp), vecType)) + getIndices(writeOp), vecType)) { return WalkResult::advance(); + } LDBG() << "--writeOp can be made async -> SUCCESS"; copyToSharedMem.insert(writeOp); @@ -263,8 +271,9 @@ void createAsyncGroups(RewriterBase &rewriter, mlir::FunctionOpInterface funcOp, // Ignore ops without side effects auto memInterface = dyn_cast(nextNode); if (memInterface && memInterface.hasNoEffect() && - !nextNode->hasTrait()) + !nextNode->hasTrait()) { continue; + } // ignore read from a different address space. if (isa(nextNode)) { Operation *readOp = nextNode; @@ -315,8 +324,9 @@ void createAsyncGroups(RewriterBase &rewriter, mlir::FunctionOpInterface funcOp, nvgpu::DeviceAsyncWaitOp::create(rewriter, funcOp.getLoc(), groupToken, nullptr); // Clean up old stores. - for (Operation *writeOp : group) + for (Operation *writeOp : group) { rewriter.eraseOp(writeOp); + } } } @@ -360,8 +370,9 @@ void addBarrier(mlir::FunctionOpInterface funcOp, Operation *alloc, needBarrier = true; } else { for (Operation &op : entryBlock->getOperations()) { - if (&op == alloc) + if (&op == alloc) { break; + } if (op.getNumRegions() != 0) { needBarrier = true; break; @@ -372,8 +383,9 @@ void addBarrier(mlir::FunctionOpInterface funcOp, Operation *alloc, } } } - if (!needBarrier) + if (!needBarrier) { return; + } OpBuilder builder(alloc); // TODO: make it a option if needed. 
if (hasAsyncCopies) { @@ -400,8 +412,9 @@ void packSharedMemoryAlloc(mlir::FunctionOpInterface funcOp) { SmallVector aliasGroups; analyseAllocsForPacking(funcOp, allocs, aliasGroups); // If there is 1 or less alias group there is nothing to do. - if (aliasGroups.size() <= 1) + if (aliasGroups.size() <= 1) { return; + } // Pack all the allocations into one i8 alloc. // We may need to add extra barriers to make sure we are done writting or diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/ROCDLPrefetchSharedMemoryCopy.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/ROCDLPrefetchSharedMemoryCopy.cpp index 9f831452610f..fcae85d1525b 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/ROCDLPrefetchSharedMemoryCopy.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/ROCDLPrefetchSharedMemoryCopy.cpp @@ -277,14 +277,17 @@ static LogicalResult classifyOperationsIntoStages( LDBG() << "\n=== Final Stage Classification ==="; LDBG() << "--- Read Stage (" << result.readStage.size() << " ops) ---"; - for (Operation *op : result.readStage) + for (Operation *op : result.readStage) { LDBG() << *op; + } LDBG() << "--- Write Stage (" << result.writeStage.size() << " ops) ---"; - for (Operation *op : result.writeStage) + for (Operation *op : result.writeStage) { LDBG() << *op; + } LDBG() << "--- Compute Stage (" << result.computeStage.size() << " ops) ---"; - for (Operation *op : result.computeStage) + for (Operation *op : result.computeStage) { LDBG() << *op; + } return success(); } @@ -365,28 +368,35 @@ populateOpToStageMap(const StageClassification &stages, scf::ForOp forOp, unsigned numStages, llvm::DenseMap &opToStage) { auto assignOp = [&](Operation *op, unsigned stage) { - if (!op || isa(op)) + if (!op || isa(op)) { return; + } opToStage[op] = stage; }; if (numStages == 2) { // Two-stage pipelining: read+write in stage 0, compute in stage 1. 
- for (Operation *op : stages.readStage) + for (Operation *op : stages.readStage) { assignOp(op, /*stage=*/0); - for (Operation *op : stages.writeStage) + } + for (Operation *op : stages.writeStage) { assignOp(op, /*stage=*/0); - for (Operation *op : stages.computeStage) + } + for (Operation *op : stages.computeStage) { assignOp(op, /*stage=*/1); + } } else { // Three-stage pipelining: read in stage 0, write in stage 1, compute in // stage 2. - for (Operation *op : stages.readStage) + for (Operation *op : stages.readStage) { assignOp(op, /*stage=*/0); - for (Operation *op : stages.writeStage) + } + for (Operation *op : stages.writeStage) { assignOp(op, /*stage=*/1); - for (Operation *op : stages.computeStage) + } + for (Operation *op : stages.computeStage) { assignOp(op, /*stage=*/2); + } } } @@ -513,8 +523,9 @@ invokePipelineForLoop(scf::ForOp forOp, const scf::PipeliningOption &options) { // Helper to check for shared memory. static bool hasSharedMemory(Value val) { auto memrefType = dyn_cast(val.getType()); - if (!memrefType) + if (!memrefType) { return false; + } auto addrSpace = dyn_cast_if_present(memrefType.getMemorySpace()); return addrSpace && addrSpace.getValue() == gpu::AddressSpace::Workgroup; @@ -587,10 +598,12 @@ static SharedBarrierState insertBarriersInRange(RewriterBase &rewriter, state.needBarrierBeforeWrite = false; } - if (hasSharedRead) + if (hasSharedRead) { state.needBarrierBeforeWrite = true; - if (hasSharedWrite) + } + if (hasSharedWrite) { state.needBarrierBeforeRead = true; + } } return state; diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel index d5e77632a5a3..631efee31d63 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel @@ -17,17 +17,10 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "amdgpu_emulate_narrow_type.mlir", 
"assign_constant_ordinals.mlir", - "conv_pipeline_test_cuda.mlir", - "convert_to_nvvm.mlir", - "convert_to_rocdl.mlir", - "convert_to_rocdl_gfx950.mlir", - "create_async_groups.mlir", - "create_tile_sizes.mlir", - "distribute_to_thread.mlir", - "elementwise_pipeline.mlir", "cast_address_space_function.mlir", "cast_type_to_fit_mma.mlir", "config_custom_op.mlir", @@ -38,23 +31,32 @@ iree_lit_test_suite( "config_root_op_attribute.mlir", "config_sort.mlir", "config_winograd.mlir", + "configure_tensor_layout.mlir", + "conv_pipeline_test_cuda.mlir", + "convert_to_nvvm.mlir", + "convert_to_rocdl.mlir", + "convert_to_rocdl_gfx950.mlir", + "create_async_groups.mlir", + "create_tile_sizes.mlir", + "distribute_to_thread.mlir", + "elementwise_pipeline.mlir", "extract_address_computation_gpu.mlir", "gpu_pipeline_data_tiling.mlir", "gpu_pipeline_relayout_ops.mlir", "horizontal_fusion_pipeline.mlir", - "link_executables.mlir", - "reduction_pipeline_cuda.mlir", - "reduction_pipeline_rocm.mlir", - "reduction_pipeline_softmax_rocm.mlir", - "reuse_shared_memory_allocs.mlir", - "rocdl_pipeline_test.mlir", "legalize.mlir", "linalg_transform.mlir", + "link_executables.mlir", "llvmgpu_bufferize.mlir", "nvvm_pipeline_test.mlir", "pack_shared_memory_alloc.mlir", "pipeline_coalesced_dma.mlir", "prefetch_shared_memory.mlir", + "reduction_pipeline_cuda.mlir", + "reduction_pipeline_rocm.mlir", + "reduction_pipeline_softmax_rocm.mlir", + "reuse_shared_memory_allocs.mlir", + "rocdl_pipeline_test.mlir", "sort_pipeline_test.mlir", "tensorcore_vectorization.mlir", "transform_dialect_bufferize.mlir", @@ -68,7 +70,6 @@ iree_lit_test_suite( "transform_gpu_pipelining.mlir", "transform_vector_to_mma.mlir", "transpose_pipeline_test.mlir", - "configure_tensor_layout.mlir", "vector_lowering.mlir", "vector_to_gpu.mlir", "winograd_pipeline_test.mlir", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel index 
05ccee9b3112..342984a9a8ec 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "annotate_kernel_for_translation.mlir", "buffer_instructions_optimization.mlir", @@ -24,12 +25,12 @@ iree_lit_test_suite( "config_igemm_tile_and_fuse.mlir", "config_tile_and_fuse.mlir", "config_tile_and_fuse_gfx950.mlir", + "config_user_vector_distribute.mlir", "config_vector_distribute_gfx1100.mlir", "config_vector_distribute_gfx942.mlir", "config_vector_distribute_gfx950.mlir", "config_vector_distribute_reduction_gfx942.mlir", "config_vector_distribute_reduction_gfx950.mlir", - "config_user_vector_distribute.mlir", "configure_buffer_instructions.mlir", "pipeline_direct_conv_tile_and_fuse.mlir", "pipeline_elementwise_f8fnuz.mlir", @@ -40,10 +41,10 @@ iree_lit_test_suite( "pipeline_tile_and_fuse.mlir", "pipeline_tile_and_fuse_gfx950.mlir", "pipeline_vector_distribute_dynamic_shapes_gfx942.mlir", + "pipeline_vector_distribute_gfx1100.mlir", "pipeline_vector_distribute_gfx942.mlir", - "pipeline_vector_distribute_reduction_gfx942.mlir", "pipeline_vector_distribute_gfx950.mlir", - "pipeline_vector_distribute_gfx1100.mlir", + "pipeline_vector_distribute_reduction_gfx942.mlir", ], include = ["*.mlir"], ), diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir index 179e5bb577c9..94f4164a6d92 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir @@ -1112,3 +1112,31 @@ func.func @aligned_matmul_biasadd(%lhs : tensor<512x512xf16>, %rhs : tensor<512x // CHECK-LABEL: func.func @aligned_matmul_biasadd( // CHECK: promote_operands = [0, 1] + +// ----- + +// 
Currently falls back to non-MMA path since MMA intrinsics require matching +// operand types. +func.func @mixed_precision_matmul_f32xbf16(%lhs: tensor<16x64xf32>, %rhs: tensor<64x32xbf16>) -> tensor<16x32xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %empty = tensor.empty() : tensor<16x32xf32> + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<16x32xf32>) -> tensor<16x32xf32> + %result = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%lhs, %rhs : tensor<16x64xf32>, tensor<64x32xbf16>) + outs(%fill : tensor<16x32xf32>) { + ^bb0(%in: f32, %in_0: bf16, %out: f32): + %0 = arith.extf %in_0 : bf16 to f32 + %1 = arith.mulf %in, %0 : f32 + %2 = arith.addf %out, %1 : f32 + linalg.yield %2 : f32 + } -> tensor<16x32xf32> + return %result : tensor<16x32xf32> +} + +// CHECK-LABEL: func.func @mixed_precision_matmul_f32xbf16( +// CHECK-SAME: #iree_codegen.translation_info +// CHECK-NOT: mma_kind diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir index 7e9c14c0f678..824caf09bd88 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir @@ -3,6 +3,12 @@ // RUN: --iree-codegen-llvmgpu-use-igemm=false \ // RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s +// RUN: iree-opt --mlir-print-local-scope --split-input-file --iree-gpu-test-target=gfx950 \ +// RUN: --iree-codegen-llvmgpu-use-tile-and-fuse-matmul=true --iree-codegen-llvmgpu-test-tile-and-fuse-vectorize=true \ +// RUN: --iree-codegen-llvmgpu-use-igemm=false \ +// RUN: 
--pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" \ +// RUN: --remarks-filter=".*" %s 2>&1 | FileCheck %s --check-prefix=CHECK-REMARKS + #lhs_map = affine_map<(M, N, Ko, Kb) -> (M, Ko, Kb)> #rhs_map = affine_map<(M, N, Ko, Kb) -> (N, Ko, Kb)> #scale_m = affine_map<(M, N, Ko, Kb) -> (M, Ko)> @@ -35,6 +41,10 @@ func.func @scaled_matmul( // CHECK-SAME: subgroup = [4, 8, 0, 0] // CHECK-SAME: workgroup = [256, 256, 0, 0] +// CHECK-REMARKS: [Analysis] SharedMemoryUsage +// CHECK-REMARKS-SAME: Category:deduceMMASchedule +// CHECK-REMARKS-SAME: Remark=34816 + // ----- #lhs_map = affine_map<(B, M, N, Ko, Kb) -> (B, M, Ko, Kb)> @@ -70,6 +80,10 @@ func.func @scaled_matmul_with_batch( // CHECK-SAME: subgroup = [0, 4, 8, 0, 0] // CHECK-SAME: workgroup = [1, 256, 256, 0, 0] +// CHECK-REMARKS: [Analysis] SharedMemoryUsage +// CHECK-REMARKS-SAME: Category:deduceMMASchedule +// CHECK-REMARKS-SAME: Remark=34816 + // ----- #lhs_map = affine_map<(M, N, Ko, Kb) -> (M, Ko, Kb)> @@ -132,6 +146,10 @@ func.func @scaled_matmul_with_dynamic_batch( // CHECK-SAME: subgroup = [0, 4, 4, 0, 0] // CHECK-SAME: workgroup = [1, 128, 256, 0, 0] +// CHECK-REMARKS: [Analysis] SharedMemoryUsage +// CHECK-REMARKS-SAME: Category:deduceMMASchedule +// CHECK-REMARKS-SAME: Remark=26112 + // ----- #lhs_map = affine_map<(M, N, Ko, Kb) -> (M, Ko, Kb)> @@ -166,6 +184,10 @@ func.func @small_scaled_matmul( // CHECK-SAME: subgroup = [1, 1, 0, 0] // CHECK-SAME: workgroup = [16, 16, 0, 0] +// CHECK-REMARKS: [Analysis] SharedMemoryUsage +// CHECK-REMARKS-SAME: Category:deduceMMASchedule +// CHECK-REMARKS-SAME: Remark=2176 + // ----- module { @@ -273,3 +295,7 @@ func.func @scaled_matmul_accumulate( // CHECK-SAME: reduction = [0, 0, 1, 1] // CHECK-SAME: subgroup = [2, 8, 0, 0] // CHECK: workgroup = [128, 256, 0, 0] + +// CHECK-REMARKS: [Analysis] SharedMemoryUsage +// CHECK-REMARKS-SAME: Category:deduceMMASchedule +// CHECK-REMARKS-SAME: Remark=157184 diff --git 
a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir index 80be4aa67df1..d0f5410dc619 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir @@ -84,11 +84,12 @@ func.func @reduction_with_no_consumer() { // CHECK-LABEL: func.func @reduction_with_no_consumer // CHECK: lowering_config = #iree_gpu.lowering_config -// CHECK-SAME: lane_basis = {{\[}}[1, 1, 1, 64], [0, 1, 2, 3] -// CHECK-SAME: partial_reduction = [0, 0, 1, 4096] -// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 8], [0, 1, 2, 3] -// CHECK-SAME: thread = [0, 0, 1, 8], -// CHECK-SAME: workgroup = [1, 1, 0, 0] +// CHECK-SAME: expand_dims = #iree_gpu.expand_dims<{{\[}}[0], [1], [2], [3, 4]{{\]}}, output_shape = [?, ?, ?, ?, 8]> +// CHECK-SAME: lane_basis = {{\[}}[1, 1, 1, 64, 1], [0, 1, 2, 3, 4]{{\]}} +// CHECK-SAME: partial_reduction = [0, 0, 1, 512, 0] +// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 8, 1], [0, 1, 2, 3, 4]{{\]}} +// CHECK-SAME: thread = [0, 0, 1, 1, 8] +// CHECK-SAME: workgroup = [1, 1, 0, 0, 0] // ----- @@ -150,20 +151,22 @@ func.func @test_multiple_reduction() { // CHECK-SAME: ins(%{{.*}} : tensor<2x32x10x16384xf32>) // CHECK-SAME: outs({{.*}}: tensor<2x32xf32>) // CHECK-SAME: attrs = {lowering_config = #iree_gpu.lowering_config<{ -// CHECK-SAME: lane_basis = {{\[}}[1, 1, 1, 64], [0, 1, 2, 3]], -// CHECK-SAME: partial_reduction = [0, 0, 1, 8192], -// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 16], [0, 1, 2, 3]], -// CHECK-SAME: thread = [0, 0, 1, 8], -// CHECK-SAME: workgroup = [1, 1, 0, 0] +// CHECK-SAME: expand_dims = #iree_gpu.expand_dims<{{\[}}[0], [1], [2], [3, 4]{{\]}}, output_shape = [?, ?, ?, ?, 8]>, +// CHECK-SAME: lane_basis = {{\[}}[1, 1, 1, 64, 1], [0, 1, 2, 3, 
4]{{\]}}, +// CHECK-SAME: partial_reduction = [0, 0, 1, 1024, 0], +// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 16, 1], [0, 1, 2, 3, 4]{{\]}}, +// CHECK-SAME: thread = [0, 0, 1, 1, 8], +// CHECK-SAME: workgroup = [1, 1, 0, 0, 0] // CHECK: %{{.*}} = linalg.generic {indexing_maps = [#map, #map1, #map1], // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "reduction"]} // CHECK-SAME: ins{{.*}}, {{.*}} : tensor<2x32x10x16384xf32>, tensor<2x32xf32>) // CHECK-SAME: outs(%{{.*}} : tensor<2x32xf32>) // CHECK-SAME: attrs = {lowering_config = #iree_gpu.lowering_config<{ -// CHECK-SAME: lane_basis = {{\[}}[1, 1, 1, 64], [0, 1, 2, 3]], -// CHECK-SAME: partial_reduction = [0, 0, 1, 8192], -// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 16], [0, 1, 2, 3]], -// CHECK-SAME: thread = [0, 0, 1, 8], +// CHECK-SAME: expand_dims = #iree_gpu.expand_dims<{{\[}}[0], [1], [2], [3, 4]{{\]}}, output_shape = [?, ?, ?, ?, 8]>, +// CHECK-SAME: lane_basis = {{\[}}[1, 1, 1, 64, 1], [0, 1, 2, 3, 4]{{\]}}, +// CHECK-SAME: partial_reduction = [0, 0, 1, 1024, 0], +// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 16, 1], [0, 1, 2, 3, 4]{{\]}}, +// CHECK-SAME: thread = [0, 0, 1, 1, 8], // CHECK: %{{.*}} = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]} // CHECK-SAME: ins({{.*}}, %{{.*}}, {{.*}} : tensor<2x32x10x16384xf16>, tensor<2x32xf32>, tensor<2x32xf32>) @@ -250,11 +253,12 @@ func.func @test_multiple_stores(%arg0: !iree_tensor_ext.dispatch.tensor, +// CHECK-SAME: lane_basis = {{\[}}[1, 64, 1], [0, 1, 2]{{\]}}, +// CHECK-SAME: partial_reduction = [0, 1024, 0], +// CHECK-SAME: subgroup_basis = {{\[}}[1, 16, 1], [0, 1, 2]{{\]}}, +// CHECK-SAME: thread = [0, 1, 4], +// CHECK-SAME: workgroup = [1, 0, 0] // ----- @@ -291,9 +295,9 @@ func.func @test_gather_config(%arg0: !iree_tensor_ext.dispatch.tensor, +// CHECK-SAME: lane_basis = {{\[}}[1, 1, 64, 1], [0, 1, 2, 3]{{\]}}, +// CHECK-SAME: 
partial_reduction = [0, 0, 64, 0], +// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 1], [0, 1, 2, 3]{{\]}}, +// CHECK-SAME: thread = [0, 0, 1, 8], +// CHECK-SAME: workgroup = [4, 1, 0, 0] diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir index 08b764e77fca..56767f219292 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir @@ -119,11 +119,12 @@ func.func @vmt1() attributes {hal.executable.target = #executable_target_rocm_hs // CHECK-SAME: translation_info = #[[$TRANSLATION]] // CHECK: linalg.generic // CHECK-SAME: attrs = {lowering_config = #iree_gpu.lowering_config<{ -// CHECK-SAME: lane_basis = {{\[}}[1, 1, 64], [0, 1, 2]], -// CHECK-SAME: partial_reduction = [0, 0, 512], -// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1], [0, 1, 2]], -// CHECK-SAME: thread = [0, 0, 8], -// CHECK-SAME: workgroup = [1, 8, 0] +// CHECK-SAME: expand_dims = #iree_gpu.expand_dims<{{\[}}[0], [1], [2, 3]{{\]}}, output_shape = [?, ?, ?, 8]>, +// CHECK-SAME: lane_basis = {{\[}}[1, 1, 64, 1], [0, 1, 2, 3]{{\]}}, +// CHECK-SAME: partial_reduction = [0, 0, 64, 0], +// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 1], [0, 1, 2, 3]{{\]}}, +// CHECK-SAME: thread = [0, 0, 1, 8], +// CHECK-SAME: workgroup = [1, 8, 0, 0] // ----- @@ -162,11 +163,12 @@ func.func @matvec_like_no_m_dim() attributes {hal.executable.target = #executabl // CHECK-SAME: translation_info = #[[$TRANSLATION]] // CHECK: linalg.generic // CHECK-SAME: attrs = {lowering_config = #iree_gpu.lowering_config<{ -// CHECK-SAME: lane_basis = {{\[}}[1, 64], [0, 1]], -// CHECK-SAME: partial_reduction = [0, 512], -// CHECK-SAME: subgroup_basis = {{\[}}[1, 1], [0, 1]], -// CHECK-SAME: thread = [0, 8], -// CHECK-SAME: workgroup = [8, 0] +// CHECK-SAME: expand_dims = #iree_gpu.expand_dims<{{\[}}[0], [1, 2]{{\]}}, output_shape = [?, ?, 8]>, +// 
CHECK-SAME: lane_basis = {{\[}}[1, 64, 1], [0, 1, 2]{{\]}}, +// CHECK-SAME: partial_reduction = [0, 64, 0], +// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1], [0, 1, 2]{{\]}}, +// CHECK-SAME: thread = [0, 1, 8], +// CHECK-SAME: workgroup = [8, 0, 0] // ----- @@ -204,11 +206,12 @@ func.func @matvec_unit_n_dim() attributes {hal.executable.target = #executable_t // CHECK-SAME: translation_info = #[[$TRANSLATION]] // CHECK: linalg.generic // CHECK-SAME: attrs = {lowering_config = #iree_gpu.lowering_config<{ -// CHECK-SAME: lane_basis = {{\[}}[1, 1, 64], [0, 1, 2]], -// CHECK-SAME: partial_reduction = [0, 0, 512], -// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1], [0, 1, 2]], -// CHECK-SAME: thread = [0, 0, 8], -// CHECK-SAME: workgroup = [8, 1, 0] +// CHECK-SAME: expand_dims = #iree_gpu.expand_dims<{{\[}}[0], [1], [2, 3]{{\]}}, output_shape = [?, ?, ?, 8]>, +// CHECK-SAME: lane_basis = {{\[}}[1, 1, 64, 1], [0, 1, 2, 3]{{\]}}, +// CHECK-SAME: partial_reduction = [0, 0, 64, 0], +// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 1], [0, 1, 2, 3]{{\]}}, +// CHECK-SAME: thread = [0, 0, 1, 8], +// CHECK-SAME: workgroup = [8, 1, 0, 0] // ----- @@ -248,11 +251,12 @@ func.func @vmt2() attributes {hal.executable.target = #executable_target_rocm_hs // CDNA3-SAME: translation_info = #[[$TRANSLATION]] // CDNA3: linalg.generic // CDNA3-SAME: attrs = {lowering_config = #iree_gpu.lowering_config<{ -// CDNA3-SAME: lane_basis = {{\[}}[1, 1, 32], [0, 1, 2]], -// CDNA3-SAME: partial_reduction = [0, 0, 256], -// CDNA3-SAME: subgroup_basis = {{\[}}[1, 1, 1], [0, 1, 2]], -// CDNA3-SAME: thread = [0, 0, 8], -// CDNA3-SAME: workgroup = [1, 4, 0] +// CDNA3-SAME: expand_dims = #iree_gpu.expand_dims<{{\[}}[0], [1], [2, 3]{{\]}}, output_shape = [?, ?, ?, 8]>, +// CDNA3-SAME: lane_basis = {{\[}}[1, 1, 32, 1], [0, 1, 2, 3]{{\]}}, +// CDNA3-SAME: partial_reduction = [0, 0, 32, 0], +// CDNA3-SAME: subgroup_basis = {{\[}}[1, 1, 1, 1], [0, 1, 2, 3]{{\]}}, +// CDNA3-SAME: thread = [0, 0, 1, 8], +// 
CDNA3-SAME: workgroup = [1, 4, 0, 0] // ----- @@ -308,11 +312,12 @@ func.func @i4_dequant_matvec() { // CHECK: linalg.generic // CHECK: linalg.generic // CHECK-SAME: attrs = {lowering_config = #iree_gpu.lowering_config<{ -// CHECK-SAME: lane_basis = {{\[}}[1, 1, 64], [0, 1, 2]], -// CHECK-SAME: partial_reduction = [0, 1, 128], -// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1], [0, 1, 2]], -// CHECK-SAME: thread = [0, 1, 2], -// CHECK-SAME: workgroup = [8, 0, 0] +// CHECK-SAME: expand_dims = #iree_gpu.expand_dims<{{\[}}[0], [1], [2, 3]{{\]}}, output_shape = [?, ?, ?, 2]>, +// CHECK-SAME: lane_basis = {{\[}}[1, 1, 64, 1], [0, 1, 2, 3]{{\]}}, +// CHECK-SAME: partial_reduction = [0, 1, 64, 0], +// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 1], [0, 1, 2, 3]{{\]}}, +// CHECK-SAME: thread = [0, 1, 1, 2], +// CHECK-SAME: workgroup = [8, 0, 0, 0] // ----- @@ -353,11 +358,12 @@ func.func @skinny_mmt_lhs_is_vector() { // CHECK: linalg.matmul // CHECK-SAME: indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]] // CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{ -// CHECK-SAME: lane_basis = {{\[}}[1, 1, 64], [0, 1, 2]], -// CHECK-SAME: partial_reduction = [0, 0, 512], -// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1], [0, 1, 2]], -// CHECK-SAME: thread = [0, 0, 8], -// CHECK-SAME: workgroup = [2, 1, 0]}>} +// CHECK-SAME: expand_dims = #iree_gpu.expand_dims<{{\[}}[0], [1], [2, 3]{{\]}}, output_shape = [?, ?, ?, 8]>, +// CHECK-SAME: lane_basis = {{\[}}[1, 1, 64, 1], [0, 1, 2, 3]{{\]}}, +// CHECK-SAME: partial_reduction = [0, 0, 64, 0], +// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 1], [0, 1, 2, 3]{{\]}}, +// CHECK-SAME: thread = [0, 0, 1, 8], +// CHECK-SAME: workgroup = [2, 1, 0, 0]}>} // ----- @@ -395,11 +401,12 @@ func.func @skinny_mmt_lhs_is_matrix() { // CHECK: linalg.matmul // CHECK-SAME: indexing_maps // CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{ -// CHECK-SAME: lane_basis = {{\[}}[1, 1, 64], [0, 1, 2]], -// CHECK-SAME: partial_reduction = [0, 0, 512], 
-// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1], [0, 1, 2]], -// CHECK-SAME: thread = [0, 0, 8], -// CHECK-SAME: workgroup = [8, 1, 0]}>} +// CHECK-SAME: expand_dims = #iree_gpu.expand_dims<{{\[}}[0], [1], [2, 3]{{\]}}, output_shape = [?, ?, ?, 8]>, +// CHECK-SAME: lane_basis = {{\[}}[1, 1, 64, 1], [0, 1, 2, 3]{{\]}}, +// CHECK-SAME: partial_reduction = [0, 0, 64, 0], +// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 1], [0, 1, 2, 3]{{\]}}, +// CHECK-SAME: thread = [0, 0, 1, 8], +// CHECK-SAME: workgroup = [8, 1, 0, 0]}>} // ----- diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_cuda.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_cuda.mlir index b728b6d9e7b6..86c4b0e4b232 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_cuda.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_cuda.mlir @@ -37,21 +37,21 @@ hal.executable.variant @cuda target(<"cuda", "cuda-nvptx-fb">) { // CHECK: #[[TRANSLATION_INFO:.+]] = #iree_codegen.translation_info : vector<1x1x4xf32> +// CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x1x1x1x4xf32> +// CHECK-DAG: %[[CST_ACC:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x1xf32> // CHECK-DAG: gpu.thread_id x -// CHECK: %[[R0:.+]] = scf.for %{{.*}} = %c0 to %c10240 step %c1024 iter_args(%[[A0:.+]] = %[[CST]]) -> (vector<1x1x4xf32>) { -// CHECK: %[[V:.+]] = vector.transfer_read {{.*}} : memref<512x10240xf32, #hal.descriptor_type>, vector<4xf32> -// CHECK: %[[STRIDED:.+]] = vector.insert_strided_slice %[[V]], {{.*}} : vector<4xf32> into vector<1x1x4xf32> -// CHECK: %[[ADD:.+]] = arith.addf %[[STRIDED]], %[[A0]] : vector<1x1x4xf32> -// CHECK: scf.yield %[[ADD]] : vector<1x1x4xf32> +// CHECK: %[[R0:.+]] = scf.for %{{.*}} = %c0 to %c2560 step %c256 iter_args(%[[A0:.+]] = %[[CST_ACC]]) -> (vector<1x1x1xf32>) { +// CHECK: %[[V:.+]] = vector.transfer_read {{.*}} : memref<512x10240xf32, {{.*}}>, vector<1x4xf32> 
+// CHECK: %[[STRIDED:.+]] = vector.insert_strided_slice %[[V]], {{.*}} : vector<1x4xf32> into vector<1x1x1x1x1x4xf32> +// CHECK: %[[REDUCE:.+]] = vector.multi_reduction , %[[STRIDED]], %[[CST_ACC]] [1, 3, 5] : vector<1x1x1x1x1x4xf32> to vector<1x1x1xf32> +// CHECK: %[[ADD:.+]] = arith.addf %[[REDUCE]], %[[A0]] : vector<1x1x1xf32> +// CHECK: scf.yield %[[ADD]] : vector<1x1x1xf32> // CHECK: } // CHECK: gpu.subgroup_reduce add {{.*}} cluster(size = 32) : (f32) -> f32 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<10xf32, #gpu.address_space> -// CHECK: vector.transfer_write %{{.*}}, %[[ALLOC]]{{.*}} : vector<1xf32> // CHECK: gpu.barrier // CHECK: vector.transfer_read %[[ALLOC]]{{.*}} // CHECK: gpu.subgroup_reduce add {{.*}} cluster(size = 8) : (f32) -> f32 -// CHECK: vector.transfer_write {{.*}} : vector, memref<512xf32, #hal.descriptor_type> // ----- @@ -103,15 +103,14 @@ hal.executable.variant @cuda target(<"cuda", "cuda-nvptx-fb">) { // CHECK: #[[TRANSLATION_INFO:.+]] = #iree_codegen.translation_info (vector<1x1x4xf32>) { -// CHECK: vector.transfer_read {{.*}} : memref<512x10240xf32, -// CHECK: arith.addf {{.*}} : vector<1x1x4xf32> +// CHECK: scf.for {{.*}} -> (vector<1x1x1xf32>) { +// CHECK: vector.transfer_read {{.*}} : memref<512x10240xf32, {{.*}}>, vector<1x4xf32> +// CHECK: vector.multi_reduction , {{.*}} [1, 3, 5] : vector<1x1x1x1x1x4xf32> to vector<1x1x1xf32> +// CHECK: arith.addf {{.*}} : vector<1x1x1xf32> // CHECK: scf.yield // CHECK: gpu.subgroup_reduce -// CHECK: vector.transfer_write {{.*}} : vector<1xf32 // CHECK: gpu.subgroup_reduce // CHECK: arith.divf {{.*}} : vector -// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, {{.*}} // CHECK: return // ----- @@ -144,30 +143,25 @@ hal.executable.variant @cuda target(<"cuda", "cuda-nvptx-fb">) { // CHECK: #[[TRANSLATION_INFO:.+]] = #iree_codegen.translation_info (vector<1x1x4xf32>) { -// CHECK: vector.transfer_read {{.*}} : memref<12x128x40960xf32, -// CHECK: arith.maxnumf {{.*}} : vector<1x1x4xf32> 
+// CHECK: scf.for {{.*}} -> (vector<1x1x1xf32>) { +// CHECK: vector.transfer_read {{.*}} : memref<12x128x40960xf32, {{.*}}>, vector<1x4xf32> +// CHECK: vector.multi_reduction , {{.*}} {{.*}} : vector<1x1x1x1x1x4xf32> to vector<1x1x1xf32> +// CHECK: arith.maxnumf {{.*}} : vector<1x1x1xf32> // CHECK: scf.yield -// CHECK: vector.multi_reduction // CHECK: gpu.subgroup_reduce maxnumf -// CHECK: vector.transfer_write // CHECK: gpu.barrier // CHECK: gpu.subgroup_reduce maxnumf -// CHECK: vector.broadcast %{{.*}} : f32 to vector<1x1x4xf32> -// CHECK: scf.for {{.*}} -> (vector<1x1x4xf32>) { +// CHECK: scf.for {{.*}} -> (vector<1x1x1xf32>) { // CHECK: vector.transfer_read // CHECK: arith.subf // CHECK: math.exp +// CHECK: vector.multi_reduction // CHECK: arith.addf // CHECK: scf.yield -// CHECK: vector.multi_reduction // CHECK: gpu.subgroup_reduce add -// CHECK: vector.transfer_write // CHECK: gpu.barrier -// CHECK: vector.transfer_read // CHECK: gpu.subgroup_reduce add -// CHECK: vector.broadcast -// CHECK: scf.for +// CHECK: scf.forall // CHECK: vector.transfer_read // CHECK: arith.subf // CHECK: math.exp @@ -206,23 +200,22 @@ hal.executable.variant @cuda target(<"cuda", "cuda-nvptx-fb">) { // CHECK: #[[TRANSLATION_INFO:.+]] = #iree_codegen.translation_info (vector<1x1x4xf32>) { -// CHECK: vector.transfer_read {{.*}} : memref<12x256x40960xf32, -// CHECK: arith.maxnumf {{.*}} : vector<1x1x4xf32> +// CHECK: scf.for {{.*}} -> (vector<1x1x1xf32>) { +// CHECK: vector.transfer_read {{.*}} : memref<12x256x40960xf32, {{.*}}>, vector<1x4xf32> +// CHECK: vector.multi_reduction , {{.*}} {{.*}} : vector<1x1x1x1x1x4xf32> to vector<1x1x1xf32> +// CHECK: arith.maxnumf {{.*}} : vector<1x1x1xf32> // CHECK: scf.yield -// CHECK: vector.multi_reduction // CHECK: gpu.subgroup_reduce maxnumf -// CHECK: vector.broadcast %{{.*}} : f32 to vector<1x1x4xf32> -// CHECK: scf.for {{.*}} -> (vector<1x1x4xf32>) { +// CHECK: vector.broadcast %{{.*}} : f32 to vector<1x1x1x1x1x4xf32> +// CHECK: scf.for 
{{.*}} -> (vector<1x1x1xf32>) { // CHECK: vector.transfer_read // CHECK: arith.subf // CHECK: math.exp +// CHECK: vector.multi_reduction // CHECK: arith.addf // CHECK: scf.yield -// CHECK: vector.multi_reduction // CHECK: gpu.subgroup_reduce add -// CHECK: vector.broadcast -// CHECK: scf.for +// CHECK: scf.forall // CHECK: vector.transfer_read // CHECK: arith.subf // CHECK: math.exp @@ -523,9 +516,13 @@ hal.executable private @i4_dequant_matvec { // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x4xf16> -// CHECK: scf.for %{{.+}} = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%{{.*}} = %[[CST]]) -> (vector<1x1x4xf16>) -// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1x1x4xf16> -// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1x1x4xf16> - -// CHECK: vector.multi_reduction , %{{.*}}, %{{.*}} [0, 1, 2] : vector<1x1x4xf16> to f16 +// CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x1xf16> +// CHECK: scf.for %{{.+}} = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%{{.*}} = %[[CST]]) -> (vector<1x1x1xf16>) +// CHECK: vector.transfer_read {{.*}} : memref<4096x32x128xi4, {{.*}}>, vector<1x4xi4> +// CHECK: arith.extui %{{.*}} : vector<1x1x1x1x1x4xi4> to vector<1x1x1x1x1x4xi32> +// CHECK: arith.uitofp %{{.*}} : vector<1x1x1x1x1x4xi32> to vector<1x1x1x1x1x4xf16> +// CHECK: arith.subf %{{.*}}, %{{.*}} : vector<1x1x1x1x1x4xf16> +// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1x1x1x1x1x4xf16> +// CHECK: vector.contract {{.*}} : vector<1x1x1x1x1x4xf16>, vector<1x1x1x1x1x4xf16> into vector<1x1x1xf16> + +// CHECK: vector.extract {{.*}} : f16 from vector<1x1x1xf16> diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_rocm.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_rocm.mlir index b9ec424ce9c6..d48dc93ec947 100644 --- 
a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_rocm.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_rocm.mlir @@ -138,16 +138,16 @@ hal.executable private @i4_dequant_matvec { // RDNA3-DAG: %[[C0:.+]] = arith.constant 0 : index // RDNA3-DAG: %[[C32:.+]] = arith.constant 32 : index // RDNA3-DAG: %[[C1:.+]] = arith.constant 1 : index -// RDNA3-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<4x1x1x1x1x4xf16> -// RDNA3: %[[FOR:.+]] = scf.for %{{.+}} = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%{{.*}} = %[[CST]]) -> (vector<4x1x1x1x1x4xf16>) -// RDNA3: %{{.*}} = arith.extui %{{.*}} : vector<4x1x1x1x1x4xi4> to vector<4x1x1x1x1x4xi32> -// RDNA3: %{{.*}} = arith.uitofp %{{.*}} : vector<4x1x1x1x1x4xi32> to vector<4x1x1x1x1x4xf16> -// RDNA3: %{{.*}} = arith.subf %{{.*}}, %{{.*}} : vector<4x1x1x1x1x4xf16> -// RDNA3: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : vector<4x1x1x1x1x4xf16> -// RDNA3: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : vector<4x1x1x1x1x4xf16> -// RDNA3: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<4x1x1x1x1x4xf16> +// RDNA3-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<4x1x1x1x1x1xf16> +// RDNA3: %[[FOR:.+]] = scf.for %{{.+}} = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%{{.*}} = %[[CST]]) -> (vector<4x1x1x1x1x1xf16>) +// RDNA3: memref.expand_shape {{.*}} : memref<4x1x128xi4, {{.*}}> into memref<4x1x32x4xi4, {{.*}}> +// RDNA3: %{{.*}} = arith.extui %{{.*}} : vector<4x1x1x1x1x1x1x1x4xi4> to vector<4x1x1x1x1x1x1x1x4xi32> +// RDNA3: %{{.*}} = arith.uitofp %{{.*}} : vector<4x1x1x1x1x1x1x1x4xi32> to vector<4x1x1x1x1x1x1x1x4xf16> +// RDNA3: %{{.*}} = arith.subf %{{.*}}, %{{.*}} : vector<4x1x1x1x1x1x1x1x4xf16> +// RDNA3: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : vector<4x1x1x1x1x1x1x1x4xf16> +// RDNA3: vector.contract {{.*}} : vector<1x1x1x1x1x4xf16>, vector<4x1x1x1x1x1x1x1x4xf16> into vector<4x1x1x1x1x1xf16> -// RDNA3: %{{.*}} = vector.multi_reduction , %{{.*}}, %{{.*}} [1, 3, 5] : 
vector<4x1x1x1x1x4xf16> to vector<4x1x1xf16> +// RDNA3: vector.shape_cast %{{.*}} : vector<4x1x1x1x1x1xf16> to vector<4x1x1xf16> // ----- @@ -252,13 +252,13 @@ hal.executable private @matvec_fp16 { // CHECK-SAME: translation_info = #[[$TRANSLATION]] // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[C512:.+]] = arith.constant 512 : index -// CHECK-DAG: %[[C4096:.+]] = arith.constant 4096 : index -// CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<8x1x1x1x1x8xf16> -// CHECK: scf.for %{{.+}} = %[[C0]] to %[[C4096]] step %[[C512]] iter_args(%[[ARG:.+]] = %[[CST]]) -> (vector<8x1x1x1x1x8xf16>) -// CHECK: {{.*}} = arith.mulf %{{.*}}, %{{.*}} : vector<8x1x1x1x1x8xf16> -// CHECK: {{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<8x1x1x1x1x8xf16> +// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index +// CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<8x1x1x1x1x1xf16> +// CHECK: scf.for %{{.+}} = %[[C0]] to %[[C512]] step %[[C64]] iter_args(%[[ARG:.+]] = %[[CST]]) -> (vector<8x1x1x1x1x1xf16>) +// CHECK: memref.expand_shape {{.*}} : memref<8x512xf16, {{.*}}> into memref<8x64x8xf16, {{.*}}> +// CHECK: vector.contract {{.*}} : vector<1x1x1x1x1x8xf16>, vector<8x1x1x1x1x1x1x1x8xf16> into vector<8x1x1x1x1x1xf16> -// CHECK: vector.multi_reduction , %{{.*}}, %{{.*}} [1, 3, 5] : vector<8x1x1x1x1x8xf16> to vector<8x1x1xf16> +// CHECK: vector.shape_cast %{{.*}} : vector<8x1x1x1x1x1xf16> to vector<8x1x1xf16> // ----- @@ -304,14 +304,14 @@ hal.executable private @matvec_fp16 { // RDNA3: func.func @matvec_fp16() // RDNA3-SAME: translation_info = #[[$TRANSLATION]] // RDNA3-DAG: %[[C0:.+]] = arith.constant 0 : index -// RDNA3-DAG: %[[C256:.+]] = arith.constant 256 : index -// RDNA3-DAG: %[[C4096:.+]] = arith.constant 4096 : index -// RDNA3-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<4x1x1x1x1x8xf16> -// RDNA3: scf.for %{{.+}} = %[[C0]] to %[[C4096]] step %[[C256]] iter_args(%[[ARG:.+]] = %[[CST]]) -> 
(vector<4x1x1x1x1x8xf16>) -// RDNA3: {{.*}} = arith.mulf %{{.*}}, %{{.*}} : vector<4x1x1x1x1x8xf16> -// RDNA3: {{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<4x1x1x1x1x8xf16> +// RDNA3-DAG: %[[C512:.+]] = arith.constant 512 : index +// RDNA3-DAG: %[[C32:.+]] = arith.constant 32 : index +// RDNA3-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<4x1x1x1x1x1xf16> +// RDNA3: scf.for %{{.+}} = %[[C0]] to %[[C512]] step %[[C32]] iter_args(%[[ARG:.+]] = %[[CST]]) -> (vector<4x1x1x1x1x1xf16>) +// RDNA3: memref.expand_shape {{.*}} : memref<4x256xf16, {{.*}}> into memref<4x32x8xf16, {{.*}}> +// RDNA3: vector.contract {{.*}} : vector<1x1x1x1x1x8xf16>, vector<4x1x1x1x1x1x1x1x8xf16> into vector<4x1x1x1x1x1xf16> -// RDNA3: vector.multi_reduction , %{{.*}}, %{{.*}} [1, 3, 5] : vector<4x1x1x1x1x8xf16> to vector<4x1x1xf16> +// RDNA3: vector.shape_cast %{{.*}} : vector<4x1x1x1x1x1xf16> to vector<4x1x1xf16> // ----- diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp index 743718c3d89b..fcd725f7d171 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp @@ -34,8 +34,9 @@ static LogicalResult setAMDMatmulConfig(linalg::LinalgOp op, if (succeeded(setCooperativeMatrixConfig( target, op, AMDNumSubgroupsPerWorkgroup, AMDNumMNTilesPerSubgroup, AMDCoopMatrixSoftwarePipelineDepth, - AMDCoopMatrixSoftwarePipelineStoreStage))) + AMDCoopMatrixSoftwarePipelineStoreStage))) { return success(); + } int subgroupSize = target.getPreferredSubgroupSize(); const std::array workgroupXY = {subgroupSize / 2, 8}; @@ -69,16 +70,18 @@ LogicalResult setAMDCodeGenConfig(IREE::GPU::TargetAttr target, int subgroupSize = target.getPreferredSubgroupSize(); if (auto linalgOp = dyn_cast(rootOp)) { - if (isMatmulOrBatchMatmul(linalgOp)) + if (isMatmulOrBatchMatmul(linalgOp)) { return setAMDMatmulConfig(linalgOp, target); + } } if (auto convOp = 
dyn_cast(rootOp)) { // Use the result type in case of larger bitwidth for accumulators. auto type = cast(convOp->getResult(0).getType()); const int bitwidth = type.getElementTypeBitWidth(); - if (bitwidth > 32) + if (bitwidth > 32) { return failure(); + } const int multipler = 32 / bitwidth; bool hasPaddedInput = convOp.image().getDefiningOp(); const int bestTilingFactor = (hasPaddedInput ? 16 : 32) * multipler; diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp index 5921c4d7612d..b99b0d8b7f22 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp @@ -40,24 +40,28 @@ LogicalResult setAdrenoCodeGenConfig(IREE::GPU::TargetAttr target, Operation *rootOp) { int subgroupSize = target.getPreferredSubgroupSize(); - if (!isa(rootOp)) + if (!isa(rootOp)) { return failure(); + } auto linalgOp = cast(rootOp); - if (isMatmulOrBatchMatmul(linalgOp)) + if (isMatmulOrBatchMatmul(linalgOp)) { return setAdrenoMatmulConfig(linalgOp, target); + } if (auto convOp = dyn_cast(rootOp)) { // Use the result type in case of larger bitwidth for accumulators. auto type = cast(convOp->getResult(0).getType()); const int bitwidth = type.getElementTypeBitWidth(); - if (bitwidth > 32) + if (bitwidth > 32) { return failure(); + } const int multipler = 32 / bitwidth; auto convDimsOrFailure = linalg::inferConvolutionDims(linalgOp); - if (failed(convDimsOrFailure)) + if (failed(convDimsOrFailure)) { return failure(); + } const int bestTilingFactor = (convDimsOrFailure->depth.empty() ? 
32 : 16) * multipler; return setConvOpConfig(cast(rootOp), subgroupSize, diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/AppleConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/AppleConfig.cpp index 1977157cca8c..091e00aad644 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/AppleConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/AppleConfig.cpp @@ -40,16 +40,18 @@ LogicalResult setAppleCodeGenConfig(IREE::GPU::TargetAttr target, int subgroupSize = target.getPreferredSubgroupSize(); if (auto linalgOp = dyn_cast(rootOp)) { - if (isMatmulOrBatchMatmul(linalgOp)) + if (isMatmulOrBatchMatmul(linalgOp)) { return setAppleMatmulConfig(linalgOp, target); + } } if (auto convOp = dyn_cast(rootOp)) { // Use the result type in case of larger bitwidth for accumulators. auto type = cast(convOp->getResult(0).getType()); const int bitwidth = type.getElementTypeBitWidth(); - if (bitwidth > 32) + if (bitwidth > 32) { return failure(); + } const int multipler = 32 / bitwidth; const int bestTilingFactor = 16 * multipler; return setConvOpConfig(cast(rootOp), subgroupSize, diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp index 4b012027be90..e296200c1f53 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp @@ -101,8 +101,9 @@ createResourceVariable(Location loc, const SubspanResourceInfo &resource, llvm::formatv("__resource_var_{}_{}_", resource.set, resource.binding); variable = spirv::GlobalVariableOp::create( builder, loc, globalVariableType, name, resource.set, resource.binding); - if (resource.aliased) + if (resource.aliased) { variable->setAttr("aliased", builder.getUnitAttr()); + } } else { std::string name = llvm::formatv("__resource_var_indirect_{}_", resource.set); @@ -543,8 +544,9 @@ class ConvertToSPIRVPass final LogicalResult initializeOptions( StringRef options, 
function_ref errorHandler) override { - if (failed(Pass::initializeOptions(options, errorHandler))) + if (failed(Pass::initializeOptions(options, errorHandler))) { return failure(); + } indexBits = indexBitsOption; return success(); } @@ -561,17 +563,20 @@ void ConvertToSPIRVPass::runOnOperation() { MLIRContext *context = &getContext(); ModuleOp moduleOp = getOperation(); - if (moduleOp.getBody()->empty()) + if (moduleOp.getBody()->empty()) { return; + } bool useIndirectBindings = usesIndirectBindingsAttr(moduleOp); for (auto funcOp : moduleOp.getOps()) { auto exportOp = getEntryPoint(funcOp); - if (!exportOp) + if (!exportOp) { continue; - if (funcOp->hasAttr(spirv::getEntryPointABIAttrName())) + } + if (funcOp->hasAttr(spirv::getEntryPointABIAttrName())) { continue; + } std::optional workgroupSize = exportOp->getWorkgroupSize(); if (!workgroupSize) { exportOp->emitOpError( @@ -757,8 +762,9 @@ void ConvertToSPIRVPass::runOnOperation() { SmallVector functions; for (auto fn : moduleOp.getOps()) { - if (!fn.isPublic()) + if (!fn.isPublic()) { continue; + } functions.push_back(fn); } @@ -770,8 +776,9 @@ void ConvertToSPIRVPass::runOnOperation() { } auto addressingModel = spirv::AddressingModel::Logical; - if (useIndirectBindings) + if (useIndirectBindings) { addressingModel = spirv::AddressingModel::PhysicalStorageBuffer64; + } // Collect all SPIR-V ops into a spirv.module. OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody()); @@ -781,10 +788,12 @@ void ConvertToSPIRVPass::runOnOperation() { Dialect *spvDialect = spvModule->getDialect(); for (Operation &op : llvm::make_early_inc_range(*moduleOp.getBody())) { // Skip the newly created spirv.module itself. 
- if (&op == spvModule) + if (&op == spvModule) { continue; - if (op.getDialect() == spvDialect) + } + if (op.getDialect() == spvDialect) { op.moveBefore(body, body->end()); + } } } diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp index 3334ce7885ca..d0ea809f68e0 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp @@ -54,8 +54,9 @@ using CodeGenPipeline = IREE::Codegen::DispatchLoweringPassPipeline; // Check if the given linalg op is fused with another op that may result // in too much shared memory usage. static bool fusedOpMayUseExtraSharedMemory(linalg::LinalgOp matmul) { - if (matmul->getNumResults() != 1) + if (matmul->getNumResults() != 1) { return true; + } auto entryPoint = matmul->getParentOfType(); @@ -105,14 +106,16 @@ static bool tileConvOneDim(const int64_t inputDim, const bool isInnerMostDim, // Handle `vectorSize` elements per thread for the innermost dimension. // We need this for the best utilization of memory. chosenTileSize = vectorSize; - if (inputDim % (dim * chosenTileSize) != 0) + if (inputDim % (dim * chosenTileSize) != 0) { continue; + } } else { - for (int64_t t = residualTilingFactor; t >= 1; t >>= 1) + for (int64_t t = residualTilingFactor; t >= 1; t >>= 1) { if (inputDim % (dim * t) == 0) { chosenTileSize = t; break; } + } } if (chosenTileSize) { wgDimSize = dim; @@ -168,12 +171,14 @@ LogicalResult setConvOpConfig(linalg::LinalgOp linalgOp, // Restrict to pure 4-D input/output shapes for now. This excludes convolution // ops with 1- or 3-D window sizes. It also excludes 2-D-window convolution // ops like `linalg.depthwise_conv_2d_nhwc_hwcm`. 
- if (inputShape.size() != 4 || outputShape.size() != 4) + if (inputShape.size() != 4 || outputShape.size() != 4) { return failure(); + } auto convDimsOrFailure = linalg::inferConvolutionDims(linalgOp); - if (failed(convDimsOrFailure)) + if (failed(convDimsOrFailure)) { return failure(); + } const mlir::linalg::ConvolutionDimensions &convDims = *convDimsOrFailure; LLVM_DEBUG(llvm::dbgs() << "conv: " << linalgOp << "\n" << "conv batch dim: " @@ -231,8 +236,9 @@ LogicalResult setConvOpConfig(linalg::LinalgOp linalgOp, // We use `vectorSize` as the tile size along IC dimension. If smaller than // 4, it will be unrolled into size 1. - if (ic && !(*ic % vectorSize == 0 || *ic < 4)) + if (ic && !(*ic % vectorSize == 0 || *ic < 4)) { return failure(); + } // The core idea is to distribute the convolution dimensions to the workgroup // Z/Y/X dimensions, with each thread in a workgroup handling multiple vector @@ -263,8 +269,9 @@ LogicalResult setConvOpConfig(linalg::LinalgOp linalgOp, // OC -> x if (!tileConvOneDim(oc, /*isInnerMostDim=*/true, vectorSize, residualThreads, residualTilingFactor, workgroupSize[0], - workgroupTileSizes[3])) + workgroupTileSizes[3])) { return failure(); + } // Deduce the configruation for the OW and OH dimension. Try to make them // even if possible given we typically have images with the same height @@ -362,18 +369,21 @@ std::tuple getMatmulBMNKIndex(linalg::LinalgOp op, } else if (inLHS) { // For cases where we have two parallel dimensions only accessed by // the LHS, treat the outer one of them as the batch dimension. - if (mIndex >= 0 && bIndex < 0) + if (mIndex >= 0 && bIndex < 0) { bIndex = mIndex; + } mIndex = i; } else if (inRHS) { // For cases where we have two parallel dimensions only accessed by // the RHS, treat the outer one of them as the batch dimension. 
- if (nIndex >= 0 && bIndex < 0) + if (nIndex >= 0 && bIndex < 0) { bIndex = nIndex; + } nIndex = i; } - if (lastParallelDim) + if (lastParallelDim) { *lastParallelDim = i; + } } LLVM_DEBUG({ @@ -459,15 +469,17 @@ int64_t getTileBytes(int64_t mTileSize, int64_t nTileSize, int64_t kTileSize, int64_t elementBits, bool promoteC) { int64_t paddingBits = detail::bankConflictReductionPaddingBits / elementBits; int64_t count = (mTileSize + nTileSize) * (kTileSize + paddingBits); - if (promoteC) + if (promoteC) { count += mTileSize * (nTileSize + paddingBits); + } return (elementBits / 8) * count; } int64_t getMultiBufferMemoryUsage(int64_t singleBufferBytes, unsigned depth, unsigned storeStage) { - if (depth == 0) + if (depth == 0) { return singleBufferBytes; + } return singleBufferBytes * (storeStage == 1 ? depth : depth + 1); }; @@ -479,8 +491,9 @@ static bool adjustToVectorLoad(ArrayRef dimMNKSize, int64_t &mTileSize, const int64_t subgroupSize, int64_t vectorSize) { const int64_t totalThreads = wgSize[0] * wgSize[1] * wgSize[2]; LLVM_DEBUG(llvm::dbgs() << "initial total thread = " << totalThreads << "\n"); - if (totalThreads <= subgroupSize) + if (totalThreads <= subgroupSize) { return false; + } const bool canVectorLoadLHS = canPerformVectorAccessUsingAllThreads( {mTileSize, kTileSize}, totalThreads, vectorSize); @@ -490,8 +503,9 @@ static bool adjustToVectorLoad(ArrayRef dimMNKSize, int64_t &mTileSize, LLVM_DEBUG(llvm::dbgs() << "RHS vector load: " << canVectorLoadRHS << "\n"); // If we can perform vector load of neither, just don't use shared memory. - if (!canVectorLoadLHS && !canVectorLoadRHS) + if (!canVectorLoadLHS && !canVectorLoadRHS) { return false; + } // If we can only perform vector load of one operands, adjust the tiling // scheme to see if we can make both work. 
Increase K to load more data for @@ -499,15 +513,18 @@ static bool adjustToVectorLoad(ArrayRef dimMNKSize, int64_t &mTileSize, if (canVectorLoadLHS && !canVectorLoadRHS) { for (const int scale : {2, 4}) { const int64_t newKTileSize = kTileSize * scale; - if (dimMNKSize[2] % newKTileSize != 0) + if (dimMNKSize[2] % newKTileSize != 0) { continue; + } const int64_t newMTileSize = mTileSize / scale; const int64_t newWgMDim = wgSize[1] / scale; - if (newMTileSize == 0 || newWgMDim == 0) + if (newMTileSize == 0 || newWgMDim == 0) { continue; + } const int64_t newCount = wgSize[0] * newWgMDim * wgSize[2]; - if (newCount <= subgroupSize) + if (newCount <= subgroupSize) { continue; + } if (!canPerformVectorAccessUsingAllThreads({newMTileSize, newKTileSize}, newCount, vectorSize) || !canPerformVectorAccessUsingAllThreads({newKTileSize, nTileSize}, @@ -542,8 +559,9 @@ static bool adjustToPromote(ArrayRef dimMNKSize, int64_t &mTileSize, LLVM_DEBUG(llvm::dbgs() << "subgroup size = " << subgroupSize << "\n"); const int vectorSize = kMaxVectorNumBits / elementBits; if (!adjustToVectorLoad(dimMNKSize, mTileSize, nTileSize, kTileSize, wgSize, - subgroupSize, vectorSize)) + subgroupSize, vectorSize)) { return false; + } // Don't do multibuffering if the inner reduction loop is folded out. if (dimMNKSize[2] == kTileSize) { @@ -563,8 +581,9 @@ static bool adjustToPromote(ArrayRef dimMNKSize, int64_t &mTileSize, // possible. do { if (getMultiBufferMemoryUsage(usedBytes, pipelineDepth, storeStage) <= - maxBytes) + maxBytes) { return true; + } } while (pipelineDepth-- > 1); // If we can't fit in workgroup memory, don't multibuffer. @@ -573,8 +592,9 @@ static bool adjustToPromote(ArrayRef dimMNKSize, int64_t &mTileSize, if (storeStage == 0) { storeStage = 1; if (getMultiBufferMemoryUsage(usedBytes, pipelineDepth, storeStage) <= - maxBytes) + maxBytes) { return true; + } } // Using too much workgroup memory. 
Try to reduce the tile size for X/Y once @@ -609,23 +629,27 @@ LogicalResult setMatmulOpConfig(IREE::GPU::TargetAttr target, auto rhsType = cast(rhs->get().getType()); auto elementBits = static_cast(IREE::Util::getTypeBitWidth(lhsType.getElementType())); - if (!llvm::is_contained({8, 16, 32}, elementBits)) + if (!llvm::is_contained({8, 16, 32}, elementBits)) { return failure(); + } ArrayRef lhsShape = lhsType.getShape(); ArrayRef rhsShape = rhsType.getShape(); - if (llvm::any_of(lhsShape, ShapedType::isDynamic)) + if (llvm::any_of(lhsShape, ShapedType::isDynamic)) { return failure(); - if (llvm::any_of(rhsShape, ShapedType::isDynamic)) + } + if (llvm::any_of(rhsShape, ShapedType::isDynamic)) { return failure(); + } assert(llvm::is_contained({2u, 3u}, op.getNumParallelLoops())); int lastParallelDim = -1; const auto [bIndex, mIndex, nIndex, kIndex] = getMatmulBMNKIndex(op, &lastParallelDim); - if (mIndex < 0 || nIndex < 0 || kIndex < 0) + if (mIndex < 0 || nIndex < 0 || kIndex < 0) { return failure(); + } const bool isBM = bIndex >= 0; SmallVector loopRanges = op.getStaticLoopRanges(); @@ -669,8 +693,9 @@ LogicalResult setMatmulOpConfig(IREE::GPU::TargetAttr target, SmallVector workgroupTileSizes(numLoops, 0); SmallVector reductionTileSizes(numLoops, 0); - if (isBM) + if (isBM) { workgroupTileSizes[bIndex] = 1; + } if (!tileMatmulNToWorkgroupX(dimN, bestThreadN, residualThreads, bestX, residualTilingFactor, workgroupSize[0], @@ -722,8 +747,9 @@ LogicalResult setMatmulOpConfig(IREE::GPU::TargetAttr target, // Tile all additional reduction dimensions with size 1 to materialize loops. 
for (auto [i, it] : llvm::enumerate(op.getIteratorTypesArray())) { - if (linalg::isReductionIterator(it) && reductionTileSizes[i] == 0) + if (linalg::isReductionIterator(it) && reductionTileSizes[i] == 0) { reductionTileSizes[i] = 1; + } } TileSizesListType tileSizes; @@ -733,8 +759,9 @@ LogicalResult setMatmulOpConfig(IREE::GPU::TargetAttr target, // Merge reductionTileSizes into workgroupTileSizes--this is needed by the // pipeline passes shared between SPIR-V and LLVMGPU. for (auto [i, it] : llvm::enumerate(op.getIteratorTypesArray())) { - if (linalg::isReductionIterator(it)) + if (linalg::isReductionIterator(it)) { workgroupTileSizes[i] = reductionTileSizes[i]; + } } tileSizes.push_back(workgroupTileSizes); @@ -787,8 +814,9 @@ static LogicalResult setTilingAndMatmulOpConfig(linalg::LinalgOp op, //===----------------------------------------------------------------------===// bool isCooperativeMatrixFusable(linalg::GenericOp genericOp) { - if (genericOp.getNumLoops() != genericOp.getNumParallelLoops()) + if (genericOp.getNumLoops() != genericOp.getNumParallelLoops()) { return false; + } // Look at fused elementwise ops to make sure they are allowed by the // cooperative matrix spec. @@ -802,8 +830,9 @@ bool isCooperativeMatrixFusable(linalg::GenericOp genericOp) { arith::UIToFPOp, // Special cases of these ops are directly allowed to sue // cooperative matrix types. Other cases can use a loop. - arith::MulFOp>(op)) + arith::MulFOp>(op)) { return false; + } } // Look at operands to make sure we don't have inlined constants. Cooperative @@ -811,8 +840,9 @@ bool isCooperativeMatrixFusable(linalg::GenericOp genericOp) { // classes. 
for (Value input : genericOp.getInputs()) { if (isa(input.getType())) { - if (matchPattern(input, m_Constant())) + if (matchPattern(input, m_Constant())) { return false; + } continue; } @@ -822,8 +852,9 @@ bool isCooperativeMatrixFusable(linalg::GenericOp genericOp) { input = subviewOp.getViewSource(); } if (auto toMemrefOp = input.getDefiningOp()) { - if (matchPattern(toMemrefOp.getTensor(), m_Constant())) + if (matchPattern(toMemrefOp.getTensor(), m_Constant())) { return false; + } } } @@ -833,11 +864,13 @@ bool isCooperativeMatrixFusable(linalg::GenericOp genericOp) { bool needToPrmoteCForCooperativeMatrix(linalg::LinalgOp matmulOp) { assert(matmulOp.hasPureTensorSemantics()); Value result = matmulOp.getOperation()->getResult(0); - if (!result.hasOneUse()) + if (!result.hasOneUse()) { return true; // Be conservative. + } Operation *user = *result.getUsers().begin(); - if (isa(user)) + if (isa(user)) { return false; + } if (auto genericOp = dyn_cast(user)) { return !isCooperativeMatrixFusable(genericOp); } @@ -854,11 +887,13 @@ setCooperativeMatrixConfig(IREE::GPU::TargetAttr target, linalg::LinalgOp op, unsigned softwarePipelineStoreStage) { LLVM_DEBUG(llvm::dbgs() << "trying to matmul cooperative matrix config...\n"); // This configuration is only for cooperative matrix. 
- if (target.getWgp().getMma().empty()) + if (target.getWgp().getMma().empty()) { return failure(); + } - if (op.hasDynamicShape()) + if (op.hasDynamicShape()) { return failure(); + } Value lhs = op.getDpsInputOperand(0)->get(); Value rhs = op.getDpsInputOperand(1)->get(); @@ -867,8 +902,9 @@ setCooperativeMatrixConfig(IREE::GPU::TargetAttr target, linalg::LinalgOp op, int lastParallelDim = -1; const auto [bIndex, mIndex, nIndex, kIndex] = getMatmulBMNKIndex(op, &lastParallelDim); - if (mIndex < 0 || nIndex < 0 || kIndex < 0) + if (mIndex < 0 || nIndex < 0 || kIndex < 0) { return failure(); + } const bool isBM = bIndex >= 0; SmallVector loopRanges = op.getStaticLoopRanges(); @@ -926,11 +962,12 @@ setCooperativeMatrixConfig(IREE::GPU::TargetAttr target, linalg::LinalgOp op, bool transposedRhs = nIndex != cast(maps[1].getResults().back()).getPosition(); - FailureOr schedule = - deduceMMASchedule(problem, intrinsics, seeds, sharedMemoryLimitInBytes, - subgroupSize, transposedLhs, transposedRhs); - if (failed(schedule)) + FailureOr schedule = deduceMMASchedule( + problem, intrinsics, seeds, sharedMemoryLimitInBytes, subgroupSize, + /*cuCount=*/std::nullopt, op.getLoc(), transposedLhs, transposedRhs); + if (failed(schedule)) { return failure(); + } assert(schedule->hasSingleDimensions() && "expected single M/N/K dimension"); auto pipeline = CodeGenPipeline::SPIRVCooperativeMatrixVectorize; @@ -940,21 +977,24 @@ setCooperativeMatrixConfig(IREE::GPU::TargetAttr target, linalg::LinalgOp op, schedule->mSubgroupCounts[0], 1}; SmallVector vectorSizes(kIndex + 1, 0); - if (isBM) + if (isBM) { vectorSizes[bIndex] = 1; + } vectorSizes[mIndex] = schedule->mSizes[0]; vectorSizes[nIndex] = schedule->nSizes[0]; vectorSizes[kIndex] = schedule->kSizes[0]; SmallVector subgroupTileSizes(lastParallelDim + 1, 0); - if (isBM) + if (isBM) { subgroupTileSizes[bIndex] = 1; + } subgroupTileSizes[mIndex] = schedule->mTileSizes[0] * vectorSizes[mIndex]; subgroupTileSizes[nIndex] = 
schedule->nTileSizes[0] * vectorSizes[nIndex]; SmallVector workgroupTileSizes(lastParallelDim + 1, 0); - if (isBM) + if (isBM) { workgroupTileSizes[bIndex] = 1; + } workgroupTileSizes[mIndex] = schedule->mSubgroupCounts[0] * subgroupTileSizes[mIndex]; workgroupTileSizes[nIndex] = @@ -1156,8 +1196,9 @@ static LogicalResult setReductionConfig(IREE::GPU::TargetAttr target, // This pipeline eventually generates non-uniform group shuffle ops, which // requires special capability. - if (!target.supportsSubgroupShuffle()) + if (!target.supportsSubgroupShuffle()) { return failure(); + } SmallVector parallelDims; SmallVector reductionDims; @@ -1168,8 +1209,9 @@ static LogicalResult setReductionConfig(IREE::GPU::TargetAttr target, int64_t numParallelDims = op.getNumParallelLoops(); // We should have reduction dimensions. - if (reductionDims.empty()) + if (reductionDims.empty()) { return failure(); + } // Make sure reduction dimensions are static and innermost ones. int64_t numDynamicReductionDims = 0; @@ -1188,8 +1230,9 @@ static LogicalResult setReductionConfig(IREE::GPU::TargetAttr target, return failure(); } - if (op.getRegionOutputArgs().size() != 1) + if (op.getRegionOutputArgs().size() != 1) { return failure(); + } // Only support projected permutation for now. This could be extended to // projected permutated with broadcast. 
@@ -1205,8 +1248,9 @@ static LogicalResult setReductionConfig(IREE::GPU::TargetAttr target, SmallVector combinerOps; if (matchReduction(op.getRegionOutputArgs(), i, combinerOps) && combinerOps.size() == 1) { - if (foundSingleReductionOutput) + if (foundSingleReductionOutput) { return failure(); + } foundSingleReductionOutput = true; continue; } @@ -1214,8 +1258,9 @@ static LogicalResult setReductionConfig(IREE::GPU::TargetAttr target, return failure(); } } - if (!foundSingleReductionOutput) + if (!foundSingleReductionOutput) { return failure(); + } int subgroupSize = target.getPreferredSubgroupSize(); @@ -1253,24 +1298,29 @@ static LogicalResult setReductionConfig(IREE::GPU::TargetAttr target, } int64_t reductionSize = 1; - for (int64_t dim : reductionDims) + for (int64_t dim : reductionDims) { reductionSize *= bounds[dim]; - if (reductionSize % subgroupSize != 0) + } + if (reductionSize % subgroupSize != 0) { return failure(); + } const Type elementType = cast(op.getDpsInits()[0].getType()).getElementType(); - if (!elementType.isIntOrFloat()) + if (!elementType.isIntOrFloat()) { return failure(); + } unsigned bitWidth = IREE::Util::getTypeBitWidth(elementType); // Reduction distribution only supports 8/16/32 bit types now. - if (bitWidth != 32 && bitWidth != 16 && bitWidth != 8) + if (bitWidth != 32 && bitWidth != 16 && bitWidth != 8) { return failure(); + } // Let each thread handle `vectorSize` elements. unsigned vectorSize = kMaxVectorNumBits / bitWidth; - while ((reductionSize / vectorSize) % subgroupSize != 0) + while ((reductionSize / vectorSize) % subgroupSize != 0) { vectorSize /= 2; + } // Deduce the workgroup size we should use for reduction. Currently a // workgroup processes all elements in reduction dimensions. 
Need to make sure @@ -1295,8 +1345,9 @@ static LogicalResult setReductionConfig(IREE::GPU::TargetAttr target, int64_t parallelSize = 1; for (int64_t dim : parallelDims) { - if (ShapedType::isStatic(bounds[dim])) + if (ShapedType::isStatic(bounds[dim])) { parallelSize *= bounds[dim]; + } } // Total parallel size that can fill the GPU with enough workgorups. // TODO: query from the target device; roughly 2x hardware compute unit. @@ -1316,8 +1367,9 @@ static LogicalResult setReductionConfig(IREE::GPU::TargetAttr target, // First, do warp reductions along multiple subgroups. // Second, reduce results from multiple subgroups using single warp reduce. // The final warp reduce requires subgroup count <= subgroup size to work. - if ((groupSize / subgroupSize) > subgroupSize) + if ((groupSize / subgroupSize) > subgroupSize) { return failure(); + } if (hasIncompatibleConsumer(op, groupSize)) { LDBG() << "Reduction has incompatible consumer, limiting workgroup size " @@ -1332,13 +1384,15 @@ static LogicalResult setReductionConfig(IREE::GPU::TargetAttr target, for (int i = reductionDims.size() - 1; i >= 0; --i) { int64_t dim = reductionDims[i]; int64_t bound = bounds[dim]; - if (i == reductionDims.size() - 1) + if (i == reductionDims.size() - 1) { bound /= vectorSize; + } APInt size = GreatestCommonDivisor(APInt(64, uint64_t(remaingGroupSize)), APInt(64, uint64_t(bound))); reductionTileSizes[dim] = size.getSExtValue(); - if (i == reductionDims.size() - 1) + if (i == reductionDims.size() - 1) { reductionTileSizes[dim] *= vectorSize; + } remaingGroupSize /= size.getSExtValue(); } @@ -1461,10 +1515,12 @@ static LogicalResult setDefaultOpConfig(IREE::GPU::TargetAttr target, // dimensions to 1 for extra dimensions. 
if (isa(linalgOp.getOperation())) { for (int64_t i = 0, e = workgroupTileSizes.size(); i < e; i++) { - if (workgroupTileSizes[i] != 0) + if (workgroupTileSizes[i] != 0) { break; - if (loopBounds[i] != 1) + } + if (loopBounds[i] != 1) { workgroupTileSizes[i] = 1; + } } } // Scan from the innermost shape dimension and try to deduce the @@ -1473,8 +1529,9 @@ static LogicalResult setDefaultOpConfig(IREE::GPU::TargetAttr target, for (auto shapeDim : llvm::reverse(partitionedLoops)) { int64_t loopBound = loopBounds[shapeDim]; // Skip dynamic dimensions. - if (ShapedType::isDynamic(loopBound)) + if (ShapedType::isDynamic(loopBound)) { continue; + } // Try to find some power of two that can devide the current shape dim // size. This vector keeps the candidate tile sizes. @@ -1495,12 +1552,14 @@ static LogicalResult setDefaultOpConfig(IREE::GPU::TargetAttr target, for (int64_t candidate : candidates) { int64_t scaledTileSize = candidate * scaleToByte; if (loopBound % scaledTileSize != 0) { - if (!lossFactor) + if (!lossFactor) { continue; + } // Skip this candidate if it causes many threads to be idle. int64_t idleThreads = candidate - (loopBound % scaledTileSize); - if (idleThreads > candidate / *lossFactor) + if (idleThreads > candidate / *lossFactor) { continue; + } } // If the workload is too small and we cannot distribute to more than 2 // workgroups, try a smaller tile size to increase parallelism. @@ -1526,8 +1585,9 @@ static LogicalResult setDefaultOpConfig(IREE::GPU::TargetAttr target, assert(numThreads % (candidate / vectorSize) == 0); numThreads /= candidate / vectorSize; } else { - if (wgDim == 0) + if (wgDim == 0) { vectorizable = false; + } threadTileSizes[shapeDim] = scaleToByte; workgroupSize[wgDim] = candidate; assert(numThreads % candidate == 0); @@ -1538,8 +1598,9 @@ static LogicalResult setDefaultOpConfig(IREE::GPU::TargetAttr target, } // Stop if we have distributed all threads. 
- if (numThreads == 1) + if (numThreads == 1) { break; + } wgDim++; } return numThreads; @@ -1555,8 +1616,9 @@ static LogicalResult setDefaultOpConfig(IREE::GPU::TargetAttr target, int64_t lossFactor = 32; for (; lossFactor >= 1; lossFactor >>= 1) { - if (distributeToThreads(numThreads, lossFactor) == 1) + if (distributeToThreads(numThreads, lossFactor) == 1) { break; + } } } @@ -1600,19 +1662,26 @@ static LogicalResult setSPIRVOpConfig(IREE::GPU::TargetAttr target, Operation *rootOp) { // First try to find a proper CodeGen configuration to tile and vectorize for // the current target architecture. - if (target.isAMD() && succeeded(detail::setAMDCodeGenConfig(target, rootOp))) + if (target.isAMD() && + succeeded(detail::setAMDCodeGenConfig(target, rootOp))) { return success(); + } if (target.isApple() && - succeeded(detail::setAppleCodeGenConfig(target, rootOp))) + succeeded(detail::setAppleCodeGenConfig(target, rootOp))) { return success(); - if (target.isARM() && succeeded(detail::setMaliCodeGenConfig(target, rootOp))) + } + if (target.isARM() && + succeeded(detail::setMaliCodeGenConfig(target, rootOp))) { return success(); + } if (target.isNVIDIA() && - succeeded(detail::setNVIDIACodeGenConfig(target, rootOp))) + succeeded(detail::setNVIDIACodeGenConfig(target, rootOp))) { return success(); + } if (target.isQualcomm() && - succeeded(detail::setAdrenoCodeGenConfig(target, rootOp))) + succeeded(detail::setAdrenoCodeGenConfig(target, rootOp))) { return success(); + } // Otherwise fallback to use a default configuration that tiles and // distributes/vectorizes. @@ -1635,8 +1704,9 @@ static LogicalResult setSPIRVOpConfig(IREE::GPU::TargetAttr target, const int subgroupSize = 32; auto result = detail::setConvOpConfig(cast(*op), subgroupSize, bestTilingFactor); - if (succeeded(result)) + if (succeeded(result)) { return success(); + } } // If unsuccessful, try to tile and distribute/vectorize. 
return setDefaultOpConfig(target, op); @@ -1692,22 +1762,26 @@ static LogicalResult setConfigForKernel(IREE::GPU::TargetAttr target, ArrayRef roots(computeOps); while (roots.size() > 1) { auto linalgOp = dyn_cast(roots.front()); - if (!linalgOp) + if (!linalgOp) { break; - if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops()) + } + if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops()) { break; + } roots = roots.drop_front(); } for (Operation *computeOp : roots) { - if (succeeded(setSPIRVOpConfig(target, funcOp, computeOp))) + if (succeeded(setSPIRVOpConfig(target, funcOp, computeOp))) { return success(); + } } Operation *computeOp = roots.back(); // If there are still no root op, check for any linalg.generic op. - if (succeeded(setDefaultOpConfig(target, computeOp))) + if (succeeded(setDefaultOpConfig(target, computeOp))) { return success(); + } // Check if the op configuration was set. return computeOp->emitOpError( @@ -1717,11 +1791,13 @@ static LogicalResult setConfigForKernel(IREE::GPU::TargetAttr target, LogicalResult initSPIRVLaunchConfig(FunctionOpInterface funcOp) { IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp); - if (!target) + if (!target) { return funcOp.emitError("missing GPU target in #hal.executable.target"); + } - if (getTranslationInfo(funcOp)) + if (getTranslationInfo(funcOp)) { return success(); + } if (auto exportOp = getEntryPoint(funcOp)) { // If no translation info set, first check whether we already have workgroup diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/MaliConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/MaliConfig.cpp index 37b8ba322161..25643415e539 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/MaliConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/MaliConfig.cpp @@ -45,16 +45,18 @@ LogicalResult setMaliCodeGenConfig(IREE::GPU::TargetAttr target, const int subgroupSize = target.getPreferredSubgroupSize(); if (auto linalgOp = dyn_cast(rootOp)) { - if 
(isMatmulOrBatchMatmul(linalgOp)) + if (isMatmulOrBatchMatmul(linalgOp)) { return setMaliMatmulConfig(linalgOp, target); + } } if (auto convOp = dyn_cast(rootOp)) { // Use the result type in case of larger bitwidth for accumulators. auto type = cast(convOp->getResult(0).getType()); const int bitwidth = type.getElementTypeBitWidth(); - if (bitwidth > 32) + if (bitwidth > 32) { return failure(); + } const int multipler = 32 / bitwidth; bool hasPaddedInput = convOp.image().getDefiningOp(); const int bestTilingFactor = (hasPaddedInput ? 8 : 16) * multipler; diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/NVIDIAConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/NVIDIAConfig.cpp index 5f4505a02d62..aefeef7ec608 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/NVIDIAConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/NVIDIAConfig.cpp @@ -30,8 +30,9 @@ static LogicalResult setNVIDIAMatmulConfig(linalg::LinalgOp op, // First try to see if we can use tensor cores. if (succeeded(setCooperativeMatrixConfig(target, op, NVIDIANumSubgroupsPerWorkgroup, - NVIDIANumMNTilesPerSubgroup))) + NVIDIANumMNTilesPerSubgroup))) { return success(); + } const int subgroupSize = target.getPreferredSubgroupSize(); const std::array workgroupXY = {subgroupSize, 8}; @@ -79,8 +80,9 @@ static LogicalResult setNVIDIAMatmulConfig(linalg::LinalgOp op, LogicalResult setNVIDIACodeGenConfig(IREE::GPU::TargetAttr target, Operation *rootOp) { if (auto linalgOp = dyn_cast(rootOp)) { - if (isMatmulOrBatchMatmul(linalgOp)) + if (isMatmulOrBatchMatmul(linalgOp)) { return setNVIDIAMatmulConfig(linalgOp, target); + } } return failure(); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp index 07737c4b9d4e..df7987e986e4 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp @@ -94,8 +94,9 @@ static LogicalResult gpuCopyFn(OpBuilder &builder, Location loc, Value 
from, bool needsBarrier = hasSharedMemoryAddressSpace(fromType) || hasSharedMemoryAddressSpace(toType); - if (needsBarrier) + if (needsBarrier) { gpu::BarrierOp::create(builder, loc); + } Operation *copy = memref::CopyOp::create(builder, loc, from, to); if (needsBarrier) { setMarker(copy, getCopyToWorkgroupMemoryMarker()); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVAnnotateWinogradLoops.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVAnnotateWinogradLoops.cpp index dbe01cc10a12..a42d102e6d36 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVAnnotateWinogradLoops.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVAnnotateWinogradLoops.cpp @@ -25,16 +25,18 @@ class SPIRVAnnotateWinogradLoopsPass final mlir::FunctionOpInterface funcOp = getOperation(); SmallVector forOps; funcOp.walk([&](scf::ForOp forOp) { - if (!isTiledAndDistributedLoop(forOp)) + if (!isTiledAndDistributedLoop(forOp)) { forOps.push_back(forOp); + } }); MLIRContext *context = &getContext(); OpBuilder builder(context); const char *attrName = getGPUDistributeAttrName(); for (auto [index, forOp] : llvm::enumerate(forOps)) { - if (index > kNumGPUDims) + if (index > kNumGPUDims) { break; + } forOp->setAttr(attrName, builder.getIndexAttr(index)); } } diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVBreakDownLargeVector.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVBreakDownLargeVector.cpp index c9ef3ffe0fd6..c64f7c3158f3 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVBreakDownLargeVector.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVBreakDownLargeVector.cpp @@ -67,26 +67,31 @@ struct BreakDownCastExtractExtend final : OpRewritePattern { PatternRewriter &rewriter) const override { auto extractOp = extOp.getIn().getDefiningOp(); - if (!extractOp) + if (!extractOp) { return failure(); + } auto bitCastOp = extractOp.getSource().getDefiningOp(); - if (!bitCastOp) + if (!bitCastOp) { return failure(); + } VectorType extractSrcType = 
extractOp.getSourceVectorType(); VectorType extractDstType = extractOp.getType(); // We expect high-D vectors are broken down into 1-D ones so here we only // handle 1-D vectors. - if (extractSrcType.getRank() != 1 || extractDstType.getRank() != 1) + if (extractSrcType.getRank() != 1 || extractDstType.getRank() != 1) { return failure(); + } // We only have power-of-two bitwidth cases for now. if (!llvm::isPowerOf2_64(extractSrcType.getNumElements()) || - !llvm::isPowerOf2_64(extractDstType.getNumElements())) + !llvm::isPowerOf2_64(extractDstType.getNumElements())) { return failure(); + } // We only handle not directly supported vector sizes. - if (extractSrcType.getNumElements() <= 4) + if (extractSrcType.getNumElements() <= 4) { return failure(); + } int64_t srcElemBitwidth = bitCastOp.getSourceVectorType().getElementTypeBitWidth(); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVConvertGPUTarget.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVConvertGPUTarget.cpp index 3ef39a32875c..c7209f2075a8 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVConvertGPUTarget.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVConvertGPUTarget.cpp @@ -65,8 +65,9 @@ std::optional processCapabilities(ArrayRef features, SetVector &caps) { for (StringRef feature : features) { if (feature.consume_front("cap:")) { - if (std::optional cap = spirv::symbolizeCapability(feature)) + if (std::optional cap = spirv::symbolizeCapability(feature)) { caps.insert(*cap); + } } } return std::nullopt; @@ -78,8 +79,9 @@ std::optional processExtensions(ArrayRef features, SetVector &exts) { for (StringRef feature : features) { if (feature.consume_front("ext:")) { - if (std::optional ext = spirv::symbolizeExtension(feature)) + if (std::optional ext = spirv::symbolizeExtension(feature)) { exts.insert(*ext); + } } } return std::nullopt; @@ -99,16 +101,21 @@ ClientAPI deduceClientAPI(StringRef backend) { } Vendor deduceVendor(IREE::GPU::TargetAttr target) { - if 
(target.isAMD()) + if (target.isAMD()) { return Vendor::AMD; - if (target.isApple()) + } + if (target.isApple()) { return Vendor::Apple; - if (target.isARM()) + } + if (target.isARM()) { return Vendor::ARM; - if (target.isNVIDIA()) + } + if (target.isNVIDIA()) { return Vendor::NVIDIA; - if (target.isQualcomm()) + } + if (target.isQualcomm()) { return Vendor::Qualcomm; + } return Vendor::Unknown; } @@ -118,19 +125,24 @@ Vendor deduceVendor(IREE::GPU::TargetAttr target) { void addComputeFeatures(ComputeBitwidths compute, SetVector &caps, SetVector &exts) { - if (bitEnumContainsAny(compute, ComputeBitwidths::FP64)) + if (bitEnumContainsAny(compute, ComputeBitwidths::FP64)) { caps.insert(Capability::Float64); + } // FP32 does not need special capabilities or extensions. - if (bitEnumContainsAny(compute, ComputeBitwidths::FP16)) + if (bitEnumContainsAny(compute, ComputeBitwidths::FP16)) { caps.insert(Capability::Float16); + } - if (bitEnumContainsAny(compute, ComputeBitwidths::Int64)) + if (bitEnumContainsAny(compute, ComputeBitwidths::Int64)) { caps.insert(Capability::Int64); + } // Int32 does not need special capabilities or extensions. 
- if (bitEnumContainsAny(compute, ComputeBitwidths::Int16)) + if (bitEnumContainsAny(compute, ComputeBitwidths::Int16)) { caps.insert(Capability::Int16); - if (bitEnumContainsAny(compute, ComputeBitwidths::Int8)) + } + if (bitEnumContainsAny(compute, ComputeBitwidths::Int8)) { caps.insert(Capability::Int8); + } } void addStorageFeatures(StorageBitwidths storage, SetVector &caps, @@ -280,8 +292,9 @@ struct SPIRVConvertGPUTargetPass final FailureOr spirvTarget = convertGPUTarget(context, variant); - if (failed(spirvTarget)) + if (failed(spirvTarget)) { return signalPassFailure(); + } moduleOp->setAttr(spirv::getTargetEnvAttrName(), *spirvTarget); } diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEmulateI64.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEmulateI64.cpp index 8291c1cf7384..e7f06146b9a8 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEmulateI64.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEmulateI64.cpp @@ -55,10 +55,11 @@ struct ConvertHalInterfaceBindingSubspan final matchAndRewrite(IREE::HAL::InterfaceBindingSubspanOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type newResultTy = getTypeConverter()->convertType(op.getType()); - if (!newResultTy) + if (!newResultTy) { return rewriter.notifyMatchFailure( op->getLoc(), llvm::formatv("failed to legalize memref type: {}", op.getType())); + } auto newOp = rewriter.replaceOpWithNewOp( @@ -111,8 +112,9 @@ struct ConvertUtilAssumeIntOp final unsigned replacementLoc = 0; for (auto result : newOp.getResults()) { - while (replacements[replacementLoc] != nullptr) + while (replacements[replacementLoc] != nullptr) { replacementLoc++; + } Value replacement = result; Type newType = getTypeConverter()->convertType( op.getResult(replacementLoc).getType()); @@ -138,11 +140,13 @@ struct ConvertUtilAssumeIntOp final // Tries to flatten `type` to a 1-D vector type. Returns `nullptr` on failure. 
static VectorType flattenVectorType(Type type) { auto vecTy = dyn_cast(type); - if (!vecTy) + if (!vecTy) { return nullptr; + } - if (vecTy.isScalable() || vecTy.getRank() <= 1) + if (vecTy.isScalable() || vecTy.getRank() <= 1) { return nullptr; + } int64_t totalElements = vecTy.getNumElements(); return VectorType::get(llvm::ArrayRef(totalElements), vecTy.getElementType()); @@ -167,13 +171,15 @@ struct FlattenElementwisePattern final : RewritePattern { LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - if (!OpTrait::hasElementwiseMappableTraits(op)) + if (!OpTrait::hasElementwiseMappableTraits(op)) { return failure(); + } auto newResultTypes = llvm::to_vector_of( llvm::map_range(op->getResultTypes(), flattenVectorType)); - if (llvm::any_of(newResultTypes, [](Type type) { return !type; })) + if (llvm::any_of(newResultTypes, [](Type type) { return !type; })) { return failure(); + } Location loc = op->getLoc(); @@ -181,8 +187,9 @@ struct FlattenElementwisePattern final : RewritePattern { auto operands = llvm::to_vector_of(op->getOperands()); for (Value &operand : operands) { VectorType newOperandTy = flattenVectorType(operand.getType()); - if (!newOperandTy) + if (!newOperandTy) { return failure(); + } operand = rewriter.createOrFold(loc, newOperandTy, operand); @@ -233,8 +240,9 @@ struct SPIRVEmulateI64Pass final void runOnOperation() override { mlir::FunctionOpInterface op = getOperation(); - if (supportsI64(op)) + if (supportsI64(op)) { return; + } arith::WideIntEmulationConverter typeConverter(32); memref::populateMemRefWideIntEmulationConversions(typeConverter); @@ -263,8 +271,9 @@ struct SPIRVEmulateI64Pass final memref::populateMemRefWideIntEmulationPatterns(typeConverter, patterns); populateIreeI64EmulationPatterns(typeConverter, patterns); - if (failed(applyPartialConversion(op, target, std::move(patterns)))) + if (failed(applyPartialConversion(op, target, std::move(patterns)))) { signalPassFailure(); + } } // Clean up 
any new 2-D vectors. We need to do it here because later passed @@ -279,8 +288,9 @@ struct SPIRVEmulateI64Pass final vector::InsertStridedSliceOp::getCanonicalizationPatterns(patterns, ctx); vector::ShapeCastOp::getCanonicalizationPatterns(patterns, ctx); - if (failed(applyPatternsGreedily(op, std::move(patterns)))) + if (failed(applyPatternsGreedily(op, std::move(patterns)))) { return signalPassFailure(); + } } } }; diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEraseStorageBufferStaticShape.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEraseStorageBufferStaticShape.cpp index df98ee11741d..1eb346982475 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEraseStorageBufferStaticShape.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEraseStorageBufferStaticShape.cpp @@ -36,12 +36,14 @@ class EraseStorageBufferStaticShapePass final bool is1DStaticShapedStorageBuffer( IREE::HAL::InterfaceBindingSubspanOp subspanOp) { auto type = dyn_cast(subspanOp.getType()); - if (!type) + if (!type) { return false; + } auto attr = dyn_cast_if_present(type.getMemorySpace()); - if (!attr) + if (!attr) { return false; + } return type.hasStaticShape() && type.getRank() == 1 && attr.getValue() == IREE::HAL::DescriptorType::StorageBuffer; } diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp index a6aa5ac39397..adec7266b5f2 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp @@ -51,29 +51,34 @@ void debugPrint(Operation *op, const char *message) { int getComputeVectorSize(int64_t size) { for (int i : {4, 3, 2}) { - if (size % i == 0) + if (size % i == 0) { return i; + } } return 1; } int getMemoryVectorSize(Value source, Type scalarType, int64_t size) { int bitwidth = scalarType.getIntOrFloatBitWidth(); - while (auto sliceOp = 
source.getDefiningOp()) + while (auto sliceOp = source.getDefiningOp()) { source = sliceOp.getSource(); + } if (!matchPattern(source, m_Constant())) { // If we are not reading from a constant array that is embedded in the // kernel, try to use a large vector size matching the bitwidth to read in // 128-bit chunks. This helps with memory access performance. Such vector // sizes are not native in SPIR-V though; this relies on following passes to // bitcast them to 32-bit 4-element vectors to be valid. - if (bitwidth <= 8 && size % 16 == 0) + if (bitwidth <= 8 && size % 16 == 0) { return 16; - if (bitwidth <= 16 && size % 8 == 0) + } + if (bitwidth <= 16 && size % 8 == 0) { return 8; + } } - if (bitwidth <= 32 && size % 4 == 0) + if (bitwidth <= 32 && size % 4 == 0) { return 4; + } return size % 2 == 0 ? 2 : 1; } @@ -108,8 +113,9 @@ Operation *stripElementBitPatternPreservingParents(Value op) { }) .Default([](Operation *) { return nullptr; }); - if (!source) + if (!source) { break; + } op = source; } @@ -119,8 +125,9 @@ Operation *stripElementBitPatternPreservingParents(Value op) { /// Returns true when |op| has the i32 element type that is likely to be result /// of a zero/sign extension from i8. bool mayExtI8ToI32(Value op) { - if (!getElementTypeOrSelf(op.getType()).isInteger(32)) + if (!getElementTypeOrSelf(op.getType()).isInteger(32)) { return false; + } // Look through vector operations created by vector unrolling patterns, // hoping to find a zero/sign extension op. Note that we do not need to find @@ -146,15 +153,18 @@ bool mayExtI8ToI32(Value op) { /// Succeeds when |contract| is a i32 matmul whose LHS and RHS operands may be /// result of zero/sign extension of i8 inputs. 
LogicalResult detectI8ToI32Matmul(vector::ContractionOp contract) { - if (contract.getKind() != vector::CombiningKind::ADD) + if (contract.getKind() != vector::CombiningKind::ADD) { return failure(); + } - if (!mayExtI8ToI32(contract.getLhs()) || !mayExtI8ToI32(contract.getRhs())) + if (!mayExtI8ToI32(contract.getLhs()) || !mayExtI8ToI32(contract.getRhs())) { return failure(); + } ArrayRef iteratorTypes = contract.getIteratorTypes().getValue(); - if (iteratorTypes.size() != 3) + if (iteratorTypes.size() != 3) { return failure(); + } return success(vector::isParallelIterator(iteratorTypes[0]) && vector::isParallelIterator(iteratorTypes[1]) && @@ -265,12 +275,14 @@ bool supportsIntegerDotProductOps(mlir::FunctionOpInterface fn) { // First check if the function op itself has a target env attribute. This may // be preferred in tests. auto targetEnvAttr = getGPUTargetAttr(fn); - if (!targetEnvAttr) + if (!targetEnvAttr) { return false; + } if (!IREE::GPU::bitEnumContainsAll(targetEnvAttr.getWgp().getDot().getValue(), - IREE::GPU::DotProductOps::DP4xI8ToI32)) + IREE::GPU::DotProductOps::DP4xI8ToI32)) { return false; + } return true; } @@ -332,8 +344,9 @@ class SPIRVInitialLoweringPass final // batch dimension. Try to drop that to map to matmul dimensions better. 
SmallVector contractOps; funcOp.walk([&](vector::ContractionOp op) { - if (op.getIteratorTypes().size() > 3) + if (op.getIteratorTypes().size() > 3) { contractOps.push_back(op); + } }); for (vector::ContractionOp op : contractOps) { OpBuilder builder(op); @@ -373,8 +386,9 @@ class SPIRVInitialLoweringPass final funcOp.walk([&](vector::MultiDimReductionOp reductionOp) { if (llvm::any_of(reductionOp->getOperands(), [](Value operand) { return operand.getDefiningOp(); - })) + })) { reductionOps.push_back(reductionOp); + } return WalkResult::advance(); }); RewritePatternSet patterns(context); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp index 4c3f84e2211a..1d70a29637f0 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp @@ -70,8 +70,9 @@ struct SPIRVLinkExecutablesPass final // Collect all source executable ops. auto sourceExecutableOps = gatherExecutablesForSPIRVCodegen(moduleOp); - if (sourceExecutableOps.size() <= 1) + if (sourceExecutableOps.size() <= 1) { return; + } // Note that at runtime, for a particular executable, only one variant of it // will be loaded. 
So, all variants of an executable are expected to provide @@ -154,8 +155,9 @@ struct SPIRVLinkExecutablesPass final } }); - if (failed(linkOneExecutableBucket(moduleOp, moduleName, key, bucket))) + if (failed(linkOneExecutableBucket(moduleOp, moduleName, key, bucket))) { return signalPassFailure(); + } } } diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMapMemRefStorageClass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMapMemRefStorageClass.cpp index 220bb9df8071..6c1af71fc727 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMapMemRefStorageClass.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMapMemRefStorageClass.cpp @@ -75,16 +75,18 @@ mapHALDescriptorTypeForOpenCL(Attribute attr) { bool allowsShaderCapability(ArrayRef features) { for (StringRef feature : features) { - if (feature.consume_front("cap:") && feature == "Shader") + if (feature.consume_front("cap:") && feature == "Shader") { return true; + } } return false; } bool allowsKernelCapability(ArrayRef features) { for (StringRef feature : features) { - if (feature.consume_front("cap:") && feature == "Kernel") + if (feature.consume_front("cap:") && feature == "Kernel") { return true; + } } return false; } diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMaterializeExecutableConditions.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMaterializeExecutableConditions.cpp index d01eab942010..40750858d891 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMaterializeExecutableConditions.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMaterializeExecutableConditions.cpp @@ -157,8 +157,9 @@ LogicalResult mapToDeviceQuery(IREE::HAL::ExecutableExportOp entryPoint, entryPoint->getAttrOfType("iree.spirv.coopmatrix.type"); auto coopmatShape = entryPoint->getAttrOfType( "iree.spirv.coopmatrix.shape"); - if (!coopmatType || !coopmatShape) + if (!coopmatType || !coopmatShape) { return failure(); + } Type inputType = 
cast(coopmatType.getValue().front()).getValue(); Type outputType = cast(coopmatType.getValue().back()).getValue(); @@ -277,8 +278,9 @@ struct SPIRVMaterializeExecutableConditionsPass final SPIRVMaterializeExecutableConditionsPass> { void runOnOperation() override { IREE::HAL::ExecutableVariantOp variantOp = getOperation(); - if (!usesSPIRVCodeGen(variantOp)) + if (!usesSPIRVCodeGen(variantOp)) { return; + } IREE::HAL::ExecutableTargetAttr executableTarget = variantOp.getTarget(); DictionaryAttr configuration = executableTarget.getConfiguration(); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp index 822e4f8c1ba0..9028aa7ab93c 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp @@ -50,8 +50,9 @@ verifyLoweringConfiguration(FunctionOpInterface funcOp, auto walkResult = funcOp.walk([&](Operation *op) -> WalkResult { auto loweringConfig = getLoweringConfig(op); - if (!loweringConfig) + if (!loweringConfig) { return WalkResult::advance(); + } return verificationFn(op, loweringConfig, translationInfo, workgroupSize); }); return failure(walkResult.wasInterrupted()); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp index 1b68acef1783..8665bb9619d2 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp @@ -127,15 +127,18 @@ class SPIRVTileAndDistributePass final void SPIRVTileAndDistributePass::runOnOperation() { MLIRContext *context = &getContext(); mlir::FunctionOpInterface funcOp = getOperation(); - if (!isEntryPoint(funcOp)) + if (!isEntryPoint(funcOp)) { return; + } auto threadTileComputeFn = getSPIRVTileSizeComputeFn(funcOp, 1); - if 
(failed(threadTileComputeFn)) + if (failed(threadTileComputeFn)) { return signalPassFailure(); + } auto reductionTileComputeFn = getSPIRVScfTileSizeComputeFn(funcOp, 2); - if (failed(reductionTileComputeFn)) + if (failed(reductionTileComputeFn)) { return signalPassFailure(); + } { // Tile and distribute to invocations. if (failed(tileToInvocation(funcOp, *threadTileComputeFn))) { diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp index 9824035c8186..6a41bc1cc7bd 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp @@ -132,16 +132,19 @@ void SPIRVTileAndPromotePass::runOnOperation() { mlir::FunctionOpInterface funcOp = getOperation(); auto threadTileComputeFn = getSPIRVTileSizeComputeFn(funcOp, 1); - if (failed(threadTileComputeFn)) + if (failed(threadTileComputeFn)) { return signalPassFailure(); + } auto reductionTileComputeFn = getSPIRVScfTileSizeComputeFn(funcOp, 2); - if (failed(reductionTileComputeFn)) + if (failed(reductionTileComputeFn)) { return signalPassFailure(); + } // Promote C matrix and propagate the potential fill producer into the // allocation. This needs to be done before reduction tiling. - if (failed(doPromoteCMatrix(funcOp))) + if (failed(doPromoteCMatrix(funcOp))) { return signalPassFailure(); + } StringLiteral markerAttrName = LinalgTransforms::kLinalgTransformMarker; auto workgroupMarker = StringAttr::get(context, getWorkgroupMemoryMarker()); @@ -219,10 +222,12 @@ void SPIRVTileAndPromotePass::runOnOperation() { // that there are no subview ops), clear markers to enable following steps. 
funcOp.walk([&](linalg::LinalgOp linalgOp) { auto marker = linalgOp->getAttrOfType(markerAttrName); - if (!marker) + if (!marker) { return WalkResult::advance(); - if (marker.getValue() == promoteBothMarker) + } + if (marker.getValue() == promoteBothMarker) { linalgOp->removeAttr(markerAttrName); + } return WalkResult::advance(); }); } @@ -271,14 +276,16 @@ void SPIRVTileAndPromotePass::runOnOperation() { LogicalResult SPIRVTileAndPromotePass::doPromoteCMatrix( mlir::FunctionOpInterface funcOp) const { MLIRContext *context = funcOp.getContext(); - if (!promoteCMatrix) + if (!promoteCMatrix) { return success(); + } SmallVector computeOps = getComputeOps(funcOp); SmallVector linalgOps; for (Operation *op : computeOps) { - if (isa(op)) + if (isa(op)) { continue; // Don't care + } if (auto linalgOp = dyn_cast(op)) { linalgOps.push_back(linalgOp); } else { @@ -291,8 +298,9 @@ LogicalResult SPIRVTileAndPromotePass::doPromoteCMatrix( } // If there are no fused elementwise ops, we can avoid promoting C matrix. - if (linalgOps.size() <= 1) + if (linalgOps.size() <= 1) { return success(); + } auto matmulOp = cast(linalgOps.front()); auto genericOp = cast(*linalgOps.back()); @@ -311,8 +319,9 @@ LogicalResult SPIRVTileAndPromotePass::doPromoteCMatrix( // If the fused elementwise ops are allowed to use cooperative types, we can // also avoid promoting C matrix. - if (isCooperativeMatrixFusable(genericOp)) + if (isCooperativeMatrixFusable(genericOp)) { return success(); + } // Finally do promote C matrix. 
RewritePatternSet patterns(context); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp index 3d754ef10cfb..6e3b1a0bdf83 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp @@ -90,8 +90,9 @@ void setSPIRVCooperativeMatrixInfo(mlir::FunctionOpInterface funcOp, ArrayRef getSPIRVCooperativeMatrixShape(mlir::FunctionOpInterface funcOp) { auto attr = funcOp->getAttrOfType(coopMatShapeAttrName); - if (!attr) + if (!attr) { return {}; + } return attr.asArrayRef(); } @@ -110,10 +111,12 @@ static SmallVector deduceSubgroupCounts(linalg::LinalgOp op) { SmallVector subgroupCounts; for (int i = 0, e = workgroupTileSizes.size(); i < e; ++i) { - if (subgroupTileSizes[i] == 0) + if (subgroupTileSizes[i] == 0) { continue; - if (linalg::isReductionIterator(op.getIteratorTypesArray()[i])) + } + if (linalg::isReductionIterator(op.getIteratorTypesArray()[i])) { continue; + } assert(workgroupTileSizes[i] % subgroupTileSizes[i] == 0); subgroupCounts.push_back(workgroupTileSizes[i] / subgroupTileSizes[i]); } @@ -174,17 +177,20 @@ std::optional> getExtOpVectorShape(ExtOpTy op, ArrayRef nativeShape) { auto insert = op.getOperand().template getDefiningOp(); - if (!insert) + if (!insert) { return std::nullopt; + } VectorType sliceType = insert.getSourceVectorType(); for (Operation *users : op->getUsers()) { auto extract = dyn_cast(users); - if (!extract) + if (!extract) { return std::nullopt; + } auto vecType = cast(extract.getResult().getType()); - if (!llvm::equal(sliceType.getShape(), vecType.getShape())) + if (!llvm::equal(sliceType.getShape(), vecType.getShape())) { return std::nullopt; + } } return llvm::to_vector(sliceType.getShape()); @@ -201,8 +207,9 @@ getCooperativeOpVectorShape(Operation *op, ArrayRef nativeShape) { // 
Unroll elementwise ops according to native cooperative matrix size. if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) { - if (auto vecType = dyn_cast(op->getResultTypes()[0])) + if (auto vecType = dyn_cast(op->getResultTypes()[0])) { return llvm::to_vector(nativeShape.drop_back()); // Drop K dim size + } } // Unrolling vector.contract generates vector.{insert|extract}_strided_slice @@ -231,27 +238,32 @@ getCooperativeOpVectorShape(Operation *op, ArrayRef nativeShape) { auto sourceOp = op; if (op->hasOneUse()) { auto user = *op->user_begin(); - if (isa(user) || isa(user)) + if (isa(user) || isa(user)) { sourceOp = user; + } } VectorType sliceType; for (Operation *users : sourceOp->getUsers()) { auto extract = dyn_cast(users); - if (!extract) + if (!extract) { return std::nullopt; + } auto vecType = cast(extract.getResult().getType()); - if (sliceType && sliceType != vecType) + if (sliceType && sliceType != vecType) { return std::nullopt; + } sliceType = vecType; } return llvm::to_vector(sliceType.getShape()); } - if (auto extOp = dyn_cast(op)) + if (auto extOp = dyn_cast(op)) { return getExtOpVectorShape(extOp, nativeShape); - if (auto extOp = dyn_cast(op)) + } + if (auto extOp = dyn_cast(op)) { return getExtOpVectorShape(extOp, nativeShape); + } return std::nullopt; } @@ -309,8 +321,9 @@ class CombineContractTranspose final newSources.push_back(transposeOp.getVector()); foundTranspose = true; } - if (!foundTranspose) + if (!foundTranspose) { return failure(); + } Value res = vector::ContractionOp::create( rewriter, loc, newSources[0], newSources[1], newSources[2], diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp index f9e163d75303..926ec2b04f53 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp @@ -62,8 +62,9 @@ static bool 
getUsesIfAllTransferOp(Value value, } continue; } - if (isa(userOp)) + if (isa(userOp)) { continue; + } if (!isa(userOp)) { @@ -109,15 +110,18 @@ calculateMemRefVectorNumBits(SmallVectorImpl &uses) { continue; } auto transferOp = dyn_cast(op); - if (!transferOp) + if (!transferOp) { return 0; + } // Masked transfers must be scalarized. - if (transferOp.getMask()) + if (transferOp.getMask()) { return 0; + } std::optional transferSize = getBitWidth(transferOp.getVectorType()); - if (!transferSize) + if (!transferSize) { return 0; + } minBits = std::min(minBits, *transferSize); } @@ -131,8 +135,9 @@ calculateMemRefVectorNumBits(SmallVectorImpl &uses) { memrefVal = storeOp.getDstMemref(); stride = storeOp.getLeadDimension().getSExtValue(); } - if (!memrefVal) + if (!memrefVal) { continue; + } // GPU subgroup MMA ops do not care about the memref element type. But we // still need to make sure we can load/store with good strides. @@ -141,12 +146,14 @@ calculateMemRefVectorNumBits(SmallVectorImpl &uses) { auto memrefType = cast(memrefVal.getType()); std::optional elementBits = getBitWidth(memrefType.getElementType()); - if (!elementBits) + if (!elementBits) { return 0; + } int64_t strideBits = stride * *elementBits; // Make sure the stride is aligned with the planned vector bitwidth. - if (strideBits % minBits != 0) + if (strideBits % minBits != 0) { return 0; + } } return minBits; @@ -197,8 +204,9 @@ static unsigned isMemRefVectorizable(Value value, if (getUsesIfAllTransferOp(value, uses)) { unsigned vectorBits = calculateMemRefVectorNumBits(uses); LLVM_DEBUG(llvm::dbgs() << "vectorBits=" << vectorBits << "\n"); - if (!vectorBits) + if (!vectorBits) { return 0; + } // TODO: Fix sub-byte type support in vector.bitcast lowering. 
if (vectorBits % 32 != 0) { @@ -377,8 +385,9 @@ class ProcessTransferRead final FailureOr> indices = adjustIndices(scalarMemrefType, vectorMemrefType, adaptor.getIndices(), rewriter, loc); - if (failed(indices)) + if (failed(indices)) { return rewriter.notifyMatchFailure(read, "failed to adjust indices"); + } // If the transfer_read can be replaced by a load after vectorization use // LoadOp and cast back to the original type. @@ -480,8 +489,9 @@ class ProcessTransferWrite final FailureOr> indices = adjustIndices(scalarMemrefType, vectorMemrefType, adaptor.getIndices(), rewriter, loc); - if (failed(indices)) + if (failed(indices)) { return rewriter.notifyMatchFailure(write, "failed to adjust indices"); + } // If the transfer_write can be replaced by a store after vectorization cast // the original value and use StoreOp. @@ -572,8 +582,9 @@ MemRefConversionPattern::getVectorizedMemRefType( Type vectorType = VectorType::get(vectorNumElements, scalarType); auto newShape = llvm::to_vector<2>(type.getShape()); unsigned ratio = vectorNumBits / type.getElementTypeBitWidth(); - if (newShape.back() % ratio != 0) + if (newShape.back() % ratio != 0) { return {}; + } newShape.back() = newShape.back() / ratio; MemRefLayoutAttrInterface layout = {}; @@ -605,8 +616,9 @@ FailureOr> MemRefConversionPattern::adjustIndices( getBitWidth(vectorMemrefType.getElementType()); std::optional scalarMemrefElemSize = getBitWidth(scalarMemrefType.getElementType()); - if (!vectorMemrefElemSize || !scalarMemrefElemSize) + if (!vectorMemrefElemSize || !scalarMemrefElemSize) { return failure(); + } MLIRContext *context = rewriter.getContext(); AffineExpr sym0, sym1; @@ -629,8 +641,9 @@ class ProcessAlloc final : public MemRefConversionPattern { matchAndRewrite(memref::AllocOp alloc, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto memrefType = getVectorizedMemRefType(rewriter, alloc.getResult()); - if (!memrefType) + if (!memrefType) { return failure(); + } 
rewriter.replaceOpWithNewOp(alloc, *memrefType, alloc.getDynamicSizes()); return success(); @@ -647,8 +660,9 @@ class ProcessInterfaceBindingSubspan final OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto memrefType = dyn_cast(subspanOp.getType()); - if (!memrefType) + if (!memrefType) { return failure(); + } // This should be guaranteed by the analysis step. But just double check. assert(memrefType.getRank() > 0 && @@ -696,8 +710,9 @@ struct ProcessSubgroupMMALoad final Location loc = loadOp.getLoc(); auto indices = adjustIndices(scalarMemrefType, vectorMemrefType, adaptor.getIndices(), rewriter, loc); - if (failed(indices)) + if (failed(indices)) { return failure(); + } // Compute how many bits the mma op stride corresponds to for the scalar // memref, and rescale it to vector memref. @@ -730,8 +745,9 @@ struct ProcessSubgroupMMAStore final Location loc = storeOp.getLoc(); auto indices = adjustIndices(scalarMemrefType, vectorMemrefType, adaptor.getIndices(), rewriter, loc); - if (failed(indices)) + if (failed(indices)) { return failure(); + } // Compute how many bits the mma op stride corresponds to for the scalar // memref, and rescale it to vector memref. 
@@ -804,8 +820,9 @@ struct ScalarizeVectorTransferRead final PatternRewriter &rewriter) const override { VectorType vectorType = readOp.getType(); auto map = readOp.getPermutationMap(); - if (vectorType.getRank() > 1 || !map.isProjectedPermutation()) + if (vectorType.getRank() > 1 || !map.isProjectedPermutation()) { return failure(); + } Location loc = readOp.getLoc(); Value maybeMask = readOp.getMask(); @@ -883,8 +900,9 @@ struct ScalarizeVectorLoad final : public OpRewritePattern { LogicalResult matchAndRewrite(vector::LoadOp loadOp, PatternRewriter &rewriter) const override { VectorType vectorType = loadOp.getType(); - if (vectorType.getRank() > 1) + if (vectorType.getRank() > 1) { return failure(); + } Location loc = loadOp.getLoc(); if (vectorType.getRank() == 0) { @@ -929,8 +947,9 @@ struct ScalarizeVectorTransferWrite final PatternRewriter &rewriter) const override { VectorType vectorType = writeOp.getVectorType(); auto map = writeOp.getPermutationMap(); - if (vectorType.getRank() > 1 || !map.isProjectedPermutation()) + if (vectorType.getRank() > 1 || !map.isProjectedPermutation()) { return failure(); + } Location loc = writeOp.getLoc(); Value maybeMask = writeOp.getMask(); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp index 2685c41b5e5e..9dfb15d68388 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp @@ -36,8 +36,9 @@ const char *getSPIRVDistributeAttrName() { return "iree.spirv.distribute_dim"; } DictionaryAttr getTargetConfigAttr(Operation *op) { auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(op); - if (!targetAttr) + if (!targetAttr) { return nullptr; + } return targetAttr.getConfiguration(); } @@ -62,8 +63,9 @@ getSPIRVTileSize(mlir::FunctionOpInterface funcOp, int tilingLevel) { FailureOr getSPIRVTileSizeComputeFn(mlir::FunctionOpInterface funcOp, int tilingLevel) { auto tileSizes = 
getSPIRVTileSize(funcOp, tilingLevel); - if (failed(tileSizes)) + if (failed(tileSizes)) { return failure(); + } linalg::TileSizeComputationFunction computeFn = [tileSizes](OpBuilder &builder, Operation *op) { auto range = llvm::map_range(*tileSizes, [&](int64_t size) -> Value { @@ -79,8 +81,9 @@ getSPIRVScfTileSizeComputeFn(mlir::FunctionOpInterface funcOp, int tilingLevel) { FailureOr> tileSizes = getSPIRVTileSize(funcOp, tilingLevel); - if (failed(tileSizes)) + if (failed(tileSizes)) { return failure(); + } scf::SCFTileSizeComputationFunction computeFn = [tileSizes](OpBuilder &builder, Operation *op) -> SmallVector { diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Verifiers.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Verifiers.cpp index fe00bd1e6be4..2a1bafa26e01 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/Verifiers.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/Verifiers.cpp @@ -34,8 +34,9 @@ LogicalResult verifySPIRVMatmulPromoteVectorizePassPipeline( << stringifyEnum(CodeGenPipeline::SPIRVMatmulPromoteVectorize); } - if (!isa(op)) + if (!isa(op)) { return success(); + } LLVM_DEBUG(llvm::dbgs() << "verifying op: " << *op << "\n" << "chosen workgroup size: " @@ -55,8 +56,9 @@ LogicalResult verifySPIRVMatmulPromoteVectorizePassPipeline( auto funcOp = op->getParentOfType(); std::optional subgroupSize = getGPUSubgroupSize(funcOp); - if (!subgroupSize) + if (!subgroupSize) { return funcOp->emitError("failed to query subgroup size"); + } const int maxThreads = target.getWgp().getMaxThreadCountPerWorkgroup(); const auto maxWorkGroupSize = target.getWgp().getMaxWorkgroupSizes().asArrayRef(); @@ -164,8 +166,9 @@ LogicalResult verifySPIRVCooperativeMatrixVectorizePassPipeline( auto funcOp = op->getParentOfType(); std::optional subgroupSize = getGPUSubgroupSize(funcOp); - if (!subgroupSize) + if (!subgroupSize) { return funcOp->emitError("failed to query subgroup size"); + } const int maxThreads = target.getWgp().getMaxThreadCountPerWorkgroup(); 
const auto maxWorkGroupSize = target.getWgp().getMaxWorkgroupSizes().asArrayRef(); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel index afeb7e4a61ea..47d4524d055d 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "annotate_winograd_loops.mlir", "break_down_large_vector.mlir", @@ -38,8 +39,8 @@ iree_lit_test_suite( "config_nvidia_matmul.mlir", "config_nvidia_matmul_cooperative_ops.mlir", "config_user.mlir", - "convert_to_spirv.mlir", "convert_gpu_target.mlir", + "convert_to_spirv.mlir", "emulate_i64.mlir", "erase_storage_buffer_static_shape.mlir", "illegal_configuration.mlir", @@ -49,17 +50,17 @@ iree_lit_test_suite( "lowering_matmul_fusion.mlir", "lowering_matmul_promotion.mlir", "lowering_matvec.mlir", - "lowering_scalar_dispatch.mlir", "lowering_reduction.mlir", + "lowering_scalar_dispatch.mlir", "map_memref_storage_class.mlir", "materialize_executable_conditions.mlir", + "physical_storage_buffer_addresses.mlir", "pipeline_matmul_cooperative_ops.mlir", "pipeline_matmul_promotion.mlir", "pipeline_matmul_vectorization.mlir", "pipeline_matvec.mlir", "pipeline_reduction_subgroup.mlir", "pipeline_sub_byte_dequant.mlir", - "physical_storage_buffer_addresses.mlir", "tile_and_distribute.mlir", "tile_and_distribute_scatter.mlir", "tile_and_distribute_sort.mlir", @@ -74,8 +75,8 @@ iree_lit_test_suite( "vectorize_conv.mlir", "vectorize_elementwise_ops.mlir", "vectorize_gather.mlir", - "vectorize_matmul.mlir", "vectorize_load_store.mlir", + "vectorize_matmul.mlir", "vectorize_reduction.mlir", ], include = ["*.mlir"], diff --git a/compiler/src/iree/compiler/Codegen/Transforms/AffineMinDistributedSCFCanonicalization.cpp b/compiler/src/iree/compiler/Codegen/Transforms/AffineMinDistributedSCFCanonicalization.cpp 
index 4817e365192a..8e2227590d04 100644 --- a/compiler/src/iree/compiler/Codegen/Transforms/AffineMinDistributedSCFCanonicalization.cpp +++ b/compiler/src/iree/compiler/Codegen/Transforms/AffineMinDistributedSCFCanonicalization.cpp @@ -40,8 +40,9 @@ static bool affineMinOpDivisible(affine::AffineMinOp minOp, int64_t dividend) { // Check if any of the dimensions is a ForOp or ParallelOp induction variable. for (auto dim : minOp.getDimOperands()) { auto ivArg = dyn_cast(dim); - if (!ivArg) + if (!ivArg) { continue; + } Operation *containingOp = ivArg.getOwner()->getParentOp(); auto forOp = dyn_cast_if_present(containingOp); if (forOp && forOp.getInductionVar() == dim) { @@ -52,8 +53,9 @@ static bool affineMinOpDivisible(affine::AffineMinOp minOp, int64_t dividend) { break; } auto parallelOp = dyn_cast_if_present(containingOp); - if (!parallelOp) + if (!parallelOp) { continue; + } for (auto [index, inductionVar] : llvm::enumerate(parallelOp.getInductionVars())) { if (inductionVar == dim) { @@ -64,11 +66,13 @@ static bool affineMinOpDivisible(affine::AffineMinOp minOp, int64_t dividend) { break; } } - if (iv) + if (iv) { break; + } } - if (!iv) + if (!iv) { return false; + } // Calculate the affine map representing `%ub - %iv`. AffineExpr ivDim; AffineExpr ubDim; @@ -94,11 +98,13 @@ static bool affineMinOpDivisible(affine::AffineMinOp minOp, int64_t dividend) { // `dividend` or equal to `%ub - %iv`. 
for (AffineExpr result : minOp.getAffineMap().getResults()) { if (auto cst = dyn_cast(result)) { - if (cst.getValue() <= 0 || cst.getValue() % dividend != 0) + if (cst.getValue() <= 0 || cst.getValue() % dividend != 0) { return false; + } } else { - if (diffExp != result) + if (diffExp != result) { return false; + } } } // Now check that for every value of the induction variable `%ub - %iv` is @@ -121,13 +127,15 @@ static bool isDivisible(Value v, int64_t dividend) { affine::canonicalizeMapAndOperands(&modMap, &ops); modMap = simplifyAffineMap(modMap); auto cst = dyn_cast(modMap.getResult(0)); - if (cst) + if (cst) { return (cst.getValue() == 0); + } // If the map doesn't fold to 0 but simplifies to (d0 %n) with d0 an // affine.min, check if all the results of the affine.min's map are divisible // by `dividend`. - if (modMap.getResult(0) != mod) + if (modMap.getResult(0) != mod) { return false; + } assert(ops.size() == 1); auto minOp = ops[0].getDefiningOp(); return (minOp && affineMinOpDivisible(minOp, dividend)); @@ -149,12 +157,14 @@ static std::optional foldAffineMin(affine::AffineMinOp minOp) { constantResult = cst.getValue(); } } - if (constantResult == 0) + if (constantResult == 0) { return {}; + } // If afine.min map's results are all positive and divisible by // `constantResult` then it can be replaced by `constantResult`. 
- if (affineMinOpDivisible(minOp, constantResult)) + if (affineMinOpDivisible(minOp, constantResult)) { return constantResult; + } return {}; } @@ -167,8 +177,9 @@ struct AffineMinDistributedSCFCanonicalizationPattern matchAndRewrite(mlir::affine::AffineMinOp minOp, mlir::PatternRewriter &rewriter) const override { std::optional cst = foldAffineMin(minOp); - if (!cst) + if (!cst) { return failure(); + } rewriter.replaceOpWithNewOp(minOp, rewriter.getIndexAttr(*cst)); return success(); diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp index d1a94c33567f..c49671bb9c5e 100644 --- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp +++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp @@ -45,11 +45,13 @@ namespace mlir::iree_compiler { static bool sliceFilter(Operation *op, ValueRange nonIndexComputationOperands, Operation *baseOp) { for (auto val : nonIndexComputationOperands) { - if (op == val.getDefiningOp()) + if (op == val.getDefiningOp()) { return false; + } } - if (op->isProperAncestor(baseOp)) + if (op->isProperAncestor(baseOp)) { return false; + } return !isa(op); } @@ -154,16 +156,18 @@ std::optional hoistOneStaticallyBoundAllocation( vector::ScalableValueBoundsConstraintSet::computeScalableBound( value, std::nullopt, vscaleRange->vscaleMin, vscaleRange->vscaleMax, presburger::BoundType::UB); - if (failed(ub)) + if (failed(ub)) { return failure(); + } if (ub->map.isSingleConstant()) { auto constantBound = ub->map.getSingleConstantResult(); return OpFoldResult(builder.getIndexAttr(constantBound)); } - if (!vscale) + if (!vscale) { vscale = vector::VectorScaleOp::create(builder, loc); + } return affine::materializeComputedBound( builder, loc, ub->map, {std::make_pair(vscale, std::nullopt)}); } @@ -172,8 +176,9 @@ std::optional hoistOneStaticallyBoundAllocation( presburger::BoundType::UB, {value, std::nullopt}, /*stopCondition=*/nullptr, 
/*closedUB=*/true); - if (failed(ub)) + if (failed(ub)) { return failure(); + } return OpFoldResult(builder.getIndexAttr(*ub)); }; @@ -202,8 +207,9 @@ std::optional hoistOneStaticallyBoundAllocation( Value dynamicSize = dynamicSizes[index++]; auto ub = computeAllocationBound(dynamicSize); - if (failed(ub)) + if (failed(ub)) { return std::nullopt; + } allocSizes.push_back(*ub); subviewSizes.push_back(dynamicSize); @@ -270,8 +276,9 @@ void hoistStaticallyBoundAllocationsInFunc( // Collect all allocLikes that are hoistable. funcOp.walk([&](AllocLikeOpType allocLikeOp) { - if (allocLikeOp->getBlock() == &funcOp.getFunctionBody().front()) + if (allocLikeOp->getBlock() == &funcOp.getFunctionBody().front()) { return; + } if (allocLikeOp.getDynamicSizes().empty()) { allocLikeOps.push_back(allocLikeOp); return; @@ -290,8 +297,9 @@ void hoistStaticallyBoundAllocationsInFunc( SmallVector deallocOps; for (Operation *user : allocLikeOp->getUsers()) { auto dealloc = dyn_cast(user); - if (dealloc) + if (dealloc) { deallocOps.push_back(dealloc); + } } LLVM_DEBUG({ @@ -303,8 +311,9 @@ void hoistStaticallyBoundAllocationsInFunc( }); std::optional replacement = hoistOneStaticallyBoundAllocation( funcOp, rewriter, allocLikeOp, vscaleRange); - if (!replacement) + if (!replacement) { continue; + } LLVM_DEBUG({ llvm::dbgs() << "Replacement : "; replacement->dump(); @@ -312,8 +321,9 @@ void hoistStaticallyBoundAllocationsInFunc( Value replacementVal = replacement.value(); rewriter.replaceOp(allocLikeOp, replacementVal); - for (memref::DeallocOp deallocOp : deallocOps) + for (memref::DeallocOp deallocOp : deallocOps) { rewriter.eraseOp(deallocOp); + } } } @@ -651,9 +661,8 @@ struct FoldSplitReductionForallWithWorkgroupForall } std::optional workgroupMapping = workgroupLoop.getMapping(); if (!workgroupMapping || - llvm::any_of(workgroupMapping->getValue(), [](Attribute attr) { - return !isa(attr); - })) { + !llvm::all_of(workgroupMapping->getValue(), + llvm::IsaPred)) { return 
rewriter.notifyMatchFailure( workgroupLoop, "nested loop is not a workgroup mapping loop"); } @@ -751,10 +760,12 @@ void moveLoopInvariantCodeFromGuaranteedLoops(Operation *target) { // like scf.for, since the value bounds interface requires index types. auto maybeLb = getConstantIntValue(lb); auto maybeUb = getConstantIntValue(ub); - if (!maybeLb || !maybeUb) + if (!maybeLb || !maybeUb) { return; - if (*maybeLb >= *maybeUb) + } + if (*maybeLb >= *maybeUb) { return; + } } } @@ -812,8 +823,9 @@ void analyseAllocsForPacking(mlir::FunctionOpInterface funcOp, // Skip the whole analysis if any user is a subview. // TODO: This could be extended if needed by recursively merging // liveness. - if (isa(user)) + if (isa(user)) { return; + } if (group.liveness.count(user)) { aliasGroups.push_back(i); break; @@ -851,14 +863,16 @@ void analyseAllocsForPacking(mlir::FunctionOpInterface funcOp, LLVM_DEBUG({ for (size_t i = 0; i < groups.size(); i++) { llvm::dbgs() << "Alias group " << i << ":\n"; - for (Operation *op : groups[i].allocs) + for (Operation *op : groups[i].allocs) { op->dump(); + } } }); for (size_t i = 0; i < groups.size(); i++) { - if (groups[i].allocs.empty()) + if (groups[i].allocs.empty()) { continue; + } aliasGroups.push_back(std::move(groups[i].allocs)); } } @@ -873,8 +887,9 @@ static int64_t getAllocSize(Operation *op, DataLayout &dataLayout) { void packAllocs(OpBuilder &builder, mlir::FunctionOpInterface funcOp, ArrayRef aliasGroups) { - if (aliasGroups.empty()) + if (aliasGroups.empty()) { return; + } DataLayout dataLayout = DataLayout::closest(funcOp); builder.setInsertionPointToStart(&(*funcOp.getFunctionBody().begin())); int64_t maxAlloc = 0; @@ -1061,8 +1076,9 @@ struct HoistForallFromFor : public OpRewritePattern { BlockArgument destBbArg = cast(parallelInsert.getDest()); tensor::ExtractSliceOp destSlice; for (auto user : destBbArg.getUsers()) { - if (user == parallelInsert) + if (user == parallelInsert) { continue; + } auto maybeSlice = 
dyn_cast(user); if (!maybeSlice) { // Fail if the destination has more users than a direct insert and @@ -1099,8 +1115,9 @@ struct HoistForallFromFor : public OpRewritePattern { for (auto [dim, size] : llvm::enumerate(insert.getMixedSizes())) { FailureOr equalDimSize = ValueBoundsConstraintSet::areEqual( {size}, {insert.getDest(), static_cast(dim)}); - if (failed(equalDimSize) || !*equalDimSize) + if (failed(equalDimSize) || !*equalDimSize) { return false; + } } return true; }; diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h index ec139f1ae402..9435589f61a7 100644 --- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h +++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h @@ -226,10 +226,12 @@ struct LinalgBasePromotionPattern : public RewritePattern { LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - if (failed(filter.checkAndNotify(rewriter, op))) + if (failed(filter.checkAndNotify(rewriter, op))) { return failure(); - if (failed(promoteSubviewsPrecondition(op, options))) + } + if (failed(promoteSubviewsPrecondition(op, options))) { return failure(); + } // TODO: We cannot use root update here. This // pattern is creating other ops, so if the diff --git a/compiler/src/iree/compiler/Codegen/Utils/CPUUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/CPUUtils.cpp index d15aed405b31..c1f6e8fd31bb 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/CPUUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/CPUUtils.cpp @@ -32,8 +32,9 @@ FailureOr getRootOperation(ArrayRef computeOps) { if (auto linalgOp = dyn_cast(op)) { // Do not treat linalg ops that are all parallel as root operations in // this sweep. - if (linalgOp.getNumLoops() == linalgOp.getNumParallelLoops()) + if (linalgOp.getNumLoops() == linalgOp.getNumParallelLoops()) { continue; + } // All other linalg ops are root ops. 
rootOperation = op; diff --git a/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.cpp index b2983b5bfe4a..02a03a6d2295 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.cpp @@ -57,8 +57,9 @@ FailureOr> getInnerTileSizesOfrImpl( if (ShapedType::isStaticShape(staticTileSizes)) { if (!materializeEncodingInfo.scalableTiles || llvm::none_of(materializeEncodingInfo.scalableTiles.value(), - [](bool scalable) { return scalable; })) + [](bool scalable) { return scalable; })) { return getAsOpFoldResult(rewriter.getI64ArrayAttr(staticTileSizes)); + } // In this case, we have scalable tiles present and we have to generate the // necessary vscale operation and the corresponding static_size * vscale // values. diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp index 4c711767a532..b444cdf50d0c 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp @@ -122,27 +122,32 @@ bool canPerformVectorAccessUsingAllThreads(ArrayRef shape, // Verify that each dimension of the shape can be distributed on the // threads // For zero dim tensor, consider it's too small to access using all threads. - if (shape.size() == 0) + if (shape.size() == 0) { return false; + } int64_t threadsAvailable = threadCount; for (const auto &[index, dim] : llvm::enumerate(llvm::reverse(shape))) { int64_t numElementPerThread = index == 0 ? vectorSize : 1; int64_t numThreads = dim / numElementPerThread; - if (numThreads == 0) + if (numThreads == 0) { return false; + } if (numThreads > threadsAvailable) { // If there are no enough remaining threads to distribute the current // dimension, try to use all remaining threads. But we still need to make // sure all work can be distributed to these threads evenly. 
- if (numThreads % threadsAvailable != 0) + if (numThreads % threadsAvailable != 0) { return false; + } numThreads = threadsAvailable; } - if (threadsAvailable % numThreads != 0) + if (threadsAvailable % numThreads != 0) { return false; + } threadsAvailable = threadsAvailable / numThreads; - if (threadsAvailable == 1) + if (threadsAvailable == 1) { break; + } } return threadsAvailable == 1; } @@ -200,8 +205,9 @@ FailureOr getGPUScfTileSizeComputeFn(mlir::FunctionOpInterface funcOp, int tilingLevel) { FailureOr> tileSizes = getGPUTileSize(funcOp, tilingLevel); - if (failed(tileSizes)) + if (failed(tileSizes)) { return failure(); + } scf::SCFTileSizeComputationFunction computeFn = [tileSizes](OpBuilder &builder, Operation *op) -> SmallVector { @@ -230,16 +236,18 @@ std::optional allocateWorkgroupMemory(OpBuilder &builder, mlir::FunctionOpInterface funcOp = subview->getParentOfType(); - if (!funcOp) + if (!funcOp) { return std::nullopt; + } // The subview size bounds are expected to be constant; they specify the shape // of the allocation. SmallVector shape; for (Value bound : sizeBounds) { APInt value; - if (!matchPattern(bound, m_ConstantInt(&value))) + if (!matchPattern(bound, m_ConstantInt(&value))) { return std::nullopt; + } shape.push_back(value.getSExtValue()); } @@ -272,10 +280,12 @@ static bool propagateCopyDestIntoProducerFill(memref::CopyOp copyOp) { } auto fillOp = dyn_cast(prevOp); - if (!fillOp) + if (!fillOp) { break; - if (fillOp.output() != copyOp.getSource()) + } + if (fillOp.output() != copyOp.getSource()) { break; + } // Move the fillOp and change the destination to the copy destination. 
fillOp->moveBefore(copyOp); fillOp.getOutputsMutable().assign(copyOp.getTarget()); @@ -327,10 +337,12 @@ propagateCopySourceIntoConsumerGeneric(memref::CopyOp copyOp, auto consumer = dyn_cast(nextOp); if (!consumer || consumer.getNumDpsInits() != 1 || !consumer.getMatchingIndexingMap(consumer.getDpsInitOperand(0)) - .isIdentity()) + .isIdentity()) { break; - if (*consumer.getOutputs().begin() != copyOp.getTarget()) + } + if (*consumer.getOutputs().begin() != copyOp.getTarget()) { break; + } insertInputValueIntoGeneric(copyOp.getSource(), consumer); toDelete.push_back(consumer); return true; @@ -346,12 +358,14 @@ void propagateSharedMemoryCopy(mlir::FunctionOpInterface funcOp) { funcOp.walk([&toDelete](memref::CopyOp copyOp) { if (hasMarker(copyOp, getCopyToWorkgroupMemoryMarker())) { if (propagateCopyDestIntoProducerFill(copyOp) || - propagateCopySourceIntoConsumerGeneric(copyOp, toDelete)) + propagateCopySourceIntoConsumerGeneric(copyOp, toDelete)) { toDelete.push_back(copyOp.getOperation()); + } } }); - for (Operation *op : toDelete) + for (Operation *op : toDelete) { op->erase(); + } } void insertBarriersAroundSharedMemoryCopy(mlir::FunctionOpInterface funcOp) { @@ -461,16 +475,18 @@ static Value warpReduction(Location loc, OpBuilder &builder, Value input, // integer type. 
auto unpack = [loc, &builder, needsPacking, equivIntType, origInputType](Value packedVal) -> Value { - if (!needsPacking) + if (!needsPacking) { return packedVal; + } auto asInt = arith::TruncIOp::create(builder, loc, equivIntType, packedVal); return arith::BitcastOp::create(builder, loc, origInputType, asInt); }; auto pack = [loc, &builder, needsPacking, equivIntType, shuffleIntType](Value unpackedVal) -> Value { - if (!needsPacking) + if (!needsPacking) { return unpackedVal; + } auto asInt = arith::BitcastOp::create(builder, loc, equivIntType, unpackedVal); return arith::ExtUIOp::create(builder, loc, shuffleIntType, asInt); @@ -667,8 +683,9 @@ std::optional> getWmmaNativeVectorSize(Operation *op) { return nativeSize; } if (auto writeOp = dyn_cast(op)) { - if (writeOp.getVectorType().getRank() < 2) + if (writeOp.getVectorType().getRank() < 2) { return std::nullopt; + } SmallVector nativeSize(writeOp.getVectorType().getRank() - 2, 1); nativeSize.append({m, n}); return nativeSize; @@ -679,11 +696,13 @@ std::optional> getWmmaNativeVectorSize(Operation *op) { VectorType sliceType; for (Operation *users : op->getUsers()) { auto extract = dyn_cast(users); - if (!extract) + if (!extract) { return std::nullopt; + } auto vecType = cast(extract.getResult().getType()); - if (sliceType && sliceType != vecType) + if (sliceType && sliceType != vecType) { return std::nullopt; + } sliceType = vecType; } return llvm::to_vector(sliceType.getShape()); @@ -692,8 +711,9 @@ std::optional> getWmmaNativeVectorSize(Operation *op) { if (auto vecType = dyn_cast(op->getResultTypes()[0])) { // TODO: The condition for unrolling elementwise should be restricted // only to operations that need unrolling (connected to the contract). - if (vecType.getRank() < 2) + if (vecType.getRank() < 2) { return std::nullopt; + } // First check whether there is a slice to infer the shape from. 
This is // required for cases where the accumulator type differs from the input @@ -702,15 +722,18 @@ std::optional> getWmmaNativeVectorSize(Operation *op) { VectorType sliceType; for (Operation *users : op->getUsers()) { auto extract = dyn_cast(users); - if (!extract) + if (!extract) { return std::nullopt; + } auto vecType = cast(extract.getResult().getType()); - if (sliceType && sliceType != vecType) + if (sliceType && sliceType != vecType) { return std::nullopt; + } sliceType = vecType; } - if (sliceType) + if (sliceType) { return llvm::to_vector(sliceType.getShape()); + } // Else unroll for trailing elementwise. SmallVector nativeSize(vecType.getRank() - 2, 1); @@ -729,12 +752,15 @@ std::optional> getWmmaNativeVectorSize(Operation *op) { static std::optional getVectorContractOpOperandId(vector::ContractionOp contractOp, OpResult result) { - if (contractOp.getLhs() == result) + if (contractOp.getLhs() == result) { return 0; - if (contractOp.getRhs() == result) + } + if (contractOp.getRhs() == result) { return 1; - if (contractOp.getAcc() == result) + } + if (contractOp.getAcc() == result) { return 2; + } return std::nullopt; } @@ -747,24 +773,30 @@ getVectorContractOpOperandIdForVectorReadOp(Operation *op) { // Check if the vector::TransferReadOp is consumed directly by // vector::ContractionOp. - if (op->use_empty()) + if (op->use_empty()) { return std::nullopt; + } Operation *firstLevelUser = *((op->getUsers()).begin()); - if (!firstLevelUser) + if (!firstLevelUser) { return std::nullopt; - if (auto contractOp = dyn_cast(firstLevelUser)) + } + if (auto contractOp = dyn_cast(firstLevelUser)) { return getVectorContractOpOperandId(contractOp, op->getResult(0)); + } // Check if the vector::TransferReadOp is consumed indirectly by // vector::ContractionOp. Only check until the second level of use-def chain. 
- if (firstLevelUser->use_empty()) + if (firstLevelUser->use_empty()) { return std::nullopt; + } Operation *secondLevelUser = *((firstLevelUser->getUsers()).begin()); - if (!secondLevelUser) + if (!secondLevelUser) { return std::nullopt; - if (auto contractOp = dyn_cast(secondLevelUser)) + } + if (auto contractOp = dyn_cast(secondLevelUser)) { return getVectorContractOpOperandId(contractOp, firstLevelUser->getResult(0)); + } return std::nullopt; } @@ -780,15 +812,15 @@ std::optional> getMmaNativeVectorSize(Operation *op) { Type sourceType = contract.getLhsType().getElementType(); // Set mmaShapeK based on sourceType. - if (sourceType.isInteger(4)) + if (sourceType.isInteger(4)) { mmaShapeK = 64; - else if (sourceType.isInteger(8)) + } else if (sourceType.isInteger(8)) { mmaShapeK = 32; - else if (sourceType.isF16() || sourceType.isBF16()) + } else if (sourceType.isF16() || sourceType.isBF16()) { mmaShapeK = 16; - else if (sourceType.isF32()) + } else if (sourceType.isF32()) { mmaShapeK = 8; - else { + } else { LDBG() << "unsupported shape for vector.contract: "; return std::nullopt; } @@ -803,8 +835,9 @@ std::optional> getMmaNativeVectorSize(Operation *op) { // Shape of warp-level vector write operation. 
if (auto writeOp = dyn_cast(op)) { - if (writeOp.getVectorType().getRank() < 2) + if (writeOp.getVectorType().getRank() < 2) { return std::nullopt; + } SmallVector outputShape(writeOp.getVectorType().getRank() - 2, 1); outputShape.append({mmaShapeM, mmaShapeN}); LDBG() << "shape for vector.xfer_write: " << llvm::interleaved(outputShape); @@ -892,11 +925,13 @@ std::optional> getMmaNativeVectorSize(Operation *op) { VectorType sliceType; for (Operation *users : op->getUsers()) { auto extract = dyn_cast(users); - if (!extract) + if (!extract) { return std::nullopt; + } auto vecType = cast(extract.getResult().getType()); - if (sliceType && sliceType != vecType) + if (sliceType && sliceType != vecType) { return std::nullopt; + } sliceType = vecType; } LDBG() << "shape for vector.xfer_read: " @@ -911,19 +946,24 @@ std::optional> getMmaNativeVectorSize(Operation *op) { bool hasGlobalMemoryAddressSpace(MemRefType memrefType) { Attribute addrSpace = memrefType.getMemorySpace(); - if (!addrSpace) + if (!addrSpace) { return true; + } auto intAttr = dyn_cast(addrSpace); // Accept both default numeric address space and HAL descriptor type address // space--the former is used by LLVMGPU while the latter is used by SPIR-V. 
- if (intAttr && intAttr.getInt() == 0) + if (intAttr && intAttr.getInt() == 0) { return true; + } auto gpuAttr = dyn_cast(addrSpace); - if (gpuAttr && gpuAttr.getValue() == gpu::AddressSpace::Global) + if (gpuAttr && gpuAttr.getValue() == gpu::AddressSpace::Global) { return true; + } auto amdgpuAttr = dyn_cast(addrSpace); - if (amdgpuAttr && amdgpuAttr.getValue() == amdgpu::AddressSpace::FatRawBuffer) + if (amdgpuAttr && + amdgpuAttr.getValue() == amdgpu::AddressSpace::FatRawBuffer) { return true; + } return isa(addrSpace); } @@ -970,8 +1010,9 @@ bool sharedMemTransposeFilter(AffineMap indexMap) { //===----------------------------------------------------------------------===// IREE::GPU::TargetAttr getCLGPUTarget(MLIRContext *context) { - if (clTestTarget.empty()) + if (clTestTarget.empty()) { return nullptr; + } auto [archAndFeatures, backend] = StringRef(clTestTarget).split("@"); if (backend.empty()) { @@ -979,16 +1020,17 @@ IREE::GPU::TargetAttr getCLGPUTarget(MLIRContext *context) { // for cases like "ampere" which can be accepted by both CUDA and Vulkan; // it's very limited. So it's targeting common cases to make writing tests // simpler. - if (StringRef(clTestTarget).starts_with("sm_")) + if (StringRef(clTestTarget).starts_with("sm_")) { backend = "cuda"; - else if (StringRef(clTestTarget).starts_with("gfx")) + } else if (StringRef(clTestTarget).starts_with("gfx")) { backend = "hip"; - else if (StringRef(clTestTarget).starts_with("adreno")) + } else if (StringRef(clTestTarget).starts_with("adreno")) { backend = "vulkan"; - else if (StringRef(clTestTarget).starts_with("apple")) + } else if (StringRef(clTestTarget).starts_with("apple")) { backend = "vulkan"; - else if (StringRef(clTestTarget).starts_with("valhall")) + } else if (StringRef(clTestTarget).starts_with("valhall")) { backend = "vulkan"; + } } auto [arch, features] = StringRef(archAndFeatures).split(':'); // Use the target specified in the command line for testing purposes. 
@@ -1041,11 +1083,13 @@ void addConfigWavesPerEu(MLIRContext *context, int64_t wavesPerEu, std::optional getGPUSubgroupSize(mlir::FunctionOpInterface func) { // First try to see if there is a subgroup size chosen in the CodeGen pipeline // configuration. - if (std::optional subgroupSize = getSubgroupSize(func)) + if (std::optional subgroupSize = getSubgroupSize(func)) { return subgroupSize.value(); + } // Then try to find the subgroup size from the target description. - if (IREE::GPU::TargetAttr target = getGPUTargetAttr(func)) + if (IREE::GPU::TargetAttr target = getGPUTargetAttr(func)) { return target.getPreferredSubgroupSize(); + } return std::nullopt; } diff --git a/compiler/src/iree/compiler/Codegen/Utils/LinalgOpInfo.cpp b/compiler/src/iree/compiler/Codegen/Utils/LinalgOpInfo.cpp index d07d8e28f060..70ed11c1b093 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/LinalgOpInfo.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/LinalgOpInfo.cpp @@ -127,8 +127,9 @@ void LinalgOpInfo::computeInfo(LinalgOp linalgOp) { bool isMatmulOrBatchMatmul(linalg::LinalgOp linalgOp) { // (Batch) matmul should be a reduction op with 2/3 parallel dimensions. if (!linalg::isaContractionOpInterface(linalgOp) || - !llvm::is_contained({2u, 3u}, linalgOp.getNumParallelLoops())) + !llvm::is_contained({2u, 3u}, linalgOp.getNumParallelLoops())) { return false; + } // Also exclude the case of matvec, which has only one non-unit parallel dim. // They should go down different pipelines. 
diff --git a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp index 40a174f5a877..f7563963b036 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp @@ -93,8 +93,9 @@ mergeModuleInto(Operation *sourceModuleOp, Operation *targetModuleOp, llvm::map_to_vector<8>(sourceBlock, [&](Operation &op) { return &op; }); for (auto &sourceOp : allOps) { - if (sourceOp->hasTrait()) + if (sourceOp->hasTrait()) { continue; + } if (auto symbolOp = dyn_cast(sourceOp)) { auto symbolName = symbolOp.getName(); @@ -172,13 +173,15 @@ replaceEntryPointUses(mlir::ModuleOp moduleOp, auto replaceSymbolRefs = [](Operation *rootOp, const DenseMap &map) { auto allUses = SymbolTable::getSymbolUses(rootOp); - if (!allUses) + if (!allUses) { return; + } for (auto use : *allUses) { auto oldAttr = use.getSymbolRef(); auto newAttr = map.lookup(oldAttr); - if (!newAttr) + if (!newAttr) { continue; + } auto newDict = use.getUser()->getAttrDictionary().replace( [&](Attribute attr) -> std::pair { if (attr == oldAttr) { @@ -267,8 +270,9 @@ LogicalResult linkExecutablesInto( // Merge sources into the linked source listing. if (auto sourcesAttr = variantOp.getSourcesAttr()) { - for (auto sourceAttr : sourcesAttr.getValue()) + for (auto sourceAttr : sourcesAttr.getValue()) { linkedSourceAttrs.set(sourceAttr.getName(), sourceAttr.getValue()); + } } // Remap variant refs. 
diff --git a/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.cpp index f4797747da7e..12d68539d0e7 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.cpp @@ -124,8 +124,9 @@ StringRef getDeleteMarker() { return "delete"; } StringRef getMarkerOrNull(Operation *op) { StringAttr attr = op->getAttrOfType(LinalgTransforms::kLinalgTransformMarker); - if (!attr) + if (!attr) { return ""; + } return attr.getValue(); } diff --git a/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.h b/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.h index f2c9a3fa80f6..3d2325d42c38 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.h +++ b/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.h @@ -51,8 +51,9 @@ struct LinalgTransformationFilter { bool hasReplacementFilter(Operation *op) const; LinalgTransformationFilter &addFilter(const FilterFunction &f) { - if (f) + if (f) { filters.push_back(f); + } return *this; } diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp index 8e7d511128cf..892194dd2513 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp @@ -296,8 +296,9 @@ std::array getMaxWorkgroupCount(Operation *op) { bool isReadOnly(Value v) { Operation *definingOp = v.getDefiningOp(); - if (!definingOp) + if (!definingOp) { return false; + } return TypeSwitch(definingOp) .Case( [&](arith::ConstantOp constantOp) { return true; }) @@ -536,8 +537,9 @@ LogicalResult setDefaultCustomOpLoweringConfig( for (Operation &op : dummyFuncOp.getBody().front()) { auto currLoweringConfig = getLoweringConfig(&op); - if (!currLoweringConfig) + if (!currLoweringConfig) { continue; + } // Translate the lowering config to the original operation. 
if (std::optional originalOperation = @@ -546,8 +548,9 @@ LogicalResult setDefaultCustomOpLoweringConfig( } auto currWorkgroupTileSizes = currLoweringConfig.getWorkgroupTileSizes(); - if (currWorkgroupTileSizes.empty()) + if (currWorkgroupTileSizes.empty()) { continue; + } workgroupTileSizes = currWorkgroupTileSizes; workgroupInterchange = currLoweringConfig.getWorkgroupInterchange(); } @@ -572,8 +575,9 @@ LogicalResult setDefaultCustomOpLoweringConfig( /// Returns the first of `exprs` which is of the type `T`. template static AffineExpr getAffineExprOfType(ArrayRef exprs) { - if (auto it = llvm::find_if(exprs, llvm::IsaPred); it != exprs.end()) + if (auto it = llvm::find_if(exprs, llvm::IsaPred); it != exprs.end()) { return *it; + } return nullptr; } @@ -611,8 +615,9 @@ static std::optional getDimension(Operation *op) { } template static std::optional getDimension(Operation *op) { - if (!op) + if (!op) { return std::nullopt; + } if (auto dimension = getDimension(op)) { return dimension; } @@ -630,8 +635,9 @@ checkDimensions(ArrayRef vals, std::optional refDimension = std::nullopt) { for (auto v : vals) { auto currDimension = getDimension(v.getDefiningOp()); - if (!currDimension) + if (!currDimension) { return std::nullopt; + } if (refDimension) { if (refDimension.value() != currDimension.value()) { return std::nullopt; @@ -891,8 +897,9 @@ isTiledAndDistributedLoop(scf::ForOp forOp) { countDim = ifx.getDimIndex(); } - if (!idDim || !countDim) + if (!idDim || !countDim) { return std::nullopt; + } Builder b(forOp.getContext()); loopInfo.untiledLowerBound = b.getIndexAttr(0); @@ -1083,8 +1090,9 @@ FailureOr getSoftwarePipelineStoreStage(DictionaryAttr config) { /// Returns a small tiling factor for the given reduction `dimSize`. /// Returns 0 to avoid tiling. int getReductionTilingFactor(int64_t dimSize) { - if (dimSize % 4 == 0) + if (dimSize % 4 == 0) { return 4; + } // Try to find the smallest prime factor as the tiling factor. 
As a trade off // between generated code size and compilation time, only look at prime @@ -1092,8 +1100,9 @@ int getReductionTilingFactor(int64_t dimSize) { static constexpr std::array primeNumbers = { 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47}; for (int n : primeNumbers) { - if (dimSize % n == 0) + if (dimSize % n == 0) { return n; + } } return 1; // Otherwise just tile with size 1. @@ -1221,16 +1230,19 @@ Value findOrCreateSubspanBuffer( // Look for an existing op. Block *block = subspanOp->getBlock(); for (Operation &op : *block) { - if (&op == subspanOp.getOperation()) + if (&op == subspanOp.getOperation()) { break; + } auto bufferSubspanOp = dyn_cast(&op); - if (!bufferSubspanOp) + if (!bufferSubspanOp) { continue; + } auto bufferMemrefType = dyn_cast(bufferSubspanOp.getResult().getType()); - if (!bufferMemrefType) + if (!bufferMemrefType) { continue; + } if (bufferSubspanOp.getBinding() != subspanOp.getBinding() || bufferSubspanOp.getDescriptorType() != subspanOp.getDescriptorType() || @@ -1238,14 +1250,16 @@ Value findOrCreateSubspanBuffer( !llvm::equal(bufferSubspanOp.getDynamicDims(), subspanOp.getDynamicDims()) || bufferSubspanOp.getAlignment() != subspanOp.getAlignment() || - memRefType != bufferMemrefType) + memRefType != bufferMemrefType) { continue; + } if (useRocdlBuffers && bufferSubspanOp->hasOneUse()) { auto castOp = dyn_cast( *bufferSubspanOp->getUsers().begin()); - if (!castOp) + if (!castOp) { continue; + } return castOp.getResult(); } return bufferSubspanOp.getResult(); @@ -1284,8 +1298,9 @@ Operation *setInsertionPointAfterLastValue(OpBuilder &builder, definingOp = &cast(val).getOwner()->getOperations().front(); } - if (!definingOp) + if (!definingOp) { continue; + } if (lastOp && definingOp == lastOp) { // Combine 'setInsertionPointBefore' by ANDing because we only want to set // the insertion point before the last op if all values this operation is @@ -1293,8 +1308,9 @@ Operation *setInsertionPointAfterLastValue(OpBuilder 
&builder, setInsertionPointBefore &= isa(val); continue; } - if (lastOp && domInfo.dominates(definingOp, lastOp)) + if (lastOp && domInfo.dominates(definingOp, lastOp)) { continue; + } lastOp = definingOp; // For block arguments we want the insertion point to be at the start of @@ -1591,12 +1607,14 @@ void sinkOpsInCFG(const SmallVector &allocs, SmallVector getStaticNumWorkgroups(mlir::FunctionOpInterface funcOp) { SmallVector result; std::optional exportOp = getEntryPoint(funcOp); - if (!exportOp) + if (!exportOp) { return result; + } Block *body = exportOp->getWorkgroupCountBody(); - if (!body) + if (!body) { return result; + } auto returnOp = cast(body->getTerminator()); assert(returnOp.getNumOperands() == 3); @@ -1684,9 +1702,10 @@ computeDimUpperBound(Value shapedValue, unsigned dimNum, ValueBoundsConstraintSet::computeConstantBound( presburger::BoundType::UB, {shapedValue, dimNum}, /*stopCondition=*/nullptr, /*closedUB=*/true); - if (succeeded(maybeDimBoundSize)) + if (succeeded(maybeDimBoundSize)) { return DimBoundSize{/*baseSize=*/*maybeDimBoundSize, /*scalable=*/false}; + } return failure(); } FailureOr maybeDimBound = @@ -1694,21 +1713,26 @@ computeDimUpperBound(Value shapedValue, unsigned dimNum, shapedValue, dimNum, /*vscaleMin=*/vscaleRange->vscaleMin, /*vscaleMax=*/vscaleRange->vscaleMax, presburger::BoundType::UB); - if (failed(maybeDimBound)) + if (failed(maybeDimBound)) { return failure(); + } auto boundSize = maybeDimBound->getSize(); - if (succeeded(boundSize)) + if (succeeded(boundSize)) { return boundSize; - if (roundUp == RoundUpVscaleMultiple::No) + } + if (roundUp == RoundUpVscaleMultiple::No) { return failure(); + } // If the upper bound map is of the form `add(subExpr, cst)` (cst <= 0), // round it up to `subExpr` (and try matching the bound again). 
auto binOp = dyn_cast(maybeDimBound->map.getResult(0)); - if (!binOp || binOp.getKind() != AffineExprKind::Add) + if (!binOp || binOp.getKind() != AffineExprKind::Add) { return failure(); + } auto cst = dyn_cast(binOp.getRHS()); - if (!cst || cst.getValue() > 0) + if (!cst || cst.getValue() > 0) { return failure(); + } DimBound roundedDimBound{AffineMap::get(maybeDimBound->map.getNumDims(), maybeDimBound->map.getNumSymbols(), binOp.getLHS())}; @@ -2052,8 +2076,9 @@ std::optional static inferSizesFromMixedSizes( } std::optional inferSizesFromIR(Value val) { - if (!val.getDefiningOp()) + if (!val.getDefiningOp()) { return std::nullopt; + } std::optional result; LDBG() << "Inferring sizes for: " << val; @@ -2076,20 +2101,23 @@ std::optional inferSizesFromIR(Value val) { } std::optional getConstantIndex(Value value) { - if (!isa(value.getType())) + if (!isa(value.getType())) { return std::nullopt; + } APInt val; - if (!matchPattern(value, m_ConstantInt(&val))) + if (!matchPattern(value, m_ConstantInt(&val))) { return std::nullopt; + } return val.getSExtValue(); } bool alwaysRunsFirstIteration(scf::ForOp op) { // Can't perform the analysis if the loops's bounds aren't index-typed. - if (!op.getInductionVar().getType().isIndex()) + if (!op.getInductionVar().getType().isIndex()) { return false; + } FailureOr isLb = ValueBoundsConstraintSet::compare( getAsOpFoldResult(op.getLowerBound()), ValueBoundsConstraintSet::LT, getAsOpFoldResult(op.getUpperBound())); @@ -2098,8 +2126,9 @@ bool alwaysRunsFirstIteration(scf::ForOp op) { bool neverRunsSecondIteration(scf::ForOp op) { // Can't perform the analysis if the loops's bounds aren't index-typed. - if (!op.getInductionVar().getType().isIndex()) + if (!op.getInductionVar().getType().isIndex()) { return false; + } // If the upper bound (ub) is less than or equal to the loop step, then // lower bound + step must be greater than the upper bound, assuming the // lower bound is non-negative. 
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/VMVX/KernelDispatch.cpp index a786176ea66a..5a8da4ac6778 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/KernelDispatch.cpp +++ b/compiler/src/iree/compiler/Codegen/VMVX/KernelDispatch.cpp @@ -28,8 +28,9 @@ getDefaultDistributionTileSizes(TilingInterface op) { llvm::DenseSet partitionedLoopsSet(partitionedLoops.begin(), partitionedLoops.end()); for (auto dim : llvm::seq(0, distTileSizes.size())) { - if (!partitionedLoopsSet.count(dim)) + if (!partitionedLoopsSet.count(dim)) { distTileSizes[dim] = 0; + } } return distTileSizes; diff --git a/compiler/src/iree/compiler/Codegen/VMVX/VMVXAssignConstantOrdinals.cpp b/compiler/src/iree/compiler/Codegen/VMVX/VMVXAssignConstantOrdinals.cpp index 89d64882c227..f36bb9ea059f 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/VMVXAssignConstantOrdinals.cpp +++ b/compiler/src/iree/compiler/Codegen/VMVX/VMVXAssignConstantOrdinals.cpp @@ -23,13 +23,15 @@ struct VMVXAssignConstantOrdinalsPass // Ignore non-VMVX variants. // TODO(benvanik): a way to nest this in the pipeline via dynamic passes. - if (variantOp.getTarget().getBackend().getValue() != "vmvx") + if (variantOp.getTarget().getBackend().getValue() != "vmvx") { return; + } // Get a constant key -> ordinal mapping. auto keyOrdinals = variantOp.gatherConstantOrdinals(); - if (keyOrdinals.empty()) + if (keyOrdinals.empty()) { return; + } // Update placeholders to hold the concrete ordinal values. // Eventually the VM global folding passes will inline them. 
@@ -39,8 +41,9 @@ struct VMVXAssignConstantOrdinalsPass moduleOp.getOps())) { auto keyAttr = globalOp->getAttr( IREE::HAL::ExecutableConstantBlockOp::getKeyAttrName()); - if (!keyAttr) + if (!keyAttr) { continue; + } auto it = keyOrdinals.find(keyAttr); if (it == keyOrdinals.end()) { globalOp.emitOpError() diff --git a/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerExecutableTargetPass.cpp b/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerExecutableTargetPass.cpp index ae8e5a3ebe9a..695af393f023 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerExecutableTargetPass.cpp +++ b/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerExecutableTargetPass.cpp @@ -54,8 +54,9 @@ void VMVXLowerExecutableTargetPass::runOnOperation() { mlir::FunctionOpInterface funcOp = getOperation(); auto translationInfo = getTranslationInfo(funcOp); - if (!translationInfo) + if (!translationInfo) { return; + } std::optional maybePipeline = getFunctionOpInterfacePassManager(funcOp); diff --git a/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerLinalgMicrokernels.cpp b/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerLinalgMicrokernels.cpp index fd40ecda75df..650e870d73bd 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerLinalgMicrokernels.cpp +++ b/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerLinalgMicrokernels.cpp @@ -179,8 +179,9 @@ class StridedBufferAnalysis { StridedBufferDescriptor &getDesc(OpBuilder &builder) { assert(isValid() && "invalid StridedBufferAnalysis"); - if (desc) + if (desc) { return *desc; + } OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointAfterValue(buffer); @@ -257,10 +258,12 @@ struct BinaryEmitter { } LogicalResult initialize(Location loc, PatternRewriter &rewriter) { - if (!isProjectedPermutation()) + if (!isProjectedPermutation()) { return rewriter.notifyMatchFailure(loc, "not projected permutation"); - if (maxRank() > 2) + } + if (maxRank() > 2) { return rewriter.notifyMatchFailure(loc, "rank > 2"); + } if 
(!operands.first.bufferAnal.isValid() || !operands.second.bufferAnal.isValid() || !result.bufferAnal.isValid()) { return rewriter.notifyMatchFailure(loc, @@ -370,10 +373,12 @@ struct UnaryEmitter { unsigned maxRank() { return std::max(operand.getRank(), result.getRank()); } LogicalResult initialize(Location loc, PatternRewriter &rewriter) { - if (!isProjectedPermutation()) + if (!isProjectedPermutation()) { return rewriter.notifyMatchFailure(loc, "not projected permutation"); - if (maxRank() > 2) + } + if (maxRank() > 2) { return rewriter.notifyMatchFailure(loc, "rank > 2"); + } if (!operand.bufferAnal.isValid() || !result.bufferAnal.isValid()) { return rewriter.notifyMatchFailure(loc, "could not compute buffer descriptor"); @@ -463,10 +468,12 @@ struct CopyEmitter { } LogicalResult initialize(Location loc, PatternRewriter &rewriter) { - if (!isProjectedPermutation()) + if (!isProjectedPermutation()) { return rewriter.notifyMatchFailure(loc, "not projected permutation"); - if (maxRank() > 2) + } + if (maxRank() > 2) { return rewriter.notifyMatchFailure(loc, "rank > 2"); + } // Initialize buffer descriptors. for (auto © : copies) { @@ -529,11 +536,13 @@ struct LinalgBinaryGenericConversion PatternRewriter &rewriter) const override { auto &children = op.getBlock()->getOperations(); // Only match two children (op + yield). - if (children.size() != 2) + if (children.size() != 2) { return failure(); + } // Only match parallel loops. - if (op.getNumParallelLoops() != op.getNumLoops()) + if (op.getNumParallelLoops() != op.getNumLoops()) { return failure(); + } // Match: // %0 = someop %arg2, %arg3 @@ -548,8 +557,9 @@ struct LinalgBinaryGenericConversion dyn_cast(binaryOp->getOperands()[0]); BlockArgument operandScalar1 = dyn_cast(binaryOp->getOperands()[1]); - if (!operandScalar0 || !operandScalar1) + if (!operandScalar0 || !operandScalar1) { return failure(); + } // Construct the emitter and start lowering. 
// Note that the operands may map to an out if the aliasing is safe, @@ -597,8 +607,9 @@ struct LinalgBinaryGenericConversion // Select the op to lower to and configure the emitter. // Emit from the iree_ukernel_x32b_opcode_t table. Type resultType = binaryOp->getResult(0).getType(); - if (!resultType.isIntOrFloat()) + if (!resultType.isIntOrFloat()) { return failure(); + } std::optional emitter = TypeSwitch>(binaryOp) .Case([&](arith::AddFOp op) -> std::optional { @@ -691,8 +702,9 @@ struct LinalgBinaryGenericConversion if (!emitter) { return rewriter.notifyMatchFailure(op, "unrecognized binary op"); } - if (failed(emitter->initialize(op.getLoc(), rewriter))) + if (failed(emitter->initialize(op.getLoc(), rewriter))) { return failure(); + } emitter->emit(op.getLoc(), rewriter); rewriter.eraseOp(op); @@ -709,11 +721,13 @@ struct LinalgUnaryGenericConversion PatternRewriter &rewriter) const override { auto &children = op.getBlock()->getOperations(); // Only match two children (op + yield). - if (children.size() != 2) + if (children.size() != 2) { return failure(); + } // Only match parallel loops. - if (op.getNumParallelLoops() != op.getNumLoops()) + if (op.getNumParallelLoops() != op.getNumLoops()) { return failure(); + } // Match: // %0 = someop %arg2 @@ -726,8 +740,9 @@ struct LinalgUnaryGenericConversion } BlockArgument operandScalar0 = dyn_cast(unaryOp->getOperands()[0]); - if (!operandScalar0) + if (!operandScalar0) { return failure(); + } // Construct the emitter and start lowering. // Note that the operands may map to an out if the aliasing is safe, @@ -755,8 +770,9 @@ struct LinalgUnaryGenericConversion // Select the op to lower to and configure the emitter. // Emit from the iree_ukernel_x32b_opcode_t table. 
Type resultType = unaryOp->getResult(0).getType(); - if (!resultType.isIntOrFloat()) + if (!resultType.isIntOrFloat()) { return failure(); + } std::optional emitter = TypeSwitch>(unaryOp) .Case([&](math::AbsFOp op) -> std::optional { @@ -814,8 +830,9 @@ struct LinalgUnaryGenericConversion if (!emitter) { return rewriter.notifyMatchFailure(op, "unrecognized unary op"); } - if (failed(emitter->initialize(op.getLoc(), rewriter))) + if (failed(emitter->initialize(op.getLoc(), rewriter))) { return failure(); + } emitter->emit(op.getLoc(), rewriter); rewriter.eraseOp(op); @@ -832,11 +849,13 @@ struct LinalgTrivialGenericConversion PatternRewriter &rewriter) const override { auto &children = op.getBlock()->getOperations(); // Only match one child (yield). - if (children.size() != 1) + if (children.size() != 1) { return failure(); + } // Only match parallel loops. - if (op.getNumParallelLoops() != op.getNumLoops()) + if (op.getNumParallelLoops() != op.getNumLoops()) { return failure(); + } // Presumed to be a yield terminator: configure the emitter. 
CopyEmitter emitter; @@ -857,8 +876,9 @@ struct LinalgTrivialGenericConversion } } - if (failed(emitter.initialize(op.getLoc(), rewriter))) + if (failed(emitter.initialize(op.getLoc(), rewriter))) { return failure(); + } emitter.emit(op.getLoc(), rewriter); rewriter.eraseOp(op); return success(); diff --git a/compiler/src/iree/compiler/Codegen/WGSL/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/WGSL/test/BUILD.bazel index 26b0d8f876e2..f9654b309494 100644 --- a/compiler/src/iree/compiler/Codegen/WGSL/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/WGSL/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ ], include = ["*.mlir"], diff --git a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp index 46502f9f59e2..4c6cbda20509 100644 --- a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp +++ b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp @@ -49,10 +49,12 @@ static llvm::cl::opt clEnableDebug( namespace { static bool isDebugEnabled() { - if (clEnableDebug) + if (clEnableDebug) { return true; - if (std::getenv("IREE_COMPILER_DEBUG_CONSTEVAL")) + } + if (std::getenv("IREE_COMPILER_DEBUG_CONSTEVAL")) { return true; + } return false; } @@ -73,6 +75,7 @@ struct CompileOptions { BindingOptions bindingOptions; InputDialectOptions inputOptions; PreprocessingOptions preprocessingOptions; + ParameterOptions parameterOptions; GlobalOptimizationOptions globalOptimizationOptions; DispatchCreationOptions dispatchCreationOptions; SchedulingOptions schedulingOptions; @@ -82,8 +85,9 @@ struct CompileOptions { }; static inline bool isAttrParameterized(Attribute attr) { - if (!attr) + if (!attr) { return false; + } return !isa(attr) && !isa(attr) && !isa(attr); } @@ -93,8 +97,9 @@ static inline bool isAccessorParameterized(const SymbolTable &moduleSymbols, AccessorTy op) { auto global = moduleSymbols.lookup(op.getGlobalName()); - if (!global) 
+ if (!global) { return true; + } return isAttrParameterized(global.getGlobalInitialValue()); } @@ -117,8 +122,9 @@ static bool isParameterized(const SymbolTable &moduleSymbols, return isAttrParameterized(accessor.getValueAttr()); }) .Default([=](auto) { return false; }); - if (parameterized) + if (parameterized) { return WalkResult::interrupt(); + } return WalkResult::advance(); }); return res.wasInterrupted(); @@ -156,8 +162,9 @@ class InitializationAnalysis { Availability getInitializerAvailability(IREE::Util::InitializerOpInterface initializerOp) { auto it = initializerAvailability.find(initializerOp); - if (it == initializerAvailability.end()) + if (it == initializerAvailability.end()) { return Availability::Unknown; + } return it->second; } @@ -192,11 +199,13 @@ class InitializationAnalysis { Availability queryGlobalInitializationStatus(StringRef globalName, unsigned opOrdinal) { auto &timeline = globalTimelines[globalName]; - if (timeline.empty()) + if (timeline.empty()) { return Availability::Unknown; + } for (auto &timepoint : timeline) { - if (timepoint.first > opOrdinal) + if (timepoint.first > opOrdinal) { return timepoint.second; + } } return timeline.back().second; } @@ -221,10 +230,11 @@ class InitializationAnalysis { availability = static_cast( std::min(static_cast(availability), static_cast(newAvailability))); - if (previousAvailability != availability) + if (previousAvailability != availability) { emitDebugWarning( initializerOp.getLoc(), [&](InFlightDiagnostic &diagnostic) { diagnostic << reason; }); + } }; if (initializerOp->getRegions().size() != 1 || @@ -403,8 +413,9 @@ static LogicalResult cloneUsedObjects(FunctionOpInterface funcOp, OpBuilder &moduleBuilder) { // Gather all symbol uses within the function. auto uses = SymbolTable::getSymbolUses(funcOp); - if (!uses.has_value()) + if (!uses.has_value()) { return success(); + } // Verify that all uses are to object-like types we can clone. 
for (auto use : uses.value()) { @@ -415,14 +426,16 @@ static LogicalResult cloneUsedObjects(FunctionOpInterface funcOp, return use.getUser()->emitOpError() << "references undefined symbol " << use.getSymbolRef(); } - if (!objectOp->hasTrait()) + if (!objectOp->hasTrait()) { continue; + } // Check if the object exists in the target yet. Since we create the // target we know there should be no conflicts: the only symbols with the // same name will be already cloned copies of the same source. - if (targetSymbolTable.lookup(objectNameAttr)) + if (targetSymbolTable.lookup(objectNameAttr)) { continue; + } // Clone the object. It's isolated and safe to copy wholesale. auto *clonedOp = moduleBuilder.clone(*objectOp); @@ -463,16 +476,18 @@ class ProgramBuilder { // compile dynamic initializers. auto availability = initializationAnalysis.getInitializerAvailability(initializerOp); - if (availability != InitializationAnalysis::Availability::Compiler) + if (availability != InitializationAnalysis::Availability::Compiler) { return failure(); + } OpBuilder moduleBuilder = OpBuilder::atBlockEnd(targetModuleOp.getBody()); // Find any object-like symbol references used by the initializer and // clone them. 
if (failed(cloneUsedObjects(initializerOp, sourceSymbolTable, - targetSymbolTable, moduleBuilder))) + targetSymbolTable, moduleBuilder))) { return failure(); + } auto funcOp = IREE::Util::FuncOp::create( moduleBuilder, initializerOp.getLoc(), "jit_eval", @@ -535,8 +550,9 @@ class ProgramBuilder { for (auto constantOp : funcOp.getOps()) { auto tensorType = dyn_cast(constantOp.getResult().getType()); auto elementsAttr = dyn_cast(constantOp.getValue()); - if (!tensorType || !elementsAttr) + if (!tensorType || !elementsAttr) { continue; + } if (!supportedTypes.supportsType(tensorType)) { emitDebugWarning(funcOp.getLoc(), [&](InFlightDiagnostic &diagnostic) { diagnostic << "skipping consteval initializer: unsupported type for " @@ -631,7 +647,7 @@ class JitGlobalsPass final : public impl::JitGlobalsPassBase { buildIREEVMTransformPassPipeline( *targetRegistry.value, compileOptions->pipelineOptions, compileOptions->bindingOptions, compileOptions->inputOptions, - compileOptions->preprocessingOptions, + compileOptions->preprocessingOptions, compileOptions->parameterOptions, compileOptions->globalOptimizationOptions, compileOptions->dispatchCreationOptions, compileOptions->schedulingOptions, compileOptions->executableOptions, @@ -667,15 +683,18 @@ class JitGlobalsPass final : public impl::JitGlobalsPassBase { FunctionCall call(binary, jitFunction.argumentBindings.size(), jitFunction.resultBindings.size()); - if (failed(call.initialize(jitFunction.loc))) + if (failed(call.initialize(jitFunction.loc))) { return failure(); + } // Convert arguments. 
for (ArgumentBinding &arg : jitFunction.argumentBindings) { switch (arg.getType()) { case ArgumentBinding::Type::ElementsAttr: { - if (failed(call.addArgument(jitFunction.loc, arg.getElementsAttr()))) + if (failed( + call.addArgument(jitFunction.loc, arg.getElementsAttr()))) { return failure(); + } break; } case ArgumentBinding::Type::GlobalOp: { @@ -686,8 +705,10 @@ class JitGlobalsPass final : public impl::JitGlobalsPassBase { "invalid: global " << arg.getGlobalOp().getGlobalName() << " has no value"; } - if (failed(call.addArgument(arg.getGlobalOp().getLoc(), globalValue))) + if (failed( + call.addArgument(arg.getGlobalOp().getLoc(), globalValue))) { return failure(); + } break; } } @@ -705,8 +726,9 @@ class JitGlobalsPass final : public impl::JitGlobalsPassBase { TypedAttr attr; if (failed(call.getResultAsAttr( resultBinding.getGlobalOp().getLoc(), it.index(), - resultBinding.getGlobalOp().getGlobalType(), attr))) + resultBinding.getGlobalOp().getGlobalType(), attr))) { return failure(); + } resultBinding.getGlobalOp().setGlobalInitialValue(attr); break; } diff --git a/compiler/src/iree/compiler/ConstEval/Runtime.cpp b/compiler/src/iree/compiler/ConstEval/Runtime.cpp index a7c48f011580..db59da95d0cd 100644 --- a/compiler/src/iree/compiler/ConstEval/Runtime.cpp +++ b/compiler/src/iree/compiler/ConstEval/Runtime.cpp @@ -22,8 +22,9 @@ namespace { LogicalResult handleRuntimeError(Location loc, iree_status_t status, bool freeStatus = true) { - if (iree_status_is_ok(status)) + if (iree_status_is_ok(status)) { return success(); + } std::string statusString = iree::Status::ToString(status); if (freeStatus) { iree_status_ignore(status); @@ -213,8 +214,9 @@ FunctionCall::importSerializableAttr( LogicalResult FunctionCall::addBufferArgumentAttr( Location loc, IREE::Util::SerializableAttrInterface serializableAttr) { auto buffer = importSerializableAttr(loc, serializableAttr); - if (failed(buffer)) + if (failed(buffer)) { return failure(); + } return handleRuntimeError( 
loc, iree_vm_list_push_ref_move(inputs.get(), std::move(*buffer))); } @@ -230,14 +232,16 @@ LogicalResult FunctionCall::addBufferViewArgumentAttr( shape[i] = shapedType.getDimSize(i); } iree_hal_element_type_t elementType = IREE_HAL_ELEMENT_TYPE_NONE; - if (failed( - convertToElementType(loc, shapedType.getElementType(), &elementType))) + if (failed(convertToElementType(loc, shapedType.getElementType(), + &elementType))) { return failure(); + } // Import buffer contents. auto buffer = importSerializableAttr(loc, serializableAttr); - if (failed(buffer)) + if (failed(buffer)) { return failure(); + } // Construct buffer view. iree::vm::ref bufferView; @@ -245,8 +249,9 @@ LogicalResult FunctionCall::addBufferViewArgumentAttr( loc, iree_hal_buffer_view_create(buffer->get(), rank, shape, elementType, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, - iree_allocator_system(), &bufferView)))) + iree_allocator_system(), &bufferView)))) { return failure(); + } return handleRuntimeError( loc, iree_vm_list_push_ref_move(inputs.get(), std::move(bufferView))); @@ -351,12 +356,14 @@ LogicalResult FunctionCall::getResultAsAttr(Location loc, size_t index, Type mlirType, TypedAttr &outAttr) { iree_vm_variant_t variant = iree_vm_variant_empty(); if (failed(handleRuntimeError(loc, iree_vm_list_get_variant_assign( - outputs.get(), index, &variant)))) + outputs.get(), index, &variant)))) { return failure(); + } outAttr = binary.convertVariantToAttribute(loc, variant, mlirType); - if (!outAttr) + if (!outAttr) { return failure(); + } return success(); } @@ -400,8 +407,9 @@ TypedAttr CompiledBinary::convertVariantToAttribute(Location loc, iree_hal_element_type_t halElementType = iree_hal_buffer_view_element_type(bufferView); Type elementType = mapElementType(loc, halElementType); - if (!elementType) + if (!elementType) { return {}; + } auto tensorType = RankedTensorType::get(shape, elementType); diff --git a/compiler/src/iree/compiler/ConstEval/test/BUILD.bazel 
b/compiler/src/iree/compiler/ConstEval/test/BUILD.bazel index 428ca152d8ac..5c4f59509016 100644 --- a/compiler/src/iree/compiler/ConstEval/test/BUILD.bazel +++ b/compiler/src/iree/compiler/ConstEval/test/BUILD.bazel @@ -16,6 +16,7 @@ iree_lit_test_suite( name = "lit", timeout = "moderate", srcs = enforce_glob( + # keep sorted [ "compile_regressions.mlir", "failing.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Encoding/IR/BUILD.bazel index 9b506a2a873a..df2e48c9a953 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "EncodingAttrs.td", "EncodingBase.td", @@ -70,6 +71,7 @@ iree_compiler_cc_library( ":EncodingTypesGen", "//compiler/src/iree/compiler/Dialect/LinalgExt/Utils", "//compiler/src/iree/compiler/Dialect/TensorExt/IR", + "//compiler/src/iree/compiler/Utils", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:DialectUtils", diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Encoding/IR/CMakeLists.txt index 708af4889087..968fbb11d817 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/CMakeLists.txt @@ -54,6 +54,7 @@ iree_cc_library( MLIRTensorUtils iree::compiler::Dialect::LinalgExt::Utils iree::compiler::Dialect::TensorExt::IR + iree::compiler::Utils PUBLIC ) diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp index 3e84bdf07581..80cb384caaaf 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp @@ -7,6 +7,7 @@ #include 
"iree/compiler/Dialect/Encoding/IR/EncodingTypes.h" #include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h" +#include "iree/compiler/Utils/EncodingUtils.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" @@ -228,9 +229,7 @@ static FailureOr getComposedAffineMap(Attribute attr) { return AffineMap(); } // All entries should have type `AffineMapAttr`. - if (!llvm::all_of(mapsAttr, [](Attribute attr) { - return isa(attr); - })) { + if (!llvm::all_of(mapsAttr, llvm::IsaPred)) { return failure(); } AffineMap map = @@ -257,60 +256,6 @@ EncodingAttr EncodingAttr::get(MLIRContext *ctx, int64_t operandIndex, b.getTypeArrayAttr(elemTypes), mapsAttr, iterationSizesAttr); } -/// Parse a list of integer values and/or dynamic values ('?') -static FailureOr> -parseDynamicI64IntegerList(AsmParser &parser) { - SmallVector integerVals; - if (failed(parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&] { - int64_t value = ShapedType::kDynamic; - if (failed(parser.parseOptionalQuestion()) && - failed(parser.parseInteger(value))) { - return failure(); - } - integerVals.push_back(value); - return success(); - }))) { - return failure(); - } - return integerVals; -} - -/// Utility to parse an array of integer and/or dynamic values (`?`). 
-static ParseResult parseDynamicI64ArrayAttr(AsmParser &p, ArrayAttr &attr) { - FailureOr> integerVals = parseDynamicI64IntegerList(p); - if (failed(integerVals)) { - return failure(); - } - auto integerValsAttr = - llvm::map_to_vector(integerVals.value(), [&](int64_t val) -> Attribute { - return IntegerAttr::get(IntegerType::get(p.getContext(), 64), val); - }); - attr = ArrayAttr::get(p.getContext(), integerValsAttr); - return success(); -} - -/// Print a list of integer values and/or dynamic values ('?') -static void printDynamicI64IntegerList(AsmPrinter &printer, - ArrayRef vals) { - printer << "["; - llvm::interleaveComma(vals, printer, [&](int64_t val) { - if (ShapedType::isDynamic(val)) { - printer << "?"; - } else { - printer << val; - } - }); - printer << "]"; -} - -/// Utility to print an array of integer and/or dynamic values. Dynamic values -/// are printed as `?`. -static void printDynamicI64ArrayAttr(AsmPrinter &p, ArrayAttr attrs) { - SmallVector intVals = llvm::map_to_vector( - attrs, [&](Attribute attr) { return cast(attr).getInt(); }); - return printDynamicI64IntegerList(p, intVals); -} - LogicalResult EncodingAttr::verify(function_ref emitError, IntegerAttr operandIndexAttr, @@ -373,8 +318,9 @@ SmallVector EncodingAttr::getRootMaps() const { return cast(m).getAffineMap(); } if (auto mapsAttr = dyn_cast(m)) { - if (mapsAttr.empty()) + if (mapsAttr.empty()) { return AffineMap(); + } return cast(mapsAttr[0]).getAffineMap(); } return AffineMap(); @@ -392,8 +338,9 @@ AffineMap EncodingAttr::getLastMapForOperandIndex() const { return mapAttr.getAffineMap(); } if (auto mapsAttr = dyn_cast(indexingMap)) { - if (mapsAttr.empty()) + if (mapsAttr.empty()) { return AffineMap(); + } return cast(mapsAttr[mapsAttr.size() - 1]).getAffineMap(); } return AffineMap(); diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.td b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.td index 29e7de2b410e..af6649e91bfd 100644 --- 
a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.td +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.td @@ -23,13 +23,19 @@ def IREEEncoding_SetEncodingOp : IREEEncoding_PureOp<"set_encoding",[ Operation to assign an encoding to a tensor. The operation does not change the rank or extent of a tensor. Instead it adds a LayoutResolverAttr attribute to the tensor type to represent a change in layout. + + The optional `encoding_dims` operand carries dynamic values needed by the + encoding (e.g., M, N, K dimensions for matmul encodings). These values are + used for runtime layout selection based on problem size. }]; - let arguments = (ins AnyRankedTensor:$source); + let arguments = (ins + AnyRankedTensor:$source, + Variadic:$encoding_dims); let results = (outs AnyRankedTensor:$result); let assemblyFormat = [{ - attr-dict $source `:` type($source) `->` type($result) + attr-dict $source (`encoding_dims` `{` $encoding_dims^ `}`)? `:` type($source) `->` type($result) }]; let hasVerifier = 1; @@ -49,21 +55,27 @@ def IREEEncoding_SetEncodingOp : IREEEncoding_PureOp<"set_encoding",[ //===----------------------------------------------------------------------===// def IREEEncoding_UnsetEncodingOp : IREEEncoding_PureOp<"unset_encoding", [ - DeclareOpInterfaceMethods, Pure + DeclareOpInterfaceMethods, + AttrSizedOperandSegments, Pure ]> { let summary = [{Perform unpack and extract operation on source.}]; let description = [{ Operation to convert a tensor with LayoutResolverAttr encoding that represents its data layout into a tensor with default layout (i.e. no encoding). For now in IREE the default layout is row-major. + + The optional `encoding_dims` operand carries dynamic values needed by the + encoding (e.g., M, N, K dimensions for matmul encodings). These values are + used for runtime layout selection based on problem size. 
}]; let arguments = (ins AnyRankedTensor:$source, - Variadic:$result_dims); + Variadic:$result_dims, + Variadic:$encoding_dims); let results = (outs AnyRankedTensor:$result); let assemblyFormat = [{ - attr-dict $source `:` type($source) `->` type($result) (`` `{` $result_dims^ `}`)? + attr-dict $source (`encoding_dims` `{` $encoding_dims^ `}`)? `:` type($source) `->` type($result) (`` `{` $result_dims^ `}`)? }]; let hasVerifier = 1; diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Encoding/IR/test/BUILD.bazel index 55216bea1866..18bbcfefa300 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "invalid.mlir", "roundtrip.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/test/roundtrip.mlir b/compiler/src/iree/compiler/Dialect/Encoding/IR/test/roundtrip.mlir index 7cfdf9dcda2d..cc623587b571 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/test/roundtrip.mlir +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/test/roundtrip.mlir @@ -259,3 +259,36 @@ func.func @identity_encoding(%arg0: tensor) -> tensor } // CHECK: func.func @identity_encoding(%[[ARG0:.+]]: tensor + +// ----- + +#encoding = #iree_encoding.testing<> +func.func @set_encoding_with_encoding_dims(%arg0: tensor, %m: index, %n: index, %k: index) -> tensor { + %0 = iree_encoding.set_encoding %arg0 encoding_dims{%m, %n, %k} : tensor -> tensor + return %0 : tensor +} +// CHECK: #[[ENCODING:.+]] = #iree_encoding.testing<> +// CHECK: func.func @set_encoding_with_encoding_dims +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[M:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[N:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[K:[a-zA-Z0-9]+]]: index +// CHECK: iree_encoding.set_encoding %[[ARG0]] encoding_dims{%[[M]], %[[N]], 
%[[K]]} : tensor -> tensor + +// ----- + +#encoding = #iree_encoding.testing<> +func.func @unset_encoding_with_encoding_dims( + %arg0: tensor, %d0: index, %d1: index, %m: index, %n: index, %k: index) -> tensor { + %0 = iree_encoding.unset_encoding %arg0 encoding_dims{%m, %n, %k} : tensor -> tensor{%d0, %d1} + return %0 : tensor +} +// CHECK: #[[ENCODING:.+]] = #iree_encoding.testing<> +// CHECK: func.func @unset_encoding_with_encoding_dims +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[D0:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[D1:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[M:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[N:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[K:[a-zA-Z0-9]+]]: index +// CHECK: iree_encoding.unset_encoding %[[ARG0]] encoding_dims{%[[M]], %[[N]], %[[K]]} : tensor -> tensor{%[[D0]], %[[D1]]} diff --git a/compiler/src/iree/compiler/Dialect/Encoding/Utils/ElementPackingUtils.cpp b/compiler/src/iree/compiler/Dialect/Encoding/Utils/ElementPackingUtils.cpp index 10d75a6ae56b..ea00b86f7878 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/Utils/ElementPackingUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/Encoding/Utils/ElementPackingUtils.cpp @@ -63,19 +63,22 @@ static Type legalizeStorageElementTypeImpl(Type elementType, bool isPackedStorage) { // Only handle integers; floats in MLIR all have aligned widths (today). auto intType = dyn_cast(elementType); - if (!intType) + if (!intType) { return elementType; + } // For sub-byte elements, default to pack them into bytes. unsigned bitWidth = intType.getWidth(); - if (needToPackSubByteElementBitWidthImpl(bitWidth, isPackedStorage)) + if (needToPackSubByteElementBitWidthImpl(bitWidth, isPackedStorage)) { return elementType; + } // Otherwise, extend them to the next power-of-two bit width. 
unsigned alignedBitWidth = IREE::Util::getRoundedElementByteWidth(intType) * 8; - if (alignedBitWidth == bitWidth) + if (alignedBitWidth == bitWidth) { return elementType; + } return IntegerType::get(elementType.getContext(), alignedBitWidth, intType.getSignedness()); } @@ -115,8 +118,9 @@ Value calculateStorageElementCountInBytes(Location loc, } for (unsigned i = 0; i < shapedType.getRank(); ++i) { - if (!shapedType.isDynamicDim(i)) + if (!shapedType.isDynamicDim(i)) { staticCount *= shapedType.getDimSize(i); + } } // Scale by dynamic dims, if present. auto value = diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/ShardToFlow/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Conversion/ShardToFlow/test/BUILD.bazel index 5f15951f4cc8..1e48034998f7 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Conversion/ShardToFlow/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/ShardToFlow/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "channel_creation.mlir", "collectives.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp index 4a36049c9151..c4600d8dc2fd 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp @@ -402,8 +402,9 @@ struct ConvertTensorReshapePattern : public OpRewritePattern { SmallVector outputDynamicShapes; for (auto [resultShape, outputShp] : llvm::zip_equal( reshapeOp.getResultType().getShape(), outputShape[0])) { - if (ShapedType::isStatic(resultShape)) + if (ShapedType::isStatic(resultShape)) { continue; + } outputDynamicShapes.push_back(getValueOrCreateConstantIndexOp( rewriter, reshapeOp.getLoc(), outputShp)); } diff --git 
a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Utils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Utils.cpp index e09f44a6ead5..aef20cd47dc7 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Utils.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Utils.cpp @@ -104,12 +104,14 @@ bool isOffsetSizeAndStrideMappableToFlow(ArrayRef offsets, int64_t staticSize = getVal(size, ShapedType::kDynamic); int64_t staticStride = getVal(stride, ShapedType::kDynamic); - if (staticStride != 1) + if (staticStride != 1) { return false; + } if (fullSlices == false) { - if (staticSize != 1) + if (staticSize != 1) { return false; + } } else { // TODO: Use ValueBoundsAnalysis to check whether two dynamic values // are equal. diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/BUILD.bazel index 80695f1e9590..c3164bbea997 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "bitcast.mlir", "cast.mlir", @@ -23,8 +24,8 @@ iree_lit_test_suite( "extract_slice.mlir", "fill.mlir", "from_elements.mlir", - "insert_slice.mlir", "insert.mlir", + "insert_slice.mlir", "reshape.mlir", ], include = ["*.mlir"], diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/IR/BUILD.bazel index 4ecd72fbfa41..65e4304312a5 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "FlowBase.td", "FlowInterfaces.td", diff --git 
a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp index c11185012d6f..e67988ddd0d3 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp @@ -49,8 +49,9 @@ struct ElideUnusedOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const override { - if (!op.use_empty()) + if (!op.use_empty()) { return failure(); + } rewriter.eraseOp(op); return success(); } @@ -59,13 +60,15 @@ struct ElideUnusedOp : public OpRewritePattern { // Returns true if |value| is definitely empty at runtime. static bool isTensorZeroElements(Value value) { auto type = dyn_cast(value.getType()); - if (!type) + if (!type) { return false; + } // Any static dimension being zero is definitely empty. for (int64_t i = 0; i < type.getRank(); ++i) { int64_t dim = type.getDimSize(i); - if (dim == 0) + if (dim == 0) { return true; + } } return false; // may still be dynamically empty } @@ -90,8 +93,9 @@ struct ReplaceOpIfTensorOperandZeroElements : public OpRewritePattern { LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const override { auto operand = op->getOperand(OperandIdx); - if (!isTensorOperandZeroElements(operand)) + if (!isTensorOperandZeroElements(operand)) { return failure(); + } auto result = op->getResult(ResultIdx); auto dynamicDims = op.getResultDynamicDims(result.getResultNumber()); rewriter.replaceOpWithNewOp(op, result.getType(), @@ -106,8 +110,9 @@ struct ReplaceOpIfTensorResultZeroElements : public OpRewritePattern { LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const override { auto result = op->getResult(ResultIdx); - if (!isTensorResultZeroElements(result)) + if (!isTensorResultZeroElements(result)) { return failure(); + } auto dynamicDims = op.getResultDynamicDims(result.getResultNumber()); 
rewriter.replaceOpWithNewOp(op, result.getType(), dynamicDims); @@ -122,8 +127,9 @@ struct ReplaceOpIfTensorOperandEmpty : public OpRewritePattern { PatternRewriter &rewriter) const override { auto operand = op->getOperand(OperandIdx); auto emptyOp = dyn_cast_if_present(operand.getDefiningOp()); - if (!emptyOp) + if (!emptyOp) { return failure(); + } auto result = op->getResult(ResultIdx); auto dynamicDims = op.getResultDynamicDims(result.getResultNumber()); rewriter.replaceOpWithNewOp(op, result.getType(), @@ -139,8 +145,9 @@ static SmallVector refreshDimsOnTypeChange(Operation *op, Type oldType, Type newType, ValueRange oldDims, PatternRewriter &rewriter) { - if (oldType == newType) + if (oldType == newType) { return llvm::to_vector(oldDims); + } // Build an expanded list of all the dims - constants will be nullptr. // This lets us map back the new types without worrying about whether some @@ -212,8 +219,9 @@ struct ReplaceDispatchResultIfZeroElements // will drop it. bool didReplaceAny = false; for (auto result : op.getResults()) { - if (result.use_empty()) + if (result.use_empty()) { continue; + } if (isTensorResultZeroElements(result)) { auto dynamicDims = op.getResultDynamicDims(result.getResultNumber()); auto emptyOp = IREE::Flow::TensorEmptyOp::create( @@ -392,8 +400,9 @@ struct DeduplicateDispatchEntryRefs final PatternRewriter &rewriter) const override { auto originalAttr = dispatchOp.getEntryPointsAttr(); auto newAttr = deduplicateArrayElements(originalAttr); - if (newAttr == originalAttr) + if (newAttr == originalAttr) { return failure(); + } rewriter.modifyOpInPlace(dispatchOp, [&]() { dispatchOp.setEntryPointsAttr(newAttr); }); return success(); @@ -598,8 +607,9 @@ struct ResolveShapedDim : public OpRewritePattern { if (dynamicDims.has_value()) { unsigned dimOffset = 0; for (unsigned i = 0; i < idx; ++i) { - if (shapedType.isDynamicDim(i)) + if (shapedType.isDynamicDim(i)) { ++dimOffset; + } } rewriter.replaceOp(op, dynamicDims.value()[dimOffset]); 
return success(); @@ -679,8 +689,9 @@ struct FoldSplatLoadIntoPrimitive : public OpRewritePattern { PatternRewriter &rewriter) const override { auto sourceOp = dyn_cast_if_present(loadOp.getSource().getDefiningOp()); - if (!sourceOp) + if (!sourceOp) { return failure(); + } rewriter.replaceOp(loadOp, sourceOp.getValue()); return success(); } @@ -699,8 +710,9 @@ void TensorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results, OpFoldResult TensorStoreOp::fold(FoldAdaptor operands) { auto value = operands.getValue(); - if (!value) + if (!value) { return {}; + } if (auto target = dyn_cast_if_present(operands.getTarget())) { // Store into the constant target tensor. auto targetType = cast(target.getType()); @@ -751,8 +763,9 @@ struct FoldSplatReshapeIntoSplat : public OpRewritePattern { PatternRewriter &rewriter) const override { auto splatOp = dyn_cast_if_present( reshapeOp.getSource().getDefiningOp()); - if (!splatOp) + if (!splatOp) { return failure(); + } rewriter.replaceOpWithNewOp( reshapeOp, reshapeOp.getResult().getType(), splatOp.getValue(), reshapeOp.getResultDims()); @@ -1067,8 +1080,9 @@ struct FoldTensorUpdateOpWithCasts : public OpRewritePattern { PatternRewriter &rewriter) const override { auto targetCastOp = updateOp.getTarget().getDefiningOp(); auto updateCastOp = updateOp.getUpdate().getDefiningOp(); - if (!targetCastOp && !updateCastOp) + if (!targetCastOp && !updateCastOp) { return failure(); + } Value target = (targetCastOp ? cast(targetCastOp.getSource()) : cast(updateOp.getTarget())); Value update = (updateCastOp ? 
cast(updateCastOp.getSource()) @@ -1094,8 +1108,9 @@ struct ReplaceOpIfTensorUpdateOperandZeroElements LogicalResult matchAndRewrite(TensorUpdateOp op, PatternRewriter &rewriter) const override { auto operand = op.getUpdate(); - if (!isTensorOperandZeroElements(operand)) + if (!isTensorOperandZeroElements(operand)) { return failure(); + } rewriter.replaceOp(op, op.getTarget()); return success(); } diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp index e5d40af3533c..4a63c88f382f 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp @@ -172,13 +172,15 @@ static ParseResult parseShapedOperandList( valueTypes.emplace_back(); if (failed(parser.parseOperand(values.back())) || failed(parser.parseColon()) || - failed(parser.parseType(valueTypes.back()))) + failed(parser.parseType(valueTypes.back()))) { return failure(); + } if (int64_t dynamicDimCount = cast(valueTypes.back()).getNumDynamicDims()) { if (failed(parser.parseOperandList(valueDims, dynamicDimCount, - AsmParser::Delimiter::Braces))) + AsmParser::Delimiter::Braces))) { return failure(); + } } } while (succeeded(parser.parseOptionalComma())); return success(); @@ -248,13 +250,15 @@ static ParseResult parseWorkgroupCountRegionWithoutKeyword(OpAsmParser &parser, static void printWorkgroupCountRegionWithoutKeyword(OpAsmPrinter &p, Operation *op, Region &body) { - if (body.empty()) + if (body.empty()) { return; + } p << "("; auto args = body.getArguments(); for (unsigned i = 0; i < args.size(); ++i) { - if (i > 0) + if (i > 0) { p << ", "; + } p.printRegionArgument(args[i]); } p << ")"; @@ -277,8 +281,9 @@ static ParseResult parseWorkgroupCountRegion(OpAsmParser &parser, static void printWorkgroupCountRegion(OpAsmPrinter &p, Operation *op, Region &body) { - if (body.empty()) + if (body.empty()) { return; + } p << "workgroups"; printWorkgroupCountRegionWithoutKeyword(p, 
op, body); } @@ -293,8 +298,9 @@ static ParseResult parseDispatchWorkgroupsCountRegion(OpAsmParser &parser, static void printDispatchWorkgroupsCountRegion(OpAsmPrinter &p, Operation *op, Region &body) { - if (body.empty()) + if (body.empty()) { return; + } p << " count"; printWorkgroupCountRegionWithoutKeyword(p, op, body); } @@ -309,16 +315,19 @@ static ParseResult parseDispatchEntryPoints(OpAsmParser &parser, if (succeeded(parser.parseOptionalLBrace())) { do { SymbolRefAttr entryPointAttr; - if (failed(parser.parseAttribute(entryPointAttr))) + if (failed(parser.parseAttribute(entryPointAttr))) { return failure(); + } entryPointAttrs.push_back(entryPointAttr); } while (succeeded(parser.parseOptionalComma())); - if (failed(parser.parseRBrace())) + if (failed(parser.parseRBrace())) { return failure(); + } } else { SymbolRefAttr entryPointAttr; - if (failed(parser.parseAttribute(entryPointAttr))) + if (failed(parser.parseAttribute(entryPointAttr))) { return failure(); + } entryPointAttrs.push_back(entryPointAttr); } entryPointAttrsArray = parser.getBuilder().getArrayAttr(entryPointAttrs); @@ -388,11 +397,12 @@ LogicalResult DispatchRegionOp::verify() { << returnOp.getNumOperands() << ")"; } for (const auto [resultType, returnType] : - llvm::zip_equal(getResultTypes(), returnOp->getOperandTypes())) + llvm::zip_equal(getResultTypes(), returnOp->getOperandTypes())) { if (resultType != returnType) { return returnOp->emitOpError() << "operand types do not match with parent results"; } + } } // Make sure that all returned values are ranked tensors. 
@@ -423,36 +433,45 @@ ParseResult DispatchRegionOp::parse(OpAsmParser &parser, (void)workloadOperandsLoc; if (succeeded(parser.parseOptionalLSquare())) { workloadOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperandList(workloadOperands)) + if (parser.parseOperandList(workloadOperands)) { return failure(); - if (parser.parseRSquare()) + } + if (parser.parseRSquare()) { return failure(); + } } if (succeeded(parser.parseOptionalArrow())) { ParseResult typeListResult = parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() { - if (parser.parseType(resultTypes.emplace_back())) + if (parser.parseType(resultTypes.emplace_back())) { return failure(); + } auto shapedType = dyn_cast(resultTypes.back()); - if (!shapedType) + if (!shapedType) { return success(); - if (shapedType.hasStaticShape()) + } + if (shapedType.hasStaticShape()) { return success(); + } SmallVector dynamicDims; if (parser.parseOperandList(dynamicDims, shapedType.getNumDynamicDims(), - OpAsmParser::Delimiter::Braces)) + OpAsmParser::Delimiter::Braces)) { return failure(); + } allOperands.append(dynamicDims.begin(), dynamicDims.end()); return success(); }); - if (typeListResult) + if (typeListResult) { return failure(); + } } - if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) + if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) { return failure(); - if (parser.parseRegion(*bodyRegion)) + } + if (parser.parseRegion(*bodyRegion)) { return failure(); + } if (parseDispatchWorkgroupsCountRegion(parser, *workloadCountRegion)) { return failure(); @@ -466,8 +485,9 @@ ParseResult DispatchRegionOp::parse(OpAsmParser &parser, static_cast(workloadOperands.size())})); if (parser.resolveOperands(allOperands, parser.getBuilder().getIndexType(), - result.operands)) + result.operands)) { return failure(); + } if (parser.resolveOperands(workloadOperands, parser.getBuilder().getIndexType(), workloadOperandsLoc, result.operands)) { @@ -498,8 +518,9 @@ void 
DispatchRegionOp::print(OpAsmPrinter &p) { resultDimCounter += shapedType.getNumDynamicDims(); } } - if (it.index() < getNumResults() - 1) + if (it.index() < getNumResults() - 1) { p << ", "; + } } p << ")"; p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs); @@ -519,9 +540,11 @@ void DispatchRegionOp::print(OpAsmPrinter &p) { ValueRange DispatchRegionOp::getResultDynamicDims(unsigned idx) { unsigned counter = 0; - for (unsigned i = 0; i < idx; ++i) - if (auto shapedType = dyn_cast(getResultTypes()[i])) + for (unsigned i = 0; i < idx; ++i) { + if (auto shapedType = dyn_cast(getResultTypes()[i])) { counter += shapedType.getNumDynamicDims(); + } + } auto shapedType = dyn_cast(getResultTypes()[idx]); return getResultDims().slice(counter, shapedType ? shapedType.getNumDynamicDims() : 0); @@ -590,8 +613,9 @@ bool dropUnusedAndRedundantDispatchRegionResults( "expected that all dynamic dims were processed"); // Nothing to do if all results are used. - if (droppedResultValues.empty()) + if (droppedResultValues.empty()) { return false; + } // Create new region and move over the body. auto newRegionOp = @@ -850,12 +874,14 @@ LogicalResult DispatchWorkgroupsOp::verify() { return success(); }; for (auto type : getOperandTypes()) { - if (failed(verifyIOType(type))) + if (failed(verifyIOType(type))) { return failure(); + } } for (auto type : getResultTypes()) { - if (failed(verifyIOType(type))) + if (failed(verifyIOType(type))) { return failure(); + } } // Workgroup count region is optional. @@ -879,22 +905,26 @@ BlockArgument DispatchWorkgroupsOp::getOutputBlockArgument(unsigned idx) { // Some outputs are tied to inputs and share their block arguments. int64_t tiedOperand = cast((*tiedOperands)[idx]).getValue().getSExtValue(); - if (tiedOperand != IREE::Util::TiedOpInterface::kUntiedIndex) + if (tiedOperand != IREE::Util::TiedOpInterface::kUntiedIndex) { // This output is tied to an input. 
return getInputBlockArgument(tiedOperand); + } unsigned nextOutArgIdx = getArguments().size(); - for (unsigned i = 0; i < idx; ++i) + for (unsigned i = 0; i < idx; ++i) { if (cast((*tiedOperands)[i]).getValue().getSExtValue() == - IREE::Util::TiedOpInterface::kUntiedIndex) + IREE::Util::TiedOpInterface::kUntiedIndex) { nextOutArgIdx++; + } + } return getWorkgroupBody().getArguments()[nextOutArgIdx]; } SmallVector DispatchWorkgroupsOp::getOutputBlockArguments() { SmallVector result; - for (unsigned i = 0; i < getNumResults(); ++i) + for (unsigned i = 0; i < getNumResults(); ++i) { result.push_back(getOutputBlockArgument(i)); + } return result; } @@ -954,10 +984,12 @@ refineTensorAccess(Value value, IREE::TensorExt::DispatchTensorType type) { hasWrites = true; }); } - if (hasReads && !hasWrites) + if (hasReads && !hasWrites) { tensorAccess = IREE::TensorExt::TensorAccess::ReadOnly; - if (!hasReads && hasWrites) + } + if (!hasReads && hasWrites) { tensorAccess = IREE::TensorExt::TensorAccess::WriteOnly; + } } return tensorAccess; } @@ -1071,16 +1103,18 @@ DispatchWorkgroupsOp::cloneReplacementExcludingOperandsAndResults( auto erasedArguments = llvm::to_vector(excludedOperandIndices); for (unsigned i = baseResultIndex, e = newBody.getNumArguments(); i != e; ++i) { - if (!is_contained(excludedResultIndices, i - baseResultIndex)) + if (!is_contained(excludedResultIndices, i - baseResultIndex)) { continue; + } auto arg = newBody.front().getArgument(i); eraseArgUseTree(arg, rewriter); erasedArguments.push_back(i); } auto &block = newBody.front(); BitVector eraseIndices(block.getNumArguments()); - for (auto i : erasedArguments) + for (auto i : erasedArguments) { eraseIndices.set(i); + } block.eraseArguments(eraseIndices); return newOp; @@ -1093,8 +1127,9 @@ DispatchWorkgroupsOp::getTiedOperandsIndexAndLength() { SmallVector DispatchWorkgroupsOp::getTiedOperandsAsIntegerList() { ArrayAttr attr = getTiedOperandsAttr(); - if (!attr) + if (!attr) { return {}; + } return 
llvm::map_to_vector(attr, [](Attribute intAttr) { return cast(intAttr).getInt(); }); diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp index ec6abe700972..63971aa4cd2f 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp @@ -53,8 +53,9 @@ Type FlowDialect::parseType(DialectAsmParser &parser) const { Type type; OptionalParseResult parseResult = generatedTypeParser(parser, &mnemonic, type); - if (parseResult.has_value()) + if (parseResult.has_value()) { return type; + } parser.emitError(parser.getCurrentLocation()) << "unknown Flow type: " << mnemonic; return {}; diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/IR/test/BUILD.bazel index 0e8edf6c4512..111c516e29aa 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "call_ops.mlir", "dispatch_folding.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/BUILD.bazel index dba8a9ed4574..aec29f76740e 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "FlowExtensionsOps.td", ], diff --git a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp index 1639cf7172a9..69ba78396530 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp +++ 
b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp @@ -34,8 +34,9 @@ void registerTransformDialectFlowExtension(DialectRegistry ®istry) { static SmallVector getIndicesOfDynamicDims(ShapedType t) { int64_t numDynamicDims = t.getNumDynamicDims(); SmallVector res(numDynamicDims); - for (int64_t dim = 0; dim != numDynamicDims; ++dim) + for (int64_t dim = 0; dim != numDynamicDims; ++dim) { res[dim] = t.getDynamicDimIndex(dim); + } return res; } @@ -61,8 +62,9 @@ static LogicalResult populateWorkgroupCountComputingRegion( // TODO: Iteratively pull operations that are only consuming IndexType. for (Value v : forallOp.getUpperBound(rewriter)) { auto op = dyn_cast_if_present(v.getDefiningOp()); - if (!op) + if (!op) { return failure(); + } results.push_back( cast(rewriter.clone(*op)).getResult()); } @@ -124,17 +126,21 @@ static void rewriteExtractSlices(RewriterBase &rewriter, scf::ForallOp forallOp, IRMapping tensorToFlowBvm) { dispatchOp->walk([&](tensor::ExtractSliceOp extractSliceOp) { Value source = extractSliceOp.getSource(); - if (auto sourceBbArg = dyn_cast(source)) - if (sourceBbArg.getOwner()->getParentOp() == forallOp.getOperation()) + if (auto sourceBbArg = dyn_cast(source)) { + if (sourceBbArg.getOwner()->getParentOp() == forallOp.getOperation()) { source = forallOp.getTiedOpOperand(sourceBbArg)->get(); + } + } auto it = llvm::find(tensorOperands, source); - if (it == tensorOperands.end()) + if (it == tensorOperands.end()) { return; + } int64_t index = std::distance(tensorOperands.begin(), it); Value sourceFlow = tensorToFlowBvm.lookupOrNull(source); - if (!sourceFlow) + if (!sourceFlow) { return; + } Location loc = extractSliceOp.getLoc(); OpBuilder::InsertionGuard g(rewriter); @@ -162,22 +168,26 @@ static void cloneOpsIntoForallOp(RewriterBase &rewriter, // Add all ops who's results are used inside the ForallOp to the // worklist. 
llvm::SetVector worklist; - for (Value v : valuesDefinedAbove) - if (Operation *op = v.getDefiningOp()) + for (Value v : valuesDefinedAbove) { + if (Operation *op = v.getDefiningOp()) { worklist.insert(op); + } + } llvm::SmallVector opsToClone; llvm::DenseSet visited; // Process all ops in the worklist. while (!worklist.empty()) { Operation *op = worklist.pop_back_val(); - if (visited.contains(op)) + if (visited.contains(op)) { continue; + } visited.insert(op); // Do not clone ops that are not clonable. - if (!IREE::Flow::isClonableIntoDispatchOp(op)) + if (!IREE::Flow::isClonableIntoDispatchOp(op)) { continue; + } // Do not clone ParallelInsertSliceOp destinations. bool isDestination = any_of( @@ -186,16 +196,18 @@ static void cloneOpsIntoForallOp(RewriterBase &rewriter, .getDest() .getDefiningOp() == op; }); - if (isDestination) + if (isDestination) { continue; + } opsToClone.push_back(op); // Add all operands to the worklist. for (Value operand : op->getOperands()) { Operation *operandOp = operand.getDefiningOp(); - if (!operandOp) + if (!operandOp) { continue; + } worklist.insert(operandOp); } } @@ -206,9 +218,11 @@ static void cloneOpsIntoForallOp(RewriterBase &rewriter, for (Operation *op : llvm::reverse(opsToClone)) { Operation *cloned = rewriter.clone(*op); SmallVector uses; - for (OpOperand &use : op->getUses()) - if (forallOp->isProperAncestor(use.getOwner())) + for (OpOperand &use : op->getUses()) { + if (forallOp->isProperAncestor(use.getOwner())) { uses.push_back(&use); + } + } for (OpOperand *use : uses) { unsigned resultNum = cast(use->get()).getResultNumber(); rewriter.modifyOpInPlace( @@ -264,13 +278,15 @@ rewriteForeachThreadToFlowDispatchWorkgroups(scf::ForallOp forallOp, BlockArgument destBbArg = cast(parallelInsertOp.getDest()); Value dest = forallOp.getTiedOpOperand(destBbArg)->get(); bool inserted = resultTensorOperands.insert(dest); - if (!inserted) + if (!inserted) { continue; + } auto dynamicDims = 
getIndicesOfDynamicDims(cast(dest.getType())); - for (int64_t dim : dynamicDims) + for (int64_t dim : dynamicDims) { resultTensorsDynamicDims.insert( tensor::DimOp::create(rewriter, loc, dest, dim)); + } } assert(resultTensorOperands.size() == forallOp.getNumResults() && "Expected as many resultTensorOperands as results of forallOp"); @@ -289,21 +305,25 @@ rewriteForeachThreadToFlowDispatchWorkgroups(scf::ForallOp forallOp, nonTensorOperands.push_back(v); continue; } - if (resultTensorOperands.contains(v)) + if (resultTensorOperands.contains(v)) { continue; + } tensorOperands.push_back(v); - for (int64_t dim : getIndicesOfDynamicDims(tensorType)) + for (int64_t dim : getIndicesOfDynamicDims(tensorType)) { tensorDynamicDims.push_back(tensor::DimOp::create(rewriter, loc, v, dim)); + } } // Also add shared outputs. (These are usually already added as result // tensor operands.) for (Value v : forallOp.getOutputs()) { auto tensorType = cast(v.getType()); - if (resultTensorOperands.contains(v)) + if (resultTensorOperands.contains(v)) { continue; + } tensorOperands.push_back(v); - for (int64_t dim : getIndicesOfDynamicDims(tensorType)) + for (int64_t dim : getIndicesOfDynamicDims(tensorType)) { tensorDynamicDims.push_back(tensor::DimOp::create(rewriter, loc, v, dim)); + } } // Step 3. Create ordered vectors of operands to pass to the builder and @@ -340,10 +360,11 @@ rewriteForeachThreadToFlowDispatchWorkgroups(scf::ForallOp forallOp, // Step 4. Outline the compute workload region and set up the workload // operands. if (failed(populateWorkgroupCountComputingRegion(rewriter, forallOp, - dispatchOp))) + dispatchOp))) { return forallOp->emitOpError( "failed to populate workload region for dispatchOp: ") << dispatchOp; + } // Step 5. Fixup dispatchOp bbArgs and terminator. 
// TODO: Ideally the builder would have created the proper bbArgs and the @@ -465,8 +486,9 @@ IREE::transform_dialect::ForeachThreadToFlowDispatchWorkgroupsOp::applyToOne( IRRewriter patternRewriter(target->getContext()); FailureOr result = rewriteForeachThreadToFlowDispatchWorkgroups(target, patternRewriter); - if (failed(result)) + if (failed(result)) { return emitDefaultDefiniteFailure(target); + } results.push_back(*result); return DiagnosedSilenceableFailure::success(); } @@ -484,8 +506,9 @@ IREE::transform_dialect::RegionToWorkgroupsOp::applyToOne( transform::ApplyToEachResultList &results, transform::TransformState &) { FailureOr result = rewriteFlowDispatchRegionToFlowDispatchWorkgroups(target, rewriter); - if (failed(result)) + if (failed(result)) { return emitDefaultDefiniteFailure(target); + } results.push_back(*result); return DiagnosedSilenceableFailure::success(); } diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/AnnotateDispatches.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/AnnotateDispatches.cpp index fdb0b3e0a4fd..05a0b7727ad6 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/AnnotateDispatches.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/AnnotateDispatches.cpp @@ -80,8 +80,9 @@ static TensorType getMainTensorForLinalgExtOp(Operation *op) { auto resultTypes = llvm::to_vector(op->getResultTypes()); for (Type t : llvm::concat(operandTypes, resultTypes)) { auto tensorType = dyn_cast(t); - if (!tensorType) + if (!tensorType) { continue; + } if (!main) { main = tensorType; } else if (costOfDomain(tensorType.getShape()) > @@ -182,19 +183,22 @@ static std::string getLinalgDataTypes(linalg::LinalgOp op) { static std::string getOpNameWithoutDialectName(Operation *op) { auto opName = op->getName().getStringRef().drop_until([](char c) { return c == '.'; }); - if (opName.starts_with(".")) + if (opName.starts_with(".")) { opName = opName.drop_front(); + } return opName.str(); } static bool 
isMatvecLike(linalg::LinalgOp linalgOp) { - if (!linalg::isaContractionOpInterface(linalgOp)) + if (!linalg::isaContractionOpInterface(linalgOp)) { return false; + } FailureOr dims = linalg::inferContractionDims(linalgOp); - if (failed(dims)) + if (failed(dims)) { return false; + } // One of the input should have all the parallel dimensions with size one. SmallVector bounds = linalgOp.getStaticLoopRanges(); @@ -207,8 +211,9 @@ static bool isMatvecLike(linalg::LinalgOp linalgOp) { unsigned pos = cast(result).getPosition(); // For a parallel dim, the bounds can be non-one if it's batch dim. if (iterators[pos] == utils::IteratorType::parallel && bounds[pos] != 1 && - !llvm::is_contained(dims->batch, pos)) + !llvm::is_contained(dims->batch, pos)) { return false; + } } return true; }; @@ -316,8 +321,9 @@ static std::string summarizeLinalgOp(linalg::LinalgOp op) { if (prefix.empty()) { // By default, use the op name as prefix. auto opName = op->getName().getStringRef(); - if (!opName.consume_front("linalg.")) + if (!opName.consume_front("linalg.")) { return ""; + } prefix = opName.str(); } @@ -331,8 +337,9 @@ static std::string summarizeLinalgExtOp(Operation *op) { auto opName = op->getName().getStringRef(); // Currently, this utility is also invoked by Linalg::SoftmaxOp. 
if (!(opName.consume_front("iree_linalg_ext.") || - opName.consume_front("linalg."))) + opName.consume_front("linalg."))) { return ""; + } std::string suffix = ""; if (TensorType mainTensor = getMainTensorForLinalgExtOp(op)) { llvm::raw_string_ostream sstream(suffix); @@ -382,8 +389,9 @@ static std::string summarizeDispatchRegion(Region ®ion) { TypeSwitch(op) .Case([&](auto op) { int64_t estimatedCost = estimateLinalgSoftmaxOpCost(op); - if (estimatedCost < bestEstimatedCost) + if (estimatedCost < bestEstimatedCost) { return; + } bestEstimatedCost = estimatedCost; bestOp = op; LLVM_DEBUG(llvm::dbgs() << "// new best op: '" << bestOp->getName() @@ -391,8 +399,9 @@ static std::string summarizeDispatchRegion(Region ®ion) { }) .Case([&](auto op) { int64_t estimatedCost = estimateLinalgOpCost(op); - if (estimatedCost < bestEstimatedCost) + if (estimatedCost < bestEstimatedCost) { return; + } bestEstimatedCost = estimatedCost; bestOp = op; LLVM_DEBUG(llvm::dbgs() << "// new best op: '" << bestOp->getName() @@ -403,8 +412,9 @@ static std::string summarizeDispatchRegion(Region ®ion) { // SetEncoding/UnsetEncoding/PackOp/UnPackOp is the bestOp only if // there are no other operations. 
int64_t estimatedCost = kMinEstimatedCost + 1; - if (estimatedCost < bestEstimatedCost) + if (estimatedCost < bestEstimatedCost) { return; + } bestEstimatedCost = estimatedCost; bestOp = op; LLVM_DEBUG(llvm::dbgs() << "// new best op: '" << bestOp->getName() @@ -412,8 +422,9 @@ static std::string summarizeDispatchRegion(Region ®ion) { }) .Case([&](auto op) { int64_t estimatedCost = estimateLinalgExtOpCost(op); - if (estimatedCost < bestEstimatedCost) + if (estimatedCost < bestEstimatedCost) { return; + } bestEstimatedCost = estimatedCost; bestOp = op; LLVM_DEBUG(llvm::dbgs() << "// new best op: '" << bestOp->getName() @@ -507,8 +518,9 @@ struct AnnotateDispatchesPass for (auto executableOp : getOperation().getBody()->getOps()) { auto innerModuleOp = executableOp.getInnerModule(); - if (!innerModuleOp) + if (!innerModuleOp) { continue; + } for (auto exportOp : executableOp.getBlock().getOps()) { auto oldSymbolRefAttr = SymbolRefAttr::get( @@ -517,11 +529,13 @@ struct AnnotateDispatchesPass auto funcOp = innerModuleOp.lookupSymbol( exportOp.getFunctionRef()); - if (!funcOp) + if (!funcOp) { continue; // extern module, maybe + } std::string summary = summarizeDispatchRegion(funcOp.getFunctionBody()); - if (summary.empty()) + if (summary.empty()) { continue; // unable to tell + } std::string newName = funcOp.getName().str() + "_" + summary; diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalize.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalize.cpp index 109f00efd2a7..9450fd50cb9e 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalize.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalize.cpp @@ -94,8 +94,9 @@ class AffineApplyLowering : public OpRewritePattern { auto maybeExpandedMap = affine::expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), llvm::to_vector<8>(op.getOperands())); - if (!maybeExpandedMap) + if (!maybeExpandedMap) { return failure(); + } rewriter.replaceOp(op, 
*maybeExpandedMap); return success(); } @@ -113,10 +114,12 @@ struct CanonicalizePass : public impl::CanonicalizePassBase { mlir::GreedySimplifyRegionLevel::Normal); RewritePatternSet owningPatterns(context); - for (auto *dialect : context->getLoadedDialects()) + for (auto *dialect : context->getLoadedDialects()) { dialect->getCanonicalizationPatterns(owningPatterns); - for (RegisteredOperationName op : context->getRegisteredOperations()) + } + for (RegisteredOperationName op : context->getRegisteredOperations()) { op.getCanonicalizationPatterns(owningPatterns, context); + } // Pull in some borderline/downstream canonicalizations for the Flow // compilation phase. diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CaptureDynamicDims.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CaptureDynamicDims.cpp index 57b1d1ac093b..b53fe9e783a1 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CaptureDynamicDims.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CaptureDynamicDims.cpp @@ -47,8 +47,9 @@ static void captureDims(IREE::Flow::DispatchWorkgroupsOp dispatchOp) { outerToInnerMap[operand] = arg; } for (auto result : dispatchOp.getResults()) { - if (dispatchOp.getTiedResultOperand(result)) + if (dispatchOp.getTiedResultOperand(result)) { continue; // ignored tied + } auto arg = entryBlock->getArgument(argIdx++); outerToInnerMap[result] = arg; } @@ -59,16 +60,19 @@ static void captureDims(IREE::Flow::DispatchWorkgroupsOp dispatchOp) { auto captureTensorDims = [&](Value externalValue, Value internalValue) { auto tensorType = dyn_cast(internalValue.getType()); - if (!tensorType) + if (!tensorType) { return; - if (tensorType.hasStaticShape()) + } + if (tensorType.hasStaticShape()) { return; + } // Find the dimensions in the parent. 
auto maybeDynamicDims = IREE::Util::findDynamicDims( externalValue, dispatchOp->getBlock(), Block::iterator(dispatchOp)); - if (!maybeDynamicDims.has_value()) + if (!maybeDynamicDims.has_value()) { return; + } // Convert to a vector -- we cannot use the ValueRange directly because // it might point into the operand list of this op, which we might mutate // in-place. @@ -116,8 +120,9 @@ static void captureDims(IREE::Flow::DispatchWorkgroupsOp dispatchOp) { captureTensorDims(operand, outerToInnerMap[operand]); } for (auto result : dispatchOp.getResults()) { - if (dispatchOp.getTiedResultOperand(result)) + if (dispatchOp.getTiedResultOperand(result)) { continue; // ignore tied + } captureTensorDims(result, outerToInnerMap[result]); } } @@ -141,19 +146,22 @@ static void captureDims(scf::ForOp forOp) { llvm::zip_equal(forOp.getInitArgs(), forOp.getYieldedValues(), forOp.getRegionIterArgs(), forOp.getResults())) { auto tensorType = dyn_cast(init.getType()); - if (!tensorType || tensorType.hasStaticShape()) + if (!tensorType || tensorType.hasStaticShape()) { continue; + } // Make the transform idempotent by not caring about tensors only used // within 'flow.tensor.tie_shape' operations. - if (llvm::all_of(bbArg.getUsers(), llvm::IsaPred)) + if (llvm::all_of(bbArg.getUsers(), llvm::IsaPred)) { continue; + } dynamicTensorIterables.push_back({init, iter, bbArg, result}); } - if (dynamicTensorIterables.empty()) + if (dynamicTensorIterables.empty()) { return; + } // Create the new dimension loop variables. 
Since the dynamic tensors may be // of different types with varying number of dynamic dimensions, 'dimBounds' @@ -169,26 +177,31 @@ static void captureDims(scf::ForOp forOp) { dimBounds.push_back(newIterables.size()); std::optional initDynamicDims = IREE::Util::findDynamicDims( init, forOp->getBlock(), Block::iterator(forOp)); - if (!initDynamicDims) + if (!initDynamicDims) { continue; + } std::optional iterDynamicDims = IREE::Util::findDynamicDims( iter, forOp.getBody(), Block::iterator(forOp.getBody()->getTerminator())); - if (!iterDynamicDims) + if (!iterDynamicDims) { continue; + } - if (iterDynamicDims->size() != initDynamicDims->size()) + if (iterDynamicDims->size() != initDynamicDims->size()) { continue; + } for (auto [initDim, iterDim] : - llvm::zip_equal(*initDynamicDims, *iterDynamicDims)) + llvm::zip_equal(*initDynamicDims, *iterDynamicDims)) { newIterables.push_back({initDim, iterDim}); + } } dimBounds.push_back(newIterables.size()); - if (newIterables.empty()) + if (newIterables.empty()) { return; + } // A new 'scf.for' has to be created to replace the old one as new results // are being added. 
@@ -223,8 +236,9 @@ static void captureDims(scf::ForOp forOp) { auto dims = ArrayRef(newIterables) .slice(dimBounds[index], dimBounds[index + 1] - dimBounds[index]); - if (dims.empty()) + if (dims.empty()) { continue; + } Value tied = Flow::TensorTieShapeOp::create( builder, forOp.getLoc(), tensor.bbArg, @@ -242,8 +256,9 @@ static void captureDims(scf::ForOp forOp) { auto dims = ArrayRef(newIterables) .slice(dimBounds[index], dimBounds[index + 1] - dimBounds[index]); - if (dims.empty()) + if (dims.empty()) { continue; + } Value &replacement = results[tensor.result.getResultNumber()]; replacement = Flow::TensorTieShapeOp::create( diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CleanupTensorShapes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CleanupTensorShapes.cpp index db3957c96f2f..6b4d127ff8c7 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CleanupTensorShapes.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CleanupTensorShapes.cpp @@ -34,8 +34,9 @@ struct CleanupTensorShapesPass foundBadOps = true; } }); - if (foundBadOps) + if (foundBadOps) { return signalPassFailure(); + } } }; diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp index 3e3a46a678d2..f1214d07f4e2 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp @@ -36,8 +36,9 @@ static void appendDynamicDims(OpBuilder &b, Location loc, } for (auto dim : llvm::enumerate(tensorType.getShape())) { - if (ShapedType::isStatic(dim.value())) + if (ShapedType::isStatic(dim.value())) { continue; + } argumentDims.push_back( b.createOrFold(loc, tensor, dim.index())); } @@ -50,8 +51,9 @@ findFirstTiedValueOutsideOfRegionOp(IREE::Flow::DispatchRegionOp regionOp, Value value) { // Check if `v` is defined outside of 
`regionOp`. auto isOutside = [&](Value v) { - if (isa(v)) + if (isa(v)) { return !regionOp->isAncestor(v.getDefiningOp()); + } assert(isa(v) && "expected bbArg"); // DispatchRegionOp does not have block arguments. return true; @@ -99,16 +101,18 @@ rewriteFlowDispatchRegionToFlowDispatchWorkgroups( llvm::SetVector argumentsSet; mlir::getUsedValuesDefinedAbove(region, argumentsSet); // Unranked tensors are not supported. - assert(!llvm::any_of(argumentsSet, [](Value v) { - return isa(v.getType()); - }) && "unranked tensors are not supported"); + assert(llvm::none_of( + argumentsSet, + [](Value v) { return isa(v.getType()); }) && + "unranked tensors are not supported"); // Compute dimensions of tensor args. SmallVector argumentDims; for (Value tensor : argumentsSet) { auto tensorType = dyn_cast(tensor.getType()); - if (!tensorType) + if (!tensorType) { continue; + } appendDynamicDims(rewriter, loc, argumentDims, tensor); } @@ -129,13 +133,15 @@ rewriteFlowDispatchRegionToFlowDispatchWorkgroups( llvm::enumerate(origTerminators.front()->getOperands())) { auto tiedArgument = findFirstTiedValueOutsideOfRegionOp(regionOp, it.value()); - if (!tiedArgument.has_value()) + if (!tiedArgument.has_value()) { continue; + } assert(argumentsSet.contains(*tiedArgument) && "expected that tiedArgument is already an argument"); // Do not tie an argument to multiple results. 
- if (tiedArgumentsSet.contains(*tiedArgument)) + if (tiedArgumentsSet.contains(*tiedArgument)) { continue; + } tiedArgumentsSet.insert(*tiedArgument); tiedArguments[it.index()] = std::distance( argumentsSet.begin(), llvm::find(argumentsSet, *tiedArgument)); diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertShardToFlow.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertShardToFlow.cpp index 3b224f821408..19e294a1dc80 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertShardToFlow.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertShardToFlow.cpp @@ -6,7 +6,6 @@ #include "iree/compiler/Dialect/Flow/Conversion/ShardToFlow/Patterns.h" #include "iree/compiler/Dialect/Flow/Transforms/Passes.h" -#include "iree/compiler/Utils/Folding.h" #include "iree/compiler/Utils/Indexing.h" #include "iree/compiler/Utils/OpVisitor.h" #include "iree/compiler/Utils/Permutation.h" @@ -27,6 +26,20 @@ namespace mlir::iree_compiler::IREE::Flow { namespace { +// Convert a `Value` or an `Attribute` range to a range of `OpFoldResult`. 
+template +static void toOpFoldResults(Range &&range, OutIt outIt) { + llvm::transform(std::forward(range), outIt, + [](auto v) { return OpFoldResult(v); }); +} + +template +static SmallVector toOpFoldResults(Range &&range) { + SmallVector res; + toOpFoldResults(std::forward(range), std::back_inserter(res)); + return res; +} + static bool hasMoreThanOneShard(Operation *op) { int shardCount = 0; op->walk([&shardCount](shard::ShardOp shard) { @@ -126,11 +139,12 @@ static bool isDefaultChannel(shard::GridOp grid, static Value getDefaultChannel(Location loc, shard::GridOp grid, bool useNamedDefaultChannels, OpBuilder &builder) { - if (useNamedDefaultChannels) + if (useNamedDefaultChannels) { return IREE::Flow::ChannelDefaultOp::create(builder, loc, grid.getSymName()); - else + } else { return IREE::Flow::ChannelDefaultOp::create(builder, loc); + } } static Value buildCachedChannelLoading(Location loc, shard::GridOp grid, @@ -241,8 +255,9 @@ static void createChannels(ModuleOp moduleOp, llvm::sort(gridAndAxesSetSorted, [](auto &a, auto &b) { int nameCompareRes = std::get<0>(a).getSymName().compare(std::get<0>(b).getSymName()); - if (nameCompareRes == 0) + if (nameCompareRes == 0) { return std::get<1>(a) < std::get<1>(b); + } return nameCompareRes < 0; }); for (auto &[shard, shardAxes] : llvm::make_range( @@ -279,8 +294,9 @@ static void removeShardOps(GridAndAxesSet &gridAndAxesSet) { DenseSet gridOpsSet(std::begin(gridRange), std::end(gridRange)); for (shard::GridOp op : gridOpsSet) { - if (op) + if (op) { op.erase(); + } } } diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DeduplicateExecutables.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DeduplicateExecutables.cpp index 31769d71e67c..e76352da21cf 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DeduplicateExecutables.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DeduplicateExecutables.cpp @@ -19,8 +19,9 @@ namespace { // Utilities to make SymbolRefAttr easier 
to construct. static SymbolRefAttr nestSymbolRef(SymbolRefAttr baseRefAttr, FlatSymbolRefAttr leafRefAttr) { - if (!baseRefAttr) + if (!baseRefAttr) { return leafRefAttr; + } SmallVector nestedRefAttrs; llvm::append_range(nestedRefAttrs, baseRefAttr.getNestedReferences()); nestedRefAttrs.push_back(leafRefAttr); @@ -43,8 +44,9 @@ static void gatherReplacements( for (auto [oldNestedSymbolOp, newNestedSymbolOp] : llvm::zip_equal(nestedOldRegion.getOps(), nestedNewRegion.getOps())) { - if (!oldNestedSymbolOp.isPublic()) + if (!oldNestedSymbolOp.isPublic()) { continue; // ignore private symbols + } auto oldNestedSymbolRefAttr = nestSymbolRef(oldSymbolRefAttr, oldNestedSymbolOp); auto newNestedSymbolRefAttr = @@ -140,8 +142,9 @@ static int deduplicateObjects(Operation *scopeOp, // We could rely on SymbolDCE for this but that makes looking at IR dumps // harder as after this pass runs and until SymbolDCE runs there are lots of // dead objects in the output. - for (auto *op : deadOps) + for (auto *op : deadOps) { op->erase(); + } return deadOps.size(); } @@ -156,11 +159,13 @@ class DeduplicateExecutablesPass mlir::ModuleOp moduleOp = getOperation(); SmallVector allObjects; for (auto &op : moduleOp.getOps()) { - if (op.hasTrait()) + if (op.hasTrait()) { allObjects.push_back(&op); + } } - if (allObjects.empty()) + if (allObjects.empty()) { return; + } (void)deduplicateObjects(moduleOp, allObjects); // totalObjects = allObjects.size(); // objectsDeduplicated = deduplicateObjects(moduleOp, allObjects); diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DumpDispatchGraph.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DumpDispatchGraph.cpp index e2082748d7dd..1f05d356b691 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DumpDispatchGraph.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DumpDispatchGraph.cpp @@ -47,8 +47,9 @@ static const StringRef kShapeNone = "plain"; static const StringRef kShapeEllipse = "ellipse"; static 
StringRef getShape(Operation *op) { - if (isa(op)) + if (isa(op)) { return kShapeBox; + } return kShapeEllipse; } @@ -57,8 +58,9 @@ static StringRef getShape(Operation *op) { static int64_t getLargeAttributeSizeLimit() { // Use the default from the printer flags if possible. if (std::optional limit = - OpPrintingFlags().getLargeElementsAttrLimit()) + OpPrintingFlags().getLargeElementsAttrLimit()) { return *limit; + } return 16; } @@ -142,8 +144,9 @@ class GraphPrinter { void emitFunctions(ModuleOp module) { auto funcOps = module.getOps(); - if (funcOps.empty()) + if (funcOps.empty()) { return; + } emitGraph([&]() { for (auto funcOp : funcOps) { @@ -167,8 +170,9 @@ class GraphPrinter { /// Emit all edges. This function should be called after all nodes have been /// emitted. void emitAllEdgeStmts() { - for (const std::string &edge : edges) + for (const std::string &edge : edges) { os << edge << ";\n"; + } edges.clear(); } @@ -243,13 +247,16 @@ class GraphPrinter { // Do not label edges that start/end at a cluster boundary. Such edges are // clipped at the boundary, but labels are not. This can lead to labels // floating around without any edge next to them. - if (!n1.clusterId && !n2.clusterId) + if (!n1.clusterId && !n2.clusterId) { attrs["label"] = quoteString(escapeString(std::move(label))); + } // Use `ltail` and `lhead` to draw edges between clusters. 
- if (n1.clusterId) + if (n1.clusterId) { attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId); - if (n2.clusterId) + } + if (n2.clusterId) { attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId); + } edges.push_back(strFromOs([&](raw_ostream &os) { os << llvm::format("v%i -> v%i ", n1.id, n2.id); @@ -344,12 +351,14 @@ class GraphPrinter { } void annotateOperation(raw_ostream &os, Operation *op, AsmState &state) { - if (isa(op)) + if (isa(op)) { return; + } if (op->hasTrait() && - isa(op->getParentOp())) + isa(op->getParentOp())) { return; + } if (auto load = dyn_cast(op)) { printDispatchTensorLoad(os, load, state); @@ -385,18 +394,21 @@ class GraphPrinter { auto entryPoint = *dispatchOp.getEntryPointRefs().begin(); auto executableOp = cast(SymbolTable::lookupNearestSymbolFrom( dispatchOp, entryPoint.getRootReference())); - if (!executableOp) + if (!executableOp) { return; + } auto calleeNameAttr = entryPoint.getLeafReference(); auto innerModule = executableOp.getInnerModule(); - if (!innerModule) + if (!innerModule) { return; + } auto funcOps = innerModule.getOps(); auto funcIt = llvm::find_if( funcOps, [&](auto op) { return op.getNameAttr() == calleeNameAttr; }); - if (funcIt == funcOps.end()) + if (funcIt == funcOps.end()) { return; + } auto callee = *funcIt; @@ -506,25 +518,29 @@ class GraphPrinter { /// operation inside the cluster. void processBlock(Block &block) { emitClusterStmt([&]() { - for (BlockArgument &blockArg : block.getArguments()) + for (BlockArgument &blockArg : block.getArguments()) { valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg)); + } // Emit a node for each operation. 
std::optional prevNode; for (Operation &op : block) { Node nextNode = processOperation(&op); - if (printControlFlowEdges && prevNode) + if (printControlFlowEdges && prevNode) { emitEdgeStmt(*prevNode, nextNode, /*label=*/"", kLineStyleControlFlow); + } prevNode = nextNode; } }); } bool isScalarConstantOp(Operation *op) { - if (auto constOp = dyn_cast(op)) - if (constOp.getResult().getType().isIntOrIndexOrFloat()) + if (auto constOp = dyn_cast(op)) { + if (constOp.getResult().getType().isIntOrIndexOrFloat()) { return true; + } + } return false; } @@ -555,8 +571,9 @@ class GraphPrinter { // Emit cluster for op with regions. node = emitClusterStmt( [&]() { - for (Region ®ion : op->getRegions()) + for (Region ®ion : op->getRegions()) { processRegion(region); + } }, getLabel(op)); } else { @@ -578,22 +595,25 @@ class GraphPrinter { } } - for (Value result : op->getResults()) + for (Value result : op->getResults()) { valueToNode[result] = node; + } return node; } /// Process a region. void processRegion(Region ®ion) { - for (Block &block : region.getBlocks()) + for (Block &block : region.getBlocks()) { processBlock(block); + } } /// Truncate long strings. std::string truncateString(std::string str) { - if (str.length() <= maxLabelLen) + if (str.length() <= maxLabelLen) { return str; + } return str.substr(0, maxLabelLen) + "..."; } @@ -629,8 +649,9 @@ class DumpDispatchGraphPass void runOnOperation() override { auto modOp = dyn_cast(getOperation()); - if (!modOp) + if (!modOp) { return; + } // Open the output file we'll be streaming to. // Since we are processing the entire module at once we overwrite the file. 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp index b9b8aafc7c56..9c2644adf868 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp @@ -225,8 +225,9 @@ createEntryPointBenchmarkFunc(mlir::ModuleOp moduleOp, for (auto arg : entryFuncOp.getArguments()) { auto dummyVar = createDummyInput(funcName, arg, symbolTable, moduleBuilder, explorer); - if (!dummyVar) + if (!dummyVar) { return failure(); + } dummyInputVariableOps.push_back(dummyVar); } diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp index 93c48013053d..5f830e2feb07 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp @@ -32,22 +32,25 @@ TensorDimTrackingRewriter::TensorDimTrackingRewriter(Operation *op) } SmallVector TensorDimTrackingRewriter::getTensorDimOps() { SmallVector result; - for (Operation *op : dimOps) + for (Operation *op : dimOps) { result.push_back(cast(op)); + } return result; } void TensorDimTrackingRewriter::notifyOperationErased(Operation *op) { IRRewriter::Listener::notifyOperationErased(op); - if (isa(op)) + if (isa(op)) { dimOps.erase(op); + } } void TensorDimTrackingRewriter::notifyOperationInserted(Operation *op, InsertPoint previous) { IRRewriter::Listener::notifyOperationInserted(op, previous); auto dimOp = dyn_cast(op); - if (dimOp && isa(dimOp.getSource())) + if (dimOp && isa(dimOp.getSource())) { dimOps.insert(op); + } } } // namespace mlir @@ -59,8 +62,9 @@ LogicalResult simplifyDimOps(RewriterBase &rewriter, for (tensor::DimOp dimOp : dimOps) { // Only DimOps with static indices are supported. 
std::optional idx = dimOp.getConstantIndex(); - if (!idx.has_value()) + if (!idx.has_value()) { continue; + } if (isa(dimOp.getSource())) { continue; @@ -68,8 +72,9 @@ LogicalResult simplifyDimOps(RewriterBase &rewriter, // Only DimOps with ranked tensors are supported. auto tensorType = dyn_cast(dimOp.getSource().getType()); - if (!tensorType) + if (!tensorType) { continue; + } OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(dimOp); @@ -85,9 +90,11 @@ LogicalResult simplifyDimOps(RewriterBase &rewriter, if (succeeded(IREE::Flow::getOptimizedDynamicResultDims( rewriter, dimOp.getSource(), dynamicDims))) { unsigned ctr = 0; - for (int64_t i = 0; i < *dimOp.getConstantIndex(); ++i) - if (tensorType.isDynamicDim(i)) + for (int64_t i = 0; i < *dimOp.getConstantIndex(); ++i) { + if (tensorType.isDynamicDim(i)) { ++ctr; + } + } rewriter.replaceOp(dimOp, dynamicDims[ctr]); } } diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InjectTensorTracing.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InjectTensorTracing.cpp index 675566d83072..91459acae3e8 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InjectTensorTracing.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InjectTensorTracing.cpp @@ -34,8 +34,9 @@ static std::string inferTraceKey(Operation *op) { static SmallVector filterTensorValues(ValueRange &&range) { SmallVector result; for (auto value : range) { - if (isa(value.getType())) + if (isa(value.getType())) { result.push_back(value); + } } return result; } @@ -76,10 +77,11 @@ struct InjectTensorTracingPass funcOp.walk([&](Operation *op) { if (auto attr = op->getAttr(attrName)) { std::string traceKey; - if (auto stringAttr = dyn_cast(attr)) + if (auto stringAttr = dyn_cast(attr)) { traceKey = stringAttr.getValue().str(); - else + } else { traceKey = inferTraceKey(op); + } injectTracingOnOp(op, traceKey); op->removeAttr(attrName); } diff --git 
a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InsertDispatchDebugTargets.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InsertDispatchDebugTargets.cpp index a34129fdf49e..701fe0e257aa 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InsertDispatchDebugTargets.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InsertDispatchDebugTargets.cpp @@ -29,8 +29,9 @@ namespace mlir::iree_compiler::IREE::Flow { static SmallVector filterNonTensorValues(ValueRange &&range) { SmallVector result; for (auto value : range) { - if (isa(value.getType())) + if (isa(value.getType())) { result.push_back(value); + } } return result; } @@ -39,18 +40,21 @@ static SmallVector filterNonTensorValues(ValueRange &&range) { // a negative ordinal indicating no match. static std::tuple getOrdinalFromDebugTarget(std::string marker) { - if (marker.empty() || marker[0] != '@') + if (marker.empty() || marker[0] != '@') { return std::make_tuple("", -1); + } SmallVector parts; auto cropped = marker.substr(1); llvm::SplitString(llvm::StringRef(cropped), parts, ":"); - if (parts.size() != 2) + if (parts.size() != 2) { return std::make_tuple("", -1); + } int ordinal; - if (parts[1].getAsInteger(10, ordinal)) + if (parts[1].getAsInteger(10, ordinal)) { return std::make_tuple("", -1); + } return std::make_tuple(parts[0].str(), ordinal); } @@ -78,18 +82,21 @@ static void traceOpWithName(IREE::Flow::DispatchOp dispatchOp, static LogicalResult replaceReturnWithOpResults(mlir::ModuleOp moduleOp, IREE::Util::FuncOp funcOp, Operation *op) { - if (!funcOp->isProperAncestor(op)) + if (!funcOp->isProperAncestor(op)) { return failure(); + } // TODO: Handle nested function calls. - if (!SymbolTable::symbolKnownUseEmpty(funcOp, moduleOp)) + if (!SymbolTable::symbolKnownUseEmpty(funcOp, moduleOp)) { return failure(); + } // TODO: Handle (nested) control flow. 
auto funcBlock = op->getBlock(); if (funcBlock->getParentOp() != funcOp || - &funcOp.getBody().front() != funcBlock) + &funcOp.getBody().front() != funcBlock) { return failure(); + } // Collect the op results and create export ops for any tensor results. OpBuilder builder(funcOp); @@ -119,8 +126,9 @@ static LogicalResult replaceReturnWithOpResults(mlir::ModuleOp moduleOp, rewriter.replaceOpWithNewOp(oldTerminator, exports); SmallVector argTypes; - for (const auto &arg : llvm::enumerate(funcOp.getArguments())) + for (const auto &arg : llvm::enumerate(funcOp.getArguments())) { argTypes.push_back(arg.value().getType()); + } funcOp.setType(FunctionType::get(context, /*inputs=*/argTypes, /*results=*/newTypes)); @@ -151,12 +159,14 @@ struct InsertDebugTargetAtOrdinalPass // Only look for dispatches in util func ops. auto funcOp = dyn_cast(operation); - if (!funcOp) + if (!funcOp) { continue; + } std::string fName = funcOp.getName().str(); - if (fName != breakFname && fName != traceFname) + if (fName != breakFname && fName != traceFname) { continue; + } int localBreakOrdinal = -1; if (fName == breakFname) { @@ -188,8 +198,9 @@ struct InsertDebugTargetAtOrdinalPass if (localBreakOrdinal >= 0 && localBreakOrdinal < dispatchOps.size()) { auto breakTarget = dispatchOps[localBreakOrdinal]; if (failed(replaceReturnWithOpResults(getOperation(), funcOp, - breakTarget))) + breakTarget))) { return signalPassFailure(); + } } } @@ -252,8 +263,9 @@ struct InsertDebugTargetAtSymbolPass Operation *operation = funcOp; auto mlirFuncOp = dyn_cast(operation); if (!mlirFuncOp || failed(replaceReturnWithOpResults( - getOperation(), mlirFuncOp, breakTarget))) + getOperation(), mlirFuncOp, breakTarget))) { return signalPassFailure(); + } } } diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineConstants.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineConstants.cpp index 5a726890523a..ca984231f5f6 100644 --- 
a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineConstants.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineConstants.cpp @@ -49,8 +49,9 @@ static SmallVector findConstantsInModule(mlir::ModuleOp moduleOp) { SmallVector results; for (auto callableOp : moduleOp.getOps()) { auto *region = callableOp.getCallableRegion(); - if (!region) + if (!region) { continue; + } region->walk([&](Operation *op) { if (auto constantOp = dyn_cast(op)) { if (isOutlinableValue(constantOp.getValue())) { @@ -80,8 +81,9 @@ static Operation *getParentInOp(Operation *childOp, Operation *ancestorOp) { assert(childOp != ancestorOp && "child can't be its own ancestor"); do { auto *parentOp = childOp->getParentOp(); - if (parentOp == ancestorOp) + if (parentOp == ancestorOp) { return childOp; + } childOp = parentOp; } while (childOp); assert(false && "child must be nested under ancestor"); @@ -94,16 +96,18 @@ static std::string getConstantName(ConstantDef &def) { if (auto parameterAttr = dyn_cast(def.value)) { os << "__parameter_"; - if (parameterAttr.getScope() && !parameterAttr.getScope().empty()) + if (parameterAttr.getScope() && !parameterAttr.getScope().empty()) { os << parameterAttr.getScope().getValue() << "_"; + } os << parameterAttr.getKey().getValue() << "_"; } else { os << "__constant_"; } def.type.print(os); str = sanitizeSymbolName(str); - if (str.substr(str.size() - 1) == "_") + if (str.substr(str.size() - 1) == "_") { str = str.substr(0, str.size() - 1); // strip trailing _ + } return str; } @@ -115,8 +119,9 @@ struct OutlineConstantsPass : public IREE::Flow::impl::OutlineConstantsPassBase { void runOnOperation() override { mlir::ModuleOp moduleOp = getOperation(); - if (moduleOp.getBody()->empty()) + if (moduleOp.getBody()->empty()) { return; + } SymbolTable moduleSymbols(moduleOp); @@ -127,8 +132,9 @@ struct OutlineConstantsPass // contains the constant. 
OpBuilder moduleBuilder(&moduleOp.getBody()->front()); auto parentFuncOp = getParentInOp(def.op, moduleOp); - if (parentFuncOp) + if (parentFuncOp) { moduleBuilder.setInsertionPoint(parentFuncOp); + } // New immutable global takes the constant attribute in its specified // encoding. diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchExterns.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchExterns.cpp index 7344bc78b1c7..8ad5c0f1a0df 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchExterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchExterns.cpp @@ -152,10 +152,12 @@ struct OutlineDispatchExternsPass }) .Default(WalkResult::advance()); }; - if (funcOp.walk(outlineOps).wasInterrupted()) + if (funcOp.walk(outlineOps).wasInterrupted()) { return signalPassFailure(); - for (auto *deadOp : deadOps) + } + for (auto *deadOp : deadOps) { deadOp->erase(); + } } } }; diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp index bd52c6fb29df..60f06cd6cbc5 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp @@ -194,10 +194,12 @@ struct OutlineDispatchRegionsPass }) .Default(WalkResult::advance()); }; - if (funcOp.walk(outlineOps).wasInterrupted()) + if (funcOp.walk(outlineOps).wasInterrupted()) { return signalPassFailure(); - for (auto *deadOp : deadOps) + } + for (auto *deadOp : deadOps) { deadOp->erase(); + } } } }; diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp index 445560f0cf27..b023712db9b6 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp +++ 
b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp @@ -110,8 +110,9 @@ SmallVector getLoopRanges(Operation *op, Location loc, /// Return `true` if an operation is within a `flow.dispatch.region` or /// `flow.dispatch.workgroups` op. bool isNonNullAndOutsideDispatch(Operation *op) { - if (!op) + if (!op) { return false; + } Operation *parentOp = op->getParentOp(); while (parentOp) { if (isa( @@ -204,8 +205,9 @@ static void createWorkgroupCountFromDagRootRegion( RewriterBase &rewriter, IREE::Flow::DispatchRegionOp ®ionOp, TypeRange workloadTypes, ArrayRef workloadLocs) { Region &countRegion = regionOp.getWorkgroupCount(); - if (!countRegion.empty()) + if (!countRegion.empty()) { return; + } Block *body = rewriter.createBlock(&countRegion, countRegion.begin(), workloadTypes, workloadLocs); auto args = body->getArguments(); @@ -221,8 +223,9 @@ static void createWorkgroupCountFromDagRootRegion( /// dynamic dimension. static bool hasDynamicShape(Type t) { auto shapedType = dyn_cast(t); - if (!shapedType) + if (!shapedType) { return false; + } return !shapedType.hasStaticShape(); } @@ -234,8 +237,9 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value, OpBuilder::InsertionGuard guard(b); // Case 1: No dynamic result dims. - if (!hasDynamicShape(value.getType())) + if (!hasDynamicShape(value.getType())) { return success(); + } // There is at least one dynamic dimension, continue... ShapedType shapedType = cast(value.getType()); @@ -252,8 +256,9 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value, // Case 2: Value is a block argument. 
if (auto bbArg = dyn_cast(value)) { - if (!createTensorDimOps) + if (!createTensorDimOps) { return failure(); + } b.setInsertionPointToStart(bbArg.getOwner()); emitTensorDimOps(); @@ -277,20 +282,24 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value, auto tiedOp = dyn_cast(op); if (tiedOp) { Value tiedOperand = tiedOp.getTiedResultOperand(value); - if (tiedOperand && tiedOperand.getType() == value.getType()) + if (tiedOperand && tiedOperand.getType() == value.getType()) { return reifyDynamicResultDimsImpl(b, tiedOperand, dynamicDims, /*createTensorDimOps=*/true); + } } // Case 5: Query ReifyRankedShapedTypeOpInterface. auto reifyShapeOp = dyn_cast(op); if (reifyShapeOp) { ReifiedRankedShapedTypeDims dims; - if (failed(reifyShapeOp.reifyResultShapes(b, dims))) + if (failed(reifyShapeOp.reifyResultShapes(b, dims))) { return failure(); - for (int64_t i = 0; i < shapedType.getRank(); ++i) - if (shapedType.isDynamicDim(i)) + } + for (int64_t i = 0; i < shapedType.getRank(); ++i) { + if (shapedType.isDynamicDim(i)) { dynamicDims.push_back(cast(dims[opResult.getResultNumber()][i])); + } + } return success(); } @@ -303,8 +312,9 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value, /*createTensorDimOps=*/true); } - if (!createTensorDimOps) + if (!createTensorDimOps) { return failure(); + } // None of the above. Insert tensor.dim ops. 
b.setInsertionPointAfter(op); @@ -416,8 +426,9 @@ clonePrecedingOpIntoDispatchRegion(RewriterBase &rewriter, Operation *target, Region *parentRegion = parentOperation->getParentRegion(); while ((parentOperation = parentOperation->getParentOp())) { - if (regionOp.getOperation() == parentOperation) + if (regionOp.getOperation() == parentOperation) { break; + } parentRegion = parentOperation->getParentRegion(); } @@ -632,9 +643,8 @@ FailureOr hoistOutOfDispatch(RewriterBase &rewriter, return producer && producer->getParentOfType(); })) { rewriter.setInsertionPoint(dispatchRegionOp); - } else if (llvm::all_of(op->getUsers(), [&](Operation *user) { - return isa(user); - })) { + } else if (llvm::all_of(op->getUsers(), + llvm::IsaPred)) { rewriter.setInsertionPointAfter(dispatchRegionOp); } else { return rewriter.notifyMatchFailure( @@ -876,8 +886,9 @@ bool isClonableIntoDispatchOp(Operation *op, } if (isa(op) || isa(op)) { - if (clInlineConstantByteLength == 0) + if (clInlineConstantByteLength == 0) { return false; + } Attribute constantValueAttr; if (!matchPattern(op->getResult(0), m_Constant(&constantValueAttr))) { return false; @@ -930,13 +941,15 @@ static bool hasUnfusableUseInDispatch(Value v, Operation *dispatchOp) { Operation *owner = ownerWorkgroupsOp ? ownerWorkgroupsOp : ownerRegionOp; // Ignore uses outside of dispatch workgroups op. - if (owner != dispatchOp) + if (owner != dispatchOp) { continue; + } // Cannot fuse producer of `dest` with `tensor.insert_slice`. if (auto insertSliceUser = dyn_cast(user)) { - if (insertSliceUser.getDest() == v) + if (insertSliceUser.getDest() == v) { return true; + } } } return false; @@ -948,8 +961,9 @@ SmallVector getCloneableOps(IREE::Flow::DispatchRegionOp regionOp, // of the dispatch region. 
llvm::SetVector valuesDefinedAbove; mlir::getUsedValuesDefinedAbove(regionOp.getBody(), valuesDefinedAbove); - if (valuesDefinedAbove.empty()) + if (valuesDefinedAbove.empty()) { return {}; + } // Traverse the defining ops of these values (and the ops on their reverse // SSA use-def chain). @@ -960,8 +974,9 @@ SmallVector getCloneableOps(IREE::Flow::DispatchRegionOp regionOp, while (!worklist.empty()) { Value outsideValue = worklist.pop_back_val(); // Skip values that were already visited. - if (visited.count(outsideValue)) + if (visited.count(outsideValue)) { continue; + } visited.insert(outsideValue); Operation *definingOp = outsideValue.getDefiningOp(); diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/TopLevelSCFToCFG.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/TopLevelSCFToCFG.cpp index 4090026d6305..ebb45e4be18a 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/TopLevelSCFToCFG.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/TopLevelSCFToCFG.cpp @@ -43,9 +43,10 @@ void TopLevelSCFToCFGPass::runOnOperation() { target.addLegalOp(); target.markOpRecursivelyLegal(); - if (failed( - applyPartialConversion(getOperation(), target, std::move(patterns)))) + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { signalPassFailure(); + } } } // namespace mlir::iree_compiler::IREE::Flow diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel index 44335a5d6b87..cfa0a91985f9 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "annotate_dispatches.mlir", "canonicalize.mlir", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp 
b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp index 14790bf2fb90..cb37aae40366 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp @@ -191,8 +191,9 @@ BindingLayoutAnalysis::BindingLayoutAnalysis(Operation *rootOp, // before we derive the layouts. auto getExportInfo = [&](Operation *exportOp) -> ExportInfo & { auto &exportInfo = exportInfos[exportOp]; - if (!exportInfo) + if (!exportInfo) { exportInfo = std::make_unique(); + } return *exportInfo; }; rootOp->walk([&](Operation *op) { @@ -238,8 +239,9 @@ bool BindingLayoutAnalysis::hasDispatches() const { ArrayRef BindingLayoutAnalysis::getExportDispatches(Operation *exportOp) const { auto it = exportInfos.find(exportOp); - if (it == exportInfos.end()) + if (it == exportInfos.end()) { return {}; // not analyzed + } return it->second.get()->dispatchOps; } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD.bazel index b38bbbea7b32..a4085ef63ddd 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "device_ops.mlir", "pseudo_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertBufferViewOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertBufferViewOps.cpp index 931c825d98f2..35bf462c11bf 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertBufferViewOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertBufferViewOps.cpp @@ -18,9 +18,10 @@ struct ElementTypeOpConversion ConversionPatternRewriter &rewriter) const override { auto value = 
IREE::HAL::ElementTypeOp::getTypeValue(op.getTypeAttr().getValue()); - if (!value.has_value()) + if (!value.has_value()) { return rewriter.notifyMatchFailure(op.getLoc(), "unsupported element type"); + } rewriter.replaceOpWithNewOp(op, value.value()); return success(); } @@ -33,9 +34,10 @@ struct EncodingTypeOpConversion matchAndRewrite(IREE::HAL::EncodingTypeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto value = IREE::HAL::EncodingTypeOp::getTypeValue(op.getEncodingAttr()); - if (!value.has_value()) + if (!value.has_value()) { return rewriter.notifyMatchFailure(op.getLoc(), "unsupported encoding type"); + } rewriter.replaceOpWithNewOp(op, value.value()); return success(); } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp index 42f92ae80e93..8c0917a718ab 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp @@ -68,14 +68,16 @@ class CommandBufferCreateOpConversion }; auto modesValue = detail::rewriteAttrToOperands( op.getLoc(), adaptor.getModesAttr(), rewriter.getI32Type(), rewriter); - if (!modesValue.has_value()) + if (!modesValue.has_value()) { return failure(); + } callOperands.append(modesValue.value()); auto categoriesValue = detail::rewriteAttrToOperands( op.getLoc(), adaptor.getCommandCategoriesAttr(), rewriter.getI32Type(), rewriter); - if (!categoriesValue.has_value()) + if (!categoriesValue.has_value()) { return failure(); + } callOperands.append(categoriesValue.value()); callOperands.push_back(adaptor.getQueueAffinity()); if (adaptor.getBindingCapacity()) { diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp index 
414cd14a46c0..82cd2872c849 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp @@ -25,10 +25,12 @@ class DeviceQueryCastOpConversion matchAndRewrite(IREE::HAL::DeviceQueryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto targetType = op.getValue().getType(); - if (targetType.isInteger(64)) + if (targetType.isInteger(64)) { return failure(); // handled natively - if (!targetType.isIntOrIndex()) + } + if (!targetType.isIntOrIndex()) { return rewriter.notifyMatchFailure(op, "unsupported result type"); + } // Query as i64. // Note that due to type conversion we need to handle the default logic @@ -94,12 +96,14 @@ class DeviceQueryI64OpConversion LogicalResult matchAndRewrite(IREE::HAL::DeviceQueryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!op.getValue().getType().isInteger(64)) + if (!op.getValue().getType().isInteger(64)) { return failure(); + } auto results = rewriteToCall(op, adaptor, importOp, *getTypeConverter(), rewriter); - if (!results.has_value()) + if (!results.has_value()) { return failure(); + } auto ok = results->front(); auto value = results->back(); if (op.getDefaultValue().has_value()) { diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp index 5b843f23bbab..771fac901c7a 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp @@ -56,8 +56,9 @@ Value createPackedConstantBuffer(Location loc, ValueRange constantValues, // extra IR for the indices. We should batch them up and append in one go. 
for (auto constantValue : llvm::enumerate(constantValues)) { // Buffer is zero-initialized so we can skip zero values. - if (mlir::matchPattern(constantValue.value(), m_Zero())) + if (mlir::matchPattern(constantValue.value(), m_Zero())) { continue; + } auto constantLoc = constantValue.value().getLoc(); IREE::VM::BufferStoreI32Op::create( builder, constantLoc, constantBuffer, diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/BUILD.bazel index c949d6efdf21..0a16a83b5434 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "allocator_ops.mlir", "buffer_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD.bazel index 779a62ec7217..853292ccd18f 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "shape_ops.mlir", ], diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp index 32442e52c18b..66c4f7490af9 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp @@ -703,8 +703,9 @@ convertCollectiveAttr(IREE::Stream::CollectiveAttr sourceAttr) { auto convertReductionOp = [](std::optional op) -> std::optional { - if (!op.has_value()) + if (!op.has_value()) { return std::nullopt; + } return 
static_cast(op.value()); }; return IREE::HAL::CollectiveAttr::get( @@ -1201,11 +1202,13 @@ static void insertSerializationBarriers(Location loc, Block &block, // Note that we can't mutate the block while iterating it so we first grab // all the original ops. SmallVector serialOps; - for (auto &op : block) + for (auto &op : block) { serialOps.push_back(&op); + } for (auto *op : serialOps) { - if (op->hasTrait()) + if (op->hasTrait()) { continue; + } builder.setInsertionPointAfter(op); IREE::HAL::CommandBufferExecutionBarrierOp::create( builder, loc, commandBuffer, sourceStage, targetStage, flags); @@ -1711,10 +1714,12 @@ struct GlobalTimepointConversionPattern matchAndRewrite(IREE::Util::GlobalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto initialValue = op.getInitialValue(); - if (!initialValue.has_value()) + if (!initialValue.has_value()) { return failure(); - if (!isa(*initialValue)) + } + if (!isa(*initialValue)) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.removeInitialValueAttr(); }); return success(); } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp index 79763bc7cc80..a4a629358ce1 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp @@ -104,8 +104,9 @@ lookupAllocatorAndQueueAffinityFor(Operation *op, Value memoryTypes, Value getOrCreateWaitFence(Location loc, Value timepointFence, PatternRewriter &rewriter) { - if (timepointFence) + if (timepointFence) { return timepointFence; + } return IREE::Util::NullOp::create(rewriter, loc, rewriter.getType()); } @@ -115,18 +116,21 @@ Value getOrCreateWaitFence(Location loc, Value timepointFence, // it is the only consumer of the timepoint and it is removed upon return. 
static Value consumeBoundFence(Value timepoint, PatternRewriter &rewriter) { // Must only have one use. We can't consume a fence multiple times. - if (!timepoint.hasOneUse()) + if (!timepoint.hasOneUse()) { return nullptr; // >1 use + } // The use must be an export to a fence. auto chainOp = dyn_cast( *timepoint.getUsers().begin()); - if (!chainOp) + if (!chainOp) { return nullptr; // non-export use + } assert(!chainOp.getExternalValues().empty()); auto fence = chainOp.getExternalValues().front(); - if (!fence || !isa(fence.getType())) + if (!fence || !isa(fence.getType())) { return nullptr; + } // Try really hard to figure out if the fence can be used. A larger analysis // pass running prior to conversion that did some code motion could help @@ -157,8 +161,9 @@ Value getOrCreateSignalFence(Location loc, Value device, Value timepoint, // Check to see if the timepoint is associated with a fence. In common cases // when along ABI boundaries we can usually find an association. auto fence = consumeBoundFence(timepoint, rewriter); - if (fence) + if (fence) { return fence; + } // Create a new fence. 
return IREE::HAL::FenceCreateOp::create( diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/BUILD.bazel index 2d6f7779493a..c91e92dab276 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "channel_ops.mlir", "cmd_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/UtilToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/UtilToHAL/Patterns.cpp index 0abf92970a0c..b5f633c5f4ca 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/UtilToHAL/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/UtilToHAL/Patterns.cpp @@ -24,8 +24,9 @@ struct GlobalConversionPattern matchAndRewrite(IREE::Util::GlobalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto newType = getTypeConverter()->convertType(op.getType()); - if (newType == op.getType()) + if (newType == op.getType()) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { // NOTE: the initial value may be invalid here! We rely on // dialect-specific conversions to handle it. 
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/UtilToHAL/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/UtilToHAL/test/BUILD.bazel index b9874626c5aa..f35d67c7b58b 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/UtilToHAL/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/UtilToHAL/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted ["global_ops.mlir"], include = ["*.mlir"], ), diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel index 9d935df7f9b5..cac7679b8232 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel @@ -21,6 +21,7 @@ exports_files([ iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "HALAttrs.td", "HALBase.td", diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp index afbed1294831..f1cbee7207e5 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp @@ -833,8 +833,9 @@ Value IREE::HAL::DeviceSelectAttr::buildDeviceEnumeration( auto deviceAttr = deviceAttrs.front(); Value tryDevice = deviceAttr.buildDeviceEnumeration( loc, buildDeviceTargetMatch, tryBuilder); - if (deviceAttrs.size() == 1) + if (deviceAttrs.size() == 1) { return tryDevice; // termination case + } Value isNull = IREE::Util::CmpEQOp::create(tryBuilder, loc, tryDevice, nullDevice); auto ifOp = @@ -868,8 +869,9 @@ Attribute DeviceAffinityAttr::parse(AsmParser &p, Type type) { queueMask = 0; if (failed(p.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() { int64_t i = 0; - if (failed(p.parseInteger(i))) + if (failed(p.parseInteger(i))) { return failure(); + } queueMask |= 1ll << i; return success(); }))) { @@ -991,8 +993,9 @@ Attribute 
DevicePromiseAttr::parse(AsmParser &p, Type type) { queueMask = 0; if (failed(p.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() { int64_t i = 0; - if (failed(p.parseInteger(i))) + if (failed(p.parseInteger(i))) { return failure(); + } queueMask |= 1ll << i; return success(); }))) { @@ -1119,10 +1122,12 @@ bool DeviceTopologyAttr::hasTransparentAccess( Attribute sourceDevice = getAffinityDevice(source); Attribute targetDevice = getAffinityDevice(target); - if (!sourceDevice || !targetDevice) + if (!sourceDevice || !targetDevice) { return false; - if (sourceDevice == targetDevice) + } + if (sourceDevice == targetDevice) { return true; // Same device has transparent access. + } // Search for a matching link and check if it has transparent access. for (DeviceLinkAttr link : getLinks()) { @@ -1140,10 +1145,12 @@ bool DeviceTopologyAttr::hasUnifiedMemory( Attribute sourceDevice = getAffinityDevice(source); Attribute targetDevice = getAffinityDevice(target); - if (!sourceDevice || !targetDevice) + if (!sourceDevice || !targetDevice) { return false; - if (sourceDevice == targetDevice) + } + if (sourceDevice == targetDevice) { return true; // Same device has unified memory. + } // Search for a matching link and check if it has unified memory. 
for (DeviceLinkAttr link : getLinks()) { diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp index b984b4a74ceb..a03d8ca44067 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp @@ -34,8 +34,9 @@ struct ElideUnusedOp : public OpRewritePattern { : OpRewritePattern(context, /*benefit=*/1000) {} LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const override { - if (!op.use_empty()) + if (!op.use_empty()) { return failure(); + } rewriter.eraseOp(op); return success(); } @@ -230,8 +231,9 @@ struct FoldBufferViewCreateSubspan needsUpdate = true; } rewriter.restoreInsertionPoint(ip); - if (!needsUpdate) + if (!needsUpdate) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.getSourceBufferMutable().assign(newSourceBuffer); op.getSourceOffsetMutable().assign(newSourceOffset); @@ -318,8 +320,9 @@ struct FoldCommandBufferFillBufferSubspans needsUpdate = true; } rewriter.restoreInsertionPoint(ip); - if (!needsUpdate) + if (!needsUpdate) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.getTargetBufferMutable().assign(newTargetBuffer); op.getTargetOffsetMutable().assign(newTargetOffset); @@ -358,8 +361,9 @@ struct FoldCommandBufferUpdateBufferSubspans needsUpdate = true; } rewriter.restoreInsertionPoint(ip); - if (!needsUpdate) + if (!needsUpdate) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.getTargetBufferMutable().assign(newTargetBuffer); op.getTargetOffsetMutable().assign(newTargetOffset); @@ -408,8 +412,9 @@ struct FoldCommandBufferCopyBufferSubspans needsUpdate = true; } rewriter.restoreInsertionPoint(ip); - if (!needsUpdate) + if (!needsUpdate) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.getSourceBufferMutable().assign(newSourceBuffer); op.getSourceOffsetMutable().assign(newSourceOffset); @@ -444,8 +449,9 @@ struct 
FoldCommandBufferDispatchBufferSubspan : public OpRewritePattern { auto bindingOffsets = llvm::to_vector(op.getBindingOffsets()); for (size_t i = 0; i < bindingBuffers.size(); ++i) { auto *definingOp = bindingBuffers[i].getDefiningOp(); - if (!definingOp) + if (!definingOp) { continue; + } if (auto subspanOp = dyn_cast(definingOp)) { needsUpdate = true; bindingBuffers[i] = subspanOp.getSourceBuffer(); @@ -454,8 +460,9 @@ struct FoldCommandBufferDispatchBufferSubspan : public OpRewritePattern { } } rewriter.restoreInsertionPoint(ip); - if (!needsUpdate) + if (!needsUpdate) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { auto mutableBindingBuffers = op.getBindingBuffersMutable(); mutableBindingBuffers.clear(); @@ -489,8 +496,9 @@ struct FoldCommandBufferDispatchIndirectBufferSubspan PatternRewriter &rewriter) const override { Value workgroupsBuffer = op.getWorkgroupsBuffer(); auto *definingOp = workgroupsBuffer.getDefiningOp(); - if (!definingOp) + if (!definingOp) { return failure(); + } Value workgroupsOffset = op.getWorkgroupsOffset(); if (auto subspanOp = dyn_cast(definingOp)) { workgroupsBuffer = subspanOp.getSourceBuffer(); @@ -526,18 +534,21 @@ void CommandBufferDispatchIndirectOp::getCanonicalizationPatterns( // same basic block. We need an abstract interpreter to do much more as we'd // need to track conditionals/branching logic. static bool isOpAlwaysExecutedWith(Operation *before, Operation *after) { - if (before == after) + if (before == after) { return true; - if (before->getBlock() != after->getBlock()) + } + if (before->getBlock() != after->getBlock()) { return false; + } return before->isBeforeInBlock(after); } // Returns true if |op| was hoisted before |insertBefore| without breaking // SSA invariants. Returns false if no IR modifications were made. 
static bool tryHoistOpBeforeUser(Operation *op, Operation *insertBefore) { - if (op == insertBefore) + if (op == insertBefore) { return false; + } // Currently conservative - should be doing a domination check. if (op->getBlock() != insertBefore->getBlock()) { @@ -811,11 +822,13 @@ static void rewriteToOneReturn(int numResults, Region ®ion, // Get all of the return ops - if there's only one then the requirement is // already satisfied and we can exit early. auto returnOps = llvm::to_vector(region.getOps()); - if (returnOps.size() <= 1) + if (returnOps.size() <= 1) { return; // no-op + } SmallVector returnLocs; - for (auto returnOp : returnOps) + for (auto returnOp : returnOps) { returnLocs.push_back(returnOp.getLoc()); + } // Create the new exit block with arguments matching 1:1 with results. auto anyReturnOp = returnOps.front(); @@ -860,8 +873,9 @@ struct MergeExecutableConstantBlocks SmallVector resultLocs; for (auto blockOp : blockOps) { blockLocs.push_back(blockOp.getLoc()); - if (blockOp.getNumArguments() > 0) + if (blockOp.getNumArguments() > 0) { anyRequireDevice = true; + } llvm::append_range(resultTypes, blockOp.getResultTypes()); llvm::append_range(resultKeys, blockOp.getKeys().getValue()); llvm::append_range( @@ -967,8 +981,9 @@ static void filterReturnOperands(ExecutableConstantBlockOp blockOp, llvm::make_early_inc_range(blockOp.getOps())) { SmallVector operands; for (auto [i, operand] : llvm::enumerate(returnOp.getOperands())) { - if (preservedIndices.test(i)) + if (preservedIndices.test(i)) { operands.push_back(operand); + } } returnOp.getOperandsMutable().assign(operands); } @@ -980,11 +995,13 @@ struct DropUnusedExecutableConstantBlockDeviceArg using Base::Base; LogicalResult matchAndRewrite(ExecutableConstantBlockOp blockOp, PatternRewriter &rewriter) const override { - if (blockOp.getNumArguments() == 0) + if (blockOp.getNumArguments() == 0) { return failure(); + } auto deviceArg = blockOp.getArgument(0); - if (!deviceArg.use_empty()) + if 
(!deviceArg.use_empty()) { return failure(); + } rewriter.modifyOpInPlace(blockOp, [&]() { // Type conversion here shouldn't fail. (void)blockOp.eraseArgument(0); @@ -1057,8 +1074,9 @@ void FenceCreateOp::getCanonicalizationPatterns(RewritePatternSet &results, //===----------------------------------------------------------------------===// OpFoldResult FenceJoinOp::fold(FoldAdaptor operands) { - if (getFences().size() == 1) + if (getFences().size() == 1) { return getFences().front(); + } return {}; } @@ -1069,8 +1087,9 @@ struct ElideEmptyFenceJoin : public OpRewritePattern { using Base::Base; LogicalResult matchAndRewrite(FenceJoinOp op, PatternRewriter &rewriter) const override { - if (op.getNumOperands() != 0) + if (op.getNumOperands() != 0) { return failure(); + } rewriter.replaceOpWithNewOp(op, op.getResult().getType()); return success(); @@ -1091,8 +1110,9 @@ deduplicateFenceOperands(ValueRange operands) { newOperands.insert(operand); } - if (newOperands.size() == operands.size()) + if (newOperands.size() == operands.size()) { return std::nullopt; + } return newOperands.takeVector(); } @@ -1102,8 +1122,9 @@ struct DeduplicateFenceJoinFences : public OpRewritePattern { LogicalResult matchAndRewrite(FenceJoinOp op, PatternRewriter &rewriter) const override { auto newOperands = deduplicateFenceOperands(op.getFences()); - if (!newOperands) + if (!newOperands) { return failure(); + } rewriter.replaceOpWithNewOp( op, op.getResult().getType(), op.getFlagsAttr(), newOperands.value()); return success(); @@ -1143,8 +1164,9 @@ struct ElideSignaledFence : public OpRewritePattern { auto fence = signalOp.getFence(); auto createOp = dyn_cast_if_present(fence.getDefiningOp()); - if (!createOp) + if (!createOp) { return failure(); + } // TODO(benvanik): broader analysis - likely in a dedicated fence elision // pass so we can do IPO. For now block-only. 
@@ -1194,8 +1216,9 @@ struct ElideEmptyFenceAwait : public OpRewritePattern { using Base::Base; LogicalResult matchAndRewrite(FenceAwaitOp op, PatternRewriter &rewriter) const override { - if (!op.getFences().empty()) + if (!op.getFences().empty()) { return failure(); + } rewriter.replaceOpWithNewOp(op, /*ok=*/0, 32); return success(); } @@ -1207,8 +1230,9 @@ struct DeduplicateFenceAwaitFences : public OpRewritePattern { LogicalResult matchAndRewrite(FenceAwaitOp op, PatternRewriter &rewriter) const override { auto newOperands = deduplicateFenceOperands(op.getFences()); - if (newOperands == std::nullopt) + if (newOperands == std::nullopt) { return failure(); + } // TODO(benvanik): resolve flag sets. rewriter.replaceOpWithNewOp( op, op.getStatus().getType(), op.getTimeoutMillis(), op.getFlagsAttr(), diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp index 157e82cf0651..445ba69e2870 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp @@ -120,12 +120,14 @@ static void printDeviceQueueAffinityList(OpAsmPrinter &p, Operation *, static ParseResult parseDescriptorType(OpAsmParser &parser, DescriptorTypeAttr &dtAttr) { StringRef enumKeyword; - if (failed(parser.parseKeyword(&enumKeyword))) + if (failed(parser.parseKeyword(&enumKeyword))) { return failure(); + } std::optional maybeEnum = symbolizeDescriptorType(enumKeyword); - if (!maybeEnum) + if (!maybeEnum) { return failure(); + } dtAttr = DescriptorTypeAttr::get(parser.getContext(), *maybeEnum); return success(); } @@ -381,8 +383,9 @@ static ParseResult parseTargetConditionRegion(OpAsmParser &parser, static void printTargetConditionRegion(OpAsmPrinter &p, Operation *op, Region &body) { - if (body.empty()) + if (body.empty()) { return; + } p << "("; llvm::interleaveComma(body.getArguments(), p, [&](BlockArgument arg) { p.printRegionArgument(arg); }); @@ -403,15 +406,17 @@ static 
ParseResult parseTargetConditionObjects( do { // #hal.executable.target<...> Attribute targetAttr; - if (failed(parser.parseAttribute(targetAttr))) + if (failed(parser.parseAttribute(targetAttr))) { return failure(); + } targetsAttrs.push_back(targetAttr); // if(...) -> i1 { ... } auto region = std::make_unique(); if (succeeded(parser.parseOptionalKeyword("if"))) { - if (failed(parseTargetConditionRegion(parser, *region))) + if (failed(parseTargetConditionRegion(parser, *region))) { return failure(); + } } targetRegions.push_back(std::move(region)); @@ -421,15 +426,17 @@ static ParseResult parseTargetConditionObjects( failed(parser.parseLParen()) || failed(parser.parseAttribute(targetOrdinalAttr, IndexType::get(parser.getContext()))) || - failed(parser.parseRParen())) + failed(parser.parseRParen())) { return failure(); + } targetOrdinalsAttrs.push_back(targetOrdinalAttr); // = [#hal.executable.object<...>, ...] ArrayAttr targetObjectsAttr; if (failed(parser.parseEqual()) || - failed(parser.parseAttribute(targetObjectsAttr))) + failed(parser.parseAttribute(targetObjectsAttr))) { return failure(); + } targetObjectsAttrs.push_back(targetObjectsAttr); } while (succeeded(parser.parseOptionalComma())); targetsAttr = ArrayAttr::get(parser.getContext(), targetsAttrs); @@ -506,8 +513,9 @@ static ParseResult parseWorkgroupCountRegion(OpAsmParser &parser, static void printWorkgroupCountRegion(OpAsmPrinter &p, Operation *op, Region &body) { - if (body.empty()) + if (body.empty()) { return; + } p << "("; llvm::interleaveComma(body.getArguments(), p, [&](BlockArgument arg) { p.printRegionArgument(arg); }); @@ -550,8 +558,9 @@ static ParseResult parseExportConditionRegion(OpAsmParser &parser, static void printExportConditionRegion(OpAsmPrinter &p, Operation *op, Region &body) { - if (body.empty()) + if (body.empty()) { return; + } p << "("; llvm::interleaveComma(body.getArguments(), p, [&](BlockArgument arg) { p.printRegionArgument(arg); }); @@ -627,8 +636,9 @@ void 
TensorImportOp::build(OpBuilder &builder, OperationState &result, "information is required"); SmallVector dynamicDims; for (int64_t i = 0; i < shapedType.getRank(); ++i) { - if (!shapedType.isDynamicDim(i)) + if (!shapedType.isDynamicDim(i)) { continue; + } dynamicDims.push_back(builder.createOrFold( result.location, builder.getIndexType(), source, builder.getIndexAttr(i))); @@ -641,12 +651,14 @@ void TensorImportOp::build(OpBuilder &builder, OperationState &result, static LogicalResult verifyTypeStorageCompatibility(Operation *op, Type encodingType, Type storageType) { - if (encodingType == storageType) + if (encodingType == storageType) { return success(); + } auto encodingShapedType = dyn_cast(encodingType); auto storageShapedType = dyn_cast(storageType); - if (!encodingShapedType || !storageShapedType) + if (!encodingShapedType || !storageShapedType) { return success(); + } if (IREE::Util::getRoundedElementByteWidth( encodingShapedType.getElementType()) != @@ -832,8 +844,9 @@ void DispatchExternOp::build(OpBuilder &builder, OperationState &state, state.addRegion(); // Add one empty region per target. - for (size_t i = 0; i < targetObjects.getTargets().size(); ++i) + for (size_t i = 0; i < targetObjects.getTargets().size(); ++i) { state.addRegion(); + } } // Verifies that |dynamicDims| contains the appropriate number of dims for all @@ -885,8 +898,9 @@ static LogicalResult verifyWorkgroupCountWorkload(Operation *op, // Verifies that the workgroup count region matches the expected // signature. Returns success if the region is empty. static LogicalResult verifyWorkgroupCountRegion(Operation *op, Region ®ion) { - if (region.empty()) + if (region.empty()) { return success(); + } // Verify one of the supported signatures. 
bool validArguments = true; @@ -946,12 +960,14 @@ LogicalResult DispatchExternOp::verify() { return success(); }; for (auto type : getOperandTypes()) { - if (failed(verifyIOType(type))) + if (failed(verifyIOType(type))) { return failure(); + } } for (auto type : getResultTypes()) { - if (failed(verifyIOType(type))) + if (failed(verifyIOType(type))) { return failure(); + } } if (failed(verifyWorkgroupCountRegion(op, getWorkgroupCount()))) { @@ -1219,16 +1235,18 @@ LogicalResult ElementTypeOp::verify() { // static std::optional EncodingTypeOp::getTypeValue(Attribute attr) { // TODO(#6762): encoding attribute handling/mapping to enums. - if (attr) + if (attr) { return std::nullopt; + } // Default to IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR for now. return 1; } void EncodingTypeOp::getAsmResultNames( function_ref setNameFn) { - if (!getEncodingAttr()) + if (!getEncodingAttr()) { setNameFn(getResult(), "dense_row_major"); + } } LogicalResult EncodingTypeOp::verify() { @@ -1637,9 +1655,10 @@ LogicalResult ExecutableSourceOp::verify() { ExecutableSourceOp op = *this; auto conditionOps = getOps(); - if (llvm::range_size(conditionOps) > 1) + if (llvm::range_size(conditionOps) > 1) { return op.emitOpError() << "only one condition op is allowed in an executable"; + } return success(); } @@ -1668,8 +1687,9 @@ LogicalResult ExecutableOp::verify() { // signature. Returns success if the region is empty. static LogicalResult verifyExportConditionRegion(Operation *op, Region ®ion) { - if (region.empty()) + if (region.empty()) { return success(); + } // Verify one of the supported signatures. 
bool validArguments = true; @@ -1937,8 +1957,9 @@ LogicalResult ExecutableVariantOp::verify() { ExecutableVariantOp op = *this; auto conditionOps = getOps(); - if (llvm::range_size(conditionOps) > 1) + if (llvm::range_size(conditionOps) > 1) { return op.emitOpError() << "only one condition op is allowed in a variant"; + } return success(); } @@ -2016,13 +2037,15 @@ void ExecutableConditionOp::build(OpBuilder &builder, OperationState &result, ParseResult ExecutableConditionOp::parse(OpAsmParser &parser, OperationState &result) { - if (parseTargetConditionRegion(parser, *result.addRegion())) + if (parseTargetConditionRegion(parser, *result.addRegion())) { return failure(); + } result.addAttribute( "function_type", TypeAttr::get(getTargetConditionRegionType(parser.getContext()))); - if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) + if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) { return failure(); + } return success(); } @@ -2068,8 +2091,9 @@ ParseResult ExecutableConstantBlockOp::parse(OpAsmParser &parser, return failure(); } SmallVector argTypes; - for (auto &arg : entryArgs) + for (auto &arg : entryArgs) { argTypes.push_back(arg.type); + } auto fnType = builder.getFunctionType(argTypes, resultTypes); result.addAttribute(getFunctionTypeAttrName(result.name), TypeAttr::get(fnType)); @@ -2078,20 +2102,23 @@ ParseResult ExecutableConstantBlockOp::parse(OpAsmParser &parser, // There must be one key per result. Note that we support omitted parens when // only one result is present. 
SmallVector keyAttrs; - if (failed(parser.parseKeyword("as"))) + if (failed(parser.parseKeyword("as"))) { return failure(); + } if (resultTypes.size() == 1) { std::string key; - if (failed(parser.parseString(&key))) + if (failed(parser.parseString(&key))) { return failure(); + } keyAttrs.push_back(builder.getStringAttr(key)); } else { if (failed(parser.parseCommaSeparatedList( AsmParser::Delimiter::OptionalParen, [&]() { std::string key; - if (failed(parser.parseString(&key))) + if (failed(parser.parseString(&key))) { return failure(); + } keyAttrs.push_back(builder.getStringAttr(key)); return success(); }, @@ -2138,12 +2165,14 @@ void ExecutableConstantBlockOp::print(OpAsmPrinter &p) { p, cast(op), argTypes, /*isVariadic=*/false, resultTypes); p << " as "; - if (resultTypes.size() != 1) + if (resultTypes.size() != 1) { p << '('; + } llvm::interleaveComma(getKeys().getValue(), p, [&](Attribute attr) { p << attr; }); - if (resultTypes.size() != 1) + if (resultTypes.size() != 1) { p << ')'; + } mlir::function_interface_impl::printFunctionAttributes( p, op, {getFunctionTypeAttrName(), getKeysAttrName()}); p << " "; @@ -2316,20 +2345,23 @@ llvm::Align InterfaceBindingSubspanOp::calculateAlignment() { // If the binding has no assigned alignment we fall back to natural alignment. auto baseAlignment = getBaseAlignment(); - if (!baseAlignment) + if (!baseAlignment) { return naturalAlignment; + } // If there's no offset specified then we can use the binding alignment // directly. - if (!getByteOffset()) + if (!getByteOffset()) { return baseAlignment.value(); + } // Try to get the alignment of the byte offset. If it's a constant then we can // find a common alignment between it and the base and otherwise we need to // try to infer the alignment from the IR - otherwise we fall back. 
auto offsetOrAlignment = lookupOffsetOrAlignment(getByteOffset()); - if (!offsetOrAlignment.has_value()) + if (!offsetOrAlignment.has_value()) { return naturalAlignment; + } // Compute the common alignment between that of the binding base and that of // the byte offset. diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp index c2ecfa3d94bb..340677b79dc5 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp @@ -21,8 +21,9 @@ namespace mlir::iree_compiler::IREE::HAL { //===----------------------------------------------------------------------===// llvm::MaybeAlign commonAlignment(llvm::MaybeAlign lhs, llvm::MaybeAlign rhs) { - if (!lhs.has_value() || !rhs.has_value()) + if (!lhs.has_value() || !rhs.has_value()) { return std::nullopt; + } return llvm::MaybeAlign( llvm::MinAlign(lhs.value().value(), rhs.value().value())); } @@ -37,8 +38,9 @@ std::optional lookupOffsetOrAlignment(Value value) { } auto op = value.getDefiningOp(); - if (!op) + if (!op) { return std::nullopt; + } if (auto alignmentAttr = op->getAttrOfType("stream.alignment")) { // The op has an alignment tagged on it we can use directly. 
return alignmentAttr.getValue().getZExtValue(); @@ -107,8 +109,9 @@ void HALDialect::registerTypes() { Type HALDialect::parseType(DialectAsmParser &parser) const { StringRef typeKind; - if (parser.parseKeyword(&typeKind)) + if (parser.parseKeyword(&typeKind)) { return {}; + } auto type = llvm::StringSwitch(typeKind) .Case("allocator", AllocatorType::get(getContext())) .Case("buffer", BufferType::get(getContext())) diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/IR/test/BUILD.bazel index 39c200fd5d06..18609d466992 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "allocator_folding.mlir", "allocator_ops.mlir", @@ -34,8 +35,8 @@ iree_lit_test_suite( "experimental_ops.mlir", "fence_folding.mlir", "fence_ops.mlir", - "interface_ops.mlir", "interface_folding.mlir", + "interface_ops.mlir", "invalid.mlir", "tensor_folding.mlir", "tensor_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.cpp index e341232dc131..9d5b1d0c3869 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.cpp @@ -49,16 +49,21 @@ void TargetOptions::bindOptions(OptionsBinder &binder) { "executable files (sources, benchmarks, intermediates, binaries) " "to."), llvm::cl::callback([&](const std::string &path) { - if (executableSourcesPath.empty()) + if (executableSourcesPath.empty()) { executableSourcesPath = path; - if (executableConfigurationsPath.empty()) + } + if (executableConfigurationsPath.empty()) { executableConfigurationsPath = path; - if (executableBenchmarksPath.empty()) + } + if (executableBenchmarksPath.empty()) { executableBenchmarksPath = path; - if 
(executableIntermediatesPath.empty()) + } + if (executableIntermediatesPath.empty()) { executableIntermediatesPath = path; - if (executableBinariesPath.empty()) + } + if (executableBinariesPath.empty()) { executableBinariesPath = path; + } }), llvm::cl::cat(halTargetOptionsCategory)); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetRegistry.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetRegistry.cpp index 34687e438195..23eae255b440 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetRegistry.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetRegistry.cpp @@ -155,8 +155,9 @@ bool llvm::cl::parser::parse(Option &O, StringRef ArgName, // We ignore Arg here and just use the global registry. We could parse a list // of target backends and create a new registry with just that subset but // ownership gets tricky. - if (Arg != "global") + if (Arg != "global") { return true; + } Val.value = &mlir::iree_compiler::IREE::HAL::TargetRegistry::getGlobal(); return false; } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/AnnotateTargetDevices.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/AnnotateTargetDevices.cpp index 337cfc0950a9..bbc81d624e2e 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/AnnotateTargetDevices.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/AnnotateTargetDevices.cpp @@ -91,8 +91,9 @@ static void annotateOperandsAndResults(Operation *op, static void annotateFuncOp(FunctionOpInterface funcOp, DeviceAnalysis &deviceAnalysis) { - if (funcOp.empty()) + if (funcOp.empty()) { return; + } for (auto arg : funcOp.front().getArguments()) { if (isa(arg.getType())) { funcOp.setArgAttr( @@ -117,8 +118,9 @@ struct AnnotateTargetDevicesPass // Annotate all ops with derived affinities. 
for (auto &op : getOperation().getOps()) { - if (op.hasTrait()) + if (op.hasTrait()) { continue; + } if (auto globalOp = dyn_cast(op)) { annotateGlobalOp(globalOp, deviceAnalysis); } else if (auto funcOp = dyn_cast(op)) { diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CaptureExecutableSources.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CaptureExecutableSources.cpp index 31424147a4f7..cd3c5bfd86d5 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CaptureExecutableSources.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CaptureExecutableSources.cpp @@ -35,8 +35,9 @@ static void insertDictionaryAttrEntry(Operation *op, StringRef dictionaryName, StringRef key, Attribute value) { NamedAttrList attrs; auto dictionaryAttr = op->getAttrOfType(dictionaryName); - if (dictionaryAttr) + if (dictionaryAttr) { attrs.assign(dictionaryAttr.getValue()); + } attrs.set(key, value); op->setAttr(dictionaryName, DictionaryAttr::get(op->getContext(), attrs)); } @@ -67,15 +68,17 @@ struct CaptureExecutableSourcesPass for (auto variantOp : executableOp.getOps()) { // Skip externally defined variants as there's no source to capture. - if (variantOp.isExternal()) + if (variantOp.isExternal()) { continue; + } // Ignore if there is already source assigned. auto fileName = (moduleName + "_" + executableOp.getName() + "_" + variantOp.getName() + "." + stage + ".mlir") .str(); - if (hasDictionaryAttrEntry(variantOp, "sources", fileName)) + if (hasDictionaryAttrEntry(variantOp, "sources", fileName)) { continue; + } // Create a standalone executable with just the variant being captured. 
// This allows the source to be passed to iree-compile in the diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConfigureExecutables.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConfigureExecutables.cpp index 5e2e466fa7fb..13effa02246d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConfigureExecutables.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConfigureExecutables.cpp @@ -49,8 +49,9 @@ class ConfigureTargetExecutableVariantsPass void runOnOperation() override { IREE::HAL::ExecutableVariantOp variantOp = getOperation(); - if (variantOp.getTarget().getBackend().getValue() != target) + if (variantOp.getTarget().getBackend().getValue() != target) { return; + } auto targetBackend = targetRegistry->getTargetBackend(target); if (!targetBackend) { diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp index e2a500172300..0ec0cde692f1 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp @@ -445,8 +445,9 @@ buildBenchmarkModule(IREE::HAL::ExecutableOp sourceExecutableOp, } // Skip the file when we could not generate any benchmarks. 
- if (!hasAnyBenchmarks) + if (!hasAnyBenchmarks) { return {}; + } IRRewriter rewriter(moduleOp->getContext()); DominanceInfo domInfo; @@ -478,8 +479,9 @@ struct DumpExecutableBenchmarksPass SymbolTable symbolTable(moduleOp); DeviceAnalysis deviceAnalysis(moduleOp); - if (failed(deviceAnalysis.run())) + if (failed(deviceAnalysis.run())) { return signalPassFailure(); + } if (deviceAnalysis.getDeviceGlobals().empty()) { mlir::emitRemark(moduleOp.getLoc()) << "Executable benchmarks were requested but no devices were " @@ -516,8 +518,9 @@ struct DumpExecutableBenchmarksPass executableOp.getOps()) { auto benchmarkModuleOp = buildBenchmarkModule( executableOp, variantOp, dispatchParamsMap, deviceAnalysis); - if (!benchmarkModuleOp) + if (!benchmarkModuleOp) { continue; + } auto fileName = (moduleName + "_" + executableOp.getName() + "_" + variantOp.getName() + "_benchmark.mlir") .str(); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ElideRedundantCommands.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ElideRedundantCommands.cpp index f496b19cc7a9..d2d284ed72c6 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ElideRedundantCommands.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ElideRedundantCommands.cpp @@ -91,8 +91,9 @@ struct ElideRedundantCommandsPass stateMap[commandBuffer].previousFullBarrier = {}; }; for (auto &op : llvm::make_early_inc_range(block.getOperations())) { - if (!op.getDialect()) + if (!op.getDialect()) { continue; + } TypeSwitch(&op) .Case([&](IREE::HAL::CommandBufferFinalizeOp op) { invalidateState(op.getCommandBuffer()); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeDispatchInstrumentation.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeDispatchInstrumentation.cpp index 96ff5613f0ee..7a75fc5b992d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeDispatchInstrumentation.cpp +++ 
b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeDispatchInstrumentation.cpp @@ -28,8 +28,9 @@ namespace mlir::iree_compiler::IREE::HAL { namespace { static std::string getAttrStr(Attribute attr) { - if (!attr) + if (!attr) { return ""; + } std::string result; llvm::raw_string_ostream os(result); attr.print(os, /*elideType=*/true); @@ -69,8 +70,9 @@ static Value createChunkHeader(Location loc, iree_idbts_chunk_type_t type, static Value createPadding(Location loc, uint64_t unalignedLength, OpBuilder &builder) { uint64_t padding = llvm::alignTo(unalignedLength, 16) - unalignedLength; - if (!padding) + if (!padding) { return nullptr; + } auto i8Type = builder.getI8Type(); auto zeroAttr = IntegerAttr::get(i8Type, 0); auto dataAttr = DenseElementsAttr::get( @@ -107,8 +109,9 @@ struct MaterializeDispatchInstrumentationPass MaterializeDispatchInstrumentationPassBase; void runOnOperation() override { mlir::ModuleOp moduleOp = getOperation(); - if (moduleOp.getBody()->empty()) + if (moduleOp.getBody()->empty()) { return; + } auto moduleBuilder = OpBuilder(&moduleOp.getBody()->front()); auto i8Type = moduleBuilder.getI8Type(); @@ -170,8 +173,9 @@ struct MaterializeDispatchInstrumentationPass for (auto exportOp : executableOp.getOps()) { auto funcOp = exportOp.lookupFunctionRef(); - if (!funcOp) + if (!funcOp) { continue; + } // Capture the source before we mess with it. auto originalSource = getOpStr(funcOp); @@ -256,8 +260,9 @@ struct MaterializeDispatchInstrumentationPass break; } } - if (!functionId) + if (!functionId) { return; // not instrumented + } // Append dispatch site ID to correlate this op with where it lives in // the program and what is being dispatched. 
Note that multiple diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp index ac0ead563b77..4d0022eff77f 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp @@ -295,8 +295,9 @@ convertBindingUsage(mlir::FunctionOpInterface sourceFuncOp, BlockArgument arg, IREE::HAL::PipelineLayoutAttr pipelineLayoutAttr, int64_t bindingOrdinal, IREE::HAL::PipelineBindingAttr bindingAttr) { - if (arg.use_empty()) + if (arg.use_empty()) { return; // no-op + } for (auto &use : llvm::make_early_inc_range(arg.getUses())) { auto oldOp = dyn_cast(use.getOwner()); assert(oldOp && "bindings are only usable by stream.binding.subspan"); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp index e23253927427..85ab0fe7780b 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp @@ -335,8 +335,9 @@ getDeviceFallbackGlobals(IREE::Util::GlobalOpInterface deviceGlobal, SymbolTable &symbolTable) { SetVector resultSet; auto processAttr = [&](Attribute attr) { - if (!attr) + if (!attr) { return true; // ignore uninitialized devices + } return TypeSwitch(attr) .Case([](auto attr) { return true; }) .Case([](auto attr) { return true; }) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp index 48a7b97a1baa..78a353eeb86c 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp @@ -85,8 +85,9 @@ struct 
MemoizeDeviceQueriesPass // we can't memoize the query today. auto deviceGlobals = deviceAnalysis.lookupDeviceGlobals(queryOp.getDevice()); - if (!deviceGlobals || deviceGlobals->size() != 1) + if (!deviceGlobals || deviceGlobals->size() != 1) { return WalkResult::advance(); + } IREE::Util::GlobalOpInterface deviceGlobalOp = deviceGlobals->front(); // Construct key used to dedupe/lookup the query. diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/PreprocessExecutables.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/PreprocessExecutables.cpp index 63900f1d03fd..7fb2dde0442f 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/PreprocessExecutables.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/PreprocessExecutables.cpp @@ -144,8 +144,9 @@ static LogicalResult preprocessWithCommand(IREE::HAL::ExecutableOp executableOp, #endif // _WIN32 Tokenize(command, stringSaver, rawArgs, /*MarkEOLs=*/false); SmallVector args; - for (auto rawArg : rawArgs) + for (auto rawArg : rawArgs) { args.push_back(StringRef(rawArg)); + } // Try to find the tool either by absolute path or by looking it up in env. 
auto tool = findTool(args[0].str()); @@ -156,8 +157,9 @@ static LogicalResult preprocessWithCommand(IREE::HAL::ExecutableOp executableOp, LLVM_DEBUG({ llvm::dbgs() << "Launching hal.executable preprocessor: "; - for (auto arg : args) + for (auto arg : args) { llvm::dbgs() << arg << " "; + } llvm::dbgs() << " 1> " << stdoutFile.str() << " 2> " << stderrFile.str() << "\n"; }); @@ -242,8 +244,9 @@ struct PreprocessExecutablesWithPipelinePass } void runOnOperation() override { - if (!pipeline.hasValue()) + if (!pipeline.hasValue()) { return; + } IREE::HAL::ExecutableOp executableOp = getOperation(); OpPassManager passManager(executableOp.getOperationName()); if (failed(buildPassPipeline(pipeline, passManager))) { @@ -270,8 +273,9 @@ struct PreprocessExecutablesWithToolPass using IREE::HAL::impl::PreprocessExecutablesWithToolPassBase< PreprocessExecutablesWithToolPass>::PreprocessExecutablesWithToolPassBase; void runOnOperation() override { - if (!command.hasValue()) + if (!command.hasValue()) { return; + } IREE::HAL::ExecutableOp executableOp = getOperation(); if (failed(preprocessWithCommand(executableOp, command))) { llvm::errs() << "ERROR: failed to preprocess executable `" diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/PruneExecutables.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/PruneExecutables.cpp index 6768ae50c5c3..19fd69e6a205 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/PruneExecutables.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/PruneExecutables.cpp @@ -35,8 +35,9 @@ static void markReferenced(SymbolRefAttr symbolRefAttr, ? 
SymbolRefAttr::get(rootRefAttr) : SymbolRefAttr::get(rootRefAttr, nestedRefAttrs); auto it = referenceMap.find(nestedRefAttr); - if (it != referenceMap.end()) + if (it != referenceMap.end()) { ++it->second.count; + } }; auto rootRefAttr = symbolRefAttr.getRootReference(); auto nestedRefAttrs = symbolRefAttr.getNestedReferences(); @@ -47,8 +48,9 @@ static void markReferenced(SymbolRefAttr symbolRefAttr, static void processOp(Operation *op, SymbolReferenceMap &referenceMap) { SmallVector worklist; - for (auto namedAttr : op->getAttrs()) + for (auto namedAttr : op->getAttrs()) { worklist.push_back(namedAttr.getValue()); + } while (!worklist.empty()) { auto attr = worklist.pop_back_val(); if (auto symbolRefAttr = dyn_cast(attr)) { @@ -107,8 +109,9 @@ struct PruneExecutablesPass SetVector exportRefAttrs; for (auto executableOp : moduleOp.getOps()) { ignoredOps.insert(executableOp); - if (!executableOp.isPrivate()) + if (!executableOp.isPrivate()) { continue; + } auto executableRefAttr = FlatSymbolRefAttr::get(executableOp.getSymNameAttr()); referenceMap[executableRefAttr].symbolOp = executableOp; @@ -156,8 +159,9 @@ struct PruneExecutablesPass // accumulate the usage counts. SymbolTable symbolTable(moduleOp); moduleOp.walk([&](Operation *op) -> WalkResult { - if (ignoredOps.contains(op)) + if (ignoredOps.contains(op)) { return WalkResult::skip(); + } processOp(op, referenceMap); return op->hasTrait() ? 
WalkResult::skip() diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/SerializeExecutables.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/SerializeExecutables.cpp index 9cbaec880ab2..2a805a30e6f1 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/SerializeExecutables.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/SerializeExecutables.cpp @@ -78,8 +78,9 @@ struct SerializeTargetExecutablesPass auto variantOps = llvm::to_vector( executableOp.getBlock().getOps()); for (auto variantOp : variantOps) { - if (variantOp.getTarget().getBackend().getValue() != target) + if (variantOp.getTarget().getBackend().getValue() != target) { continue; + } OpBuilder executableBuilder(variantOp); // Ask the target backend to serialize the executable. Note that it // may create one or more hal.executable.binary ops in the case of diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/SubstituteExecutables.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/SubstituteExecutables.cpp index 35bf99d46e46..80dc320476ab 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/SubstituteExecutables.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/SubstituteExecutables.cpp @@ -42,8 +42,9 @@ scanSearchPath(std::string prefix, StringRef searchPath, dir != dir_end && !ec; dir.increment(ec)) { auto childPath = dir->path(); llvm::sys::fs::file_status status; - if (llvm::sys::fs::status(childPath, status)) + if (llvm::sys::fs::status(childPath, status)) { continue; + } switch (status.type()) { case llvm::sys::fs::file_type::regular_file: case llvm::sys::fs::file_type::symlink_file: @@ -104,8 +105,9 @@ replaceExecutableOpWithMLIR(IREE::HAL::ExecutableOp executableOp, // Load the replacement IR. It may have any mix of stuff in it including // multiple other executables. 
auto rootOpRef = loadModuleObject(executableOp.getContext(), filePath); - if (!rootOpRef) + if (!rootOpRef) { return failure(); + } IREE::HAL::ExecutableOp replacementOp; if (auto moduleOp = dyn_cast(rootOpRef.get())) { // We expect a `hal.executable` with the same name as the one we are @@ -165,8 +167,9 @@ externalizeExecutableOp(IREE::HAL::ExecutableOp executableOp, auto fileObjectAttr = builder.getAttr( builder.getStringAttr(filePath), nullptr); auto fileContents = fileObjectAttr.loadData(); - if (!fileContents) + if (!fileContents) { return failure(); + } // Link the referenced object file contents. We fully replace the existing // objects in case there were any as this does entire executable replacement - @@ -243,8 +246,9 @@ struct SubstituteExecutablesPass uniqueSubstitutions[std::string(key)] = value; } - if (uniqueSubstitutions.empty()) + if (uniqueSubstitutions.empty()) { return; // no-op + } // Walk each substitution and process the matching executable if found. for (auto &[executableName, filePath] : uniqueSubstitutions) { diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/TranslateExecutables.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/TranslateExecutables.cpp index f1f39fa64918..391ff79be794 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/TranslateExecutables.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/TranslateExecutables.cpp @@ -51,10 +51,12 @@ struct TranslateTargetExecutableVariantsPass void runOnOperation() override { IREE::HAL::ExecutableVariantOp variantOp = getOperation(); - if (variantOp.getTarget().getBackend().getValue() != target) + if (variantOp.getTarget().getBackend().getValue() != target) { return; - if (variantOp.isExternal()) + } + if (variantOp.isExternal()) { return; + } auto targetBackend = targetRegistry->getTargetBackend(target); if (!targetBackend) { diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel 
b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel index d100a35b1ba0..0f87d6865510 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "annotate_target_devices.mlir", "assign_legacy_target_devices.mlir", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Utils/LLVMLinkerUtils.cpp b/compiler/src/iree/compiler/Dialect/HAL/Utils/LLVMLinkerUtils.cpp index f860eaaa0269..5588cecbfa1a 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Utils/LLVMLinkerUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Utils/LLVMLinkerUtils.cpp @@ -89,8 +89,9 @@ loadBitcodeObject(IREE::HAL::ExecutableObjectAttr objectAttr, llvm::MemoryBufferRef bitcodeBufferRef(objectData.value(), objectAttr.getPath()); auto bitcodeModuleValue = llvm::parseBitcodeFile(bitcodeBufferRef, context); - if (!bitcodeModuleValue) + if (!bitcodeModuleValue) { return bitcodeModuleValue; + } // NOTE: at this point the bitcode may not have the expected data layout! 
return std::move(bitcodeModuleValue.get()); } diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp index eddece9acfa8..3cb319dbe17a 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp @@ -209,7 +209,7 @@ static Value applyPostQKMatmulElementwise(OpBuilder &builder, Location loc, } static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap, - AffineMap maskMap, Value qk, Value mask) { + AffineMap maskMap, Value qk, Value mask, bool useExp2) { SmallVector compressedMaps = compressUnusedDims(SmallVector{qkMap, maskMap}); @@ -245,9 +245,11 @@ static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap, maskVal = convertScalarToDtype(b, loc, maskVal, qkVal.getType(), /*isUnsignedCast=*/false); // Scaling to compensate for base-2 softmax - Value log2e = arith::ConstantOp::create( - b, loc, b.getFloatAttr(qkVal.getType(), M_LOG2E)); - maskVal = arith::MulFOp::create(b, loc, maskVal, log2e); + if (useExp2) { + Value log2e = arith::ConstantOp::create( + b, loc, b.getFloatAttr(qkVal.getType(), M_LOG2E)); + maskVal = arith::MulFOp::create(b, loc, maskVal, log2e); + } } // Finally, set the returned value to the qk element plus the mask // element (or 0/-infinity if bool mask). We opt for a AddFOp (instead @@ -260,10 +262,10 @@ static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap, return genericOp.getResult(0); } -// Compute output = exp2(output - input) -static Value computeSubAndExp2(OpBuilder &builder, Location loc, - AffineMap inputMap, AffineMap outputMap, - Value input, Value output) { +// Compute output = exp2/exp(output - input) depending on useExp2 flag. 
+static Value computeSubAndExp(OpBuilder &builder, Location loc, + AffineMap inputMap, AffineMap outputMap, + Value input, Value output, bool useExp2) { SmallVector compressedMaps = compressUnusedDims(SmallVector{inputMap, outputMap}); inputMap = compressedMaps[0]; @@ -279,8 +281,9 @@ static Value computeSubAndExp2(OpBuilder &builder, Location loc, Value in = convertScalarToDtype(b, loc, args[0], args[1].getType(), /*isUnsignedCast=*/false); Value diff = arith::SubFOp::create(b, loc, args[1], in); - Value weight = math::Exp2Op::create(b, loc, diff); - linalg::YieldOp::create(b, loc, weight); + Operation *weight = useExp2 ? math::Exp2Op::create(b, loc, diff) + : math::ExpOp::create(b, loc, diff); + linalg::YieldOp::create(b, loc, weight->getResult(0)); }); return genericOp.getResult(0); } @@ -316,15 +319,18 @@ Value computeQKAndElementwise(Location loc, OpBuilder &b, Value query, std::optional maskMap, SmallVector iterationDomain, Type sElementType, Region &elementwiseRegion, - DictionaryAttr qkAttrs, bool lowPrecision) { + DictionaryAttr qkAttrs, bool lowPrecision, + bool useExp2) { MLIRContext *ctx = b.getContext(); - // Since we use exp2 for attention instead of the original exp, we have to + // If using exp2 for attention instead of the original exp, we have to // multiply the scale by log2(e). We use exp2 instead of exp as most platforms // have better support for exp2 (we verified that we gain some speedup on // some GPUs). 
- Value log2e = arith::ConstantOp::create( - b, loc, b.getFloatAttr(scale.getType(), M_LOG2E)); - scale = arith::MulFOp::create(b, loc, scale, log2e); + if (useExp2) { + Value log2e = arith::ConstantOp::create( + b, loc, b.getFloatAttr(scale.getType(), M_LOG2E)); + scale = arith::MulFOp::create(b, loc, scale, log2e); + } auto qETy = getElementTypeOrSelf(query.getType()); @@ -392,7 +398,7 @@ Value computeQKAndElementwise(Location loc, OpBuilder &b, Value query, // S += mask if (mask != nullptr) { - s = applyMask(b, loc, sMap, *maskMap, s, mask.value()); + s = applyMask(b, loc, sMap, *maskMap, s, mask.value(), useExp2); } return s; @@ -436,9 +442,9 @@ FailureOr> AttentionOp::decomposeOperation(OpBuilder &b) { Type f32Type = b.getF32Type(); // ---- QK Matmul + elementwise math ---- - Value s = computeQKAndElementwise(loc, b, query, key, getScale(), mask, qMap, - kMap, sMap, getMaskMap(), sizes, f32Type, - getRegion(), qkAttrs, lowPrecision); + Value s = computeQKAndElementwise( + loc, b, query, key, getScale(), mask, qMap, kMap, sMap, getMaskMap(), + sizes, f32Type, getRegion(), qkAttrs, lowPrecision, /*useExp2=*/true); // ---- Softmax ---- @@ -480,7 +486,7 @@ FailureOr> AttentionOp::decomposeOperation(OpBuilder &b) { // P = exp2(S - max) AffineMap pMap = sMap; - Value p = computeSubAndExp2(b, loc, maxMap, sMap, max, s); + Value p = computeSubAndExp(b, loc, maxMap, sMap, max, s, /*useExp2=*/true); // sum = rowSum(P) Value sum = reduce(b, loc, pMap, sumMap, p, sumFill); @@ -530,9 +536,13 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { DictionaryAttr config = getDecompositionConfigAttr(); DictionaryAttr qkAttrs, pvAttrs; + bool useExp2 = true; if (config) { qkAttrs = config.getAs(getQKAttrStr()); pvAttrs = config.getAs(getPVAttrStr()); + if (auto useExp2Attr = config.getAs(getUseExp2AttrStr())) { + useExp2 = useExp2Attr.getValue(); + } } FailureOr maybeOpInfo = AttentionOpDetail::get( @@ -553,7 +563,7 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { // 
---- QK Matmul + elementwise math ---- Value s = computeQKAndElementwise( loc, b, query, key, getScale(), mask, qMap, kMap, sMap, getMaskMap(), - sizes, elementType, getRegion(), qkAttrs, lowPrecision); + sizes, elementType, getRegion(), qkAttrs, lowPrecision, useExp2); // TODO: This decomposition should be in a seperate op called // "online softmax". @@ -563,20 +573,21 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { AffineMap maxMap = getMaxMap(); Value newMax = reduce(b, loc, sMap, maxMap, s, oldMax); - // norm = exp2(oldMax - newMax) + // norm = exp2(oldMax - newMax) or exp(oldMax - newMax) depending on useExp2 // normMap = maxMap AffineMap normMap = getMaxMap(); - Value norm = computeSubAndExp2(b, loc, maxMap, normMap, newMax, oldMax); + Value norm = + computeSubAndExp(b, loc, maxMap, normMap, newMax, oldMax, useExp2); // normSum = norm * oldSum AffineMap sumMap = getSumMap(); Value normSum = elementwiseValueInPlace(b, loc, sumMap, normMap, oldSum, norm); - // P = exp2(S - newMax) + // P = exp2(S - newMax) or exp(S - newMax) depending on useExp2 // PMap = SMap AffineMap pMap = sMap; - Value p = computeSubAndExp2(b, loc, maxMap, sMap, newMax, s); + Value p = computeSubAndExp(b, loc, maxMap, sMap, newMax, s, useExp2); // newSum = normSum + rowSum(P) Value newSum = reduce(b, loc, pMap, sumMap, p, normSum); @@ -816,8 +827,9 @@ FailureOr> Im2colOp::decomposeOperation(OpBuilder &b) { } SetVector batchPosSet(getBatchPos().begin(), getBatchPos().end()); for (auto [idx, size] : enumerate(inputSizes)) { - if (batchPosSet.contains(idx)) + if (batchPosSet.contains(idx)) { continue; + } if (mPosSet.contains(idx)) { kBasis.push_back(kernelSize[mKernelIdx[idx]]); continue; @@ -850,8 +862,9 @@ FailureOr> Im2colOp::decomposeOperation(OpBuilder &b) { int delinKIdx = 0; SmallVector invInputKPerm = invertPermutationVector(inputKPerm); for (int i = 0; i < getInputRank(); ++i) { - if (batchPosSet.contains(i)) + if (batchPosSet.contains(i)) { continue; + } if 
(mPosSet.contains(i)) { windowOffset.push_back(delinKOffset[invInputKPerm[delinKIdx++]]); continue; @@ -1209,11 +1222,11 @@ FailureOr> ExpReductionOp::decomposeOperation(OpBuilder &b) { Value currMax = reduce( rewriter, loc, normValMap, prevMaxMap, sValue->get(), prevMax->get()); // ex = e^{sValue - curr_max} - Value ex = computeSubAndExp2(rewriter, loc, prevMaxMap, normValMap, currMax, - sValue->get()); + Value ex = computeSubAndExp(rewriter, loc, prevMaxMap, normValMap, currMax, + sValue->get(), /*useExp2=*/true); // norm = e^(prev_max - curr_max) - Value norm = computeSubAndExp2(rewriter, loc, prevMaxMap, prevMaxMap, currMax, - prevMax->get()); + Value norm = computeSubAndExp(rewriter, loc, prevMaxMap, prevMaxMap, currMax, + prevMax->get(), /*useExp2=*/true); SmallVector inputs = getDpsInputs(); SmallVector normOuts(getNumDpsInits()); diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel index 7209bda88aa7..244aac9a100a 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "LinalgExtBase.td", "LinalgExtInterfaces.td", diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp index 22728e5167e2..197435f69779 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp @@ -656,6 +656,45 @@ void MapGatherOp::insertTransformationAtStart( transformBody.eraseArguments(0, oldOutputIndices.size()); } +/// Shared implementation for inlining the transformation body of map_gather +/// and map_scatter ops. 
+static void inlineMapGatherScatterBodyImpl( + OpBuilder &b, Location loc, Region &transformRegion, + ValueRange transformBodyIndices, + function_ref)> bodyBuilder) { + Block &transformBlock = transformRegion.front(); + IRMapping mapping; + // Map the induction variables of the loop nest to the block arguments of the + // transformation body. + for (auto [idx, arg] : llvm::enumerate(transformBlock.getArguments())) { + mapping.map(arg, transformBodyIndices[idx]); + } + // Clone the operations within the transformation body to the current + // insertion point, and map their results to the new cloned operations' + // results. + for (Operation &op : transformBlock.without_terminator()) { + Operation *clonedOp = b.clone(op, mapping); + for (auto [result, clonedResult] : + llvm::zip_equal(op.getResults(), clonedOp->getResults())) { + mapping.map(result, clonedResult); + } + } + + // Get the cloned values that were yielded by the transformation body to pass + // to the bodyBuilder. + SmallVector mappedYieldedValues = llvm::map_to_vector( + transformBlock.getTerminator()->getOperands(), + [&](Value operand) -> Value { return mapping.lookupOrDefault(operand); }); + bodyBuilder(b, loc, mappedYieldedValues); +} + +void MapGatherOp::inlineMapGatherBody( + OpBuilder &b, Location loc, ValueRange transformBodyIndices, + function_ref)> bodyBuilder) { + inlineMapGatherScatterBodyImpl(b, loc, getTransformationRegion(), + transformBodyIndices, bodyBuilder); +} + //===----------------------------------------------------------------------===// // MapScatterOp //===----------------------------------------------------------------------===// @@ -787,31 +826,8 @@ void MapScatterOp::insertTransformationAtStart( void MapScatterOp::inlineMapScatterBody( OpBuilder &b, Location loc, ValueRange transformBodyIndices, function_ref)> bodyBuilder) { - Block &transformBlock = getTransformationRegion().front(); - IRMapping mapping; - // Map the induction variables of the loop nest to the block 
arguments of the - // transformation body. The induction variables are the indices looping over - // the elements of input operand. - for (auto [idx, arg] : llvm::enumerate(transformBlock.getArguments())) { - mapping.map(arg, transformBodyIndices[idx]); - } - // Clone the operations within the transformation body to the current - // insertion point, and map their results to the new cloned operations' - // results. - for (Operation &op : transformBlock.without_terminator()) { - Operation *clonedOp = b.clone(op, mapping); - for (auto [result, clonedResult] : - llvm::zip_equal(op.getResults(), clonedOp->getResults())) { - mapping.map(result, clonedResult); - } - } - - // Get the cloned values that were yielded by the transformation body to pass - // to the bodyBuilder. - SmallVector mappedYieldedValues = llvm::map_to_vector( - transformBlock.getTerminator()->getOperands(), - [&](Value operand) -> Value { return mapping.lookupOrDefault(operand); }); - bodyBuilder(b, loc, mappedYieldedValues); + inlineMapGatherScatterBodyImpl(b, loc, getTransformationRegion(), + transformBodyIndices, bodyBuilder); } bool MapScatterOp::isIdentity() { @@ -1032,8 +1048,9 @@ LogicalResult FftOp::verify() { // After tiling, it could be dynamic shape. (Because // subview/subtensor does not inference the type correctly // on (1 << x)) cases). 
- if (ShapedType::isDynamic(length)) + if (ShapedType::isDynamic(length)) { return success(); + } if (length & (length - 1)) { return op->emitOpError("only powers of 2 are handled currently"); } @@ -1271,8 +1288,9 @@ LogicalResult ArgCompareOp::verify() { SmallVector expectedShape; for (int64_t i = 0; i < rank; ++i) { - if (i != dim) + if (i != dim) { expectedShape.push_back(inputType.getDimSize(i)); + } } if (!llvm::equal(expectedShape, outputValueType.getShape())) { return op->emitOpError("output shape must match input shape with reduction " @@ -1376,15 +1394,18 @@ areNotFullTiles(ArrayRef inputShape, DenseMap const &dimAndTileMapping) { int64_t rank = inputShape.size(); for (int64_t dim = 0; dim < rank; dim++) { - if (ShapedType::isDynamic(inputShape[dim])) + if (ShapedType::isDynamic(inputShape[dim])) { continue; + } auto it = dimAndTileMapping.find(dim); if (it != dimAndTileMapping.end()) { std::optional constantTile = getConstantIntValue(it->second); - if (!constantTile) + if (!constantTile) { continue; - if (inputShape[dim] % (*constantTile) != 0) + } + if (inputShape[dim] % (*constantTile) != 0) { return true; + } } } return false; @@ -2111,8 +2132,9 @@ LogicalResult AttentionOp::verify() { // Additional check case if mask exists if (auto maskMap = getMaskMap()) { - if (failed(checkShape("Mask", getMask().getType().getShape(), *maskMap))) + if (failed(checkShape("Mask", getMask().getType().getShape(), *maskMap))) { return failure(); + } } int expectedSymbols = getQueryMap().getNumInputs(); @@ -2137,14 +2159,16 @@ LogicalResult AttentionOp::verify() { // Additional check case if mask exists if (auto maskMap = getMaskMap()) { - if (failed(checkDomain("Mask", *maskMap))) + if (failed(checkDomain("Mask", *maskMap))) { return failure(); + } } auto &block = getRegion().front(); auto blockTys = block.getArgumentTypes(); - if (!isa(blockTys[0])) + if (!isa(blockTys[0])) { return attnOp->emitOpError("block argument 0 should be float"); + } auto yieldOp = 
dyn_cast(block.getTerminator()); if (!yieldOp) { @@ -2288,8 +2312,9 @@ LogicalResult OnlineAttentionOp::verify() { // Additional check case if mask exists if (auto maskMap = getMaskMap()) { - if (failed(checkShape("Mask", getMask().getType().getShape(), *maskMap))) + if (failed(checkShape("Mask", getMask().getType().getShape(), *maskMap))) { return failure(); + } } int expectedSymbols = getQueryMap().getNumInputs(); @@ -2316,8 +2341,9 @@ LogicalResult OnlineAttentionOp::verify() { // Additional check case if mask exists if (auto maskMap = getMaskMap()) { - if (failed(checkDomain("Mask", *maskMap))) + if (failed(checkDomain("Mask", *maskMap))) { return failure(); + } } Block &block = attnOp.getRegion().front(); diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td index ee84c2abd433..f09805746254 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td @@ -280,7 +280,8 @@ def IREELinalgExt_MapGatherOp : IREELinalgExt_Op<"map_gather", "getTiledImplementation", "generateResultTileValue", "getIterationDomainTileFromOperandTiles", - "getTiledImplementationFromOperandTiles"]>]> { + "getTiledImplementationFromOperandTiles", + "generateScalarImplementation"]>]> { let summary = [{Gather with a mapping from result indices to source indices.}]; let description = [{ Takes two operands, `source` and `output`, and reads every element from @@ -337,6 +338,15 @@ def IREELinalgExt_MapGatherOp : IREELinalgExt_Op<"map_gather", transformationBuilder, int64_t numOutputIndices); + // Inline the transformation region of the map_gather op without its + // terminator, replacing the block arguments with the passed + // `transformBodyIndices`. 
The `bodyBuilder` function is called with the + // cloned `Value`s that would have been yielded by the terminator of + // the inlined transformation body (source indices and padding value). + void inlineMapGatherBody( + OpBuilder &b, Location loc, ValueRange transformBodyIndices, + function_ref)> bodyBuilder); + // Method to implement for specifying output range for // DestinationStyleOpInterface MutableOperandRange getDpsInitsMutable() { @@ -997,6 +1007,16 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_Op<"online_attention", it over the entire softmax reduction dimension by: x, _, sum : results x = (1 / sum) * x + + Decomposition Configuration: + The `decomposition_config` attribute is a DictionaryAttr that controls how + this operation is decomposed into lower-level operations. It supports: + - "qk_attrs": DictionaryAttr - Attributes to attach to the Q@K matmul + operation after decomposition (e.g., lowering_config, attention markers) + - "pv_attrs": DictionaryAttr - Attributes to attach to the P@V matmul + operation after decomposition + - "use_exp2": BoolAttr - If true, uses exp2 with log2(e) scaling instead + of exp. (Gives better perf on some hardware, but trades off accuracy) }]; let arguments = (ins AnyShaped:$query, @@ -1081,6 +1101,8 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_Op<"online_attention", // Attributes to set on QK and PV matmul after decomposition. static StringRef getQKAttrStr() { return "qk_attrs"; } static StringRef getPVAttrStr() { return "pv_attrs"; } + // Flag to control whether to use exp2 (with log2(e) scaling) or exp. 
+ static StringRef getUseExp2AttrStr() { return "use_exp2"; } }]; let hasCanonicalizer = 1; diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp index f3b71954fb0c..acd09029fc9b 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp @@ -291,8 +291,9 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b, auto dim = dimMap[i]; - if (starts[dim]) + if (starts[dim]) { ret = arith::AddIOp::create(b, loc, ret, starts[dim]); + } starts[dim] = ret; } @@ -441,8 +442,9 @@ LogicalResult GatherOp::generateScalarImplementation(OpBuilder &b, Location loc, Value idx = memref::LoadOp::create(b, loc, getIndices(), loadIndices); Value ret = arith::IndexCastOp::create(b, loc, b.getIndexType(), idx); auto dim = dimMap[i]; - if (starts[dim]) + if (starts[dim]) { ret = arith::AddIOp::create(b, loc, ret, starts[dim]); + } starts[dim] = ret; } @@ -565,6 +567,71 @@ FailureOr MapGatherOp::getTiledImplementationFromOperandTiles( return getTiledImplementation(b, mappedOffsets, mappedSizes); } +/// The body of the transformation_region is inlined, and the yielded indices +/// are used to read values from the source and write to the output. Bounds +/// checking is performed on the source indices, and the padding value is used +/// if the indices are out of bounds. +LogicalResult MapGatherOp::generateScalarImplementation(OpBuilder &b, + Location loc, + ValueRange ivs) { + // The scalar implementation is currently only implemented for buffer + // semantics. + if (!hasPureBufferSemantics()) { + return failure(); + } + + auto bodyBuilder = [&](OpBuilder nestedBuilder, Location nestedLoc, + ArrayRef yieldedValues) { + // The last yielded Value is the padding, the rest are source indices. 
+ Value paddingValue = yieldedValues.back(); + ArrayRef loadIndices = yieldedValues.drop_back(); + + // Check bounds for each source dimension. Start with true so that + // for 0-D sources, inBounds is always true. + Value inBounds = nestedBuilder.createOrFold( + nestedLoc, /*value=*/1, /*width=*/1); + Value zero = + nestedBuilder.createOrFold(nestedLoc, 0); + for (auto [dim, idx] : llvm::enumerate(loadIndices)) { + Value dimSize = + memref::DimOp::create(nestedBuilder, nestedLoc, getSource(), dim); + + // Check: idx >= 0 + Value geZero = arith::CmpIOp::create( + nestedBuilder, nestedLoc, arith::CmpIPredicate::sge, idx, zero); + // Check: idx < dimSize + Value ltDim = arith::CmpIOp::create( + nestedBuilder, nestedLoc, arith::CmpIPredicate::slt, idx, dimSize); + // Combine: idx >= 0 && idx < dimSize + Value dimInBounds = + arith::AndIOp::create(nestedBuilder, nestedLoc, geZero, ltDim); + + inBounds = arith::AndIOp::create(nestedBuilder, nestedLoc, inBounds, + dimInBounds); + } + + // Create if-else: if in bounds, load from source; else use padding. + // The if yields the value to store. 
+ auto ifOp = scf::IfOp::create(nestedBuilder, nestedLoc, + TypeRange{paddingValue.getType()}, inBounds, + /*addThenBlock=*/true, /*addElseBlock=*/true); + { + auto thenBuilder = ifOp.getThenBodyBuilder(); + Value loaded = memref::LoadOp::create(thenBuilder, nestedLoc, getSource(), + loadIndices); + scf::YieldOp::create(thenBuilder, nestedLoc, loaded); + } + { + auto elseBuilder = ifOp.getElseBodyBuilder(); + scf::YieldOp::create(elseBuilder, nestedLoc, paddingValue); + } + memref::StoreOp::create(nestedBuilder, nestedLoc, ifOp.getResult(0), + getOutput(), ivs); + }; + inlineMapGatherBody(b, loc, ivs, bodyBuilder); + return success(); +} + //===----------------------------------------------------------------------===// // MapScatterOp //===----------------------------------------------------------------------===// @@ -1127,11 +1194,13 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc, scanBlkArgs.push_back( memref::LoadOp::create(b, loc, getOutput(), indices)); Value i0; - if (!isInclusive) + if (!isInclusive) { i0 = memref::LoadOp::create(b, loc, getInput(), indices); + } indices[scanDim] = iv; - if (isInclusive) + if (isInclusive) { i0 = memref::LoadOp::create(b, loc, getInput(), indices); + } scanBlkArgs.push_back(i0); }); @@ -1227,8 +1296,9 @@ LogicalResult ScanOp::getResultTilePosition( int64_t rank = getOperandRank(); if (rank > 1) { for (auto i : llvm::seq(0, rank)) { - if (i == getDimension()) + if (i == getDimension()) { continue; + } resultOffsets.push_back(offsets[i]); resultSizes.push_back(sizes[i]); } @@ -1552,8 +1622,9 @@ LogicalResult ArgCompareOp::generateScalarImplementation(OpBuilder &b, uint64_t reductionDim = getDimension(); SmallVector parallelIndices; for (size_t i = 0, rank = ivs.size(); i < rank; ++i) { - if (i == reductionDim) + if (i == reductionDim) { continue; + } parallelIndices.push_back(ivs[i]); } @@ -3507,8 +3578,9 @@ static void offsetCustomOpIndices(OpBuilder &b, CustomOp customOp, ArrayRef offsets) { 
IRRewriter rewriter(b); for (auto indexOp : customOp.getBody()->getOps()) { - if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()]) + if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()]) { continue; + } OpBuilder::InsertionGuard guard(b); rewriter.setInsertionPointAfter(indexOp); AffineExpr index, offset; diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/BUILD.bazel index aff5921fd57b..1d224bb0bcc3 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "canonicalize.mlir", "decompose_aggregate_op.mlir", diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir index 7de65d86c8c7..7eff2e014b79 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir @@ -190,6 +190,7 @@ func.func @online_attention_f16(%query: tensor<192x1024x64xf16>, // correct number of extf/truncfs are emitted. // CHECK-LABEL: @online_attention_f16 // Q = Q * scale +// CHECK: arith.constant 1.442380e+00 : f16 // CHECK: linalg.generic // CHECK: arith.mulf // S = Q @ K @@ -419,6 +420,65 @@ func.func @online_attention_f8_masked(%query: tensor<192x1024x64xf8E4M3FNUZ>, // ----- +// Spec to decompose online attention op. 
+module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op + transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> () + transform.yield + } +} + +#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)> +#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)> +#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)> +#mapS = affine_map<(batch, m, k1, k2, n) -> ()> +#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)> +#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)> + +func.func @online_attention_f16_noexp2(%query: tensor<192x1024x64xf16>, + %key: tensor<192x1024x64xf16>, + %value: tensor<192x1024x64xf16>, + %output: tensor<192x1024x64xf32>, + %max: tensor<192x1024xf32>, + %sum: tensor<192x1024xf32>) + -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) { + %scale = arith.constant 1.0 : f16 + + %out:3 = iree_linalg_ext.online_attention + {decomposition_config = {use_exp2=false}, indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] } + ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16) + outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) { + ^bb0(%score: f32): + iree_linalg_ext.yield %score: f32 + } + -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32> + + return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32> +} + +// We want to check that we're correctly using exp +// when specified so from the decomposition_config. 
+// CHECK-LABEL: @online_attention_f16_noexp2 +// Q = Q * scale +// CHECK: arith.constant 1.000000e+00 : f16 +// CHECK: linalg.generic +// CHECK: arith.mulf +// norm = exp (oldMax - newMax) +// CHECK: linalg.generic +// CHECK: arith.subf +// CHECK-NOT: arith.extf +// CHECK-NOT: math.exp2 +// CHECK: linalg.yield +// P = exp(S - newMax) +// CHECK: linalg.generic +// CHECK: arith.subf +// CHECK-NOT: arith.extf +// CHECK-NOT: math.exp2 +// CHECK: linalg.yield + +// ----- + // Spec to decompose exp reduction op. module attributes { transform.with_named_sequence } { transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { @@ -495,4 +555,4 @@ func.func @exp_reduction( // CHECK-SAME: outs(%[[acc_norm]] // CHECK: arith.mulf // CHECK: arith.addf -// CHECK: return %[[M]], %[[SUM]], %[[PV]] +// CHECK: return %[[M]], %[[SUM]], %[[PV]] \ No newline at end of file diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/BUILD.bazel index d70b7ab22f1f..65619199475c 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "LinalgExtExtensionsOps.td", ], diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/test/BUILD.bazel index 603f52011f7a..43f44a7a07e0 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ ], include = ["*.mlir"], diff --git 
a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConvToIm2ColOp.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConvToIm2ColOp.cpp index ecc7dd2035e4..2b6076aa0793 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConvToIm2ColOp.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConvToIm2ColOp.cpp @@ -27,15 +27,17 @@ static bool hasAllOneValues(ArrayRef attr) { static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) { bool isInt = isa(x.getType()); - if (isInt) + if (isInt) { return arith::AddIOp::create(builder, loc, x, y); + } return arith::AddFOp::create(builder, loc, x, y); } static Value createMul(Location loc, Value x, Value y, OpBuilder &builder) { bool isInt = isa(x.getType()); - if (isInt) + if (isInt) { return arith::MulIOp::create(builder, loc, x, y); + } return arith::MulFOp::create(builder, loc, x, y); } @@ -153,9 +155,10 @@ class ConvertConvGeneric final auto igemmConvDetailsOrFailure = LinalgExt::getIGEMMGenericConvDetails(linalgOp); - if (failed(igemmConvDetailsOrFailure)) + if (failed(igemmConvDetailsOrFailure)) { return rewriter.notifyMatchFailure(linalgOp, "Failed to extract IGEMM details"); + } LinalgExt::IGEMMGenericConvDetails igemmConvDetails = *igemmConvDetailsOrFailure; diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp index ac7c42ab58ec..1fb50d83d52a 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp @@ -32,8 +32,15 @@ struct DecomposeAttentionPass final void DecomposeAttentionPass::runOnOperation() { MLIRContext *context = &getContext(); IRRewriter rewriter(context); + getOperation().walk([&](OnlineAttentionOp onlineAtt) { rewriter.setInsertionPoint(onlineAtt); + + NamedAttrList 
decompositionConfig(onlineAtt.getDecompositionConfigAttr()); + decompositionConfig.set("use_exp2", rewriter.getBoolAttr(useExp2)); + onlineAtt.setDecompositionConfigAttr( + decompositionConfig.getDictionary(context)); + FailureOr> results = onlineAtt.decomposeOperation(rewriter); if (failed(results)) { diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td index c1ce03397950..ea27c2e16fd7 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td @@ -116,6 +116,11 @@ def DecomposeAttentionPass : InterfacePass<"iree-linalg-ext-decompose-attention", "mlir::FunctionOpInterface"> { let summary = "Decomposes attention op into a sequence of linalg ops"; + let options = [ + Option<"useExp2", "use-exp2", "bool", /*default=*/"true", + "Use exp2 for computations; Tunable to allow for accuracte computations" + "in case of accuracy losses due to fp-reassociation.">, + ]; } def ConvertAttentionToOnlineAttentionPass : diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp index 252e6204d790..d91a41213e20 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp @@ -71,8 +71,9 @@ static SmallVector getDimSizes(Value v) { static bool isIdentityReassoc(const SmallVector &indices) { for (auto &index : indices) { - if (index.size() != 1) + if (index.size() != 1) { return false; + } } return true; }; @@ -240,8 +241,9 @@ LogicalResult ExpansionInfo::compute( SmallVector infos, SmallVector loopRanges, OpOperand *fusableOpOperand, ArrayRef operandReassoc, ArrayRef expandedShape) { - if (operandReassoc.empty()) + if (operandReassoc.empty()) { return failure(); + } // Check that the operand dim 
size matches the iteration space dim size. This // can fail when one is static and the other is dynamic. @@ -307,28 +309,33 @@ CollapsingInfo::initialize(unsigned origNumLoops, llvm::SmallDenseSet processedDims; // Find all the dims that are folded. for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) { - if (foldedIterationDim.empty()) + if (foldedIterationDim.empty()) { continue; + } // If the folded dims contain dims already folded, that's illegal // specification. Repetition within a list is also illegal. for (auto dim : foldedIterationDim) { - if (dim >= origNumLoops) + if (dim >= origNumLoops) { return failure(); - if (processedDims.count(dim)) + } + if (processedDims.count(dim)) { return failure(); + } processedDims.insert(dim); } collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(), foldedIterationDim.end()); } - if (processedDims.size() > origNumLoops) + if (processedDims.size() > origNumLoops) { return failure(); + } // Add all the preserved dims of the original op as single // elements to `collapsedOpToOrigOpIterationDim`. 
for (auto dim : llvm::seq(0, origNumLoops)) { - if (processedDims.count(dim)) + if (processedDims.count(dim)) { continue; + } collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim}); } @@ -339,9 +346,10 @@ CollapsingInfo::initialize(unsigned origNumLoops, origOpToCollapsedOpIterationDim.resize(origNumLoops); for (const auto &foldedDims : llvm::enumerate(collapsedOpToOrigOpIterationDim)) { - for (const auto &dim : enumerate(foldedDims.value())) + for (const auto &dim : enumerate(foldedDims.value())) { origOpToCollapsedOpIterationDim[dim.value()] = std::make_pair(foldedDims.index(), dim.index()); + } } return success(); } @@ -387,9 +395,10 @@ getReshapeInfo(LinalgExt::ScatterOp scatterOp) { indicesInfo.originalShape = getDimSizes(scatterOp.getIndices()); llvm::append_range(indicesInfo.operandToIterationSpace, llvm::seq(0, scatterOp.getBatchRank())); - if (scatterOp.getBatchRank() != scatterOp.getIndicesType().getRank()) + if (scatterOp.getBatchRank() != scatterOp.getIndicesType().getRank()) { indicesInfo.operandToIterationSpace.push_back( ReshapeOperandInfo::kNoMapping); + } infos.push_back(std::move(indicesInfo)); ReshapeOperandInfo originalInfo; @@ -420,9 +429,10 @@ getReshapeInfo(LinalgExt::GatherOp gatherOp) { indicesInfo.originalShape = getDimSizes(gatherOp.getIndices()); llvm::append_range(indicesInfo.operandToIterationSpace, llvm::seq(0, gatherOp.getBatchRank())); - if (gatherOp.getBatchRank() != gatherOp.getIndicesType().getRank()) + if (gatherOp.getBatchRank() != gatherOp.getIndicesType().getRank()) { indicesInfo.operandToIterationSpace.push_back( ReshapeOperandInfo::kNoMapping); + } infos.push_back(std::move(indicesInfo)); ReshapeOperandInfo outputInfo; @@ -846,10 +856,12 @@ struct FoldWithProducerReshapeByExpansion final for (OpOperand &opOperand : op->getOpOperands()) { tensor::CollapseShapeOp reshapeOp = opOperand.get().getDefiningOp(); - if (!reshapeOp) + if (!reshapeOp) { continue; - if (!controlFoldingReshapes(&opOperand)) + } + if 
(!controlFoldingReshapes(&opOperand)) { continue; + } std::optional replacementValue = fuseWithReshapeByExpansion(op, reshapeOp, &opOperand, rewriter); @@ -893,8 +905,9 @@ struct FoldWithConsumerReshapeByExpansion final std::optional replacementValue = fuseWithReshapeByExpansion( op, expandOp, op.getTiedOpOperand(producerResult), rewriter); - if (!replacementValue) + if (!replacementValue) { return failure(); + } rewriter.replaceOp(op, *replacementValue); return success(); } @@ -946,8 +959,9 @@ static Value getCollapsedOpOperand(Location loc, AttentionOp op, // the number of results of the indexing map, then nothing to do for this // operand. Value operand = opOperand->get(); - if (operandReassociation.size() == indexingMap.getNumResults()) + if (operandReassociation.size() == indexingMap.getNumResults()) { return operand; + } // Insert a reshape to collapse the dimensions. if (isa(operand.getType())) { @@ -982,8 +996,9 @@ static void collapseOperandsAndResults(AttentionOp op, outputOperands.push_back(newOutput); // If the op has "buffer semantics", then the init operands are ranked // memrefs and the op has no results. - if (!op.hasPureBufferSemantics()) + if (!op.hasPureBufferSemantics()) { resultTypes.push_back(newOutput.getType()); + } } } @@ -1001,8 +1016,9 @@ getCollapsedOpIndexingMap(AffineMap indexingMap, for (auto expr : indexingMap.getResults()) { unsigned dim = cast(expr).getPosition(); // If the dim is not the first of the collapsed dim, do nothing. - if (origOpToCollapsedOpMapping[dim].second != 0) + if (origOpToCollapsedOpMapping[dim].second != 0) { continue; + } // The next n-dims are guaranteed to be collapsed. So just use the // iteration dimension of the collapsed op. 
resultExprs.push_back( @@ -1067,8 +1083,9 @@ collapseOpIterationDims(AttentionOp op, if (op.getNumLoops() <= 1 || foldedIterationDims.empty() || llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) { return foldedDims.size() <= 1; - })) + })) { return failure(); + } CollapsingInfo collapsingInfo; if (failed( diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/RewriteFft.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/RewriteFft.cpp index 087ec5a7e729..e34b54c33cc3 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/RewriteFft.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/RewriteFft.cpp @@ -25,8 +25,9 @@ FailureOr> rewriteFft(Operation *op, Value operand, } // Skip else getBitReversalOrder produces invalid dense elements attr. - if (!operandType.getElementType().isF32()) + if (!operandType.getElementType().isF32()) { return rewriter.notifyMatchFailure(op, "expected F32 types"); + } ImplicitLocOpBuilder b(loc, rewriter); diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/SplitReduction.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/SplitReduction.cpp index f03682a23394..20baa05c7a3c 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/SplitReduction.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/SplitReduction.cpp @@ -658,9 +658,10 @@ splitArgmaxReduction(RewriterBase &rewriter, linalg::GenericOp genericOp, Value outVal = args[1]; Value outIdx = args[2]; Value reductionIdx = linalg::IndexOp::create(b, loc, reductionDim + 1); - if (outIdx.getType() != reductionIdx.getType()) + if (outIdx.getType() != reductionIdx.getType()) { reductionIdx = arith::IndexCastOp::create(b, loc, outIdx.getType(), reductionIdx); + } Value inCast = in; Type inType = in.getType(); Type outType = outVal.getType(); @@ -715,8 +716,9 @@ splitArgmaxReduction(RewriterBase &rewriter, linalg::GenericOp genericOp, Value outIdx = inputs[3]; Value 
outer = linalg::IndexOp::create(b, loc, insertSplitDimension); Value offset = arith::MulIOp::create(b, loc, outer, tileSize); - if (offset.getType() != local.getType()) + if (offset.getType() != local.getType()) { offset = arith::IndexCastOp::create(b, loc, local.getType(), offset); + } // gidx = outer * ratio + local. Value gidx = arith::AddIOp::create(b, loc, offset, local); Operation *clonedMax = b.clone(*combinerOps.maxOp); diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel index 6e36532763a8..2176e3431a51 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "conv2d_to_winograd.mlir", "conv_to_im2col.mlir", diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir index 32c39875fa04..06ed7cd81d59 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir @@ -1703,3 +1703,49 @@ func.func @map_scatter_memref( // CHECK-NEXT: %[[INPUT_ELEM:.+]] = memref.load %[[INPUT]][%[[IV]]] // CHECK-NEXT: memref.store %[[INPUT_ELEM]], %[[OUTPUT]] // CHECK-SAME: [%[[OUT_IDX]]#0, %[[OUT_IDX]]#1] : memref + +// ----- + +func.func @map_gather_memref( + %source: memref, %output: memref +) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim0 = memref.dim %source, %c0 : memref + %dim1 = memref.dim %source, %c1 : memref + iree_linalg_ext.map_gather %source into %output { + ^bb0(%idx0: index): + %src_idx:2 = affine.delinearize_index %idx0 into (%dim0, %dim1) : index, index + %pad = arith.constant 0.0 : 
f32 + iree_linalg_ext.yield %src_idx#0, %src_idx#1, %pad : index, index, f32 + } : memref into memref + return +} +// CHECK: func @map_gather_memref +// CHECK-SAME: %[[SOURCE:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]+]] +// CHECK-DAG: %[[PAD:.+]] = arith.constant 0.{{0+}}e+00 : f32 +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[SRC_D0:.+]] = memref.dim %[[SOURCE]], %[[C0]] +// CHECK-DAG: %[[SRC_D1:.+]] = memref.dim %[[SOURCE]], %[[C1]] +// CHECK-DAG: %[[OUT_D0:.+]] = memref.dim %[[OUTPUT]], %[[C0]] +// CHECK: scf.for %[[IV:.+]] = %[[C0]] to %[[OUT_D0]] step %[[C1]] +// CHECK: %[[SRC_IDX:.+]]:2 = affine.delinearize_index %[[IV]] into (%[[SRC_D0]], %[[SRC_D1]]) : index, index +// CHECK-DAG: %[[BOUND_D0:.+]] = memref.dim %[[SOURCE]], %[[C0]] +// CHECK-DAG: %[[GE_ZERO_0:.+]] = arith.cmpi sge, %[[SRC_IDX]]#0, %[[C0]] : index +// CHECK-DAG: %[[LT_DIM_0:.+]] = arith.cmpi slt, %[[SRC_IDX]]#0, %[[BOUND_D0]] : index +// CHECK-DAG: %[[IN_BOUNDS_0:.+]] = arith.andi %[[GE_ZERO_0]], %[[LT_DIM_0]] +// CHECK-DAG: %[[BOUND_D1:.+]] = memref.dim %[[SOURCE]], %[[C1]] +// CHECK-DAG: %[[GE_ZERO_1:.+]] = arith.cmpi sge, %[[SRC_IDX]]#1, %[[C0]] : index +// CHECK-DAG: %[[LT_DIM_1:.+]] = arith.cmpi slt, %[[SRC_IDX]]#1, %[[BOUND_D1]] : index +// CHECK-DAG: %[[IN_BOUNDS_1:.+]] = arith.andi %[[GE_ZERO_1]], %[[LT_DIM_1]] +// CHECK-DAG: %[[IN_BOUNDS:.+]] = arith.andi %[[IN_BOUNDS_0]], %[[IN_BOUNDS_1]] +// CHECK: %[[IF_RESULT:.+]] = scf.if %[[IN_BOUNDS]] -> (f32) { +// CHECK: %[[SOURCE_ELEM:.+]] = memref.load %[[SOURCE]] +// CHECK-SAME: [%[[SRC_IDX]]#0, %[[SRC_IDX]]#1] : memref +// CHECK: scf.yield %[[SOURCE_ELEM]] : f32 +// CHECK: } else { +// CHECK: scf.yield %[[PAD]] : f32 +// CHECK: } +// CHECK: memref.store %[[IF_RESULT]], %[[OUTPUT]][%[[IV]]] diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/MatchUtils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/MatchUtils.cpp index 
468ccc8f8bf1..2481289a6904 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/MatchUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/MatchUtils.cpp @@ -97,8 +97,9 @@ findPermutationsIndexingOperand(AffineMap indexingMap, if (iterators[d.getPosition()] == iter && llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) { return e.isFunctionOfDim(d.getPosition()); - }) == 1) + }) == 1) { res.insert(d.getPosition()); + } } } return res; diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp index 464bf7b4f80d..39a0ed0cf37e 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp @@ -461,8 +461,9 @@ FailureOr getIGEMMGenericConvDetails(linalg::LinalgOp linalgOp) { auto convDimsOrFailure = linalg::inferConvolutionDims(linalgOp); MLIRContext *ctx = linalgOp->getContext(); - if (failed(convDimsOrFailure)) + if (failed(convDimsOrFailure)) { return failure(); + } const mlir::linalg::ConvolutionDimensions &convDims = *convDimsOrFailure; LLVM_DEBUG({ llvm::dbgs() << "conv: " << linalgOp; @@ -524,8 +525,9 @@ getIGEMMGenericConvDetails(linalg::LinalgOp linalgOp) { LDBG() << "output image or output channel dim not found in output."; return failure(); } - if (outputChannelLastDim.value() < outputImageFirstDim.value()) + if (outputChannelLastDim.value() < outputImageFirstDim.value()) { isOutputChannelFirst = true; + } SmallVector filterkPos; for (auto reductionDim : reductionDims) { @@ -620,8 +622,9 @@ getIGEMMGenericConvDetails(linalg::LinalgOp linalgOp) { // Lambda to remap conv dim indices to igemm dimensions. 
auto remapDims = [&](ArrayRef dims) -> SmallVector { SmallVector mapped; - for (unsigned d : dims) + for (unsigned d : dims) { mapped.push_back(convToIgemmDimMap.at(d)); + } return mapped; }; @@ -721,8 +724,9 @@ static Value getSourceSkipUnary(Value value) { Operation *op = value.getDefiningOp(); while (op && op->getNumOperands() == 1) { auto iface = dyn_cast(op); - if (!iface || !iface.hasNoEffect()) + if (!iface || !iface.hasNoEffect()) { break; + } value = op->getOperand(0); op = value.getDefiningOp(); } @@ -782,13 +786,15 @@ template static bool isPairTemplateImpl(Operation *add, Operation *mul) { static_assert(sizeof...(Args) % 2 == 0, "expected an even number of template arguments"); - if (isa(add) && isa(mul)) + if (isa(add) && isa(mul)) { return true; + } - if constexpr (sizeof...(Args) > 0) + if constexpr (sizeof...(Args) > 0) { return isPairTemplateImpl(add, mul); - else + } else { return false; + } } /// Returns true if the block is a body of a contraction with the kinds of @@ -918,19 +924,22 @@ bool isArgmaxOp(linalg::GenericOp genericOp) { // TODO: Add better affine map checks. auto indexing_maps = genericOp.getIndexingMapsArray(); - if (!indexing_maps[0].isIdentity()) + if (!indexing_maps[0].isIdentity()) { return false; + } // Check that initial value is negative Infinite. // TODO: Move this check to ukernel once we implement // variant to handle non neg-Inf initial value. Value initVal = genericOp.getDpsInitOperand(0)->get(); auto fillOp = initVal.getDefiningOp(); - if (!fillOp) + if (!fillOp) { return false; + } Value fillVal = fillOp.getDpsInputOperand(0)->get(); - if (!matchPattern(fillVal, m_NegInfFloat())) + if (!matchPattern(fillVal, m_NegInfFloat())) { return false; + } // Work back from linalg.yield and check body of genericOp. 
// The genericOp should yield the result of an arith.select, @@ -965,13 +974,15 @@ bool isArgmaxOp(linalg::GenericOp genericOp) { } auto selectOp = cast(producerOutput.getDefiningOp()); Value trueVal = selectOp.getTrueValue(); - if (auto castOp = trueVal.getDefiningOp()) + if (auto castOp = trueVal.getDefiningOp()) { trueVal = castOp.getIn(); + } // Ensure the true value is directly produced by linalg.index. auto indexOp = trueVal.getDefiningOp(); - if (!indexOp) + if (!indexOp) { return false; + } } // Producer of arith.select op is arith.cmpf @@ -1034,11 +1045,13 @@ bool isPureBatchMatmul(Operation *op) { // it requires a single input where the indexing maps are full permutations and // non-equal. bool isaTransposeOpInterface(linalg::LinalgOp linalgOp) { - if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops()) + if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops()) { return false; + } - if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1) + if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1) { return false; + } auto mapRange = linalgOp.getIndexingMapsArray(); if (mapRange.size() != 2 || !mapRange.front().isPermutation() || !mapRange.back().isPermutation() || mapRange.front() == mapRange.back()) { diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp index 367162905ba8..31c5664d00be 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp @@ -156,8 +156,9 @@ LogicalResult Partition::verify(Location loc) { for (auto in : ins) { // Only check ops, not bare values. auto definingOp = in.getDefiningOp(); - if (!definingOp) + if (!definingOp) { continue; + } // Collect all values used by this input op (including nested regions). 
SetVector inputConsumedValues; @@ -216,8 +217,9 @@ LogicalResult PartitionSet::verify(Location loc) { } void PartitionSet::topologicalSort() { - if (partitions.empty()) + if (partitions.empty()) { return; + } SetVector unsortedSet; DenseMap> consumers; @@ -246,8 +248,9 @@ void PartitionSet::topologicalSort() { } } }; - for (auto *partition : unsortedSet) + for (auto *partition : unsortedSet) { postorderWalk(partition); + } SmallVector sortedSet; sortedSet.reserve(partitions.size()); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp index 0ae3d0d0997c..9f4c0e5ce7f9 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp @@ -64,8 +64,9 @@ struct PartitionBuilder { : affinityOp.getAffinityAttr(); } opInfo.membership.set(ordinal); - if (opInfo.hazards.size() > ordinal) + if (opInfo.hazards.size() > ordinal) { opInfo.hazards.reset(ordinal); + } ops.insert(op); hazards |= opInfo.hazards; hazards |= opInfo.nestedRegionHazards; @@ -497,8 +498,9 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config, // First see which partitions are consuming this that we can also safely // move in to. 
consumers &= candidates; - if (consumers.any()) + if (consumers.any()) { candidates = consumers; + } opInfo.membership.reserve(builders.size() + 1); opInfo.membership.resize(builders.size(), /*t=*/false); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp index cf0fe20d7339..268531a62615 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp @@ -132,37 +132,50 @@ class AbstractResourceUsage const std::string getAsStr(AsmState &asmState) const override { std::string str; - if (!isValidState()) + if (!isValidState()) { return "*"; + } auto append = [&](const char *part) { - if (!str.empty()) + if (!str.empty()) { str += '|'; + } str += part; }; - if (!this->isAssumed(NOT_INDIRECT)) + if (!this->isAssumed(NOT_INDIRECT)) { append("indirect"); + } append(this->isAssumed(NOT_EXTERNAL) ? "internal" : "external"); append(this->isAssumed(NOT_MUTATED) ? 
"immutable" : "mutable"); - if (!this->isAssumed(NOT_CONSTANT)) + if (!this->isAssumed(NOT_CONSTANT)) { append("constant"); - if (!this->isAssumed(NOT_TRANSFER_READ)) + } + if (!this->isAssumed(NOT_TRANSFER_READ)) { append("transfer_read"); - if (!this->isAssumed(NOT_TRANSFER_WRITE)) + } + if (!this->isAssumed(NOT_TRANSFER_WRITE)) { append("transfer_write"); - if (!this->isAssumed(NOT_STAGING_READ)) + } + if (!this->isAssumed(NOT_STAGING_READ)) { append("staging_read"); - if (!this->isAssumed(NOT_STAGING_WRITE)) + } + if (!this->isAssumed(NOT_STAGING_WRITE)) { append("staging_write"); - if (!this->isAssumed(NOT_DISPATCH_READ)) + } + if (!this->isAssumed(NOT_DISPATCH_READ)) { append("dispatch_read"); - if (!this->isAssumed(NOT_DISPATCH_WRITE)) + } + if (!this->isAssumed(NOT_DISPATCH_WRITE)) { append("dispatch_write"); - if (!this->isAssumed(NOT_GLOBAL_READ)) + } + if (!this->isAssumed(NOT_GLOBAL_READ)) { append("global_read"); - if (!this->isAssumed(NOT_GLOBAL_WRITE)) + } + if (!this->isAssumed(NOT_GLOBAL_WRITE)) { append("global_write"); - if (!this->isAssumed(NOT_GLOBAL_STORAGE)) + } + if (!this->isAssumed(NOT_GLOBAL_STORAGE)) { append("global_storage"); + } return str.empty() ? "*" : str; } @@ -250,8 +263,9 @@ class ValueResourceUsage : public AbstractResourceUsage { // itself is under analysis. void updateFromDefiningOp(Value value, OpResult result, DFX::Solver &solver) { // Some tied uses route through ops that change types - ignore those. - if (!isa(result.getType())) + if (!isa(result.getType())) { return; + } TypeSwitch(result.getOwner()) .Case([&](mlir::arith::SelectOp op) { @@ -552,8 +566,9 @@ class ValueResourceUsage : public AbstractResourceUsage { // This walks through tied uses as well. void updateFromUse(Value value, OpOperand &operand, DFX::Solver &solver) { // Some tied uses route through ops that change types - ignore those. 
- if (!isa(operand.get().getType())) + if (!isa(operand.get().getType())) { return; + } auto *userOp = operand.getOwner(); unsigned operandIdx = operand.getOperandNumber(); @@ -977,8 +992,9 @@ std::optional ResourceUsageAnalysis::tryLookupResourceUsage(Value value) { auto resourceUsage = solver.lookupElementFor(Position::forValue(value)); - if (!resourceUsage) + if (!resourceUsage) { return std::nullopt; + } return resourceUsage->getAssumedUsage(); } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp index ae07021b2b49..f316560f9570 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp @@ -85,8 +85,9 @@ struct ConvertTensorDynamicConstantOp IREE::Stream::AffinityAttr executionAffinityAttr, ConversionPatternRewriter &rewriter) const override { auto attrType = dyn_cast(constantOp.getValue().getType()); - if (!attrType) + if (!attrType) { return failure(); + } auto resultType = constantOp.getType(); // If the op is acting as a dynamic value then preserve that behavior by @@ -355,13 +356,16 @@ struct ConvertTensorUpdateOp }; static bool isScalarTensor(RankedTensorType type) { - if (type.getRank() == 0) + if (type.getRank() == 0) { return true; // tensor - if (!type.hasStaticShape()) + } + if (!type.hasStaticShape()) { return false; // tensor<...?...xi32> + } int64_t elementCount = 1; - for (int64_t dim : type.getShape()) + for (int64_t dim : type.getShape()) { elementCount *= dim; + } return elementCount == 1; // tensor<1xi32> or tensor<1x1x1xi32> } @@ -1002,8 +1006,9 @@ static bool insertBindingOp(BlockArgument arg, IREE::TensorExt::DispatchTensorType tensorType, Value zero, OpBuilder &builder) { // No uses: don't need a binding op. 
- if (arg.use_empty()) + if (arg.use_empty()) { return true; + } // Find the dynamic dimension SSA values of the argument within the region. // If the flow dialect properly modeled dimension associations we wouldn't @@ -1018,8 +1023,9 @@ static bool insertBindingOp(BlockArgument arg, IREE::Flow::DispatchTieShapeOp tieShapeOp; for (auto user : arg.getUsers()) { tieShapeOp = dyn_cast(user); - if (tieShapeOp) + if (tieShapeOp) { break; + } } if (tieShapeOp) { // Found a tie shape op - we'll insert ourselves there. @@ -1125,8 +1131,9 @@ struct ConvertExecutableOp // Dispatch tensor arguments become bindings and all others are preserved // as adaptor. Note that we only touch public (exported) functions. for (auto funcOp : moduleOp.getOps()) { - if (!funcOp.isPublic()) + if (!funcOp.isPublic()) { continue; + } SmallVector newTypes; newTypes.reserve(funcOp.getNumArguments()); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/BUILD.bazel index eb5c9b643c5d..7ad8f03376e4 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "call_ops.mlir", "collective_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp index 18f8f4cbdd07..1bb8c5ecc6bd 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp @@ -20,8 +20,9 @@ namespace { /// Flatten the given value ranges into a single vector of values. 
static SmallVector flattenValues(ArrayRef values) { SmallVector result; - for (const auto &vals : values) + for (const auto &vals : values) { llvm::append_range(result, vals); + } return result; } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/BUILD.bazel index 34c6442f57a3..e1497efcb477 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "abi_ops.mlir", ], diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp index 2dd7777dd523..a87455290614 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp @@ -14,8 +14,9 @@ namespace mlir::iree_compiler { TypedAttr convertAttributeToStream(TypedAttr attr) { - if (!attr) + if (!attr) { return {}; + } if (auto parameterAttr = dyn_cast(attr)) { return IREE::Stream::NamedParameterAttr::get( attr.getContext(), parameterAttr.getType(), parameterAttr.getScope(), diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.cpp index 8d90a92090ca..253c3b3d338d 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.cpp @@ -32,8 +32,9 @@ namespace { /// Flatten the given value ranges into a single vector of values. 
static SmallVector flattenValues(ArrayRef values) { SmallVector result; - for (const auto &vals : values) + for (const auto &vals : values) { llvm::append_range(result, vals); + } return result; } @@ -130,8 +131,9 @@ struct SelectOpConversion matchAndRewrite(mlir::arith::SelectOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Only handle selects where the operands are tensors (resources). - if (!isa(op.getTrueValue().getType())) + if (!isa(op.getTrueValue().getType())) { return failure(); + } auto trueOperand = resolveTensorOperands(op.getLoc(), op.getTrueValue(), adaptor.getTrueValue(), rewriter); auto falseOperand = resolveTensorOperands( diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/test/BUILD.bazel index 54fbdc060085..cd8d750cc3fa 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "constant_ops.mlir", "structural_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp index de0497e3f439..58bc0c85e097 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp @@ -22,8 +22,9 @@ namespace { /// Flatten the given value ranges into a single vector of values. 
static SmallVector flattenValues(ArrayRef values) { SmallVector result; - for (const auto &vals : values) + for (const auto &vals : values) { llvm::append_range(result, vals); + } return result; } @@ -99,8 +100,9 @@ struct CallOpConversion }, [&](unsigned i, Type type, SmallVectorImpl &newTypes) { size_t newIndex = newTypes.size(); - if (failed(getTypeConverter()->convertType(type, newTypes))) + if (failed(getTypeConverter()->convertType(type, newTypes))) { anyFailed = true; + } resultMap.push_back(Result{i, newIndex, newTypes[newIndex]}); }, rewriter); @@ -158,8 +160,9 @@ struct GlobalExpansionState { }; static bool isExpandedType(Type type) { - if (isa(type)) + if (isa(type)) { return true; + } if (auto ptrType = dyn_cast(type)) { return isExpandedType(ptrType); } @@ -190,8 +193,9 @@ struct GlobalOpExpansion matchAndRewrite(IREE::Util::GlobalOp globalOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Only apply to expanded types (tensors/etc). - if (!isExpandedType(globalOp.getType())) + if (!isExpandedType(globalOp.getType())) { return failure(); + } SmallVector newTypes; if (failed(getTypeConverter()->convertType(globalOp.getType(), newTypes))) { @@ -297,13 +301,15 @@ struct GlobalLoadOpExpansion matchAndRewrite(IREE::Util::GlobalLoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Only apply to expanded types (tensors/etc). 
- if (!isExpandedType(loadOp.getType())) + if (!isExpandedType(loadOp.getType())) { return failure(); + } auto expandedGlobalIt = this->expansionState->globalMap.find(adaptor.getGlobal()); - if (expandedGlobalIt == this->expansionState->globalMap.end()) + if (expandedGlobalIt == this->expansionState->globalMap.end()) { return rewriter.notifyMatchFailure(loadOp, "expanded global not found"); + } auto &expandedGlobal = expandedGlobalIt->getSecond(); @@ -336,13 +342,15 @@ struct GlobalStoreOpExpansion matchAndRewrite(IREE::Util::GlobalStoreOp storeOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Only apply to expanded types (tensors/etc). - if (!isExpandedType(storeOp.getValue().getType())) + if (!isExpandedType(storeOp.getValue().getType())) { return failure(); + } auto expandedGlobalIt = this->expansionState->globalMap.find(adaptor.getGlobal()); - if (expandedGlobalIt == this->expansionState->globalMap.end()) + if (expandedGlobalIt == this->expansionState->globalMap.end()) { return rewriter.notifyMatchFailure(storeOp, "expanded global not found"); + } auto &expandedGlobal = expandedGlobalIt->getSecond(); @@ -430,8 +438,9 @@ void populateUtilToStreamConversionPatterns( typeConverter.addConversion([=](IREE::Util::PtrType type, SmallVectorImpl &resultTypes) { // Expand pointers to tensors to [resource, sizeof resource] pointers. - if (!isExpandedType(type)) + if (!isExpandedType(type)) { return failure(); + } resultTypes.push_back( IREE::Util::PtrType::get(IREE::Stream::ResourceType::get(context))); resultTypes.push_back(IREE::Util::PtrType::get(IndexType::get(context))); @@ -441,8 +450,9 @@ void populateUtilToStreamConversionPatterns( typeConverter.addConversion( [=](IREE::Util::PtrType type, SmallVectorImpl &resultTypes) { // Expand pointers to tensors to [ptr, ptr]. 
- if (!isExpandedType(type.getTargetType())) + if (!isExpandedType(type.getTargetType())) { return failure(); + } resultTypes.push_back(IREE::Stream::ResourceType::get(context)); resultTypes.push_back(IndexType::get(context)); return success(); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/BUILD.bazel index ceb82144e9f2..eecd75ff1755 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "compiler_hints.mlir", "global_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/IR/BUILD.bazel index 91244ba3b01e..33c213a8a6e4 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "StreamAttrs.td", "StreamBase.td", diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamDialect.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamDialect.cpp index 32f357808ef9..6424b8fa9ceb 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamDialect.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamDialect.cpp @@ -69,8 +69,9 @@ struct StripResourceConversionCastPattern LogicalResult matchAndRewrite(UnrealizedConversionCastOp castOp, PatternRewriter &rewriter) const override { auto result = castOp.getResult(0); - if (!isa(result.getType())) + if (!isa(result.getType())) { return failure(); + } assert(castOp.getNumOperands() == 2 && "expect resource, index -> resource"); auto resourceValue = castOp.getOperand(0); diff --git 
a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td index e9977dde4470..d96ab6899f5a 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td @@ -213,6 +213,18 @@ def Stream_AffinityOp : Stream_OpInterface<"AffinityOpInterface"> { return $_op.getAffinityAttr(); }] >, + InterfaceMethod< + /*desc=*/[{ + Removes all affinities specified on the op. + }], + /*retTy=*/"void", + /*methodName=*/"removeAffinityAttrs", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + $_op.setAffinityAttr(nullptr); + }] + >, ]; } diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp index 3dddcd632de7..148877ddeba3 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp @@ -44,10 +44,12 @@ namespace { // 0xCDCDCDCD : i32 -> 0xCD : i8 static APInt computeRequiredPatternBits(APInt pattern) { // Special case for well-known constant values. - if (pattern.isZero()) + if (pattern.isZero()) { return APInt(8, 0u); - if (pattern.isAllOnes()) + } + if (pattern.isAllOnes()) { return APInt(8, 0xFF); + } // Extend up to a power of two bit width. This makes the value easier to work // with as we'll be dealing with one of 4 sizes (1/2/4/8b). @@ -142,8 +144,9 @@ static TypedAttr tryNarrowPatternBits(TypedAttr patternAttr) { // Try narrowing the pattern. auto newPattern = computeRequiredPatternBits(oldPattern); - if (newPattern.getBitWidth() == oldPattern.getBitWidth()) + if (newPattern.getBitWidth() == oldPattern.getBitWidth()) { return patternAttr; + } // Wrap the result in an attribute - note that it is always an integer. 
return IntegerAttr::get( @@ -163,8 +166,9 @@ struct NarrowFillPattern : public OpRewritePattern { return failure(); } auto newPatternAttr = tryNarrowPatternBits(oldPatternAttr); - if (newPatternAttr == oldPatternAttr) + if (newPatternAttr == oldPatternAttr) { return failure(); + } // Replace the pattern on the op with the new one. auto narrowValue = @@ -182,13 +186,16 @@ struct NarrowFillPattern : public OpRewritePattern { // stream.yield // } static std::optional getYieldIfOnlyOp(Block &block) { - if (block.empty()) + if (block.empty()) { return std::nullopt; - if (&block.front() != &block.back()) + } + if (&block.front() != &block.back()) { return std::nullopt; + } auto yieldOp = dyn_cast(block.back()); - if (yieldOp) + if (yieldOp) { return yieldOp; + } return std::nullopt; } @@ -250,14 +257,16 @@ static bool canStablySinkTo(Operation *toBeSunkOp, Operation *targetOp) { // If the sinking operation would be a no-op, then we need to prevent // the sinking operation, to avoid infinite pattern applications. - if (Block::iterator(targetOp) == std::next(Block::iterator(toBeSunkOp))) + if (Block::iterator(targetOp) == std::next(Block::iterator(toBeSunkOp))) { return false; + } // If the sinking is to a different block, then it okay, since for any later // sinkings, this reduces the problem to stable sinking within a single // block (handled below). - if (toBeSunkOp->getBlock() != targetOp->getBlock()) + if (toBeSunkOp->getBlock() != targetOp->getBlock()) { return true; + } SmallPtrSet producerOps; if (allowUseDefPruning) { @@ -274,11 +283,13 @@ static bool canStablySinkTo(Operation *toBeSunkOp, Operation *targetOp) { Block::iterator(targetOp))) { // If the intervening op that is not even a sink candidate itself, // then it cannot fight. - if (!isSinkCandidate(&op)) + if (!isSinkCandidate(&op)) { return true; + } // If the op is pruned by use-def chains, then it won't fight. 
- if (allowUseDefPruning && !producerOps.contains(&op)) + if (allowUseDefPruning && !producerOps.contains(&op)) { return true; + } } return false; } @@ -286,8 +297,9 @@ static bool canStablySinkTo(Operation *toBeSunkOp, Operation *targetOp) { // Sinks |op| down to |targetOp|, ensuring that we don't oscillate. // Returns success if the op was sunk and failure if sinking was not needed. static LogicalResult sinkOp(Operation *op, Operation *targetOp) { - if (!canStablySinkTo(op, targetOp)) + if (!canStablySinkTo(op, targetOp)) { return failure(); + } op->moveBefore(targetOp); return success(); } @@ -319,8 +331,9 @@ struct ElideUnusedOp : public OpRewritePattern { : OpRewritePattern(context, /*benefit=*/1000) {} LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const override { - if (!op.use_empty()) + if (!op.use_empty()) { return failure(); + } rewriter.eraseOp(op); return success(); } @@ -447,8 +460,9 @@ struct ElideImmediateTimepointWait : public OpRewritePattern { bool isImmediate = op.getAwaitTimepoint() && isa_and_nonnull( op.getAwaitTimepoint().getDefiningOp()); - if (!isImmediate) + if (!isImmediate) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.getAwaitTimepointMutable().clear(); }); return success(); @@ -482,8 +496,9 @@ struct ChainDependentAwaits : public OpRewritePattern { } } } - if (replacements.empty()) + if (replacements.empty()) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.setAwaitTimepoints(newTimepoints, rewriter); for (auto replacement : replacements) { @@ -712,8 +727,9 @@ struct SelectResourceSizeOp : public OpRewritePattern { LogicalResult matchAndRewrite(ResourceSizeOp op, PatternRewriter &rewriter) const override { auto selectOp = op.getOperand().getDefiningOp(); - if (!selectOp) + if (!selectOp) { return failure(); + } auto trueSize = rewriter.createOrFold( op.getLoc(), selectOp.getTrueValue(), op.getAffinityAttr()); auto falseSize = rewriter.createOrFold( @@ -761,8 +777,9 @@ struct 
FoldSubviewIntoLoadOp : public OpRewritePattern { LogicalResult matchAndRewrite(ResourceLoadOp op, PatternRewriter &rewriter) const override { auto subviewOp = ResourceSubviewOp::findSubviewOp(op.getSource()); - if (!subviewOp) + if (!subviewOp) { return failure(); + } auto fusedLoc = rewriter.getFusedLoc({subviewOp.getLoc(), op.getLoc()}); auto newOffset = rewriter.createOrFold( fusedLoc, subviewOp.getSourceOffset(), op.getSourceOffset()); @@ -806,8 +823,9 @@ struct FoldSubviewIntoStoreOp : public OpRewritePattern { LogicalResult matchAndRewrite(ResourceStoreOp op, PatternRewriter &rewriter) const override { auto subviewOp = ResourceSubviewOp::findSubviewOp(op.getTarget()); - if (!subviewOp) + if (!subviewOp) { return failure(); + } auto fusedLoc = rewriter.getFusedLoc({subviewOp.getLoc(), op.getLoc()}); auto newOffset = rewriter.createOrFold( fusedLoc, subviewOp.getSourceOffset(), op.getTargetOffset()); @@ -873,8 +891,9 @@ struct PropagateResourcePackBaseOffset PatternRewriter &rewriter) const override { // Offset is optional. auto baseOffset = op.getOffset(); - if (!baseOffset) + if (!baseOffset) { return failure(); + } // We always strip the offset here. rewriter.modifyOpInPlace(op, [&]() { op.getOffsetMutable().clear(); }); @@ -932,8 +951,9 @@ struct CanonicalizeResourcePackIntervals break; } } - if (!orderChanged) + if (!orderChanged) { return failure(); + } // TODO(benvanik): compact the slice ranges. 
@@ -993,8 +1013,9 @@ struct FoldResourceSubviewOps : public OpRewritePattern { LogicalResult matchAndRewrite(ResourceSubviewOp op, PatternRewriter &rewriter) const override { auto parentOp = ResourceSubviewOp::findSubviewOp(op.getSource()); - if (!parentOp) + if (!parentOp) { return failure(); + } auto fusedLoc = rewriter.getFusedLoc({parentOp.getLoc(), op.getLoc()}); auto newOffset = rewriter.createOrFold( fusedLoc, parentOp.getSourceOffset(), op.getSourceOffset()); @@ -1021,14 +1042,16 @@ struct SinkSubviewAcrossSelectOps using Base::Base; LogicalResult matchAndRewrite(mlir::arith::SelectOp op, PatternRewriter &rewriter) const override { - if (!isa(op.getType())) + if (!isa(op.getType())) { return failure(); + } auto trueSubview = dyn_cast_if_present( op.getTrueValue().getDefiningOp()); auto falseSubview = dyn_cast_if_present( op.getFalseValue().getDefiningOp()); - if (!trueSubview || !falseSubview) + if (!trueSubview || !falseSubview) { return failure(); + } if (trueSubview.getSource() != falseSubview.getSource() || trueSubview.getResultSize() != falseSubview.getResultSize()) { return failure(); @@ -1134,8 +1157,9 @@ struct TensorConstantToEmpty : public OpRewritePattern { LogicalResult matchAndRewrite(TensorConstantOp constantOp, PatternRewriter &rewriter) const override { auto shapedType = dyn_cast(constantOp.getResultEncoding()); - if (!shapedType) + if (!shapedType) { return failure(); + } // See if any dim (including dynamic ones) is known zero. // It's still possible for empty tensors to slip through if their dynamic @@ -1155,8 +1179,9 @@ struct TensorConstantToEmpty : public OpRewritePattern { break; } } - if (!anyZeroDims) + if (!anyZeroDims) { return failure(); + } // Definitely empty if here. 
Value resultSize = IREE::Stream::TensorSizeOfOp::create( @@ -1383,8 +1408,9 @@ struct DeduplicateTensorDispatchEntryRefs final PatternRewriter &rewriter) const override { auto originalAttr = dispatchOp.getEntryPointsAttr(); auto newAttr = deduplicateArrayElements(originalAttr); - if (newAttr == originalAttr) + if (newAttr == originalAttr) { return failure(); + } rewriter.modifyOpInPlace(dispatchOp, [&]() { dispatchOp.setEntryPointsAttr(newAttr); }); return success(); @@ -1414,8 +1440,9 @@ struct SinkAllocaLikeOpToConsumers : public OpRewritePattern { LogicalResult matchAndRewrite(Op producerOp, PatternRewriter &rewriter) const override { auto users = llvm::to_vector(producerOp->getUsers()); - if (users.size() == 0) + if (users.size() == 0) { return failure(); + } // If we have a single user then we can sink right to it. if (users.size() == 1) { @@ -1576,8 +1603,9 @@ struct PropagateSplatsThroughSlices : public OpRewritePattern { PatternRewriter &rewriter) const override { auto splatOp = sliceOp.getSource().getDefiningOp(); - if (!splatOp) + if (!splatOp) { return failure(); + } rewriter.replaceOpWithNewOp( sliceOp, sliceOp.getResult().getType(), splatOp.getValue(), sliceOp.getResultSize(), sliceOp.getAffinityAttr(), @@ -1615,8 +1643,9 @@ struct FlattenFullFillToSplat : public OpRewritePattern { using Base::Base; LogicalResult matchAndRewrite(AsyncFillOp fillOp, PatternRewriter &rewriter) const override { - if (fillOp.getTargetLength() != fillOp.getTargetSize()) + if (fillOp.getTargetLength() != fillOp.getTargetSize()) { return failure(); + } auto targetOp = fillOp.getTarget().getDefiningOp(); if (!targetOp || IREE::Util::TiedOpInterface::findTiedBaseValue( @@ -1647,8 +1676,9 @@ struct ElideRedundantFill : public OpRewritePattern { PatternRewriter &rewriter) const override { auto splatOp = dyn_cast_if_present( fillOp.getTarget().getDefiningOp()); - if (!splatOp) + if (!splatOp) { return failure(); + } if (splatOp.getValue() != fillOp.getValue()) { return 
rewriter.notifyMatchFailure(fillOp, "fill patterns are not compatible"); @@ -1678,8 +1708,9 @@ struct CoalesceAdjacentFills : public OpRewritePattern { PatternRewriter &rewriter) const override { auto sourceOp = dyn_cast_if_present( fillOp.getTarget().getDefiningOp()); - if (!sourceOp) + if (!sourceOp) { return failure(); + } if (!sourceOp.getResult().hasOneUse()) { // Note that hazard analysis could make this work if we can guarantee that // the source result is only ever sliced out to a range that doesn't @@ -1757,20 +1788,23 @@ static bool hasValueSemantics(Value value) { // Can't analyze function arguments (though we could add arg attrs to indicate // value semantics). auto *definingOp = value.getDefiningOp(); - if (!definingOp) + if (!definingOp) { return false; + } // If produced by a tied op then see if the particular result is tied. if (auto tiedOp = dyn_cast(definingOp)) { - if (tiedOp.getTiedResultOperand(value)) + if (tiedOp.getTiedResultOperand(value)) { return false; + } } // To be conservative we only allow stream dialect ops that produce the // resource as we know they all indicate value semantics when non-tied - ops // from other dialects may not. 
- if (!definingOp->hasTrait()) + if (!definingOp->hasTrait()) { return false; + } return true; } @@ -1894,8 +1928,9 @@ struct CombineSplatUpdateFromToFill : public OpRewritePattern { PatternRewriter &rewriter) const override { auto splatOp = updateOp.getUpdate().getDefiningOp(); - if (!splatOp) + if (!splatOp) { return failure(); + } rewriter.replaceOpWithNewOp( updateOp, updateOp.getResult().getType(), updateOp.getTarget(), updateOp.getTargetSize(), updateOp.getTargetOffset(), @@ -2078,12 +2113,14 @@ struct IntermediateTransferElision : public OpRewritePattern { auto source = originTransferOp.getSource(); auto previousTransferOp = dyn_cast_if_present(source.getDefiningOp()); - if (!previousTransferOp) + if (!previousTransferOp) { break; + } originTransferOp = previousTransferOp; } - if (originTransferOp == transferOp) + if (originTransferOp == transferOp) { return failure(); + } rewriter.replaceOpWithNewOp( transferOp, transferOp.getResult().getType(), originTransferOp.getSource(), originTransferOp.getSourceSize(), @@ -2116,12 +2153,14 @@ struct FoldAsyncLoadBitcast : public OpRewritePattern { LogicalResult matchAndRewrite(AsyncLoadOp loadOp, PatternRewriter &rewriter) const override { auto loadedValue = loadOp.getResult(); - if (!loadedValue.hasOneUse()) + if (!loadedValue.hasOneUse()) { return failure(); + } auto bitcastOp = dyn_cast(*loadedValue.getUsers().begin()); - if (!bitcastOp) + if (!bitcastOp) { return failure(); + } rewriter.modifyOpInPlace( loadOp, [&]() { loadedValue.setType(bitcastOp.getType()); }); rewriter.replaceOp(bitcastOp, loadedValue); @@ -2187,8 +2226,9 @@ struct DeduplicateAsyncDispatchEntryRefs final PatternRewriter &rewriter) const override { auto originalAttr = dispatchOp.getEntryPointsAttr(); auto newAttr = deduplicateArrayElements(originalAttr); - if (newAttr == originalAttr) + if (newAttr == originalAttr) { return failure(); + } rewriter.modifyOpInPlace(dispatchOp, [&]() { dispatchOp.setEntryPointsAttr(newAttr); }); return success(); 
@@ -2235,13 +2275,15 @@ struct CloneCapturedAsyncExecuteSubviewOps SmallVector captures; for (auto operand : llvm::enumerate(op.getResourceOperands())) { auto subviewOp = ResourceSubviewOp::findSubviewOp(operand.value()); - if (!subviewOp) + if (!subviewOp) { continue; + } captures.push_back( SubviewCapture{static_cast(operand.index()), subviewOp}); } - if (captures.empty()) + if (captures.empty()) { return failure(); + } rewriter.startOpModification(op); auto &entryBlock = op.getBody().front(); @@ -2383,8 +2425,9 @@ findConsumerThroughAwait(Value timelineResult) { for (auto [resource, result] : llvm::zip_equal(awaitOp.getResourceOperands(), awaitOp.getResults())) { if (resource == timelineResult) { - if (!result.hasOneUse()) + if (!result.hasOneUse()) { return {nullptr, nullptr}; + } return {*result.getUsers().begin(), result}; } } @@ -2867,8 +2910,9 @@ struct FoldSubviewsIntoCmdFlushOp : public OpRewritePattern { LogicalResult matchAndRewrite(CmdFlushOp op, PatternRewriter &rewriter) const override { auto subviewOp = ResourceSubviewOp::findSubviewOp(op.getTarget()); - if (!subviewOp) + if (!subviewOp) { return failure(); + } setInsertionPointToParentExecutionScope(op, rewriter); auto fusedLoc = rewriter.getFusedLoc({subviewOp.getLoc(), op.getLoc()}); auto newOffset = rewriter.createOrFold( @@ -2909,8 +2953,9 @@ struct FoldSubviewsIntoCmdInvalidateOp LogicalResult matchAndRewrite(CmdInvalidateOp op, PatternRewriter &rewriter) const override { auto subviewOp = ResourceSubviewOp::findSubviewOp(op.getTarget()); - if (!subviewOp) + if (!subviewOp) { return failure(); + } setInsertionPointToParentExecutionScope(op, rewriter); auto fusedLoc = rewriter.getFusedLoc({subviewOp.getLoc(), op.getLoc()}); auto newOffset = rewriter.createOrFold( @@ -2950,8 +2995,9 @@ struct FoldSubviewsIntoCmdDiscardOp : public OpRewritePattern { LogicalResult matchAndRewrite(CmdDiscardOp op, PatternRewriter &rewriter) const override { auto subviewOp = 
ResourceSubviewOp::findSubviewOp(op.getTarget()); - if (!subviewOp) + if (!subviewOp) { return failure(); + } setInsertionPointToParentExecutionScope(op, rewriter); auto fusedLoc = rewriter.getFusedLoc({subviewOp.getLoc(), op.getLoc()}); auto newOffset = rewriter.createOrFold( @@ -2991,8 +3037,9 @@ struct FoldSubviewsIntoCmdFillOp : public OpRewritePattern { LogicalResult matchAndRewrite(CmdFillOp op, PatternRewriter &rewriter) const override { auto subviewOp = ResourceSubviewOp::findSubviewOp(op.getTarget()); - if (!subviewOp) + if (!subviewOp) { return failure(); + } setInsertionPointToParentExecutionScope(op, rewriter); auto fusedLoc = rewriter.getFusedLoc({subviewOp.getLoc(), op.getLoc()}); auto newOffset = rewriter.createOrFold( @@ -3034,8 +3081,9 @@ struct FoldSubviewsIntoCmdCopyOp : public OpRewritePattern { PatternRewriter &rewriter) const override { auto sourceSubviewOp = ResourceSubviewOp::findSubviewOp(op.getSource()); auto targetSubviewOp = ResourceSubviewOp::findSubviewOp(op.getTarget()); - if (!sourceSubviewOp && !targetSubviewOp) + if (!sourceSubviewOp && !targetSubviewOp) { return failure(); + } setInsertionPointToParentExecutionScope(op, rewriter); if (sourceSubviewOp) { auto fusedLoc = @@ -3100,19 +3148,22 @@ struct FoldSubviewsIntoDispatchOp : public OpRewritePattern { bool anySubviewOps = false; for (auto operand : op.getResources()) { auto subviewOp = ResourceSubviewOp::findSubviewOp(operand); - if (subviewOp) + if (subviewOp) { anySubviewOps = true; + } resourceSubviewOps.push_back(subviewOp); } - if (!anySubviewOps) + if (!anySubviewOps) { return failure(); + } rewriter.startOpModification(op); setInsertionPointToParentExecutionScope(op, rewriter); for (auto [resourceIndex, subviewOp] : llvm::enumerate(resourceSubviewOps)) { - if (!subviewOp) + if (!subviewOp) { continue; + } auto fusedLoc = rewriter.getFusedLoc({subviewOp.getLoc(), op.getLoc()}); auto newOffset = rewriter.createOrFold( fusedLoc, subviewOp.getSourceOffset(), @@ -3151,8 
+3202,9 @@ struct DeduplicateCmdDispatchEntryRefs final PatternRewriter &rewriter) const override { auto originalAttr = dispatchOp.getEntryPointsAttr(); auto newAttr = deduplicateArrayElements(originalAttr); - if (newAttr == originalAttr) + if (newAttr == originalAttr) { return failure(); + } rewriter.modifyOpInPlace(dispatchOp, [&]() { dispatchOp.setEntryPointsAttr(newAttr); }); return success(); @@ -3187,21 +3239,24 @@ struct FoldSubviewsIntoCmdCallOp : public OpRewritePattern { llvm::enumerate(op.getResourceOperands())) { if (isa(operand.getType())) { auto subviewOp = ResourceSubviewOp::findSubviewOp(operand); - if (subviewOp) + if (subviewOp) { anySubviewOps = true; + } resourceSubviewOps.push_back({operandIndex, subviewOp}); } } - if (!anySubviewOps) + if (!anySubviewOps) { return failure(); + } rewriter.startOpModification(op); setInsertionPointToParentExecutionScope(op, rewriter); for (auto [resourceIndex, resourceSubviewOp] : llvm::enumerate(resourceSubviewOps)) { auto [operandIndex, subviewOp] = resourceSubviewOp; - if (!subviewOp) + if (!subviewOp) { continue; + } auto fusedLoc = rewriter.getFusedLoc({subviewOp.getLoc(), op.getLoc()}); auto newOffset = rewriter.createOrFold( fusedLoc, subviewOp.getSourceOffset(), @@ -3258,13 +3313,15 @@ struct CloneCapturedCmdExecuteSubviewOps SmallVector captures; for (auto operand : llvm::enumerate(op.getResourceOperands())) { auto subviewOp = ResourceSubviewOp::findSubviewOp(operand.value()); - if (!subviewOp) + if (!subviewOp) { continue; + } captures.push_back( SubviewCapture{static_cast(operand.index()), subviewOp}); } - if (captures.empty()) + if (captures.empty()) { return failure(); + } rewriter.startOpModification(op); auto &entryBlock = op.getBody().front(); @@ -3414,8 +3471,9 @@ struct FoldParameterLoadTargetSubviews } rewriter.restoreInsertionPoint(ip); - if (!needsUpdate) + if (!needsUpdate) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.getSourceOffsetsMutable().assign(newSourceOffsets); 
op.getResultSizesMutable().assign(newResultSizes); @@ -3465,8 +3523,9 @@ struct FoldParameterReadTargetSubview needsUpdate = true; } rewriter.restoreInsertionPoint(ip); - if (!needsUpdate) + if (!needsUpdate) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.getSourceOffsetMutable().assign(newSourceOffset); op.getTargetMutable().assign(newTargetResource); @@ -3518,8 +3577,9 @@ struct FoldParameterWriteSourceSubview needsUpdate = true; } rewriter.restoreInsertionPoint(ip); - if (!needsUpdate) + if (!needsUpdate) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.getSourceMutable().assign(newSourceResource); op.getSourceSizeMutable().assign(newSourceSize); @@ -3768,8 +3828,9 @@ struct ElideImmediateTimepointJoinOperands newTimepoints.push_back(timepoint); } } - if (newTimepoints.size() == op.getAwaitTimepoints().size()) + if (newTimepoints.size() == op.getAwaitTimepoints().size()) { return failure(); + } if (newTimepoints.empty()) { // Fully immediate; replace entire join with immediate. rewriter.replaceOpWithNewOp( @@ -3790,8 +3851,9 @@ struct FoldDuplicateTimepointJoinOperands SetVector newTimepoints; newTimepoints.insert(op.getAwaitTimepoints().begin(), op.getAwaitTimepoints().end()); - if (newTimepoints.size() == op.getAwaitTimepoints().size()) + if (newTimepoints.size() == op.getAwaitTimepoints().size()) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.getAwaitTimepointsMutable().assign(newTimepoints.takeVector()); }); @@ -3821,8 +3883,9 @@ struct ExpandTimepointJoinOperands : public OpRewritePattern { newTimepoints.insert(timepoint); } } - if (!didExpand) + if (!didExpand) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.getAwaitTimepointsMutable().assign(newTimepoints.takeVector()); }); @@ -3853,8 +3916,9 @@ static bool isSourceImmediatelyResolved(Value resource) { // TODO(benvanik): data flow analysis/at least walk up tied ops. 
For now we // err on the conservative side and only check for a few common scenarios. auto *definingOp = resource.getDefiningOp(); - if (!definingOp) + if (!definingOp) { return false; + } return TypeSwitch(definingOp) .Case( [](auto op) { return true; }) @@ -3902,8 +3966,9 @@ findSourceAwaitOp(Value resource) { } } auto tiedValue = definingOp.getTiedResultOperand(baseResource); - if (!tiedValue) + if (!tiedValue) { break; + } baseResource = tiedValue; } return {nullptr, nullptr}; @@ -3925,8 +3990,9 @@ struct ChainTimepoints : public OpRewritePattern { // Try to find an await op. This may traverse through any number of tied ops // along the way. auto [awaitOp, baseResource] = findSourceAwaitOp(barrierOp.getResource()); - if (!awaitOp) + if (!awaitOp) { return failure(); + } // TODO(benvanik): move this to a pass that can do IPO. Local analysis is // insufficient for this. For now we conservatively ignore any case where @@ -4007,8 +4073,9 @@ struct SinkAwaitToFirstConsumer : public OpRewritePattern { // Its possible we are nested in an SCF region. If so the SCF operation // depends on the timepoint as a whole. Operation *owner = use.getOwner(); - while (owner && owner->getParentOp() != op->getParentOp()) + while (owner && owner->getParentOp() != op->getParentOp()) { owner = owner->getParentOp(); + } if (allUsers.insert(owner)) { auto *userBlock = owner->getBlock(); @@ -4019,8 +4086,9 @@ struct SinkAwaitToFirstConsumer : public OpRewritePattern { } } } - if (!commonDominator) + if (!commonDominator) { return failure(); + } // Find the first use within the dominator block (if any) so that we // can sink down to it. @@ -4035,8 +4103,9 @@ struct SinkAwaitToFirstConsumer : public OpRewritePattern { // If sinking to `firstUserInDominator` could result in patterns // fighting each other, then don't sink. 
- if (!canStablySinkTo(op, firstUserInDominator)) + if (!canStablySinkTo(op, firstUserInDominator)) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op->moveBefore(firstUserInDominator); }); @@ -4057,8 +4126,9 @@ struct SinkSubviewsAcrossAwaits : public OpRewritePattern { for (auto operand : llvm::enumerate(op.getResourceOperands())) { auto subviewOp = operand.value().getDefiningOp(); - if (!subviewOp) + if (!subviewOp) { continue; + } didChange = true; unsigned operandIdx = static_cast(operand.index()); @@ -4095,8 +4165,9 @@ struct SinkSubviewsAcrossAwaits : public OpRewritePattern { static bool areAllOperandsDefinedBy(Operation *op, Operation *insertionPoint, DominanceInfo &dominanceInfo) { for (auto operand : op->getOperands()) { - if (!dominanceInfo.dominates(operand, insertionPoint)) + if (!dominanceInfo.dominates(operand, insertionPoint)) { return false; + } } return true; } @@ -4124,15 +4195,19 @@ struct GroupAwaitsByTimepoint : public OpRewritePattern { // TODO(benvanik): make this handle joins/ties; today we get blocked // there. We rely on other canonicalizers to sink things such that // (hopefully) we get them directly accessible here. - if (use.getOwner() == op) + if (use.getOwner() == op) { continue; - if (op->getBlock() != use.getOwner()->getBlock()) + } + if (op->getBlock() != use.getOwner()->getBlock()) { continue; - if (dominanceInfo.dominates(use.getOwner(), op)) + } + if (dominanceInfo.dominates(use.getOwner(), op)) { continue; + } auto awaitOp = dyn_cast(use.getOwner()); - if (!awaitOp || awaitOp.getSync()) + if (!awaitOp || awaitOp.getSync()) { continue; + } // Ensure all dependencies of the await op are available. if (!areAllOperandsDefinedBy(awaitOp, op, dominanceInfo)) { // One or more operands is defined after op so we can't merge. 
@@ -4140,8 +4215,9 @@ struct GroupAwaitsByTimepoint : public OpRewritePattern { } coveredOps.push_back(awaitOp); } - if (coveredOps.empty()) + if (coveredOps.empty()) { return failure(); + } coveredOps.push_back(op); // Sort the ops by their definition order; this gives us a deterministic diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp index 1e8d8fb35ea4..cb5249d0c752 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp @@ -167,10 +167,12 @@ static LogicalResult verifyAllResourcesCaptured(Region ®ion) { availableResources.insert(result); } for (auto operand : op.getOperands()) { - if (!operand) + if (!operand) { continue; - if (!isa(operand.getType())) + } + if (!isa(operand.getType())) { continue; + } if (!availableResources.contains(operand)) { return op.emitOpError() << "used resource not listed in explicit " "captures (or produced internally)"; @@ -215,8 +217,9 @@ static void eraseStreamRegionResults(Region ®ion, ArrayRef excludedResultIndices) { for (auto &block : region.getBlocks()) { auto yieldOp = dyn_cast(block.getTerminator()); - if (!yieldOp) + if (!yieldOp) { continue; + } // HACK: there's no good way of updating the operand and size together today // - we should add a helper to the ClosureYieldOpInterface that checks for // size/shape aware traits and does this automatically. 
@@ -316,8 +319,9 @@ static IREE::Util::ValueAccess computeValueAccess(Value rootValue) { DenseSet processedValues; SmallVector worklist; auto enqueueValue = [&](Value value) { - if (processedValues.contains(value)) + if (processedValues.contains(value)) { return; + } processedValues.insert(value); worklist.push_back(value); }; @@ -357,8 +361,9 @@ static IREE::Util::ValueAccess computeValueAccess(Value rootValue) { if (auto tiedOp = dyn_cast(user)) { auto tiedIndices = tiedOp.getTiedResultOperandIndices(); for (int64_t tiedIndex : tiedIndices) { - if (tiedIndex == IREE::Util::TiedOpInterface::kUntiedIndex) + if (tiedIndex == IREE::Util::TiedOpInterface::kUntiedIndex) { continue; + } auto operand = user->getOperand(tiedIndex); if (operand == value) { // Tied operand. @@ -387,16 +392,19 @@ static ParseResult parseDispatchEntryPoints(OpAsmParser &parser, if (succeeded(parser.parseOptionalLBrace())) { do { SymbolRefAttr entryPointAttr; - if (failed(parser.parseAttribute(entryPointAttr))) + if (failed(parser.parseAttribute(entryPointAttr))) { return failure(); + } entryPointAttrs.push_back(entryPointAttr); } while (succeeded(parser.parseOptionalComma())); - if (failed(parser.parseRBrace())) + if (failed(parser.parseRBrace())) { return failure(); + } } else { SymbolRefAttr entryPointAttr; - if (failed(parser.parseAttribute(entryPointAttr))) + if (failed(parser.parseAttribute(entryPointAttr))) { return failure(); + } entryPointAttrs.push_back(entryPointAttr); } entryPointAttrsArray = parser.getBuilder().getArrayAttr(entryPointAttrs); @@ -434,21 +442,24 @@ static ParseResult parseEncodedResourceOperands( TypeAttr resourceEncoding; if (failed(parser.parseOperand(resources.back())) || failed(parser.parseColon()) || - failed(parser.parseAttribute(resourceEncoding))) + failed(parser.parseAttribute(resourceEncoding))) { return failure(); + } resourceEncodingAttrs.push_back(resourceEncoding); if (int64_t dynamicDimCount = cast(resourceEncoding.getValue()).getNumDynamicDims()) { 
if (failed(parser.parseOperandList(resourceEncodingDims, dynamicDimCount, - AsmParser::Delimiter::Braces))) + AsmParser::Delimiter::Braces))) { return failure(); + } } resourceTypes.emplace_back(); resourceSizes.emplace_back(); if (failed(parser.parseKeyword("in")) || failed(parseSizeAwareType(parser, resourceTypes.back(), - resourceSizes.back()))) + resourceSizes.back()))) { return failure(); + } } while (succeeded(parser.parseOptionalComma())); resourceEncodings = parser.getBuilder().getArrayAttr(resourceEncodingAttrs); return success(); @@ -1429,12 +1440,14 @@ static void printResourceRegion(OpAsmPrinter &p, Operation *op, p << ")"; if (!resultTypes.empty()) { p << " -> "; - if (resultTypes.size() != 1) + if (resultTypes.size() != 1) { p << "("; + } printShapedResultList(p, op, operands, operandTypes, operandSizes, resultTypes, resultSizes, tiedOperands); - if (resultTypes.size() != 1) + if (resultTypes.size() != 1) { p << ")"; + } } p << " "; p.printRegion(body, /*printEntryBlockArgs=*/false, @@ -1527,8 +1540,9 @@ static ParseResult parsePackSliceRanges( auto indexType = parser.getBuilder().getIndexType(); SmallVector lifetimeRangeValues; do { - if (failed(parser.parseOptionalLSquare())) + if (failed(parser.parseOptionalLSquare())) { break; + } IntegerAttr lifetimeStart; IntegerAttr lifetimeEnd; OpAsmParser::UnresolvedOperand dynamicSliceSize; @@ -1552,8 +1566,9 @@ static void printPackSliceRanges(OpAsmPrinter &p, Operation *op, ArrayAttr lifetimeIntervals, ValueRange dynamicSliceSizes, TypeRange packedOffsetTypes) { - if (packedOffsetTypes.empty()) + if (packedOffsetTypes.empty()) { return; + } for (unsigned i = 0; i < packedOffsetTypes.size(); ++i) { auto lifetimeStart = lifetimeIntervals[i * 2]; auto lifetimeEnd = lifetimeIntervals[i * 2 + 1]; @@ -1565,8 +1580,9 @@ static void printPackSliceRanges(OpAsmPrinter &p, Operation *op, p.printAttributeWithoutType(lifetimeEnd); p << "] = "; p.printOperand(sliceSize); - if (i < packedOffsetTypes.size() - 1) + if (i < 
packedOffsetTypes.size() - 1) { p << ","; + } } p.printNewline(); } @@ -1604,16 +1620,18 @@ static ParseResult parseConstantValueList( static void printConstantValueList(OpAsmPrinter &p, Operation *op, TypeRange resultTypes, ValueRange resultSizes, ArrayAttr values) { - if (resultTypes.empty()) + if (resultTypes.empty()) { return; + } for (unsigned i = 0; i < resultTypes.size(); ++i) { p.printNewline(); p << " "; printSizeAwareType(p, op, resultTypes[i], resultSizes[i]); p << " = "; p.printAttribute(values[i]); - if (i < resultTypes.size() - 1) + if (i < resultTypes.size() - 1) { p << ","; + } } } @@ -1667,13 +1685,15 @@ static ParseResult parseWorkgroupCountRegion(OpAsmParser &parser, static void printWorkgroupCountRegion(OpAsmPrinter &p, Operation *op, Region &body) { - if (body.empty()) + if (body.empty()) { return; + } p << "workgroups("; auto args = body.getArguments(); for (unsigned i = 0; i < args.size(); ++i) { - if (i > 0) + if (i > 0) { p << ", "; + } p.printRegionArgument(args[i]); } p << ")"; @@ -1695,8 +1715,9 @@ ResourceAllocOp::createSuballocations( bool uninitialized, AffinityAttr affinityAttr, OpBuilder &builder) { assert(locs.size() == storageSizes.size() && "expect locs and storageSizes to match"); - if (locs.empty()) + if (locs.empty()) { return {}; + } if (locs.size() == 1) { auto allocOp = IREE::Stream::ResourceAllocOp::create( builder, locs.front(), resourceType, storageSizes.front(), @@ -1712,17 +1733,10 @@ ResourceAllocOp::createSuballocations( // small enough workloads and our target devices are relatively lax on // things so long as we stay under UINT32_MAX boundaries. - // All slices are 0-0 (overlapping). - size_t sliceCount = locs.size(); - SmallVector lifetimeIntervals(sliceCount * 2, 0); - // Compute total size and the offsets of all suballocated resources via the // pack op. 
- auto indexType = builder.getIndexType(); - SmallVector packedOffsetTypes(sliceCount, indexType); auto packOp = IREE::Stream::ResourcePackOp::create( - builder, fusedLoc, indexType, packedOffsetTypes, /*offset=*/nullptr, - builder.getIndexArrayAttr(lifetimeIntervals), storageSizes, affinityAttr); + builder, fusedLoc, /*offset=*/nullptr, storageSizes, affinityAttr); // Create the new alloca based on the total required size. auto allocOp = IREE::Stream::ResourceAllocOp::create( @@ -1757,8 +1771,9 @@ ResourceAllocaOp::createSuballocations(Type timepointType, Type resourceType, OpBuilder &builder) { assert(locs.size() == storageSizes.size() && "expect locs and storageSizes to match"); - if (locs.empty()) + if (locs.empty()) { return {}; + } if (locs.size() == 1) { auto allocaOp = IREE::Stream::ResourceAllocaOp::create( builder, locs.front(), resourceType, timepointType, @@ -1884,6 +1899,18 @@ void ResourcePackOp::getAsmResultNames( // } } +void ResourcePackOp::build(OpBuilder &builder, OperationState &state, + Value offset, ValueRange valueSizes, + IREE::Stream::AffinityAttr affinity) { + // All slices are 0-0 (overlapping). 
+ size_t sliceCount = valueSizes.size(); + SmallVector lifetimeIntervals(sliceCount * 2, 0); + auto indexType = builder.getIndexType(); + SmallVector indexTypes(sliceCount, indexType); + build(builder, state, indexType, indexTypes, offset, + builder.getIndexArrayAttr(lifetimeIntervals), valueSizes, affinity); +} + LogicalResult ResourcePackOp::verify() { ResourcePackOp op = *this; size_t sliceCount = op.getPackedOffsets().size(); @@ -2541,12 +2568,14 @@ void AsyncSplatOp::build(OpBuilder &builder, OperationState &state, Type result_type, Value value, Value result_size, Attribute affinity, Value await_timepoint) { state.addTypes(result_type); - if (await_timepoint) + if (await_timepoint) { state.addOperands(await_timepoint); + } state.addOperands(value); state.addOperands(result_size); - if (affinity) + if (affinity) { state.addAttribute("affinity", affinity); + } } LogicalResult AsyncSplatOp::verify() { @@ -2743,8 +2772,9 @@ static ParseResult parseCollectiveParam( OpAsmParser &parser, Attribute opAttr, std::optional &optionalParamValue) { const char *keyword = getCollectiveParamKeyword(opAttr); - if (!keyword) + if (!keyword) { return success(); // optional + } OpAsmParser::UnresolvedOperand paramValue; if (failed(parser.parseKeyword(keyword)) || failed(parser.parseLParen()) || failed(parser.parseOperand(paramValue)) || failed(parser.parseRParen())) { @@ -2909,6 +2939,11 @@ IREE::Stream::AffinityAttr AsyncTransferOp::getResultAffinityAttr() { return getTargetAffinityAttr(); } +void AsyncTransferOp::removeAffinityAttrs() { + removeSourceAffinityAttr(); + removeTargetAffinityAttr(); +} + void AsyncTransferOp::getAsyncAccessRanges( SmallVectorImpl &ranges) { ranges.push_back({ResourceAccessBitfield::Read, getSource(), Value{}, @@ -2985,16 +3020,19 @@ static ParseResult parseDispatchOperands( SmallVectorImpl &resourceOffsets, SmallVectorImpl &resourceEnds, SmallVectorImpl &resourceLengths) { - if (failed(parser.parseLParen())) + if (failed(parser.parseLParen())) { 
return failure(); + } // Handle the case of no operands specially. - if (succeeded(parser.parseOptionalRParen())) + if (succeeded(parser.parseOptionalRParen())) { return success(); + } do { // All entries at least have an %operand. resourceOperands.emplace_back(); - if (failed(parser.parseOperand(resourceOperands.back()))) + if (failed(parser.parseOperand(resourceOperands.back()))) { return failure(); + } // Resources have a range. if (succeeded(parser.parseOptionalLSquare())) { resourceOffsets.emplace_back(); @@ -3010,8 +3048,9 @@ static ParseResult parseDispatchOperands( } } } while (succeeded(parser.parseOptionalComma())); - if (failed(parser.parseRParen())) + if (failed(parser.parseRParen())) { return failure(); + } return success(); } @@ -3080,8 +3119,9 @@ void AsyncDispatchOp::getAsyncAccessRanges( unsigned rangeIndex = 0; unsigned tiedOperandBase = getTiedOperandsIndexAndLength().first; for (auto [operandIndex, operand] : llvm::enumerate(getResourceOperands())) { - if (!isa(operand.getType())) + if (!isa(operand.getType())) { continue; + } ResourceAccessBitfield access = ResourceAccessBitfield::Read; auto tiedResults = getOperandTiedResults(tiedOperandBase + operandIndex); if (!tiedResults.empty()) { @@ -3161,12 +3201,14 @@ void AsyncFuncOp::build(OpBuilder &builder, OperationState &state, bool AsyncFuncOp::isResultTied(int resultIndex) { auto tiedOperandsAttr = getTiedOperandsAttr(); - if (!tiedOperandsAttr) + if (!tiedOperandsAttr) { return false; + } auto indexAttr = dyn_cast_if_present( tiedOperandsAttr.getValue()[resultIndex]); - if (!indexAttr) + if (!indexAttr) { return false; + } return indexAttr.getInt() != IREE::Util::TiedOpInterface::kUntiedIndex; } @@ -3236,8 +3278,9 @@ AsyncCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { } // auto typesCompatible = [](Type actual, Type expected) { auto typesCompatible = [](Type callee, Type call) { - if (callee == call) + if (callee == call) { return true; + } auto calleeResource = 
dyn_cast(callee); auto callResource = dyn_cast(call); if (calleeResource && callResource) { @@ -3283,8 +3326,9 @@ void AsyncCallOp::getAsyncAccessRanges( unsigned rangeIndex = 0; unsigned tiedOperandBase = getTiedOperandsIndexAndLength().first; for (auto [operandIndex, operand] : llvm::enumerate(getResourceOperands())) { - if (!isa(operand.getType())) + if (!isa(operand.getType())) { continue; + } ResourceAccessBitfield access = ResourceAccessBitfield::Read; auto tiedResults = getOperandTiedResults(tiedOperandBase + operandIndex); if (!tiedResults.empty()) { @@ -3326,8 +3370,9 @@ void AsyncExecuteOp::build(OpBuilder &builder, OperationState &state, state.addOperands(operands); state.addOperands(operandSizes); state.addOperands(resultSizes); - if (awaitTimepoint) + if (awaitTimepoint) { state.addOperands(awaitTimepoint); + } state.addAttributes(attributes); state.attributes.erase(IREE::Util::TiedOpInterface::getStorageAttrName()); state.addAttribute(IREE::Util::TiedOpInterface::getStorageAttrName(), @@ -3390,8 +3435,9 @@ getExecutionAsyncAccessRanges(Op op, for (auto [i, operand, operandSize] : llvm::zip_equal( llvm::seq(0, op.getResourceOperands().size()), op.getResourceOperands(), op.getResourceOperandSizes())) { - if (!isa(operand.getType())) + if (!isa(operand.getType())) { continue; + } ResourceAccessBitfield access = ResourceAccessBitfield::Read; auto tiedResults = op.getOperandTiedResults(tiedOperandBase + i); if (!tiedResults.empty()) { @@ -3469,8 +3515,9 @@ AsyncExecuteOp::cloneReplacementExcludingOperandsAndResults( auto &block = newBody.front(); BitVector eraseIndices(block.getNumArguments()); - for (auto i : excludedOperandIndices) + for (auto i : excludedOperandIndices) { eraseIndices.set(i); + } block.eraseArguments(eraseIndices); return newOp; } @@ -3588,8 +3635,9 @@ AsyncConcurrentOp::cloneReplacementExcludingOperandsAndResults( eraseStreamRegionResults(newBody, excludedResultIndices); auto &block = newBody.front(); BitVector 
eraseIndices(block.getNumArguments()); - for (auto i : excludedOperandIndices) + for (auto i : excludedOperandIndices) { eraseIndices.set(i); + } block.eraseArguments(eraseIndices); return newOp; } @@ -3630,8 +3678,9 @@ Value AsyncParameterReadOp::getTiedResult(unsigned resultIndex) { ::std::optional AsyncParameterReadOp::getTiedResultOperandIndex(unsigned resultIndex) { - if (resultIndex == 0) - return {0}; // result tied to target + if (resultIndex == 0) { + return {0}; // result tied to target + } return std::nullopt; // result_timepoint not tied } @@ -3668,8 +3717,9 @@ Value AsyncParameterWriteOp::getTiedResult(unsigned resultIndex) { ::std::optional AsyncParameterWriteOp::getTiedResultOperandIndex(unsigned resultIndex) { - if (resultIndex == 0) - return {0}; // result tied to source + if (resultIndex == 0) { + return {0}; // result tied to source + } return std::nullopt; // result_timepoint not tied } @@ -3727,10 +3777,11 @@ Value AsyncParameterGatherOp::getTiedResult(unsigned resultIndex) { ::std::optional AsyncParameterGatherOp::getTiedResultOperandIndex(unsigned resultIndex) { - if (resultIndex == 0) + if (resultIndex == 0) { return { getSourceOffsets() - .size()}; // result tied to target (after variadic source_offsets) + .size()}; // result tied to target (after variadic source_offsets) + } return std::nullopt; // result_timepoint not tied } @@ -3792,8 +3843,9 @@ Value AsyncParameterScatterOp::getTiedResult(unsigned resultIndex) { ::std::optional AsyncParameterScatterOp::getTiedResultOperandIndex(unsigned resultIndex) { - if (resultIndex == 0) - return {0}; // result tied to source + if (resultIndex == 0) { + return {0}; // result tied to source + } return std::nullopt; // result_timepoint not tied } @@ -4099,8 +4151,9 @@ printDispatchResources(OpAsmPrinter &p, Operation *op, ValueRange resources, p.printOperand(resourceLength); p << "] : "; printSizeAwareType(p, op, resourceType, resourceSize); - if (i < resources.size() - 1) + if (i < resources.size() - 
1) { p << ","; + } } } @@ -4186,8 +4239,9 @@ static ParseResult parseDispatchFunctionArgumentList( SmallVector argAttrsVec; do { OpAsmParser::UnresolvedOperand arg; - if (failed(parser.parseOperand(arg))) + if (failed(parser.parseOperand(arg))) { return failure(); + } bool hasOffsetLength = false; OpAsmParser::UnresolvedOperand offsetArg; OpAsmParser::UnresolvedOperand lengthArg; @@ -4262,8 +4316,9 @@ static void printDispatchFunctionResultList(OpAsmPrinter &p, Operation *op, p.printOptionalAttrDict(attrs.getValue()); } } - if (i < resultTypes.size() - 1) + if (i < resultTypes.size() - 1) { p << ", "; + } } } @@ -4274,8 +4329,9 @@ ParseResult parseDispatchFunctionSignature(OpAsmParser &parser, SmallVector args; SmallVector argTypes; SmallVector resultTypes; - if (failed(parser.parseLParen())) + if (failed(parser.parseLParen())) { return failure(); + } if (failed(parser.parseOptionalRParen())) { if (failed(parseDispatchFunctionArgumentList(parser, args, argTypes, argAttrs)) || @@ -4308,8 +4364,9 @@ void printDispatchFunctionSignature(OpAsmPrinter &p, Operation *op, auto functionType = cast(functionTypeAttr.getValue()); p << "("; for (size_t argIndex = 0; argIndex < functionType.getNumInputs();) { - if (argIndex) + if (argIndex) { p << ", "; + } int baseArgIndex = argIndex; auto type = functionType.getInput(baseArgIndex); p << "%arg"; @@ -4335,11 +4392,13 @@ void printDispatchFunctionSignature(OpAsmPrinter &p, Operation *op, auto resultTypes = functionType.getResults(); if (!resultTypes.empty()) { p << " -> "; - if (resultTypes.size() != 1) + if (resultTypes.size() != 1) { p << "("; + } printDispatchFunctionResultList(p, op, resultTypes, resultAttrs); - if (resultTypes.size() != 1) + if (resultTypes.size() != 1) { p << ")"; + } } } @@ -4401,11 +4460,13 @@ static ParseResult parseCmdCallOperands( SmallVectorImpl &resourceOffsets, SmallVectorImpl &resourceLengths, ArrayAttr &resourceAccesses) { - if (failed(parser.parseLParen())) + if (failed(parser.parseLParen())) { 
return failure(); + } // Handle the case of no operands specially. - if (succeeded(parser.parseOptionalRParen())) + if (succeeded(parser.parseOptionalRParen())) { return success(); + } SmallVector accessAttrs; do { StringRef accessStr; @@ -4444,8 +4505,9 @@ static ParseResult parseCmdCallOperands( } } while (succeeded(parser.parseOptionalComma())); resourceAccesses = parser.getBuilder().getArrayAttr(accessAttrs); - if (failed(parser.parseRParen())) + if (failed(parser.parseRParen())) { return failure(); + } return success(); } @@ -4490,8 +4552,9 @@ static void printCmdCallOperands(OpAsmPrinter &p, Operation *op, // Primitive/custom type. p.printOperand(operand); } - if (i < resourceOperands.size() - 1) + if (i < resourceOperands.size() - 1) { p << ", "; + } } p << ")"; } @@ -4507,8 +4570,9 @@ void CmdExecuteOp::build(OpBuilder &builder, OperationState &state, state.addTypes(IREE::Stream::TimepointType::get(builder.getContext())); state.addOperands(operands); state.addOperands(operandSizes); - if (awaitTimepoint) + if (awaitTimepoint) { state.addOperands(awaitTimepoint); + } state.addAttributes(attributes); state.attributes.erase(getOperandSegmentSizeAttr()); state.addAttribute(getOperandSegmentSizeAttr(), @@ -4542,8 +4606,9 @@ LogicalResult CmdExecuteOp::verify() { return failure(); } for (auto &nestedOp : op.getBody().front()) { - if (failed(verifyCmdOp(&nestedOp))) + if (failed(verifyCmdOp(&nestedOp))) { return failure(); + } } return success(); } @@ -4606,8 +4671,9 @@ CmdExecuteOp::cloneReplacementExcludingOperandsAndResults( newBody.takeBody(getClosureBodyRegion()); auto &block = newBody.front(); BitVector eraseIndices(block.getNumArguments()); - for (auto i : excludedOperandIndices) + for (auto i : excludedOperandIndices) { eraseIndices.set(i); + } block.eraseArguments(eraseIndices); return newOp; } @@ -4619,8 +4685,9 @@ CmdExecuteOp::cloneReplacementExcludingOperandsAndResults( LogicalResult CmdSerialOp::verify() { CmdSerialOp op = *this; for (auto &nestedOp 
: op.getBody().front()) { - if (failed(verifyCmdOp(&nestedOp))) + if (failed(verifyCmdOp(&nestedOp))) { return failure(); + } } return success(); } @@ -4645,8 +4712,9 @@ void CmdSerialOp::getSuccessorRegions( LogicalResult CmdConcurrentOp::verify() { CmdConcurrentOp op = *this; for (auto &nestedOp : op.getBody().front()) { - if (failed(verifyCmdOp(&nestedOp))) + if (failed(verifyCmdOp(&nestedOp))) { return failure(); + } } return success(); } @@ -4758,8 +4826,9 @@ LogicalResult TimepointJoinOp::verify() { Value TimepointJoinOp::join(Location loc, ValueRange timepoints, OpBuilder &builder) { assert(!timepoints.empty() && "must have at least one timepoint"); - if (timepoints.size() == 1) + if (timepoints.size() == 1) { return timepoints.front(); + } return IREE::Stream::TimepointJoinOp::create( builder, loc, builder.getType(), timepoints); } @@ -4957,11 +5026,13 @@ LogicalResult ExecutableExportOp::verify() { mlir::FunctionOpInterface ExecutableExportOp::lookupFunctionRef() { auto executableOp = this->getOperation()->getParentOfType(); - if (!executableOp) + if (!executableOp) { return {}; + } auto innerModuleOp = executableOp.getInnerModule(); - if (!innerModuleOp) + if (!innerModuleOp) { return {}; + } return innerModuleOp.lookupSymbol( getFunctionRef()); } diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td index a56218f637db..8a8469478932 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td @@ -577,6 +577,14 @@ def Stream_ResourcePackOp : Stream_PureOp<"resource.pack", [ attr-dict-with-keyword }]; + let builders = [ + OpBuilder<(ins + "Value":$offset, + "ValueRange":$valueSizes, + CArg<"AffinityAttr", "{}">:$affinityAttr + )>, + ]; + let hasVerifier = 1; let extraClassDeclaration = [{ @@ -1844,6 +1852,7 @@ def Stream_AsyncConstantOp : Stream_PureOp<"async.constant", [ DeclareOpInterfaceMethods, + 
Util_HoistableOpInterface, Util_SizeAwareOp, DeclareOpInterfaceMethods, ]> { @@ -1906,6 +1915,7 @@ def Stream_AsyncSplatOp : Stream_Op<"async.splat", [ "getAsyncAccessRanges", ]>, Util_SizeAwareOp, + Util_HoistableOpInterface, ]> { let summary = [{Splats a value into a resource.}]; let description = [{ @@ -1972,6 +1982,7 @@ def Stream_AsyncCloneOp : Stream_Op<"async.clone", [ "getAsyncAccessRanges", ]>, Util_SizeAwareOp, + Util_HoistableOpInterface, ]> { let summary = [{Clones the contents of a value.}]; let description = [{ @@ -2016,6 +2027,7 @@ def Stream_AsyncSliceOp : Stream_PureOp<"async.slice", [ DeclareOpInterfaceMethods, + Util_HoistableOpInterface, Util_SizeAwareOp, ]> { let summary = [{Slices out a cloned subview of a value.}]; @@ -2069,6 +2081,7 @@ def Stream_AsyncFillOp : Stream_Op<"async.fill", [ "getTiedResultOperandIndex", "getTiedResultOperandIndices", ]>, + Util_HoistableOpInterface, ]> { let summary = [{Fills a subview of a stream resource with a value.}]; let description = [{ @@ -2124,6 +2137,7 @@ def Stream_AsyncUpdateOp : Stream_Op<"async.update", [ "getTiedResultOperandIndex", "getTiedResultOperandIndices", ]>, + Util_HoistableOpInterface, ]> { let summary = [{Updates a slice of a subview of a resource in-place.}]; let description = [{ @@ -2181,6 +2195,7 @@ def Stream_AsyncCopyOp : Stream_Op<"async.copy", [ "getTiedResultOperandIndex", "getTiedResultOperandIndices", ]>, + Util_HoistableOpInterface, ]> { let summary = [{Copies a subview of a stream resource to another.}]; let description = [{ @@ -2244,6 +2259,7 @@ def Stream_AsyncCollectiveOp : Stream_Op<"async.collective", [ "getTiedResultOperandIndex", "getTiedResultOperandIndices", ]>, + Util_HoistableOpInterface, ]> { let summary = [{Performs a collective operation.}]; let description = [{ @@ -2358,12 +2374,14 @@ def Stream_AsyncTransferOp : Stream_PureOp<"async.transfer", [ "getAffinityAttr", "setAffinityAttr", "getResultAffinityAttr", + "removeAffinityAttrs", ]>, Stream_AsyncPhaseOp, 
Stream_StreamableOp, DeclareOpInterfaceMethods, + Util_HoistableOpInterface, Util_SizeAwareOp, ]> { let summary = [{Transfers a resource from one location/state to another.}]; @@ -2504,6 +2522,7 @@ def Stream_AsyncDispatchOp : Stream_PureOp<"async.dispatch", [ DeclareOpInterfaceMethods, + Util_HoistableOpInterface, Util_SizeAwareOp, DeclareOpInterfaceMethodsgetAttrOfType(attrId); - if (attr) + if (attr) { return attr; + } // See if the affinity specified provides a resource configuration. if (auto affinityOp = dyn_cast(op)) { auto affinityAttr = affinityOp.getAffinityAttr(); if (affinityAttr) { auto attr = affinityAttr.getResourceConfigAttr(); - if (attr) + if (attr) { return attr; + } } } op = op->getParentOp(); @@ -325,13 +341,15 @@ int64_t NamedParameterAttr::getStorageSize() const { Attribute TimepointAttr::parse(AsmParser &p, Type type) { StringRef timeStr; - if (failed(p.parseLess())) + if (failed(p.parseLess())) { return {}; + } if (failed(p.parseKeyword(&timeStr))) { return {}; } - if (failed(p.parseGreater())) + if (failed(p.parseGreater())) { return {}; + } if (timeStr != "immediate") { p.emitError(p.getCurrentLocation(), "only immediate timepoint attrs are supported"); @@ -389,8 +407,9 @@ AffinityAttr AffinityAttr::lookupOrDefault(Operation *fromOp) { // static bool AffinityAttr::areCompatible(AffinityAttr desiredAffinity, AffinityAttr requiredAffinity) { - if (desiredAffinity == requiredAffinity) + if (desiredAffinity == requiredAffinity) { return true; + } if ((desiredAffinity && !requiredAffinity) || (requiredAffinity && !desiredAffinity)) { return true; @@ -401,10 +420,12 @@ bool AffinityAttr::areCompatible(AffinityAttr desiredAffinity, // static bool AffinityAttr::canExecuteTogether(AffinityAttr lhs, AffinityAttr rhs) { - if (lhs == rhs) + if (lhs == rhs) { return true; - if ((lhs && !rhs) || (rhs && !lhs)) + } + if ((lhs && !rhs) || (rhs && !lhs)) { return true; + } return lhs.isExecutableWith(rhs); } @@ -429,15 +450,17 @@ AffinityAttr 
AffinityAttr::joinOR(ArrayRef affinityAttrs) { Attribute PartitioningConfigAttr::parse(AsmParser &p, Type type) { std::string favorStr; - if (failed(p.parseLess())) + if (failed(p.parseLess())) { return {}; + } if (succeeded(p.parseOptionalStar())) { favorStr = "size"; } else if (failed(p.parseString(&favorStr))) { return {}; } - if (failed(p.parseGreater())) + if (failed(p.parseGreater())) { return {}; + } auto favor = symbolizeFavor(favorStr); if (!favor.has_value()) { p.emitError(p.getNameLoc(), "unknown favor value: ") << favorStr; @@ -458,8 +481,9 @@ PartitioningConfigAttr PartitioningConfigAttr::lookup(Operation *op) { auto attrId = StringAttr::get(op->getContext(), "stream.partitioning"); while (op) { auto attr = op->getAttrOfType(attrId); - if (attr) + if (attr) { return attr; + } op = op->getParentOp(); } // No config found; use defaults. @@ -499,15 +523,17 @@ static void printLifetime(Lifetime lifetime, llvm::raw_ostream &os) { Type ResourceType::parse(AsmParser &p) { StringRef lifetimeStr; - if (failed(p.parseLess())) + if (failed(p.parseLess())) { return {}; + } if (succeeded(p.parseOptionalStar())) { lifetimeStr = "*"; } else if (failed(p.parseKeyword(&lifetimeStr))) { return {}; } - if (failed(p.parseGreater())) + if (failed(p.parseGreater())) { return {}; + } auto lifetime = parseLifetime(lifetimeStr); if (!lifetime.has_value()) { p.emitError(p.getNameLoc(), "unknown lifetime value: ") << lifetimeStr; diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/IR/test/BUILD.bazel index 7d7ee55b763a..cda6a510571e 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "async_folding.mlir", "async_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateAffinities.cpp 
b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateAffinities.cpp index 891ec95a8584..440c0341b626 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateAffinities.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateAffinities.cpp @@ -137,8 +137,9 @@ struct AnnotateAffinitiesPass // Annotate all ops with derived affinities. for (auto &op : getOperation().getOps()) { - if (op.hasTrait()) + if (op.hasTrait()) { continue; + } if (auto globalOp = dyn_cast(op)) { annotateGlobalOp(globalOp, affinityAnalysis); } else if (auto funcOp = dyn_cast(op)) { diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateDispatchArguments.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateDispatchArguments.cpp index 74104a7a9536..6197f580acd5 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateDispatchArguments.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateDispatchArguments.cpp @@ -216,8 +216,9 @@ ChangeStatus GlobalPVS::updateOperation(IREE::Util::GlobalOp globalOp, auto *globalInfo = solver.getExplorer().getGlobalInfo(globalOp); for (auto use : globalInfo->uses) { auto storeOp = dyn_cast(use); - if (!storeOp) + if (!storeOp) { continue; + } auto value = solver.getElementFor( *this, Position::forValue(storeOp.getStoredGlobalValue()), DFX::Resolution::REQUIRED); @@ -275,8 +276,9 @@ class ValueAlignment } static llvm::MaybeAlign computeAlignment(const ValuePVS::SetTy &set) { - if (set.empty()) + if (set.empty()) { return llvm::MaybeAlign(); + } llvm::MaybeAlign alignment; for (auto value : set) { APInt valueDivisor = (value & (~(value - 1))); @@ -373,8 +375,9 @@ class ArgumentAnalysis { ArrayRef getDispatchSites(IREE::Stream::ExecutableExportOp exportOp) { auto it = entryDispatchMap.find(exportOp); - if (it == entryDispatchMap.end()) + if (it == entryDispatchMap.end()) { return {}; + } return it->second; } @@ -383,8 +386,9 @@ class ArgumentAnalysis { 
llvm::MaybeAlign getAlignmentFor(Value value) { auto element = solver.lookupElementFor(Position::forValue(value)); - if (!element) + if (!element) { return llvm::MaybeAlign(); + } return element->getAssumedAlignment(); } @@ -422,8 +426,9 @@ class ArgumentAnalysis { for (auto dispatchOp : getDispatchSites(exportOp)) { auto element = solver.lookupElementFor( Position::forValue(dispatchOp.getUniformOperands()[operandIdx])); - if (!element || !element->isValidState()) + if (!element || !element->isValidState()) { return llvm::MaybeAlign(); + } alignment = commonAlignment(alignment, element->getAssumedAlignment()); } if (alignment.valueOrOne().value() == kMaximumAlignment) { @@ -441,8 +446,9 @@ class ArgumentAnalysis { for (auto dispatchOp : getDispatchSites(exportOp)) { auto element = solver.lookupElementFor( Position::forValue(dispatchOp.getResourceOffsets()[resourceIdx])); - if (!element || !element->isValidState()) + if (!element || !element->isValidState()) { return llvm::MaybeAlign(); + } alignment = commonAlignment(alignment, element->getAssumedAlignment()); } if (alignment.valueOrOne().value() == kMaximumAlignment) { @@ -477,8 +483,9 @@ static void annotateExport(IREE::Stream::ExecutableOp executableOp, // Operands/resources on the func are in an arbitrary order; get maps that // lets us go from dispatch site operand/resource to function argument. 
auto funcOp = exportOp.lookupFunctionRef(); - if (!funcOp) + if (!funcOp) { return; + } auto operandToArgMap = IREE::Stream::CmdDispatchOp::makeOperandToArgMap(funcOp); auto resourceToArgMap = @@ -502,8 +509,9 @@ static void annotateExport(IREE::Stream::ExecutableOp executableOp, llvm::sort(potentialValues, [](Attribute lhs, Attribute rhs) { auto lhsInt = dyn_cast(lhs); auto rhsInt = dyn_cast(rhs); - if (!lhsInt || !rhsInt) + if (!lhsInt || !rhsInt) { return false; + } return lhsInt.getValue().ult(rhsInt.getValue()); }); auto potentialValuesAttr = ArrayAttr::get(context, potentialValues); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateDispatchAssumptions.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateDispatchAssumptions.cpp index eb74773f2ffe..378c3bf6b1d5 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateDispatchAssumptions.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateDispatchAssumptions.cpp @@ -54,8 +54,9 @@ class ArgumentAnalysis { LogicalResult run() { for (Operation *analysisRoot : analysisRoots) { - if (failed(solver.initializeAndRun(analysisRoot))) + if (failed(solver.initializeAndRun(analysisRoot))) { return failure(); + } } return success(); } @@ -65,8 +66,9 @@ class ArgumentAnalysis { ArrayRef getDispatchSites(IREE::Stream::ExecutableExportOp exportOp) { auto it = entryDispatchMap.find(exportOp); - if (it == entryDispatchMap.end()) + if (it == entryDispatchMap.end()) { return {}; + } return it->second; } @@ -109,8 +111,9 @@ class ArgumentAnalysis { IREE::Util::IntAssumptionAttr::get(context, umin, umax, udiv)); } - if (assumptions.empty()) + if (assumptions.empty()) { return {}; + } return std::make_pair( ArrayAttr::get(context, ArrayRef(assumptions.begin(), @@ -138,8 +141,9 @@ static void annotateExport(IREE::Stream::ExecutableOp executableOp, // Operands/resources on the func are in an arbitrary order; get maps that // lets us go from dispatch site 
operand/resource to function argument. auto funcOp = exportOp.lookupFunctionRef(); - if (!funcOp || funcOp.empty()) + if (!funcOp || funcOp.empty()) { return; + } auto operandToArgMap = IREE::Stream::CmdDispatchOp::makeOperandToArgMap(funcOp); auto resourceToArgMap = @@ -156,8 +160,9 @@ static void annotateExport(IREE::Stream::ExecutableOp executableOp, unsigned argIdx = operandToArgMap[operandIdx]; Value argValue = funcOp.getArgument(argIdx); Type argType = argValue.getType(); - if (!argType.isIndex() && !argType.isInteger()) + if (!argType.isIndex() && !argType.isInteger()) { continue; + } auto [assumptions, hasNonEmpty] = analysis.getOperandAssumptions(exportOp, operandIdx); @@ -168,8 +173,9 @@ static void annotateExport(IREE::Stream::ExecutableOp executableOp, } } - if (nonEmptyCount == 0) + if (nonEmptyCount == 0) { return; + } // Do the rewrite. OpBuilder builder = OpBuilder::atBlockBegin(&funcOp.front()); @@ -186,8 +192,9 @@ class AnnotateDispatchAssumptionsPass AnnotateDispatchAssumptionsPass> { void runOnOperation() override { ArgumentAnalysis analysis(getOperation()); - if (failed(analysis.run())) + if (failed(analysis.run())) { return signalPassFailure(); + } // Annotate the exported dispatch functions. for (auto executableOp : getOperation().getBodyRegion().getOps()) { diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/AutomaticReferenceCounting.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AutomaticReferenceCounting.cpp index 005c5e0daed3..9e5377c5e673 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/AutomaticReferenceCounting.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AutomaticReferenceCounting.cpp @@ -306,31 +306,40 @@ struct LastUseSet { } }; +// Returns the timepoints sorted by their order in the block (textual order). +// All timepoints must be in the same block. 
+static SmallVector getSortedTimepointsInBlock(TimepointSet &timepoints) { + auto sorted = llvm::to_vector_of(timepoints); + llvm::sort(sorted, [](Value a, Value b) { + Operation *opA = a.getDefiningOp(); + Operation *opB = b.getDefiningOp(); + if (!opA && !opB) { + // Both are block arguments, compare by argument number. + return cast(a).getArgNumber() < + cast(b).getArgNumber(); + } + if (!opA) { + return true; // Block argument comes before operation. + } + if (!opB) { + return false; // Operation comes before block argument. + } + return opA->isBeforeInBlock(opB); + }); + return sorted; +} + // Returns the last defined SSA value in the block in |timepoints| (textual // order within the block). All timepoints must be in the same block. static Value getLastTimepointInBlock(TimepointSet &timepoints) { if (timepoints.empty()) { return nullptr; - } else if (timepoints.size() == 1) { - return *timepoints.begin(); } - Value lastTimepoint; - for (auto timepoint : timepoints) { - if (!lastTimepoint) { - lastTimepoint = timepoint; - } else { - auto *timepointOp = timepoint.getDefiningOp(); - auto *lastTimepointOp = lastTimepoint.getDefiningOp(); - if (!timepointOp) { - continue; // block arg - } else if (!lastTimepointOp) { - lastTimepoint = timepoint; // last found was a block arg, this isn't - } else if (lastTimepointOp->isBeforeInBlock(timepointOp)) { - lastTimepoint = timepoint; - } - } + if (timepoints.size() == 1) { + return *timepoints.begin(); } - return lastTimepoint; + SmallVector sorted = getSortedTimepointsInBlock(timepoints); + return sorted.back(); } // Returns a FusedLoc with the location of all |timepoints| and the base |loc|. 
@@ -595,8 +604,7 @@ static void insertDeallocations(LastUseSet &lastUseSet, AsmState *asmState, auto joinOp = IREE::Stream::TimepointJoinOp::create( builder, timepointsLoc, builder.getType(), - llvm::map_to_vector(timepoints, - [](Value timepoint) { return timepoint; })); + getSortedTimepointsInBlock(timepoints)); auto deallocaOp = IREE::Stream::ResourceDeallocaOp::create( builder, timepointsLoc, builder.getType(), resource, diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel index f35877b02a4d..ad17fd38648e 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel @@ -52,6 +52,7 @@ iree_compiler_cc_library( "ScheduleExecution.cpp", "SpecializeDispatches.cpp", "SpecializeEncodings.cpp", + "SplitParameterEncoder.cpp", "SyncInitializers.cpp", "UnifyEncodingForGlobals.cpp", "Utils.cpp", @@ -101,6 +102,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt index ab7f116493a0..3cba16cf865a 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt @@ -49,6 +49,7 @@ iree_cc_library( "ScheduleExecution.cpp" "SpecializeDispatches.cpp" "SpecializeEncodings.cpp" + "SplitParameterEncoder.cpp" "SyncInitializers.cpp" "UnifyEncodingForGlobals.cpp" "Utils.cpp" @@ -72,6 +73,7 @@ iree_cc_library( MLIRPass MLIRSCFDialect MLIRSCFToControlFlow + MLIRSideEffectInterfaces MLIRSupport MLIRTensorDialect MLIRTransformUtils diff --git 
a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp index 1894b47f9694..86bc008142bd 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp @@ -45,8 +45,9 @@ namespace { static bool doesOperationNeedWrapping(Operation *op) { return llvm::any_of(op->getOperands(), [](Value operand) { - if (!isa(operand.getType())) + if (!isa(operand.getType())) { return false; + } return !isa_and_nonnull( operand.getDefiningOp()); }) || diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/DumpStatistics.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/DumpStatistics.cpp index 63c70436ab16..9715f7929fbe 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/DumpStatistics.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/DumpStatistics.cpp @@ -127,8 +127,9 @@ struct Statistics { for (auto [name, globalOp] : usageInfo.resourceGlobalOps) { auto globalType = dyn_cast(globalOp.getType()); - if (!globalType) + if (!globalType) { continue; + } // TODO(benvanik): analyze size in UsageInfo where possible. 
switch (globalType.getLifetime()) { case IREE::Stream::Lifetime::Constant: @@ -436,14 +437,16 @@ static void dumpExecutionCSVTable(const UsageInfo &usageInfo, TypeSwitch(op) .Case([&](auto op) { ++depth; - for (auto &nestedOp : op.getBody().front()) + for (auto &nestedOp : op.getBody().front()) { dumpRow(&nestedOp); + } --depth; }) .Case([&](auto op) { ++depth; - for (auto &nestedOp : op.getBody().front()) + for (auto &nestedOp : op.getBody().front()) { dumpRow(&nestedOp); + } --depth; }) .Case([&](auto op) { @@ -462,8 +465,9 @@ static void dumpExecutionCSVTable(const UsageInfo &usageInfo, auto workload = op.getWorkload(); SmallString<32> workloadStr; for (unsigned i = 0; i < workload.size(); ++i) { - if (i > 0) + if (i > 0) { workloadStr.append(";"); + } APInt dimValue; if (matchPattern(workload[i], m_ConstantInt(&dimValue))) { dimValue.toString(workloadStr, 10, /*signed=*/true); @@ -575,8 +579,9 @@ openOutputFile(StringRef filePath) { std::error_code ec; auto result = std::make_unique( filePath, ec, llvm::sys::fs::OF_TextWithCRLF); - if (!ec) + if (!ec) { return result; + } llvm::errs() << "Error opening iree-stream-dump-statistics output file '" << filePath << "'\n"; return std::make_unique(2, false); // stderr. @@ -588,8 +593,9 @@ struct DumpStatisticsPass using IREE::Stream::impl::DumpStatisticsPassBase< DumpStatisticsPass>::DumpStatisticsPassBase; void runOnOperation() override { - if (outputFormat == DumpOutputFormat::None) + if (outputFormat == DumpOutputFormat::None) { return; + } // Open the output file we'll be streaming to. // Since we are processing the entire module at once we overwrite the file. 
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp index c25281f0ed33..563a65034c07 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp @@ -173,8 +173,9 @@ class ArgumentSemantics const std::string getAsStr(AsmState &asmState) const override { std::string str; auto append = [&](const char *part) { - if (!str.empty()) + if (!str.empty()) { str += '|'; + } str += part; }; append(this->isAssumed(NOT_MUTATED) ? "immutable" : "mutable"); @@ -190,8 +191,9 @@ class ArgumentSemantics static bool isTiedUse(OpOperand &operand) { if (auto tiedOp = dyn_cast(operand.getOwner())) { - if (tiedOp.isOperandTied(operand.getOperandNumber())) + if (tiedOp.isOperandTied(operand.getOperandNumber())) { return true; + } } return false; } @@ -573,8 +575,9 @@ class ElisionAnalysis { bool isArgMoved(BlockArgument arg) { auto argumentSemantics = solver.lookupElementFor(Position::forValue(arg)); - if (!argumentSemantics) + if (!argumentSemantics) { return false; + } return argumentSemantics->getAssumedByValue(); } @@ -1006,16 +1009,18 @@ static bool isSafeToElideSliceOp(IREE::Stream::AsyncSliceOp sliceOp, SmallVector consumerRanges; SmallVector queryRanges; for (auto user : source.getUsers()) { - if (user == sliceOp) + if (user == sliceOp) { continue; + } if (auto accessOp = dyn_cast(user)) { // Async op consuming part of the resource. We can query it to see what // it's doing to its operands/results and filter to just the accesses of // the source value. 
accessOp.getAsyncAccessRanges(queryRanges); for (auto range : queryRanges) { - if (range.resource == source) + if (range.resource == source) { consumerRanges.push_back(range); + } } queryRanges.clear(); } else { @@ -1058,10 +1063,12 @@ static bool isSafeToElideSliceOp(IREE::Stream::AsyncSliceOp sliceOp, // arith.addi folders are terrible and don't handle adds of 0 so we handle that // here and then avoid doing the folding. static Value addOffset(Value lhs, Value rhs, OpBuilder &builder) { - if (matchPattern(lhs, m_Zero())) + if (matchPattern(lhs, m_Zero())) { return rhs; - if (matchPattern(rhs, m_Zero())) + } + if (matchPattern(rhs, m_Zero())) { return lhs; + } return builder.createOrFold( builder.getFusedLoc(lhs.getLoc(), rhs.getLoc()), lhs, rhs); } @@ -1111,8 +1118,9 @@ static void foldSliceIntoDispatch(IREE::Stream::AsyncSliceOp sliceOp, // Elides a stream.async.slice op (assuming able) by folding it into consumers. static void elideSliceOp(IREE::Stream::AsyncSliceOp sliceOp) { SmallVector> consumers; - for (auto &use : sliceOp.getResult().getUses()) + for (auto &use : sliceOp.getResult().getUses()) { consumers.push_back(std::make_pair(use.getOwner(), use.getOperandNumber())); + } for (auto [owner, operandNumberIt] : consumers) { unsigned operandNumber = operandNumberIt; // need C++20 to avoid this :| TypeSwitch(owner) @@ -1222,8 +1230,9 @@ static bool isSafeToElideUpdateOp(IREE::Stream::AsyncUpdateOp updateOp, // the dispatch fully overwrites our update region. if (auto dispatchOp = dyn_cast(user)) { for (auto &operand : user->getOpOperands()) { - if (operand.get() != result) + if (operand.get() != result) { continue; + } if (dispatchOp.isOperandTied(operand.getOperandNumber())) { // Result is tied to dispatch - check if dispatch fully overwrites // the update region. 
If not, downstream reads might access our diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideTimepoints.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideTimepoints.cpp index 30e478919cb0..119bd605d7eb 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideTimepoints.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideTimepoints.cpp @@ -786,8 +786,9 @@ class TimepointCoverageAnalysis { // Seed FenceCoverage for fence values that might be relevant. // This includes fences from timeline-aware ops and imported fences. for (auto op : block.getOps()) { - if (!op.participatesInTimeline()) + if (!op.participatesInTimeline()) { continue; + } if (Value signalFence = op.getSignalFence()) { solver.getOrCreateElementFor( Position::forValue(signalFence)); @@ -828,8 +829,9 @@ class TimepointCoverageAnalysis { }; for (auto callableOp : getTopLevelOps()) { auto *region = callableOp.getCallableRegion(); - if (!region || region->empty()) + if (!region || region->empty()) { continue; + } seedRegion(*region); } @@ -846,8 +848,9 @@ class TimepointCoverageAnalysis { // Returns true if |value| is known to be immediately resolved. 
bool isImmediate(Value value) { - if (isDefinedImmediate(value)) + if (isDefinedImmediate(value)) { return true; + } auto &isImmediate = solver.getOrCreateElementFor(Position::forValue(value)); return isImmediate.isValidState() && isImmediate.isKnown(); @@ -881,8 +884,9 @@ class TimepointCoverageAnalysis { bool unionTransitivelyReachedTimepoints(Value value, SetVector &set) { auto coverage = solver.getOrCreateElementFor( Position::forValue(value)); - if (!coverage.isValidState() || coverage.isUndefContained()) + if (!coverage.isValidState() || coverage.isUndefContained()) { return false; + } for (auto reached : coverage.getAssumedSet()) { set.insert(reached); } @@ -914,8 +918,9 @@ buildRequiredCoverageSet(SmallVector possibleTimepoints, if (isValid) { for (auto reachedTimepoint : reachedTimepoints) { // TODO(benvanik): avoid self-references so we don't need this check. - if (reachedTimepoint == possibleTimepoint) + if (reachedTimepoint == possibleTimepoint) { continue; + } ++coverageMap[reachedTimepoint]; } } @@ -1036,8 +1041,9 @@ static bool trySinkAwaitIntoBranch(IREE::Stream::TimepointAwaitOp awaitOp, llvm::dbgs() << "[ElideTimepoints] sinking await into scf.if "; bool first = true; for (Region *region : regionsWithDirectUse) { - if (!first) + if (!first) { llvm::dbgs() << " and "; + } if (region == &ifOp.getThenRegion()) { llvm::dbgs() << "then"; } else { @@ -1066,8 +1072,9 @@ static bool trySinkAwaitIntoBranch(IREE::Stream::TimepointAwaitOp awaitOp, bool first = true; auto caseRegions = switchOp.getCaseRegions(); for (Region *region : regionsWithDirectUse) { - if (!first) + if (!first) { llvm::dbgs() << ", "; + } // Find which case this is. bool foundCase = false; for (auto [idx, caseRegion] : llvm::enumerate(caseRegions)) { @@ -1532,8 +1539,9 @@ static bool tryElideTimepointsInRegion(Region ®ion, // Elides |elidedTimepoint| by replacing all its uses by |op| with an // immediate timepoint value. 
auto elideTimepointOperand = [&](Operation *op, Value elidedTimepoint) { - if (isDefinedImmediate(elidedTimepoint)) + if (isDefinedImmediate(elidedTimepoint)) { return; // already immediate + } auto immediateTimepoint = makeImmediate(elidedTimepoint, OpBuilder(op)); elidedTimepoint.replaceUsesWithIf( immediateTimepoint, @@ -1544,10 +1552,12 @@ static bool tryElideTimepointsInRegion(Region ®ion, // Elides all timepoint operands of |op| that are immediately resolved. auto elideTimepointOperands = [&](Operation *op) { for (auto operand : llvm::make_early_inc_range(op->getOperands())) { - if (!isa(operand.getType())) + if (!isa(operand.getType())) { continue; - if (isDefinedImmediate(operand)) + } + if (isDefinedImmediate(operand)) { continue; + } if (analysis.isImmediate(operand)) { LLVM_DEBUG({ llvm::dbgs() << " >>> eliding known-immediate operand "; @@ -1562,10 +1572,12 @@ static bool tryElideTimepointsInRegion(Region ®ion, // Elides |elidedTimepoint| by replacing all its uses with an immediate // timepoint value. The original value will end up with zero uses. 
auto elideTimepointResult = [&](Operation *op, Value elidedTimepoint) { - if (elidedTimepoint.use_empty()) + if (elidedTimepoint.use_empty()) { return; // no-op - if (isDefinedImmediate(elidedTimepoint)) + } + if (isDefinedImmediate(elidedTimepoint)) { return; // already immediate + } OpBuilder afterBuilder(op); afterBuilder.setInsertionPointAfterValue(elidedTimepoint); Value immediateTimepoint = IREE::Stream::TimepointImmediateOp::create( @@ -1583,10 +1595,12 @@ static bool tryElideTimepointsInRegion(Region ®ion, // %imm0 = immediate // %imm1 = immediate for (auto result : llvm::reverse(op->getResults())) { - if (!isa(result.getType())) + if (!isa(result.getType())) { continue; - if (isDefinedImmediate(result)) + } + if (isDefinedImmediate(result)) { continue; + } if (analysis.isImmediate(result)) { LLVM_DEBUG({ llvm::dbgs() << " >>> eliding known-immediate result "; @@ -1604,8 +1618,9 @@ static bool tryElideTimepointsInRegion(Region ®ion, auto processTimelineOp = [&](IREE::Stream::TimelineOpInterface op) { auto resultTimepoint = op.getResultTimepoint(); auto awaitTimepoints = op.getAwaitTimepoints(); - if (awaitTimepoints.empty()) + if (awaitTimepoints.empty()) { return; + } LLVM_DEBUG({ llvm::dbgs() << "[ElideTimepoints] pruning " << op->getName() @@ -1652,8 +1667,9 @@ static bool tryElideTimepointsInRegion(Region ®ion, } // If there's only one timepoint we don't have to worry with coverage. - if (possibleTimepoints.size() <= 1) + if (possibleTimepoints.size() <= 1) { return; + } // Perform the analysis on the possible timepoints to find which are covered // by others and elide all of those known-covered. 
@@ -1761,8 +1777,9 @@ struct ElideTimepointsPass : public IREE::Stream::impl::ElideTimepointsPassBase { void runOnOperation() override { mlir::ModuleOp moduleOp = getOperation(); - if (moduleOp.getBody()->empty()) + if (moduleOp.getBody()->empty()) { return; + } // Perform whole-program analysis to find for each timepoint what other // timepoints are known to be reached. @@ -1793,8 +1810,9 @@ struct ElideTimepointsPass tryElideTimepointsInRegion(*region, analysis, domInfo) || didChange; } - if (didChange) + if (didChange) { signalFixedPointModified(moduleOp); + } } }; diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/EncodeTensors.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/EncodeTensors.cpp index 9b808c74674b..d1ee6cbb0981 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/EncodeTensors.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/EncodeTensors.cpp @@ -60,8 +60,9 @@ static LogicalResult checkEncoding(Operation *op, RankedTensorType encodingType, static RankedTensorType alignTensorType(RankedTensorType originalType) { Type elementType = originalType.getElementType(); Type alignedType = legalizeStorageElementType(originalType); - if (alignedType == elementType) + if (alignedType == elementType) { return originalType; + } return RankedTensorType::get(originalType.getShape(), alignedType, originalType.getEncoding()); } @@ -79,8 +80,9 @@ static Value makeTensorDim(Location loc, RankedTensorType tensorType, // Map from absolute dimension index to the compact dynamic index. 
unsigned di = 0; for (unsigned j = 0; j < i; ++j) { - if (tensorType.isDynamicDim(j)) + if (tensorType.isDynamicDim(j)) { ++di; + } } return dynamicDims[di]; } @@ -661,8 +663,9 @@ alignDispatchTensorType(IREE::TensorExt::DispatchTensorType originalType) { Type elementType = originalType.getBoundElementType(); Type alignedType = legalizeStorageElementType(originalType.asRankedTensorType()); - if (alignedType == elementType) + if (alignedType == elementType) { return originalType; + } return IREE::TensorExt::DispatchTensorType::get( originalType.getAccess(), originalType.getShape(), alignedType); } @@ -688,8 +691,9 @@ struct EncodeBindingSubspanOp // Align the element type, if needed. IREE::TensorExt::DispatchTensorType alignedType = alignDispatchTensorType(originalType); - if (originalType == alignedType) + if (originalType == alignedType) { return failure(); // already aligned. + } // Directly swap the type with the one, changing all uses in the IR. // This works because @@ -713,8 +717,9 @@ struct EncodeDispatchTensorLoadOp // Align the element type, if needed. RankedTensorType alignedType = alignTensorType(targetType); - if (targetType == alignedType) + if (targetType == alignedType) { return failure(); // already aligned. + } // Loads always truncate from an byte aligned type to a sub-byte one. assert(targetType.getElementTypeBitWidth() < @@ -747,8 +752,9 @@ struct EncodeDispatchTensorStoreOp // Align the element type, if needed. RankedTensorType alignedType = alignTensorType(sourceType); - if (sourceType == alignedType) + if (sourceType == alignedType) { return failure(); // already aligned. + } // Stores always extend from a sub-byte aligned type to a byte aligned one. 
assert(sourceType.getElementTypeBitWidth() < diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/FoldUniformOperands.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/FoldUniformOperands.cpp index 50440574d0c8..4a9a1c2deb3f 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/FoldUniformOperands.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/FoldUniformOperands.cpp @@ -143,8 +143,9 @@ deduplicateOperands(mlir::FunctionOpInterface funcOp, llvm::interleaveComma(deadOperandsMap.set_bits(), llvm::dbgs()); llvm::dbgs() << "\n"; for (auto replacement : llvm::enumerate(argReplacementMap)) { - if (replacement.index() == replacement.value()) + if (replacement.index() == replacement.value()) { continue; + } llvm::dbgs() << " %arg" << replacement.index() << " -> %arg" << replacement.value() << "\n"; } @@ -155,8 +156,9 @@ deduplicateOperands(mlir::FunctionOpInterface funcOp, for (auto replacement : llvm::enumerate(argReplacementMap)) { unsigned deadIdx = replacement.index(); unsigned liveIdx = replacement.value(); - if (deadIdx == liveIdx) + if (deadIdx == liveIdx) { continue; + } deadArgMap.set(deadIdx); entryBlock.getArgument(deadIdx).replaceAllUsesWith( entryBlock.getArgument(liveIdx)); @@ -164,8 +166,9 @@ deduplicateOperands(mlir::FunctionOpInterface funcOp, // Update each dispatch site to remove duplicates. 
SmallVector deadOperands; - for (auto idx : deadOperandsMap.set_bits()) + for (auto idx : deadOperandsMap.set_bits()) { deadOperands.push_back(idx); + } for (auto dispatchOp : dispatchOps) { for (auto idx : llvm::reverse(deadOperands)) { dispatchOp.getUniformOperandsMutable().erase(idx); @@ -202,8 +205,9 @@ inlineUniformConstants(mlir::FunctionOpInterface funcOp, llvm::BitVector uniformOperandMap(operandCount, /*t=*/true); for (auto dispatchOp : dispatchOps) { for (unsigned idx = 0; idx < operandCount; ++idx) { - if (!uniformOperandMap.test(idx)) + if (!uniformOperandMap.test(idx)) { continue; + } auto value = dispatchOp.getUniformOperands()[idx]; APInt intValue; if (!matchPattern(value, m_ConstantInt(&intValue))) { @@ -232,8 +236,9 @@ inlineUniformConstants(mlir::FunctionOpInterface funcOp, LLVM_DEBUG({ llvm::dbgs() << "inlineUniformConstants for " << funcOp.getName() << "\n"; for (unsigned i = 0; i < operandValues.size(); ++i) { - if (!operandValues[i].has_value()) + if (!operandValues[i].has_value()) { continue; + } llvm::dbgs() << " operand " << i << " = " << operandValues[i].value() << "\n"; } @@ -258,8 +263,9 @@ inlineUniformConstants(mlir::FunctionOpInterface funcOp, // Update each dispatch site to remove duplicates. 
SmallVector deadOperands; - for (auto idx : uniformOperandMap.set_bits()) + for (auto idx : uniformOperandMap.set_bits()) { deadOperands.push_back(idx); + } for (auto dispatchOp : dispatchOps) { for (auto idx : llvm::reverse(deadOperands)) { dispatchOp.getUniformOperandsMutable().erase(idx); @@ -410,8 +416,9 @@ struct FoldUniformOperandsPass for (auto exportOp : executableOp.getOps()) { auto &dispatchOps = entryDispatchMap[exportOp]; - if (dispatchOps.empty()) + if (dispatchOps.empty()) { continue; // no-op if no dispatches + } auto funcOp = exportOp.lookupFunctionRef(); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/FuseDispatchBindings.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/FuseDispatchBindings.cpp index 3e182f2b97c1..5ca60ead127b 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/FuseDispatchBindings.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/FuseDispatchBindings.cpp @@ -148,8 +148,9 @@ findCorrelatedBindings(unsigned bindingCount, llvm::BitVector handledBindings(bindingCount, /*t=*/false); for (unsigned i = 0; i < bindingCount; ++i) { // Ignore bindings we've already covered earlier during iteration. - if (handledBindings.test(i)) + if (handledBindings.test(i)) { continue; + } // Build new binding. 
Binding binding; @@ -316,8 +317,9 @@ fuseDispatchBindings(IREE::Stream::ExecutableOp executableOp, IREE::Stream::ExecutableExportOp exportOp, ArrayRef dispatchOps, MemoizedCmdZeros &memoizedZeros) { - if (dispatchOps.empty()) + if (dispatchOps.empty()) { return; // no-op if no dispatches + } auto anyDispatchOp = dispatchOps.front(); unsigned bindingCount = anyDispatchOp.getResources().size(); @@ -443,8 +445,9 @@ struct FuseDispatchBindingsPass MemoizedCmdZeros memoizedZeros; for (auto executableOp : getOperation().getBodyRegion().getOps()) { - if (!executableOp.getInnerModule()) + if (!executableOp.getInnerModule()) { continue; + } for (auto exportOp : executableOp.getOps()) { fuseDispatchBindings(executableOp, exportOp, entryDispatchMap[exportOp], diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/LayoutSlices.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/LayoutSlices.cpp index 677fcab47d25..184c713300e0 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/LayoutSlices.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/LayoutSlices.cpp @@ -53,7 +53,7 @@ using Slice = IREE::Stream::ResourcePackOp::Slice; // |baseOffset|. Returns |baseOffset| + the total size of the allocation // aligned to the requirements of |resourceConfig|. static Value -packStaticSlicesGreedily(IREE::Stream::ResourcePackOp packOp, Value baseOffset, +packStaticSlicesGreedily(Location loc, Value baseOffset, MutableArrayRef slices, IREE::Stream::ResourceConfigAttr resourceConfig, IndexSet &indexSet, OpBuilder &builder) { @@ -114,7 +114,7 @@ packStaticSlicesGreedily(IREE::Stream::ResourcePackOp packOp, Value baseOffset, } reservations.insert(insertionIt, reservation); slice.packedOffset.replaceAllUsesWith(builder.createOrFold( - packOp.getLoc(), baseOffset, indexSet.get(bestOffset))); + loc, baseOffset, indexSet.get(bestOffset))); // Update highwater mark indicating how much memory needs to be allocated // for the entire slab. 
@@ -122,7 +122,7 @@ packStaticSlicesGreedily(IREE::Stream::ResourcePackOp packOp, Value baseOffset, } highwaterMark = IREE::Util::align(highwaterMark, rangeAlignment); - return builder.createOrFold(packOp.getLoc(), baseOffset, + return builder.createOrFold(loc, baseOffset, indexSet.get(highwaterMark)); } @@ -145,11 +145,10 @@ packStaticSlicesGreedily(IREE::Stream::ResourcePackOp packOp, Value baseOffset, // |baseOffset|. Returns |baseOffset| + the total size of the allocation // aligned to the requirements of |resourceConfig|. static Value -packDynamicSlicesConservatively(IREE::Stream::ResourcePackOp packOp, - Value baseOffset, MutableArrayRef slices, +packDynamicSlicesConservatively(Location loc, Value baseOffset, + MutableArrayRef slices, IREE::Stream::ResourceConfigAttr resourceConfig, IndexSet &indexSet, OpBuilder &builder) { - auto loc = packOp.getLoc(); int64_t offsetAlignment = resourceConfig.getMinBufferOffsetAlignment(); int64_t rangeAlignment = resourceConfig.getMinBufferRangeAlignment(); @@ -181,8 +180,9 @@ packDynamicSlicesConservatively(IREE::Stream::ResourcePackOp packOp, SmallVector slices; bool intersects(const Slice &slice) const { for (auto *binSlice : slices) { - if (binSlice->intersects(slice)) + if (binSlice->intersects(slice)) { return true; + } } return false; } @@ -255,9 +255,9 @@ struct LayoutSlicesPass // First pack all static slices as these are entirely knowable here at // compile time. - auto offset = packOp.getOffset() ? packOp.getOffset() : indexSet.get(0); + Value offset = packOp.getOffset() ? 
packOp.getOffset() : indexSet.get(0); if (!staticSlices.empty()) { - offset = packStaticSlicesGreedily(packOp, offset, staticSlices, + offset = packStaticSlicesGreedily(packOp.getLoc(), offset, staticSlices, resourceConfig, indexSet, builder); // TODO(benvanik): make this an option; it can be useful for debugging @@ -270,8 +270,9 @@ struct LayoutSlicesPass // available we could reuse static slices with non-overlapping lifetimes // in some cases. if (!dynamicSlices.empty()) { - offset = packDynamicSlicesConservatively( - packOp, offset, dynamicSlices, resourceConfig, indexSet, builder); + offset = packDynamicSlicesConservatively(packOp.getLoc(), offset, + dynamicSlices, resourceConfig, + indexSet, builder); } // Total packed length is the current offset after all slices are diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeBuiltins.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeBuiltins.cpp index b9b1b9589e10..e122c3917497 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeBuiltins.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeBuiltins.cpp @@ -351,8 +351,9 @@ struct MaterializeBuiltinsPass MaterializeBuiltinsPass> { void runOnOperation() override { mlir::ModuleOp moduleOp = getOperation(); - if (moduleOp.getBody()->empty()) + if (moduleOp.getBody()->empty()) { return; + } // Find and replace (if needed) ops that we want to turn into builtins // across the entire program. 
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeCopyOnWrite.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeCopyOnWrite.cpp index dd9a40d7e80d..65211e9e23b6 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeCopyOnWrite.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeCopyOnWrite.cpp @@ -41,26 +41,29 @@ namespace { static bool isSafeToElideCOW(Value operand, IREE::Stream::ResourceType type) { // Can't do anything with block args without analysis - we don't know if the // value they carry is the last user (move semantics). - if (isa(operand)) + if (isa(operand)) { return false; + } // If our value is a constant then we need to ensure that we aren't // tied to a constant operand. If we are we need to clone to a // non-constant value. We could make this work in cases where constants are // being initialized, however those are best modeled as transfer operations // where no mutations will occur on the constant transfer target. - if (type.getLifetime() == IREE::Stream::Lifetime::Constant) + if (type.getLifetime() == IREE::Stream::Lifetime::Constant) { return false; + } // If there's more than one user we can't make a local decision. It's // expensive to query relative operation order within a block and within a // region the lifetime of values may vary - all things we can't tell here. Operation *firstUser = nullptr; for (Operation *user : operand.getUsers()) { - if (firstUser == nullptr) + if (firstUser == nullptr) { firstUser = user; - else if (firstUser != user) + } else if (firstUser != user) { return false; + } } // We are the only user and the value is contained entirely within the @@ -80,10 +83,12 @@ static Value materializeOperandCOW(Location loc, OpOperand &operand, // has to wait until a subsequent pass. 
auto resourceType = dyn_cast(operand.get().getType()); - if (!resourceType) + if (!resourceType) { return nullptr; - if (isSafeToElideCOW(operand.get(), resourceType)) + } + if (isSafeToElideCOW(operand.get(), resourceType)) { return nullptr; + } // Materialize a clone operation just for the operand provided. auto sizeAwareType = cast(resourceType); @@ -110,8 +115,9 @@ static bool materializeTiedOpCOW(IREE::Util::TiedOpInterface tiedOp) { auto tiedOperandIndices = tiedOp.getTiedResultOperandIndices(); for (unsigned i = 0; i < tiedOperandIndices.size(); ++i) { int64_t operandIdx = tiedOperandIndices[i]; - if (operandIdx == IREE::Util::TiedOpInterface::kUntiedIndex) + if (operandIdx == IREE::Util::TiedOpInterface::kUntiedIndex) { continue; + } auto &tiedOperand = tiedOp->getOpOperand(operandIdx); // If copy was required and materialized, we should forward it to all @@ -125,8 +131,9 @@ static bool materializeTiedOpCOW(IREE::Util::TiedOpInterface tiedOp) { // TODO(#11249): Support in-place collective operations. if (!isa(tiedOp)) { for (auto &operand : tiedOp->getOpOperands()) { - if (operand.get() == original) + if (operand.get() == original) { operand.set(clone); + } } } } @@ -141,8 +148,9 @@ static bool materializeRegionCOW(Region ®ion) { bool didChange = false; for (auto &block : region.getBlocks()) { for (auto &op : block) { - if (!op.hasTrait()) + if (!op.hasTrait()) { continue; + } didChange = TypeSwitch(&op) .Case values; int64_t offset = 0; for (auto &constantSpan : storageBuffer.spans) { - if (constantSpan.length == 0) + if (constantSpan.length == 0) { continue; + } int64_t start = constantSpan.offset; int64_t end = start + constantSpan.length; @@ -465,8 +466,9 @@ static Value generateSerializedUpload( // will need and where each value will be placed. 
auto storageResources = computePackingMap(slices, resourceConfig, builder.getContext()); - if (storageResources.empty()) + if (storageResources.empty()) { return nullptr; + } // TODO(benvanik): should be able to have a single buffer constant and // subrange it so that we don't need so many files. @@ -551,8 +553,9 @@ static Value generateParameterUpload( storageResources = computePackingMap(slices, resourceConfig, builder.getContext()); } - if (storageResources.empty()) + if (storageResources.empty()) { return nullptr; + } // Sort resources by type so we can batch them. // Loads are only possible if we are using the parameter as a constant and diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackDispatchOperands.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackDispatchOperands.cpp index 26e58773884a..3427890e5ca5 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackDispatchOperands.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackDispatchOperands.cpp @@ -251,16 +251,18 @@ static Value recomposeFromI32sAndConvert( // Preserve the arg attrs on either the final op or the function argument // if none was required. if (auto definingOp = value.getDefiningOp()) { - if (oldArgAttr) + if (oldArgAttr) { definingOp->setAttrs(oldArgAttr); + } newArgAttrs.push_back(nullptr); } else { newArgAttrs.push_back(oldArgAttr); } // Note that if we had decomposed the arg we'll expect that there are two attr // dicts for the two new args. 
- if (wasDecomposed) + if (wasDecomposed) { newArgAttrs.push_back(nullptr); + } return value; } @@ -298,7 +300,7 @@ static void updateExportFuncOp(mlir::FunctionOpInterface funcOp) { } //===----------------------------------------------------------------------===// -// --iree-hal-pack-dispatch-operands +// --iree-stream-pack-dispatch-operands //===----------------------------------------------------------------------===// struct PackDispatchOperandsPass @@ -311,8 +313,9 @@ struct PackDispatchOperandsPass for (auto executableOp : getOperation().getOps()) { auto innerModuleOp = executableOp.getInnerModule(); - if (!innerModuleOp) + if (!innerModuleOp) { continue; + } for (auto funcOp : innerModuleOp.getOps()) { if (funcOp.isPublic()) { updateExportFuncOp(funcOp); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp index 4f3879ae1414..d0ec1998266f 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp @@ -184,26 +184,67 @@ void buildStreamAsyncPassPipeline(OpPassManager &passManager, // Specialize the encodings before the lowering of stream tensor ops. passManager.addPass(IREE::Stream::createSpecializeEncodingsPass()); + // Lower stream.tensor.* ops to stream.async.* ops based on + // affinity/configuration assigned during placement. FunctionLikeNest(passManager) - // Run canonicalization after specializing to clean up any - // duplicate/redundant IR and fold any duplicate encoding chains before we - // perform the encoding materialization. - .addPass(mlir::createCanonicalizerPass) - .addPass(mlir::createCSEPass) - - // Lower stream.tensor.* ops to stream.async.* ops based on - // affinity/configuration assigned during placement. 
.addPass(IREE::Stream::createEncodeHostTensorsPass); passManager.addNestedPass( IREE::Stream::createEncodeDeviceTensorsPass()); passManager.addPass(IREE::Stream::createMaterializeEncodingsPass()); + // Layout packed slices (if any exist yet) to emit the arithmetic required for + // all resource offsets. We introduce more packing ops later on but do want + // to support using the layout utilities earlier if the encodings need them. + // Having the arithmetic baked out allows for better propagation (resource + // offsets and sizes can be detected as constant if statically packed, etc). + FunctionLikeNest(passManager).addPass(IREE::Stream::createLayoutSlicesPass); + buildStreamCleanupPassPipeline(passManager, transformOptions); // Everything must now be in stream.async.* form but we don't yet have - // lifetime assigned. + // lifetime assigned. We don't expect there to be any aliasing or other + // trickery yet as we haven't materialized copy-on-write handling and copy + // elision. passManager.addPass(IREE::Stream::createVerifyLoweringToAsyncResourcesPass()); + // If we want to split out a parameter encoder now is the best time: we have + // all of the encodings specialized but haven't yet started allocating memory + // (which will be entirely different in the split module) and if any are + // multi-targeting we haven't yet materialized their concrete forms. + // + // Once this pass runs the original parameters will (mostly) be removed and + // in place of globally initialized constants will be loads from the new + // encoded parameters. Any packing/layout is done now so that the parameter + // index has a common layout between both modules. 
+ if (transformOptions.parameterEncoderOutputFile.hasValue() && + !transformOptions.parameterEncoderOutputFile.empty()) { + IREE::Stream::SplitParameterEncoderPassOptions encoderPassOptions; + encoderPassOptions.mode = transformOptions.parameterEncoderMode; + encoderPassOptions.outputScope = + transformOptions.parameterEncoderOutputScope; + encoderPassOptions.outputFile = transformOptions.parameterEncoderOutputFile; + passManager.addPass( + IREE::Stream::createSplitParameterEncoderPass(encoderPassOptions)); + + // This is somewhat dangerous in that if there is any aliasing in the + // program this _may_ break it. But we don't allow aliasing at this point of + // the pipeline so that's a risk I'm willing to take. The splitting pass + // introduces resource subview ops that we need to propagate to consumers. + // + // TODO(benvanik): improve stream.async.slice handling in + // ElideAsyncCopiesPass. Today it is local only and it results in parameters + // sliced in initializers being treated as copies. If we fixed that we could + // use stream.async.slice as is appropriate at this phase of lowering and + // remove this pass. + passManager.addPass(IREE::Util::createPropagateSubrangesPass()); + + // DCE any executables no longer required just to make the IR cleaner. + // Often times we'll have quite a few hoisted initialization and encoding + // dispatches that are not used elsewhere in the program (though some may + // be due to deduplication!). + passManager.addPass(mlir::createSymbolDCEPass()); + } + // Materialize copy-on-write behavior with explicit stream.async.* ops. // This will insert a lot of copies, so follow it up with a pass that elides // ones that aren't needed. This is easier to verify than if there was one @@ -346,11 +387,6 @@ void buildStreamOptimizationPassPipeline( // cause duplication. Run CSE to collapse. buildStreamCleanupPassPipeline(passManager, transformOptions); - // If any scf ops crept in we get rid of them here. 
We should be able to - // support them all the way through the stream dialect but some passes are not - // currently set up to handle them (such as elide timepoints). - FunctionLikeNest(passManager).addPass(mlir::createSCFToControlFlowPass); - //---------------------------------------------------------------------------- // Whole-program scheduling optimization //---------------------------------------------------------------------------- @@ -374,6 +410,11 @@ void buildStreamOptimizationPassPipeline( FunctionLikeNest(passManager) .addPass(IREE::Stream::createReuseAllocationsPass); + // If any scf ops crept in we get rid of them here. We should be able to + // support them all the way through the stream dialect but some passes are + // not currently set up to handle them (such as elide timepoints). + FunctionLikeNest(passManager).addPass(mlir::createSCFToControlFlowPass); + // Elide timepoints in dependency chains where one is known to have been // reached by the time another is (A -> B -> A|C). ipoPipeline.addPass(IREE::Stream::createElideTimepointsPass()); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h index da852c2a5055..4a22080f2f2a 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h @@ -9,6 +9,7 @@ #include "iree/compiler/Dialect/Stream/IR/StreamOps.h" #include "iree/compiler/Dialect/TensorExt/IR/TensorExtDialect.h" +#include "iree/compiler/Utils/OptionUtils.h" #include "llvm/ADT/StringMap.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" @@ -55,6 +56,21 @@ enum class DumpOutputFormat { JSON = 4, }; +// Controls how the encoder manages parameters. +enum class ParameterEncoderMode { + // Merge all encoded parameters and original parameters into a single + // consolidated scope. 
+ Consolidate = 0, + // Only produce encoded parameters and leave original parameters untouched. + Overlay = 1, +}; + +// Options for the Stream transformation pipeline. +// +// These options are typically populated from top-level compiler options +// (ParameterOptions in Pipelines/Options.h) when building the full compiler +// pipeline. When constructing individual passes, relevant options are mapped +// to pass-specific option structs (e.g., SplitParameterEncoderPassOptions). struct TransformOptions : public PassPipelineOptions { Option initializationMode{ *this, @@ -72,6 +88,34 @@ struct TransformOptions : public PassPipelineOptions { "waiting for them to complete.")), }; + Option parameterEncoderMode{ + *this, + "parameter-encoder-mode", + llvm::cl::desc("Controls how the encoder manages parameters."), + llvm::cl::init(ParameterEncoderMode::Consolidate), + llvm::cl::values( + clEnumValN(ParameterEncoderMode::Consolidate, "consolidate", + "Merge all encoded parameters and original parameters " + "into a single consolidated scope."), + clEnumValN(ParameterEncoderMode::Overlay, "overlay", + "Only produce encoded parameters and leave original " + "parameters untouched.")), + }; + Option parameterEncoderOutputScope{ + *this, + "parameter-encoder-output-scope", + llvm::cl::desc( + "Parameter scope for the output parameters. 
Omit for global."), + llvm::cl::init("encoded"), + }; + Option parameterEncoderOutputFile{ + *this, + "parameter-encoder-output-file", + llvm::cl::desc(".mlir/.mlirbc file path to write the split parameter " + "encoder module to."), + llvm::cl::init(""), + }; + Option optimizeBindings{ *this, "optimize-bindings", diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td index 8a175399a162..9c15ada3b4b5 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td @@ -44,6 +44,56 @@ def ConvertToStreamPass : ]; } +def SplitParameterEncoderPass : + Pass<"iree-stream-split-parameter-encoder", "mlir::ModuleOp"> { + let summary = "Splits out a parameter encoder module for compatible hoisted expressions."; + let description = [{ + }]; + let dependentDialects = [ + "mlir::arith::ArithDialect", + "IREE::HAL::HALDialect", + "IREE::Stream::StreamDialect", + "IREE::Util::UtilDialect", + ]; + let options = [ + Option< + "mode", "mode", + "IREE::Stream::ParameterEncoderMode", + /*default=*/"IREE::Stream::ParameterEncoderMode::Consolidate", + "Controls how the encoder manages parameters.", + [{::llvm::cl::values( + clEnumValN(IREE::Stream::ParameterEncoderMode::Consolidate, "consolidate", "Merge all encoded parameters and original parameters into a single consolidated scope."), + clEnumValN(IREE::Stream::ParameterEncoderMode::Overlay, "overlay", "Only produce encoded parameters and leave original parameters untouched.") + )}] + >, + Option< + "outputScope", "output-scope", + "std::string", /*default=*/"\"encoded\"", + "Parameter scope for the output parameters. Omit for global." + >, + Option< + "outputFile", "output-file", + "std::string", /*default=*/"std::string()", + ".mlir/.mlirbc file path to write the split parameter encoder module to." 
+ >, + Option< + "hoistParameterExpressions", "hoist-parameter-expressions", + "bool", /*default=*/"true", + "Enable hoisting parameter transformation expressions." + >, + Option< + "hoistConstantExpressions", "hoist-constant-expressions", + "bool", /*default=*/"true", + "Enable hoisting pure constant expressions with transformations." + >, + Option< + "maxEncodingGrowthFactor", "max-encoding-growth-factor", + "float", /*default=*/"1.2f", + "Maximum ratio of output size to input parameter size." + >, + ]; +} + def EncodeHostTensorsPass : Pass<"iree-stream-encode-host-tensors", ""> { let summary = "Encodes tensors into storage formats based on affinity and target support."; @@ -707,7 +757,7 @@ def DumpStatisticsPass : Option< "outputFormat", "output-format", "IREE::Stream::DumpOutputFormat", - "IREE::Stream::DumpOutputFormat::Pretty", + /*default=*/"IREE::Stream::DumpOutputFormat::Pretty", "Specifies the output format to produce.", [{::llvm::cl::values( clEnumValN(IREE::Stream::DumpOutputFormat::Pretty, "pretty", "Human-readable pretty printed output."), @@ -717,8 +767,7 @@ def DumpStatisticsPass : >, Option< "outputFile", "output-file", - "std::string", - /*default=*/"std::string()", + "std::string", /*default=*/"std::string()", "File path to write to; or `` for stderr or `-` for stdout." >, ]; diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PropagateTimepoints.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PropagateTimepoints.cpp index f1f5eb7b2f7e..ace1ef6ac0bd 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PropagateTimepoints.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PropagateTimepoints.cpp @@ -64,8 +64,9 @@ static ExpandedGlobalMap expandResourceGlobals(Operation *rootOp, // Gather all of the resource globals in the root. 
for (auto ®ion : rootOp->getRegions()) { for (auto globalOp : region.getOps()) { - if (!isa(globalOp.getType())) + if (!isa(globalOp.getType())) { continue; + } expandedGlobals[globalOp.getName()].resourceOp = globalOp; } } @@ -113,8 +114,9 @@ static void expandType(Type type, SmallVectorImpl &newTypes) { // Expands resources in the given |types| list to (timepoint, resource). // This could be changed to some iterator magic to avoid the alloc. static SmallVector expandTypes(TypeRange types) { - if (types.empty()) + if (types.empty()) { return {}; + } SmallVector newTypes; newTypes.reserve(types.size() * 2); for (auto type : types) { @@ -199,8 +201,9 @@ static Value makeBlockArgResourceSize(Location loc, Value resourceValue, if (auto sizeAwareOp = dyn_cast_if_present( resourceValue.getDefiningOp())) { auto sizeValue = sizeAwareOp.getResultSizeFromValue(resourceValue); - if (sizeValue) + if (sizeValue) { return sizeValue; + } } // Try first to scan uses in the IR. Since we carry the shape in most ops we @@ -208,11 +211,13 @@ static Value makeBlockArgResourceSize(Location loc, Value resourceValue, for (auto &use : resourceValue.getUses()) { auto sizeAwareOp = dyn_cast(use.getOwner()); - if (!sizeAwareOp) + if (!sizeAwareOp) { continue; + } auto sizeValue = sizeAwareOp.getOperandSize(use.getOperandNumber()); - if (!sizeValue) + if (!sizeValue) { continue; + } if (sizeValue.getParentRegion()->isProperAncestor( builder.getInsertionBlock()->getParent())) { // Size value found and implicitly captured; we can reuse (could be @@ -242,16 +247,19 @@ static Value makeBlockArgResourceSize(Location loc, Value resourceValue, static void expandRegion(Region ®ion, bool canModifyEntryBlock, SymbolTable &symbolTable, ExpandedGlobalMap &globalMap, IRMapping &resourceTimepointMap) { - if (region.empty()) + if (region.empty()) { return; + } // Update all block arguments. 
auto timepointType = IREE::Stream::TimepointType::get(region.getContext()); for (auto &block : region.getBlocks()) { - if (!llvm::any_of(block.getArgumentTypes(), isResourceType)) + if (llvm::none_of(block.getArgumentTypes(), isResourceType)) { continue; - if (block.isEntryBlock() && !canModifyEntryBlock) + } + if (block.isEntryBlock() && !canModifyEntryBlock) { continue; + } // Insert and build a list of expanded (timepoint, resource) pairs. // Don't add mappings here - we need to check if wrapExpandedBlockArgFn @@ -259,8 +267,9 @@ static void expandRegion(Region ®ion, bool canModifyEntryBlock, SmallVector> expansions; for (int i = block.getNumArguments() - 1; i >= 0; --i) { auto resourceArg = block.getArgument(i); - if (!isResourceType(resourceArg.getType())) + if (!isResourceType(resourceArg.getType())) { continue; + } auto timepointArg = block.insertArgument(i + 1, timepointType, resourceArg.getLoc()); expansions.push_back(std::make_pair(timepointArg, resourceArg)); @@ -272,8 +281,9 @@ static void expandRegion(Region ®ion, bool canModifyEntryBlock, // If the resource already has an associated timepoint mapping from the // region branch expansion (wrapExpandedBlockArgFn), defer awaiting to // the consumer to avoid over-synchronization at block boundaries. - if (resourceTimepointMap.contains(resource)) + if (resourceTimepointMap.contains(resource)) { continue; + } // Add the mapping for this block arg since we're inserting an await. resourceTimepointMap.map(resource, timepoint); // If we can look down the chain and see the size then we can use that. 
@@ -325,8 +335,9 @@ static void expandRegion(Region ®ion, bool canModifyEntryBlock, static void expandGlobalLoadOp(IREE::Util::GlobalLoadOpInterface op, ExpandedGlobalMap &globalMap, IRMapping &resourceTimepointMap) { - if (!usesResources(op)) + if (!usesResources(op)) { return; + } OpBuilder builder(op); auto &expandedGlobal = globalMap[op.getGlobalName()]; auto timepoint = expandedGlobal.timepointOp.createLoadOp(op.getLoc(), builder) @@ -369,8 +380,9 @@ static void expandGlobalLoadOp(IREE::Util::GlobalLoadOpInterface op, static void expandGlobalStoreOp(IREE::Util::GlobalStoreOpInterface op, ExpandedGlobalMap &globalMap, IRMapping &resourceTimepointMap) { - if (!usesResources(op)) + if (!usesResources(op)) { return; + } OpBuilder builder(op); auto timepointOperand = consumeTimepoint( op.getLoc(), op.getStoredGlobalValue(), resourceTimepointMap, builder); @@ -433,13 +445,15 @@ static void expandFuncOp(IREE::Util::FuncOp op, SymbolTable &symbolTable, // stream.timepoint.await %rt, %t static void expandCallOp(IREE::Util::CallOp op, SymbolTable &symbolTable, IRMapping &resourceTimepointMap) { - if (!usesResources(op)) + if (!usesResources(op)) { return; + } // Ignore calls to public/external functions. auto calleeOp = symbolTable.lookup(op.getCallee()); - if (IREE::Util::isPublicOrExternal(calleeOp)) + if (IREE::Util::isPublicOrExternal(calleeOp)) { return; + } // Build the new call op with expanded operands and results. 
OpBuilder builder(op); @@ -490,10 +504,13 @@ static void expandCallOp(IREE::Util::CallOp op, SymbolTable &symbolTable, // util.return %t, %0 static void expandReturnOp(IREE::Util::ReturnOp op, IRMapping &resourceTimepointMap) { - if (!usesResources(op)) + if (!usesResources(op)) { return; - if (IREE::Util::isPublicOrExternal(op->getParentOfType())) + } + if (IREE::Util::isPublicOrExternal( + op->getParentOfType())) { return; + } OpBuilder builder(op); auto operands = expandOperands(op.getLoc(), op.getOperands(), resourceTimepointMap, builder); @@ -514,8 +531,9 @@ static void expandReturnOp(IREE::Util::ReturnOp op, // %1 = stream.timepoint.await %bb_t, %bb_0 static void expandBranchOp(mlir::cf::BranchOp op, IRMapping &resourceTimepointMap) { - if (!usesResources(op)) + if (!usesResources(op)) { return; + } OpBuilder builder(op); auto operands = expandOperands(op.getLoc(), op.getDestOperands(), resourceTimepointMap, builder); @@ -525,8 +543,9 @@ static void expandBranchOp(mlir::cf::BranchOp op, static void expandCondBranchOp(mlir::cf::CondBranchOp op, IRMapping &resourceTimepointMap) { - if (!usesResources(op)) + if (!usesResources(op)) { return; + } OpBuilder builder(op); mlir::cf::CondBranchOp::create( builder, op.getLoc(), op.getCondition(), op.getTrueDest(), @@ -540,8 +559,9 @@ static void expandCondBranchOp(mlir::cf::CondBranchOp op, static void expandSwitchOp(mlir::cf::SwitchOp op, IRMapping &resourceTimepointMap) { - if (!usesResources(op)) + if (!usesResources(op)) { return; + } OpBuilder builder(op); auto caseOperands = llvm::to_vector( llvm::map_range(op.getCaseOperands(), [&](ValueRange operands) { @@ -577,8 +597,9 @@ static void expandAwaitOp(IREE::Stream::TimepointAwaitOp op, // mappings to leak between sibling regions (e.g., scf.if then/else // branches), leading to invalid IR where one branch tries to use a // timepoint defined in another branch. 
- if (isa(inputOperand)) + if (isa(inputOperand)) { continue; + } resourceTimepointMap.map(inputOperand, op.getAwaitTimepoint()); } } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp index 3378748644f4..c3dfaf6b41dd 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp @@ -132,8 +132,9 @@ struct UsageRefinementPattern : public OpRewritePattern { // Returns true if a change was made. bool applyArgTransition(BlockArgument arg, PatternRewriter &rewriter) const { auto oldType = dyn_cast(arg.getType()); - if (!oldType) + if (!oldType) { return false; + } auto newUsage = analysis.lookupResourceUsage(arg); auto newLifetime = convertUsageToLifetime(newUsage); auto newType = rewriter.getType(newLifetime); @@ -155,8 +156,9 @@ struct UsageRefinementPattern : public OpRewritePattern { bool applyResultTransition(Operation *op, Value result, PatternRewriter &rewriter) const { auto oldType = dyn_cast(result.getType()); - if (!oldType) + if (!oldType) { return false; + } auto newUsage = analysis.lookupResourceUsage(result); auto newLifetime = convertUsageToLifetime(newUsage); auto newType = rewriter.getType(newLifetime); @@ -193,8 +195,9 @@ struct UsageRefinementPattern : public OpRewritePattern { IREE::Stream::AffinityAttr affinityAttr, PatternRewriter &rewriter) const { auto oldType = dyn_cast(result.getType()); - if (!oldType) + if (!oldType) { return false; + } auto newUsage = analysis.lookupResourceUsage(result); auto newLifetime = convertUsageToLifetime(newUsage); auto newType = rewriter.getType(newLifetime); @@ -335,8 +338,9 @@ struct ApplyFuncOp : public UsageRefinementPattern { } // Blocks and nested operations: - if (this->applyRegionTransitions(op, rewriter)) + if (this->applyRegionTransitions(op, rewriter)) { didChange = true; + } return success(didChange); } @@ 
-350,8 +354,9 @@ struct ApplyScfIfOp : public UsageRefinementPattern { for (unsigned i = 0; i < op->getNumResults(); ++i) { auto result = op->getResult(i); if (isa(result.getType())) { - if (this->applyResultTransition(op, result, rewriter)) + if (this->applyResultTransition(op, result, rewriter)) { didChange |= true; + } } } @@ -367,8 +372,9 @@ struct ApplyScfForOp : public UsageRefinementPattern { for (unsigned i = 0; i < op->getNumResults(); ++i) { auto result = op->getResult(i); if (isa(result.getType())) { - if (this->applyResultTransition(op, result, rewriter)) + if (this->applyResultTransition(op, result, rewriter)) { didChange |= true; + } } } return success(didChange); @@ -383,8 +389,9 @@ struct ApplyScfWhileOp : public UsageRefinementPattern { for (unsigned i = 0; i < op->getNumResults(); ++i) { auto result = op->getResult(i); if (isa(result.getType())) { - if (this->applyResultTransition(op, result, rewriter)) + if (this->applyResultTransition(op, result, rewriter)) { didChange |= true; + } } } @@ -406,8 +413,9 @@ struct ApplyGenericOp : public UsageRefinementPattern { for (unsigned i = 0; i < op->getNumResults(); ++i) { auto result = op->getResult(i); if (isa(result.getType())) { - if (this->applyResultTransition(op, result, rewriter)) + if (this->applyResultTransition(op, result, rewriter)) { didChange = true; + } } } if (didChange) { @@ -499,6 +507,7 @@ static void insertUsageRefinementPatterns(MLIRContext *context, ApplyGenericOp, ApplyGenericOp, ApplyGenericOp, + ApplyGenericOp, ApplyGenericOp, ApplyGenericOp>(context, analysis); @@ -534,8 +543,9 @@ struct RefineUsagePass : public IREE::Stream::impl::RefineUsagePassBase { void runOnOperation() override { mlir::ModuleOp moduleOp = getOperation(); - if (moduleOp.getBody()->empty()) + if (moduleOp.getBody()->empty()) { return; + } // Run analysis on the entire module. 
ResourceUsageAnalysis analysis(moduleOp); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp index fb11759a5682..2e50bd325a2f 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp @@ -47,8 +47,9 @@ class ValueAliasingSet { SmallVector> getValueAliasSets() const { SmallVector> result; for (auto it = valueAliasing.begin(); it != valueAliasing.end(); ++it) { - if (!(*it)->isLeader()) + if (!(*it)->isLeader()) { continue; // Ignore non-leader sets. + } auto &aliasSet = result.emplace_back(); for (auto mi = valueAliasing.member_begin(**it); mi != valueAliasing.member_end(); ++mi) { @@ -110,8 +111,9 @@ static void computeRegionValueAliases(Operation *regionOp, // Tied results reuse their operand buffer. auto tiedOp = dyn_cast(op); for (auto result : op.getResults()) { - if (!isa(result.getType())) + if (!isa(result.getType())) { continue; + } if (tiedOp) { auto tiedOperand = tiedOp.getTiedResultOperand(result); if (tiedOperand) { @@ -181,8 +183,9 @@ computeExecutionRegionLivenessIntervals(IREE::Stream::AsyncExecuteOp executeOp, SmallPtrSet liveOuts; auto yieldOp = cast(streamBlock->back()); for (auto returnValue : yieldOp.getResourceOperands()) { - if (!isa(returnValue.getType())) + if (!isa(returnValue.getType())) { continue; + } liveOuts.insert(returnValue); } @@ -191,8 +194,9 @@ computeExecutionRegionLivenessIntervals(IREE::Stream::AsyncExecuteOp executeOp, LivenessIntervalMap valueIntervals; int ordinal = 0; for (Value value : streamBlock->getArguments()) { - if (!isa(value.getType())) + if (!isa(value.getType())) { continue; + } LivenessInterval interval; interval.start = LIVE_IN; if (liveOuts.contains(value)) { @@ -218,16 +222,19 @@ computeExecutionRegionLivenessIntervals(IREE::Stream::AsyncExecuteOp executeOp, // the duration of the 
region. concurrentOp.walk([&](Operation *op) { for (auto value : op->getResults()) { - if (!isa(value.getType())) + if (!isa(value.getType())) { continue; + } if (auto tiedOp = dyn_cast(op)) { // Skip tied results as their liveness is determined by the tied // operand. - if (tiedOp.getTiedResultOperand(value)) + if (tiedOp.getTiedResultOperand(value)) { continue; + } } - if (!value.use_empty()) + if (!value.use_empty()) { continue; + } LivenessInterval interval; interval.start = start; interval.end = start; @@ -238,8 +245,9 @@ computeExecutionRegionLivenessIntervals(IREE::Stream::AsyncExecuteOp executeOp, }); } for (auto value : op.getResults()) { - if (!isa(value.getType())) + if (!isa(value.getType())) { continue; + } LivenessInterval interval; interval.start = start; if (liveOuts.contains(value)) { @@ -267,8 +275,9 @@ computeExecutionRegionLivenessIntervals(IREE::Stream::AsyncExecuteOp executeOp, // We'd need to update this analysis to handle the nesting in order to // compute the ranges here but that's not (currently) required as all // allocated values roll up to the parent scope by way of the yields. - if (llvm::all_of(aliasSet, isNested)) + if (llvm::all_of(aliasSet, isNested)) { continue; + } assert((llvm::all_of(aliasSet, isNested) || llvm::none_of(aliasSet, isNested)) && @@ -371,8 +380,9 @@ struct AllocationScope { // Returns a memoized ConstantIndexOp of |value|. Value lookupOrCreateIndex(int64_t value) { auto it = indexConstantMap.find(value); - if (it != indexConstantMap.end()) + if (it != indexConstantMap.end()) { return it->second; + } auto constantValue = OpBuilder(rootOp).createOrFold( rootOp->getLoc(), value); indexConstantMap.insert(std::make_pair(value, constantValue)); @@ -382,10 +392,12 @@ struct AllocationScope { // Performs a memoized add (as many adds of offsets or lengths are redundant). Value add(Location loc, Value lhs, Value rhs) { // TODO(benvanik): memoize - if worth it. Needs profiling. 
- if (matchPattern(lhs, m_Zero())) + if (matchPattern(lhs, m_Zero())) { return rhs; - if (matchPattern(rhs, m_Zero())) + } + if (matchPattern(rhs, m_Zero())) { return lhs; + } auto result = OpBuilder(rootOp).createOrFold(loc, lhs, rhs); return result; } @@ -394,8 +406,9 @@ struct AllocationScope { // All aliases of |resource| will also be mapped. void mapResourceRange(Value resource, ResourceRange resourceRange, AsmState *asmState) { - if (resourceRangeMap.count(resource)) + if (resourceRangeMap.count(resource)) { return; + } if (!resourceRange.offset && !resourceRange.length) { resourceRange.offset = lookupOrCreateIndex(0); @@ -957,8 +970,9 @@ applyAsyncAllocations(IREE::Stream::AffinityAttr executionAffinityAttr, auto ops = llvm::map_to_vector(llvm::reverse(block), [&](Operation &op) { return &op; }); for (auto *op : ops) { - if (op->hasTrait()) + if (op->hasTrait()) { continue; + } if (failed(TypeSwitch(op) .Case([&](IREE::Stream::ResourceSubviewOp op) { return applyResourceSubviewOp(op, scope, OpBuilder(op)); @@ -1053,8 +1067,9 @@ allocateLocalTransients(IREE::Stream::AsyncExecuteOp executeOp, auto value = valueInterval.value; assert(value && "must have value for interval"); auto valueType = dyn_cast(value.getType()); - if (!valueType) + if (!valueType) { continue; + } // Only handle transient buffers (created/used/dropped within the stream). if (valueInterval.start == LIVE_IN || valueInterval.end == LIVE_OUT) { @@ -1268,8 +1283,9 @@ struct ConstantAllocation { // Returns true if |value| has one use and it is a stream.yield op. 
static bool isOnlyUseYield(Value value) { for (auto *user : value.getUsers()) { - if (!isa(user)) + if (!isa(user)) { return false; + } } return true; } @@ -1552,8 +1568,9 @@ gatherSubranges(Value derivedValue) { while (auto definingOp = dyn_cast_if_present( baseValue.getDefiningOp())) { auto tiedValue = definingOp.getTiedResultOperand(baseValue); - if (!tiedValue) + if (!tiedValue) { break; + } if (auto subrangeOp = dyn_cast( definingOp.getOperation())) { if (subrangeOp.getSubrangeResource() == tiedValue) { @@ -1580,8 +1597,9 @@ static ResourceRange deriveResourceRangeFromResult(Value resultValue, Value resultSize, OpBuilder &builder) { auto subranges = gatherSubranges(resultValue); - if (subranges.empty()) + if (subranges.empty()) { return ResourceRange(resultValue, resultSize); + } // TODO(benvanik): switch to affine.apply when fully supported. Value offset; @@ -1716,8 +1734,9 @@ allocateExecutionRegion(IREE::Stream::AsyncExecuteOp executeOp, // Replace results of escaping uploads with the upload values. 
for (auto &reservation : constantAllocation.reservations) { auto result = findTiedYieldResult(reservation.constantOp.getResult()); - if (!result) + if (!result) { continue; + } result.replaceAllUsesWith(reservation.resource); handledResults.insert(result); LLVM_DEBUG({ @@ -1954,8 +1973,9 @@ allocateExecutionRegion(IREE::Stream::AsyncExecuteOp executeOp, executeOp.getResultTimepoint().replaceAllUsesWith( newExecuteOp.getResultTimepoint()); for (auto replacement : resultReplacements) { - if (!replacement.second) + if (!replacement.second) { continue; // handled already + } LLVM_DEBUG({ AsmState asmState(newExecuteOp->getParentOp()); llvm::dbgs() << " == replacing region result "; diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleConcurrency.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleConcurrency.cpp index 504a42fef10a..5625923d5c76 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleConcurrency.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleConcurrency.cpp @@ -54,10 +54,12 @@ struct WavePartitionBuilder { Operation *insertionPt = nullptr; for (auto in : partition->ins) { auto *definingOp = in.getDefiningOp(); - if (!definingOp) + if (!definingOp) { continue; - if (definingOp->getBlock() != parentBlock) + } + if (definingOp->getBlock() != parentBlock) { continue; + } if (!insertionPt) { insertionPt = definingOp; // first defining op } else if (insertionPt->isBeforeInBlock(definingOp)) { @@ -83,8 +85,9 @@ struct WavePartitionBuilder { resultTypes.push_back(out.getType()); auto resultSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( fusedLoc, out, parentBuilder); - if (resultSize) + if (resultSize) { resultSizes.push_back(resultSize); + } } SmallVector operands; SmallVector operandTypes; @@ -93,14 +96,16 @@ struct WavePartitionBuilder { operandTypes.reserve(partition->ins.size()); operandSizes.reserve(partition->ins.size()); for (auto in : partition->ins) { - if 
(!isa(in.getType())) + if (!isa(in.getType())) { continue; + } operands.push_back(in); operandTypes.push_back(in.getType()); auto operandSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( fusedLoc, in, parentBuilder); - if (operandSize) + if (operandSize) { operandSizes.push_back(operandSize); + } } // TODO(benvanik): tie operands, or leave to canonicalization. @@ -134,8 +139,9 @@ struct WavePartitionBuilder { // // Returns true if the operation was cloned into the partition. bool visit(Operation *op) { - if (!partition->ops.contains(op)) + if (!partition->ops.contains(op)) { return false; + } // Clone the op into the partition and remap it. auto *clonedOp = builder.clone(*op, mapping); @@ -159,8 +165,9 @@ struct WavePartitionBuilder { results.push_back(newResult); auto resultSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( concurrentOp.getLoc(), newResult, builder); - if (resultSize) + if (resultSize) { resultSizes.push_back(resultSize); + } } IREE::Stream::YieldOp::create(builder, concurrentOp.getLoc(), results, resultSizes); @@ -188,8 +195,9 @@ struct ScheduleConcurrencyPass } for (auto executeOp : parentOp.getCallableRegion()->getOps()) { - if (failed(runOnRegion(executeOp))) + if (failed(runOnRegion(executeOp))) { return signalPassFailure(); + } } } @@ -205,10 +213,12 @@ struct ScheduleConcurrencyPass // Compute a set of partitions covering all of the streamable ops in the // execution region. auto waveSet = partitionRegionConcurrency(configAttr, block); - if (waveSet.empty()) + if (waveSet.empty()) { return success(); - if (failed(waveSet.verify(parentOp.getLoc()))) + } + if (failed(waveSet.verify(parentOp.getLoc()))) { return failure(); + } // Create partition builders for each partition. 
// We'll clone ops into each and insert them into the block at the @@ -217,8 +227,9 @@ struct ScheduleConcurrencyPass SmallVector partitionBuilders; partitionBuilders.reserve(waveSet.size()); for (auto partition : llvm::enumerate(waveSet.partitions)) { - if (partition.value().ops.size() == 1) + if (partition.value().ops.size() == 1) { continue; + } partitionBuilders.push_back(WavePartitionBuilder(block, partition.index(), &partition.value(), mapping, &getContext())); @@ -231,8 +242,9 @@ struct ScheduleConcurrencyPass // creates a lot of new IR (up to O(op*partitions)). SetVector deadOps; for (auto &op : *block) { - if (op.hasTrait()) + if (op.hasTrait()) { continue; + } bool handled = false; for (auto &partitionBuilder : partitionBuilders) { handled = partitionBuilder.visit(&op) || handled; diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp index 76976a163177..457e30a03edb 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp @@ -55,8 +55,9 @@ struct ExecutePartitionBuilder { // This is at the last op in the partition. 
Operation *insertionPt = nullptr; for (auto *op : partition->ops) { - if (op->getBlock() != parentBlock) + if (op->getBlock() != parentBlock) { continue; + } if (!insertionPt) { insertionPt = op; // first defining op } else if (insertionPt->isBeforeInBlock(op)) { @@ -82,8 +83,9 @@ struct ExecutePartitionBuilder { resultTypes.push_back(out.getType()); auto resultSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( fusedLoc, out, parentBuilder); - if (resultSize) + if (resultSize) { resultSizes.push_back(resultSize); + } } SmallVector operands; SmallVector operandTypes; @@ -92,14 +94,16 @@ struct ExecutePartitionBuilder { operandTypes.reserve(partition->ins.size()); operandSizes.reserve(partition->ins.size()); for (auto in : partition->ins) { - if (!isa(in.getType())) + if (!isa(in.getType())) { continue; + } operands.push_back(in); operandTypes.push_back(in.getType()); auto operandSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( fusedLoc, in, parentBuilder); - if (operandSize) + if (operandSize) { operandSizes.push_back(operandSize); + } } // Collect await timepoints from all ops being partitioned and join them. @@ -148,8 +152,9 @@ struct ExecutePartitionBuilder { // // Returns true if the operation was cloned into the partition. bool visit(Operation *op) { - if (!partition->ops.contains(op)) + if (!partition->ops.contains(op)) { return false; + } // Clone the op into the partition and remap it. 
auto *clonedOp = builder.clone(*op, mapping); @@ -197,8 +202,9 @@ struct ExecutePartitionBuilder { results.push_back(newResult); auto resultSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( executeOp.getLoc(), newResult, builder); - if (resultSize) + if (resultSize) { resultSizes.push_back(resultSize); + } } IREE::Stream::YieldOp::create(builder, executeOp.getLoc(), results, resultSizes); @@ -228,8 +234,9 @@ static SmallVector sortBlocksInDominanceOrder(Region ®ion) { } llvm::SmallSetVector markedBlocks; std::function visit = [&](Block *block) { - if (markedBlocks.count(block) > 0) + if (markedBlocks.count(block) > 0) { return; + } for (auto *childBlock : dominanceInfo.getNode(block)->children()) { visit(childBlock->getBlock()); } @@ -322,8 +329,9 @@ LogicalResult processRegion(Location loc, MLIRContext *context, Region ®ion, // creates a lot of new IR (up to O(op*partitions)). SetVector deadOps; for (auto &op : *block) { - if (op.hasTrait()) + if (op.hasTrait()) { continue; + } for (auto &partitionBuilder : partitionBuilders) { partitionBuilder.visit(&op); } @@ -436,8 +444,9 @@ LogicalResult processRegion(Location loc, MLIRContext *context, Region ®ion, } for (auto &subregion : op.getRegions()) { - if (failed(processRegion(loc, context, subregion, configAttr))) + if (failed(processRegion(loc, context, subregion, configAttr))) { return failure(); + } } } } @@ -479,8 +488,9 @@ struct ScheduleExecutionPass // order so that we are sure if we replace values that dominate other blocks // they see the correct values. auto ®ion = *parentOp.getCallableRegion(); - if (failed(processRegion(parentOp.getLoc(), context, region, configAttr))) + if (failed(processRegion(parentOp.getLoc(), context, region, configAttr))) { return signalPassFailure(); + } // Cleanup the dead ops. // TODO(benvanik): less work here - maybe no patterns to just force folding? 
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeDispatches.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeDispatches.cpp index 38ce8d49b77a..50391ac33946 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeDispatches.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeDispatches.cpp @@ -64,8 +64,9 @@ buildConstantTable(mlir::FunctionOpInterface funcOp, llvm::BitVector constantOperandMap(operandCount, /*t=*/true); for (auto dispatchOp : dispatchOps) { for (unsigned idx = 0; idx < operandCount; ++idx) { - if (!constantOperandMap.test(idx)) + if (!constantOperandMap.test(idx)) { continue; + } auto value = dispatchOp.getUniformOperands()[idx]; Attribute constantValue; if (!matchPattern(value, m_Constant(&constantValue))) { @@ -86,8 +87,9 @@ buildConstantTable(mlir::FunctionOpInterface funcOp, DenseMap typeSets; SmallVector typeOrder; for (unsigned idx = 0; idx < operandCount; ++idx) { - if (!constantOperandMap.test(idx)) + if (!constantOperandMap.test(idx)) { continue; + } auto operandType = anyDispatchOp.getUniformOperands()[idx].getType(); auto &set = typeSets[operandType]; if (!set.type) { @@ -286,15 +288,17 @@ specializeDispatches(IREE::Stream::ExecutableOp executableOp, IREE::Stream::ExecutableExportOp exportOp, SmallVector &dispatchOps, MemoizedCmdConstants &memoizedConstants) { - if (dispatchOps.empty()) + if (dispatchOps.empty()) { return; // no-op if no dispatches + } auto funcOp = exportOp.lookupFunctionRef(); // Build a constant table for unique per-dispatch constant values. 
auto constantTable = buildConstantTable(funcOp, dispatchOps); - if (constantTable.coveredOperands.none()) + if (constantTable.coveredOperands.none()) { return; + } LLVM_DEBUG({ AsmState asmState(executableOp->getParentOp()); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/SplitParameterEncoder.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SplitParameterEncoder.cpp new file mode 100644 index 000000000000..de0c6fdb001f --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SplitParameterEncoder.cpp @@ -0,0 +1,2216 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include + +#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h" +#include "iree/compiler/Dialect/Stream/IR/StreamOps.h" +#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" +#include "iree/compiler/Dialect/Stream/Transforms/Passes.h" +#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" +#include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "iree/compiler/Utils/IntegerSet.h" +#include "iree/compiler/Utils/ModuleUtils.h" +#include "iree/compiler/Utils/RegionOpUtils.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FileSystem.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Transforms/CSE.h" +#include "mlir/Transforms/RegionUtils.h" 
+ +#define DEBUG_TYPE "iree-stream-split-parameter-encoder" +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") + +namespace mlir::iree_compiler::IREE::Stream { + +#define GEN_PASS_DEF_SPLITPARAMETERENCODERPASS +#include "iree/compiler/Dialect/Stream/Transforms/Passes.h.inc" + +namespace { + +//===----------------------------------------------------------------------===// +// Utilities +//===----------------------------------------------------------------------===// + +// Collection of IndexSets for managing index memoization across functions. +class IndexSetCollection { +public: + // Returns an index set for the parent function of |op|. + IndexSet *get(Operation *op) { + auto parentOp = op->getParentOfType(); + auto it = funcMap.find(parentOp); + if (it != funcMap.end()) { + return it->second.get(); + } + auto indexSet = std::make_unique( + op->getLoc(), OpBuilder::atBlockBegin(&parentOp.front())); + IndexSet *indexSetPtr = indexSet.get(); + funcMap.insert({parentOp, std::move(indexSet)}); + return indexSetPtr; + } + +private: + DenseMap> funcMap; +}; + +// Erases all ops in |leafOps| and all of their potentially newly-dead +// transitive producer dependencies. +// +// This custom DCE is required because MLIR's standard mlir::dce::removeDeadCode +// doesn't handle two cases we need: +// 1. Ops implementing HoistableOpInterface - control flow ops like scf.for/if +// whose bodies contain only hoistable (pure) operations can be deleted. +// 2. Ops with MemoryEffects::Read but no Write - these are "pure-ish" ops that +// can't be marked Pure (which would allow CSE) but are still safe to delete. +// +// TODO(benvanik): figure out how to move this to RegionOpUtils - it relies on +// some util op interfaces, though, so is hard to get out there now. 
+static void pruneDeadOps(ArrayRef leafOps) { + SmallVector deadOpWorklist{leafOps}; + + // Use a DenseSet to track already-processed operations to avoid duplicate + // processing when operations appear multiple times in the worklist. + DenseSet processedOps; + while (!deadOpWorklist.empty()) { + Operation *op = deadOpWorklist.pop_back_val(); + + // Skip if we've already processed this operation. + if (!processedOps.insert(op).second) { + continue; + } + + // Skip if the operation is no longer trivially dead (may have been + // deleted already or gained new uses). + // Also elide ops with no uses that have MemoryEffects::Effect::Read but no + // writes - these match the semantics of canonicalization ElideUnusedOp + // patterns (ops that are pure-ish but can't be marked Pure due to CSE). + bool canDelete = mlir::isOpTriviallyDead(op); + if (!canDelete && op->use_empty()) { + auto memInterface = dyn_cast(op); + if (memInterface) { + SmallVector effects; + memInterface.getEffects(effects); + // Safe to delete if it only has Allocate/Read effects (no Write). + canDelete = + llvm::none_of(effects, [](const MemoryEffects::EffectInstance &it) { + return isa(it.getEffect()); + }); + } else if (auto hoistableOp = + dyn_cast(op)) { + // Operations with HoistableOpInterface can be deleted if unused and + // hoistable (pure). This handles control flow ops like scf.for/scf.if + // whose bodies contain only hoistable operations. + canDelete = hoistableOp.isHoistableOp(); + } + } + if (!canDelete) { + continue; + } + + // Collect defining operations before we delete this op. + SmallVector producerOps; + for (Value operand : op->getOperands()) { + if (Operation *producer = operand.getDefiningOp()) { + producerOps.push_back(producer); + } + } + + // Erase the dead operation. + op->erase(); + + // Check if any of the producers now have no uses and add them to the + // worklist. The worklist loop will determine if they're safe to delete. 
+ for (Operation *producer : producerOps) { + if (producer->use_empty()) { + deadOpWorklist.push_back(producer); + } + } + } +} + +//===----------------------------------------------------------------------===// +// EncodingExpr +//===----------------------------------------------------------------------===// + +// Configuration controlling which expressions are hoisted to the encoder +// module. This policy determines hoisting eligibility based on expression type, +// size growth limits, and parameter/constant handling preferences. +struct EncodingPolicy { + // Pack multiple parameters into larger slabs to reduce overheads. + // This can dramatically improve startup time, reduces memory fragmentation, + // and reduces dispatch overheads. + bool packParameters = true; // false; + // Include direct parameter loads that have no modifications. + // When true the output parameter indices will have all required parameters + // and any original parameters will not be required by the base program at + // runtime. When false the user must provide the original parameters. + bool includeUnmodified = true; + // Any splat under this size will be serialized to the output parameter index + // as if it were data instead of being embedded as a splat. + // This increases the file size but allows for better parameter batching and + // can reduce runtime overhead. + int64_t serializeSplatSizeThreshold = 1024; + + // Enable hoisting parameter transformation expressions. + // When true, expressions that transform parameters (parameter → + // dispatch/encoding) will be extracted into the encoder module for offline + // evaluation. + bool hoistParameterExpressions = true; + + // Enable hoisting pure constant expressions with transformations. + // When true, expressions that transform pure constants (constant → + // dispatch/encoding) will be extracted into the encoder module for offline + // evaluation. 
+ bool hoistConstantExpressions = true; + + // Maximum ratio of output size to input size before rejecting hoisting. + // This prevents expressions that significantly increase storage from being + // hoisted. Example: 1.2 allows 20% growth for padding/alignment. + float maxEncodingGrowthFactor = 1.2f; +}; + +// An encoding expression represents a subgraph of operations that transforms +// input parameters/constants into output values stored to globals. Each +// expression can have multiple inputs (parameter loads) and multiple outputs +// (global stores). The expression is hoisted to the encoder module where it +// can be evaluated offline, with the results stored as pre-encoded parameters. +struct EncodingExpr { + // Affinity of consumers of the expression in the original program. + // All outputs share the same affinity. + IREE::Stream::AffinityAttr affinityAttr; + + struct Input { + // Inlined constant resource or parameter load. + mutable IREE::Stream::AsyncConstantOp constantOp; + + Location getLoc() const { return constantOp.getLoc(); } + + // Returns true if the input is sourced from a parameter. + bool isParameter() const { + return isa(constantOp.getValue()); + } + }; + SmallVector inputs; + + struct Output { + // Size in bytes of the output resource. + int64_t size = 0; + // Constant pattern value if this is a splat. + TypedAttr splatPattern; + // Sink op storing the produced output into a global. + mutable IREE::Util::GlobalStoreOpInterface storeOp; + // Produced value feeding into the store. + // This may be either be directly consumed by the store or an op earlier in + // the slice in cases where there are metadata ops we want to skip. + Value producedValue; + + Location getLoc() const { return storeOp.getLoc(); } + + // Returns true if the output is a constant splat that needs no execution. + // Only certain data types/widths are supported in the format and if not + // supported natively we'll need to splat the value into the file. 
It's + // rare for there to be splats that end up like this and it's unlikely the + // user wants a file full of splatted values but at this point in the + // pipeline we can only assume they asked for it. + bool isSupportedSplat() const { + if (!splatPattern || !splatPattern.getType().isIntOrFloat()) { + return false; + } + const unsigned bitWidth = splatPattern.getType().getIntOrFloatBitWidth(); + return bitWidth == 8 || bitWidth == 16 || bitWidth == 32 || + bitWidth == 64; + } + }; + SmallVector outputs; + + // All operations (excluding outputs). + SetVector ops; + + // Returns a fused location from all operations in the expression. + Location getLoc() const { + SetVector locs; + for (auto *op : ops) { + locs.insert(op->getLoc()); + } + for (auto &output : outputs) { + locs.insert(output.getLoc()); + } + return FusedLoc::get(ops.front()->getContext(), locs.getArrayRef()); + } + + // Returns the resource config for the expression by checking all outputs. + // If any outputs have differing configs + IREE::Stream::ResourceConfigAttr getResourceConfigAttr() const { + // Expressions should only be formed from outputs that share an affinity + // so we can look at the first output and assume they all match. + if (outputs.empty()) { + return {}; + } + auto globalStoreOp = outputs.front().storeOp; + Value storedValue = globalStoreOp.getStoredGlobalValue(); + auto *producingOp = storedValue.getDefiningOp(); + return IREE::Stream::ResourceConfigAttr::lookup( + producingOp ? producingOp : globalStoreOp); + } + + // Returns true if the expression has any parameter inputs. + bool hasParameterInputs() const { + return llvm::any_of(inputs, + [](const Input &input) { return input.isParameter(); }); + } + + // Returns true if the expression has any constant inputs (non-parameter). + bool hasConstantInputs() const { + return llvm::any_of( + inputs, [](const Input &input) { return !input.isParameter(); }); + } + + // Estimates total input size from all inputs in bytes. 
+ int64_t estimateInputSize() const { + int64_t total = 0; + for (const auto &input : inputs) { + if (input.constantOp) { + Value sizeValue = input.constantOp.getResultSize(); + APInt size; + if (matchPattern(sizeValue, m_ConstantInt(&size))) { + total += size.getZExtValue(); + } + } + } + return total; + } + + // Estimates total output size from all outputs in bytes. + int64_t estimateOutputSize() const { + int64_t total = 0; + for (const auto &output : outputs) { + total += output.size; + } + return total; + } +}; + +struct EncodingExprSet { + // All expressions terminating in parameter outputs in the order they were + // originally present in the module (even if split across initializers). + SmallVector exprs; + + bool empty() const { return exprs.empty(); } +}; + +// Collects all external timepoint dependencies from the expression. This +// includes await timepoints from TimelineOpInterface ops in the expression that +// reference external values, and timepoints from external resource operands +// extracted via getResultTimepoint or by inserting a barrier. +static Value collectExternalTimepoints(const EncodingExpr &expr, + OpBuilder &builder) { + SetVector timepoints; + + // Build a set of ops that contribute RESOURCES (not just timepoints) to the + // expression. An op is a "resource contributor" if at least one of its + // non-timepoint results is used by another op in the expression. + // + // This distinction is important because the backward slice follows ALL + // operands including await timepoints. Ops that only contribute timepoints + // (like a timeline_op whose resource output is unused) should be considered + // "external" for synchronization purposes - their timepoints need to be + // awaited by the replacement op. + DenseSet resourceContributors; + for (Operation *op : expr.ops) { + for (Value result : op->getResults()) { + // Skip timepoint results - we only care about resource contributions. 
+ if (isa(result.getType())) { + continue; + } + // Check if any user of this non-timepoint result is in the expression. + for (Operation *user : result.getUsers()) { + if (expr.ops.contains(user)) { + resourceContributors.insert(op); + break; + } + } + if (resourceContributors.contains(op)) { + break; + } + } + } + + // A timepoint is "internal" only if its defining op contributes resources + // (not just timepoints) to the expression. + auto isInternalTimepoint = [&](Value tp) -> bool { + Operation *defOp = tp.getDefiningOp(); + return defOp && resourceContributors.contains(defOp); + }; + + // Collect external await timepoints from resource-contributing ops only. + // We only look at resource contributors because: + // 1. They represent the "core" data flow of the expression + // 2. Non-resource-contributor ops (like joins, unused timeline ops) are + // "synchronization helpers" whose await timepoints are transitively + // covered by the resource contributors' awaits + // This ensures we don't collect both a joined timepoint AND its component + // timepoints when a join is in the expression but doesn't contribute + // resources. + for (Operation *op : resourceContributors) { + auto timelineOp = dyn_cast(op); + if (!timelineOp) { + continue; + } + for (Value awaitTp : timelineOp.getAwaitTimepoints()) { + if (!isInternalTimepoint(awaitTp)) { + timepoints.insert(awaitTp); + } + } + } + + // A resource is "internal" only if its defining op contributes resources + // (not just timepoints) to the expression. + auto isInternalResource = [&](Value resource) -> bool { + Operation *defOp = resource.getDefiningOp(); + return defOp && resourceContributors.contains(defOp); + }; + + // Collect timepoints from external resource operands. + for (Operation *op : expr.ops) { + for (Value operand : op->getOperands()) { + if (!isa(operand.getType())) { + continue; + } + if (isInternalResource(operand)) { + continue; + } + + // Try to get timepoint from TimelineOpInterface. 
+ Value timepoint; + Operation *definingOp = operand.getDefiningOp(); + if (definingOp) { + if (auto timelineOp = + dyn_cast(definingOp)) { + timepoint = timelineOp.getResultTimepoint(); + } + } + + // If no timepoint available, insert barrier to extract it. + if (!timepoint) { + Value resourceSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( + operand.getLoc(), operand, builder); + assert(resourceSize && "stream resource must have queryable size"); + auto affinityAttr = IREE::Stream::AffinityAttr::lookup(definingOp); + auto barrierOp = IREE::Stream::TimepointBarrierOp::create( + builder, operand.getLoc(), operand.getType(), + builder.getType(), operand, + resourceSize, affinityAttr); + timepoint = barrierOp.getResultTimepoint(); + } + + if (timepoint) { + timepoints.insert(timepoint); + } + } + } + + if (timepoints.empty()) { + return {}; + } + return IREE::Stream::joinTimepoints( + expr.getLoc(), SmallVector(timepoints.begin(), timepoints.end()), + builder); +} + +// Finds all util.global.store-like ops that store constant resources in +// initializers. Stores are returned in program order. +// +// TODO: note that this does not check for stores in functions called by +// initializers and also does not currently check for variables (as they are +// usually uninitialized). +static SmallVector +findAllConstantStoreOps(mlir::ModuleOp moduleOp) { + SmallVector storeOps; + for (auto initializerOp : + moduleOp.getOps()) { + // Skip initializers that have CFGs. We don't handle conditional + // initialization of globals today. + auto ®ion = initializerOp.getInitializerRegion(); + if (!region.hasOneBlock()) { + LLVM_DEBUG(DBGS() << "ignoring initializer as it has multiple blocks\n"); + continue; + } + // Find all stores. Note that we purposefully skip nested regions today. 
+ for (auto storeOp : + region.front().getOps()) { + Type storedType = storeOp.getStoredGlobalValue().getType(); + if (auto resourceType = + dyn_cast(storedType)) { + if (resourceType.getLifetime() == IREE::Stream::Lifetime::Constant) { + storeOps.push_back(storeOp); + } + } + } + } + return storeOps; +} + +// Returns true if the operation's memory effects allow it to be hoisted as a +// const-expr operation. We allow Allocate and Free effects (memory management) +// but reject Read/Write effects to external memory. +static bool hasHoistableMemoryEffects(Operation *op) { + auto effectInterface = dyn_cast(op); + if (!effectInterface) { + // No memory effect interface means no effects - hoistable. + return true; + } + + SmallVector effects; + effectInterface.getEffects(effects); + + for (const auto &effect : effects) { + // Allocate effects are fine (creating new memory). + if (isa(effect.getEffect())) { + continue; + } + // Free effects are also fine (releasing memory). + if (isa(effect.getEffect())) { + continue; + } + // Read or Write effects on non-result values are not const-expr. + // Operations can write to their own results (that's how they produce + // them), but reading/writing external memory is not allowed. + if (isa(effect.getEffect()) || + isa(effect.getEffect())) { + // Check if the effect is on a result of this op (allowed) or + // on external memory (not allowed). + if (Value value = llvm::dyn_cast_if_present(effect.getValue())) { + // If it's a result of this op, it's fine. + if (value.getDefiningOp() == op) { + continue; + } + } + // Read/Write to external memory - not const-expr. + return false; + } + } + + return true; +} + +static bool isConstExprOp(Operation *op) { + // Optimization barriers cannot be folded. + if (isa(op)) { + return false; + } + + // By default, ops without results are not const-expr. 
+ if (op->getNumResults() == 0) { + return false; + } + + // If implementing the HoistableOpInterface, just use the decision made by + // the interface. + if (auto hoistableOp = dyn_cast(op)) { + return hoistableOp.isHoistableOp(); + } + + // Forbid if part of a parent that should be treated atomically. + Operation *parent = op; + while (auto hoistableParent = + parent->getParentOfType()) { + if (hoistableParent.isAtomicallyHoistableOp()) { + return false; + } + parent = hoistableParent; + } + + // Check memory effects: we allow Allocate effects (creating new memory) + // but reject Read/Write effects to external memory. This is more permissive + // than OpOracle.cpp's isMemoryEffectFree check, allowing operations like + // stream.async.splat that allocate but don't have other side effects. + return hasHoistableMemoryEffects(op); +} + +static IREE::Stream::AffinityAttr +lookupConsumerAffinityAttr(Value storedValue) { + if (auto affinityOp = dyn_cast( + storedValue.getDefiningOp())) { + return affinityOp.getResultAffinityAttr(); + } + return IREE::Stream::AffinityAttr::lookupOrDefault( + storedValue.getDefiningOp()); +} + +// Returns true if the expression producing |storedValue| is an input without +// any modification (such as inlined constants/parameters). +static bool isPassThroughStore(Value storedValue) { + Operation *op = storedValue.getDefiningOp(); + do { + if (auto transferOp = dyn_cast(op)) { + op = transferOp.getSource().getDefiningOp(); + } else if (auto constantOp = dyn_cast(op)) { + return true; + } else { + return false; + } + } while (op); + return false; +} + +// Returns the result index of |result| in the parent operation. +// The result must be a valid result of op. 
+static unsigned findResultIndex(Operation *op, Value result) { + for (unsigned i = 0; i < op->getNumResults(); ++i) { + if (op->getResult(i) == result) { + return i; + } + } + llvm_unreachable("result not found in operation"); +} + +// Attempts to evaluate a size value to a constant integer. +// This handles direct constants and analyzes through control flow operations +// where the size is provably constant (e.g., scf.if with matching branch +// sizes). +static std::optional tryEvaluateConstantSize(Value sizeValue) { + if (!sizeValue) { + return std::nullopt; + } + + // Try direct constant match (existing behavior). + APInt size; + if (matchPattern(sizeValue, m_ConstantInt(&size))) { + return size.getZExtValue(); + } + + // For scf.if, check if both branches yield the same constant size. + if (auto ifOp = sizeValue.getDefiningOp()) { + unsigned resultIndex = findResultIndex(ifOp, sizeValue); + + // Get the yielded values from both regions. + auto thenYield = + cast(ifOp.getThenRegion().front().getTerminator()); + auto elseYield = + cast(ifOp.getElseRegion().front().getTerminator()); + + // Recursively evaluate both branch sizes. + Value thenValue = thenYield.getOperand(resultIndex); + Value elseValue = elseYield.getOperand(resultIndex); + + // Find sizes for the yielded resource values. + auto thenSizeValue = IREE::Util::SizeAwareTypeInterface::findSizeValue( + thenValue, &ifOp.getThenRegion().front(), Block::iterator(thenYield)); + auto elseSizeValue = IREE::Util::SizeAwareTypeInterface::findSizeValue( + elseValue, &ifOp.getElseRegion().front(), Block::iterator(elseYield)); + + auto thenSize = tryEvaluateConstantSize(thenSizeValue); + auto elseSize = tryEvaluateConstantSize(elseSizeValue); + + // If both branches have the same constant size, return it. + if (thenSize && elseSize && *thenSize == *elseSize) { + return *thenSize; + } + + return std::nullopt; + } + + // For scf.for, check if the size is loop-invariant. 
+ if (auto forOp = sizeValue.getDefiningOp()) { + unsigned resultIndex = findResultIndex(forOp, sizeValue); + + // Check the initial value (iter_arg). + Value initArg = forOp.getInitArgs()[resultIndex]; + auto initSizeValue = IREE::Util::SizeAwareTypeInterface::findSizeValue( + initArg, forOp->getBlock(), Block::iterator(forOp)); + auto initSize = tryEvaluateConstantSize(initSizeValue); + + if (!initSize) { + return std::nullopt; + } + + // Check the yielded value in the loop body. + auto yieldOp = + cast(forOp.getRegion().front().getTerminator()); + Value yieldedValue = yieldOp.getOperand(resultIndex); + + // Find size for the yielded resource value. + auto yieldedSizeValue = IREE::Util::SizeAwareTypeInterface::findSizeValue( + yieldedValue, &forOp.getRegion().front(), Block::iterator(yieldOp)); + auto yieldedSize = tryEvaluateConstantSize(yieldedSizeValue); + + // If the yielded size matches the initial size, it's invariant. + if (yieldedSize && *yieldedSize == *initSize) { + return *initSize; + } + + return std::nullopt; + } + + // Could not evaluate to a constant. + return std::nullopt; +} + +// Returns a constant pattern for a value derived entirely from a splatted +// value. Returns nullptr if the value is not derived from a splat or has a +// non-constant pattern. +static TypedAttr findConstantSplatPattern(Value storedValue) { + Operation *op = storedValue.getDefiningOp(); + do { + if (auto transferOp = dyn_cast(op)) { + op = transferOp.getSource().getDefiningOp(); + } else if (auto splatOp = dyn_cast(op)) { + TypedAttr pattern; + if (matchPattern(splatOp.getValue(), m_Constant(&pattern))) { + return pattern; + } + return {}; + } else { + return {}; + } + } while (op); + return {}; +} + +// Returns the last value produced that is non-metadata (according to us). +// This lets us skip meaningless ops like transfers and clones that change +// lifetime when cloning into the target program. 
Those ops, though valid, make +// the IR a lot more confusing to follow and prevent some early folding +// opportunities. +static Value findProducedValue(Value value) { + while (Operation *defOp = value.getDefiningOp()) { + if (auto transferOp = dyn_cast(defOp)) { + // We never care about transfers unless they are transferring to unknown. + auto resultType = + cast(transferOp.getResult().getType()); + if (resultType.getLifetime() != IREE::Stream::Lifetime::Unknown) { + value = transferOp.getSource(); + continue; + } + } else if (auto cloneOp = dyn_cast(defOp)) { + // Skip past clones to find the actual producing operation. + // Clones are just type/lifetime conversions, not data producers. + value = cloneOp.getSource(); + continue; + } + break; + } + return value; +} + +// Returns true if the expression should be hoisted based on policy. +static bool shouldHoistExpression(const EncodingExpr &expr, + const EncodingPolicy &policy) { + bool hasParams = expr.hasParameterInputs(); + bool hasConstants = expr.hasConstantInputs(); + + // Check if this expression type should be hoisted per policy. + if (hasParams && !policy.hoistParameterExpressions) { + LLVM_DEBUG(DBGS() << "skipping parameter expression per policy\n"); + return false; + } + if (!hasParams && hasConstants && !policy.hoistConstantExpressions) { + LLVM_DEBUG(DBGS() << "skipping constant expression per policy\n"); + return false; + } + if (!hasParams && !hasConstants) { + // No inputs at all - probably an error case or pure splat. + LLVM_DEBUG(DBGS() << "skipping expression with no inputs\n"); + return false; + } + + // Check size growth threshold. 
+ int64_t inputSize = expr.estimateInputSize(); + int64_t outputSize = expr.estimateOutputSize(); + if (inputSize > 0) { + float growthFactor = static_cast(outputSize) / inputSize; + if (growthFactor > policy.maxEncodingGrowthFactor) { + LLVM_DEBUG(DBGS() << "rejecting expression due to size growth: " + << growthFactor << "x (threshold: " + << policy.maxEncodingGrowthFactor << "x)\n"); + return false; + } + } + + return true; +} + +// Analyzes |moduleOp| to find all expressions producing global constants that +// we can turn into parameters, if any. +static EncodingExprSet gatherEncodingExprSet(mlir::ModuleOp moduleOp, + EncodingPolicy policy) { + auto constantStoreOps = findAllConstantStoreOps(moduleOp); + + EncodingExprSet exprSet; + + std::unique_ptr asmState; + LLVM_DEBUG(asmState = std::make_unique( + moduleOp, OpPrintingFlags().elideLargeElementsAttrs())); + + for (auto storeOp : constantStoreOps) { + LLVM_DEBUG({ + DBGS() << "evaluating store slice for inclusion: "; + storeOp->print(llvm::dbgs(), *asmState); + llvm::dbgs() << "\n"; + }); + Value storedValue = storeOp.getStoredGlobalValue(); + + BackwardSliceOptions sliceOptions; + sliceOptions.inclusive = true; + bool foundAnyNonConstExprOps = false; + sliceOptions.filter = [&](Operation *op) { + if (isConstExprOp(op)) { + return true; + } + foundAnyNonConstExprOps = true; + return false; + }; + // Collect all values that need to be included in the slice: + // - The stored value itself + // - Values used inside nested regions that are defined outside + // + // We compute backward slices for all of them into the same SetVector, + // which gives us proper topological ordering with deduplication. + SetVector rootValues; + rootValues.insert(storedValue); + + // Do a first pass to find region-containing operations. 
+ SetVector tempSlice; + if (failed(mlir::getBackwardSlice(storedValue, &tempSlice, sliceOptions)) || + foundAnyNonConstExprOps) { + LLVM_DEBUG(DBGS() << "failed to calculate backward slice for op or found " + "non-const-expr ops, skipping\n"); + continue; + } + + // Find external dependencies from nested regions using MLIR's standard API. + // getUsedValuesDefinedAbove returns all values used inside a region but + // defined outside of it - exactly what we need for region captures. + for (auto *op : tempSlice) { + for (Region ®ion : op->getRegions()) { + SetVector capturedValues; + mlir::getUsedValuesDefinedAbove(region, capturedValues); + LLVM_DEBUG({ + if (!capturedValues.empty()) { + DBGS() << "found " << capturedValues.size() + << " captured values in region of "; + op->print(llvm::dbgs()); + llvm::dbgs() << ":\n"; + for (Value captured : capturedValues) { + llvm::dbgs() << " "; + captured.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + } + } + }); + for (Value captured : capturedValues) { + rootValues.insert(captured); + } + } + } + + // Now compute backward slices for all root values. + // When we have multiple roots (due to captured values), calling + // getBackwardSlice iteratively can break topological order because new + // operations get appended. We need to sort after merging. + bool needsSort = rootValues.size() > 1; + SetVector slice; + for (Value rootValue : rootValues) { + if (failed(mlir::getBackwardSlice(rootValue, &slice, sliceOptions)) || + foundAnyNonConstExprOps) { + LLVM_DEBUG(DBGS() << "failed to calculate backward slice for op or " + "found non-const-expr ops, skipping\n"); + break; + } + } + + if (foundAnyNonConstExprOps) { + continue; + } + + // Sort only when we merged multiple slices (i.e., had captured values). + // This is a small set (one expression), not the whole program. 
+ // Use mlir::topologicalSort which correctly handles operations across + // different blocks and regions, unlike isBeforeInBlock which only works + // for operations within the same block. + if (needsSort) { + slice = mlir::topologicalSort(slice); + } + + LLVM_DEBUG({ + DBGS() << "slice:\n"; + llvm::interleave( + slice, llvm::dbgs(), + [&](Operation *op) { + llvm::dbgs() << " "; + op->print(llvm::dbgs(), *asmState); + }, + "\n"); + llvm::dbgs() << "\n"; + }); + + // Overlay mode optimization: When a slice is just a parameter load with no + // transformation (detected by isPassThroughStore below), we skip including + // it as an output in overlay mode since the original parameter is + // unchanged. This is controlled by policy.includeUnmodified: + // - Consolidate mode (includeUnmodified=true): includes all parameters + // - Overlay mode (includeUnmodified=false): skips pass-through parameters + // + // Future enhancement: Could add overlap detection to merge expressions that + // write to overlapping parameter regions, possibly requiring a two-pass + // approach. For now, non-overlapping expressions work correctly. + + EncodingExpr expr; + expr.affinityAttr = + lookupConsumerAffinityAttr(storeOp.getStoredGlobalValue()); + + for (auto *op : slice) { + if (auto constantOp = dyn_cast(op)) { + EncodingExpr::Input input; + input.constantOp = constantOp; + expr.inputs.push_back(input); + } + } + + if (!isPassThroughStore(storeOp.getStoredGlobalValue()) || + policy.includeUnmodified) { + // Check if the produced value prefers cloning (like pure splats). + // These should be included in the slice for cloning but not serialized + // as outputs. 
+ Value producedValue = findProducedValue(storeOp.getStoredGlobalValue()); + auto *producingOp = producedValue.getDefiningOp(); + if (producingOp) { + if (auto streamableOp = + dyn_cast(producingOp)) { + if (streamableOp.preferCloneToConsumers()) { + LLVM_DEBUG(DBGS() + << "skipping output for op that prefers cloning\n"); + continue; + } + } + } + + Value storedValue = storeOp.getStoredGlobalValue(); + Value sizeValue = IREE::Util::SizeAwareTypeInterface::findSizeValue( + storedValue, storeOp->getBlock(), Block::iterator(storeOp)); + + // If findSizeValue returns null, it might be because the value comes from + // a control flow operation (like scf.for or scf.if) that doesn't + // implement SizeAwareOpInterface. Try analyzing the control flow + // directly. + std::optional sizeOpt; + if (sizeValue) { + sizeOpt = tryEvaluateConstantSize(sizeValue); + } else if (auto *defOp = storedValue.getDefiningOp()) { + // Try analyzing control flow operations directly. + if (auto forOp = dyn_cast(defOp)) { + // Find which result this is. + unsigned resultIdx = 0; + for (unsigned i = 0; i < forOp.getNumResults(); ++i) { + if (forOp.getResult(i) == storedValue) { + resultIdx = i; + break; + } + } + // Get size from init arg. + Value initArg = forOp.getInitArgs()[resultIdx]; + Value initSizeValue = + IREE::Util::SizeAwareTypeInterface::findSizeValue( + initArg, forOp->getBlock(), Block::iterator(forOp)); + sizeOpt = tryEvaluateConstantSize(initSizeValue); + } else if (auto ifOp = dyn_cast(defOp)) { + // Find which result this is. + unsigned resultIdx = 0; + for (unsigned i = 0; i < ifOp.getNumResults(); ++i) { + if (ifOp.getResult(i) == storedValue) { + resultIdx = i; + break; + } + } + + // Get sizes from both branches. 
+ auto thenYield = + cast(ifOp.getThenRegion().front().getTerminator()); + auto elseYield = + cast(ifOp.getElseRegion().front().getTerminator()); + Value thenValue = thenYield.getOperand(resultIdx); + Value elseValue = elseYield.getOperand(resultIdx); + auto thenSizeValue = + IREE::Util::SizeAwareTypeInterface::findSizeValue( + thenValue, &ifOp.getThenRegion().front(), + Block::iterator(thenYield)); + auto elseSizeValue = + IREE::Util::SizeAwareTypeInterface::findSizeValue( + elseValue, &ifOp.getElseRegion().front(), + Block::iterator(elseYield)); + auto thenSize = tryEvaluateConstantSize(thenSizeValue); + auto elseSize = tryEvaluateConstantSize(elseSizeValue); + + // Both branches must have the same constant size. + if (thenSize && elseSize && *thenSize == *elseSize) { + sizeOpt = *thenSize; + } + } + } + + if (!sizeOpt) { + LLVM_DEBUG(DBGS() << "failed to find stored resource size, skipping\n"); + continue; + } + EncodingExpr::Output output; + output.size = *sizeOpt; + output.splatPattern = + findConstantSplatPattern(storeOp.getStoredGlobalValue()); + output.storeOp = storeOp; + output.producedValue = producedValue; + expr.outputs.push_back(output); + } + + if (expr.outputs.empty()) { + LLVM_DEBUG(DBGS() << "no outputs produced by policy, skipping\n"); + continue; + } + + expr.ops = std::move(slice); + exprSet.exprs.push_back(std::move(expr)); + } + + return exprSet; +} + +//===----------------------------------------------------------------------===// +// ParameterIndex and builders +//===----------------------------------------------------------------------===// + +// An entry in the parameter index describing a single output parameter. +// Entries can be either SPLAT (constant pattern fill) or DATA (computed bytes). +// A single EncodingExpr may produce multiple entries if it has multiple +// outputs. +struct ParameterEntry { + // Location of the parameter based on the original consumer op. 
+ std::optional loc; + enum class Type { + SPLAT = 0, + DATA = 1, + }; + // Type of the entry (indicates which value field is valid). + Type type; + // Key of the entry within the parameter scope. + StringAttr key; + // Optional metadata embedded with the entry. + SmallVector metadata; + // Total byte length of the parameter in memory. + int64_t length; + // Type-specific value. + union { + struct SplatEntry { + int64_t pattern; + int64_t patternLength; + } splat; + struct DataEntry { + int64_t minimumAlignment; + } data; + } value; + + static ParameterEntry createSplat(Location loc, StringAttr key, + int64_t length, int64_t pattern, + int64_t patternLength) { + ParameterEntry entry{loc}; + entry.type = Type::SPLAT; + entry.key = key; + entry.length = length; + entry.value.splat.pattern = pattern; + entry.value.splat.patternLength = patternLength; + return entry; + } + + static ParameterEntry createData(Location loc, StringAttr key, int64_t length, + int64_t minimumAlignment) { + ParameterEntry entry{loc}; + entry.type = Type::DATA; + entry.key = key; + entry.length = length; + entry.value.data.minimumAlignment = minimumAlignment; + return entry; + } + + Location getLoc() const { + return loc.has_value() ? loc.value() : UnknownLoc::get(key.getContext()); + } +}; + +// An IRPA parameter index. +struct ParameterIndex { + // Fused location derived from all included parameter locations. + Location loc; + // Scope name the index is referenced with, if any. + StringAttr scope; + // All parameter entries in the index. 
+ SmallVector entries; + + void dump(llvm::raw_ostream &os) const { + os << "ParameterIndex[" << scope << "]:\n"; + llvm::interleave( + entries, os, + [&](const ParameterEntry &entry) { + os << " '" << entry.key << "' " << entry.length << " bytes "; + if (!entry.metadata.empty()) { + os << "(metadata: " << entry.metadata.size() << "B) "; + } + switch (entry.type) { + case ParameterEntry::Type::SPLAT: + os << "splat: " + << APInt(entry.value.splat.patternLength * 8, + entry.value.splat.pattern); + break; + case ParameterEntry::Type::DATA: + os << "data: min alignment " << entry.value.data.minimumAlignment + << "B"; + break; + } + }, + "\n"); + os << "\n"; + } +}; + +struct ParameterBuilder { + MLIRContext *context; + StringAttr scope; + StringAttr key; + + ParameterBuilder() = delete; + explicit ParameterBuilder(MLIRContext *context, StringAttr scope, + StringAttr key) + : context(context), scope(scope), key(key) {} + virtual ~ParameterBuilder() = default; + virtual ParameterEntry finalize() = 0; +}; + +struct SplatParameterBuilder : public ParameterBuilder { + Location loc; + int64_t length = 0; + Attribute pattern; + + SplatParameterBuilder(StringAttr scope, StringAttr key, Location loc, + int64_t length, Attribute pattern) + : ParameterBuilder(loc.getContext(), scope, key), loc(loc), + length(length), pattern(pattern) {} + + ParameterEntry finalize() override { + APInt intValue; + APFloat floatValue(0.0f); + if (matchPattern(pattern, m_ConstantFloat(&floatValue))) { + intValue = floatValue.bitcastToAPInt(); + } else if (matchPattern(pattern, m_ConstantInt(&intValue))) { + } else { + assert(false && "ints/floats only; should have been verified"); + } + return ParameterEntry::createSplat( + loc, key, length, intValue.getZExtValue(), intValue.getBitWidth() / 8); + } +}; + +struct DataParameterBuilder : public ParameterBuilder { + IREE::Stream::AffinityAttr affinityAttr; + int64_t maxSize = 0; + int64_t offsetAlignment = 0; + int64_t rangeAlignment = 0; + int64_t 
currentOffset = 0; + SmallVector locs; + + DataParameterBuilder(StringAttr scope, StringAttr key, + IREE::Stream::AffinityAttr affinityAttr, + IREE::Stream::ResourceConfigAttr resourceConfigAttr) + : ParameterBuilder(resourceConfigAttr.getContext(), scope, key), + affinityAttr(affinityAttr), + maxSize(resourceConfigAttr.getMaxAllocationSize()), + offsetAlignment(resourceConfigAttr.getMinBufferOffsetAlignment()), + rangeAlignment(resourceConfigAttr.getMinBufferRangeAlignment()) {} + + // Reserves |length| bytes of storage in the parameter and returns the aligned + // offset within the parameter if there is sufficient capacity remaining. + std::optional tryReserve(Location loc, int64_t length) { + int64_t alignedOffset = IREE::Util::align(currentOffset, offsetAlignment); + int64_t alignedLength = IREE::Util::align(length, rangeAlignment); + int64_t newOffset = std::max(currentOffset, alignedOffset + alignedLength); + if (newOffset > maxSize) { + // Capacity exceeded. + return std::nullopt; + } + currentOffset = newOffset; + return alignedOffset; + } + + ParameterEntry finalize() override { + return ParameterEntry::createData( + FusedLoc::get(context, locs), key, + IREE::Util::align(currentOffset, rangeAlignment), offsetAlignment); + } +}; + +// A subrange of an output parameter produced by an encoding expression. +// Note that a single expression may produce multiple output subranges. +struct ParameterSubrange { + // Parameter index scope. + StringAttr scope; + // Parameter key the subrange is referencing. + StringAttr key; + // Offset within the parameter where the produced value will be placed. + // Aligned to the requirements of the parameter. + int64_t offset = 0; + // Length of subrange the produced value occupies. Note that if padding is + // present this may not extend to all of the parameter storage. 
+ int64_t length = 0; + + ParameterSubrange(StringAttr scope, StringAttr key, int64_t offset, + int64_t length) + : scope(scope), key(key), offset(offset), length(length) {} + + // Creates a named parameter attribute for this subrange with the given total + // length of the storage parameter. + IREE::Stream::NamedParameterAttr + createNamedParameterAttr(int64_t totalLength) const { + Type i8Type = IntegerType::get(scope.getContext(), 8); + auto parameterType = RankedTensorType::get({totalLength}, i8Type); + return IREE::Stream::NamedParameterAttr::get( + scope.getContext(), parameterType, scope, key, DictionaryAttr{}); + } +}; + +// Map of expression outputs to a reserved parameter subrange. +using OutputParameterSubrangeMap = + llvm::MapVector; + +// Incremental ParameterIndex builder with support for parameter combining. +class ParameterIndexBuilder { +public: + ParameterIndexBuilder(StringAttr scope, const EncodingPolicy &encodingPolicy) + : scope(scope), encodingPolicy(encodingPolicy) {} + + FailureOr insertExpr(const EncodingExpr *expr) { + OutputParameterSubrangeMap outputMap; + for (const auto &output : expr->outputs) { + FailureOr> subrangeOr; + if (output.isSupportedSplat() && + output.size > encodingPolicy.serializeSplatSizeThreshold) { + subrangeOr = insertSplatOutput(expr, &output); + } else { + subrangeOr = insertDataOutput(expr, &output); + } + if (failed(subrangeOr)) { + return failure(); + } + if (subrangeOr->has_value()) { + outputMap.insert( + std::make_pair(&output, std::move(subrangeOr->value()))); + } + } + return outputMap; + } + + ParameterIndex finalize() { + SmallVector parameterLocs; + SmallVector parameterEntries; + for (auto ¶meter : parameters) { + ParameterEntry parameterEntry = parameter->finalize(); + parameterLocs.push_back(parameterEntry.getLoc()); + parameterEntries.push_back(std::move(parameterEntry)); + } + ParameterIndex index{FusedLoc::get(scope.getContext(), parameterLocs)}; + index.scope = scope; + index.entries = 
std::move(parameterEntries); + return index; + } + +private: + StringAttr makeParameterName() { + return StringAttr::get(scope.getContext(), + Twine("parameter") + std::to_string(nextId++)); + } + + FailureOr> + insertSplatOutput(const EncodingExpr *expr, + const EncodingExpr::Output *output) { + auto splatBuilder = std::make_unique( + scope, makeParameterName(), expr->getLoc(), output->size, + output->splatPattern); + auto subrange = ParameterSubrange(splatBuilder->scope, splatBuilder->key, 0, + output->size); + parameters.push_back(std::move(splatBuilder)); + return {subrange}; + } + + // Inserts a data output into the parameter index, packing into existing + // parameters when possible. + // + // Uses first-fit allocation: iterates through existing parameters in order + // and places the output in the first one with matching affinity and available + // space. This is simple and fast for compilation, though not optimal for + // minimizing fragmentation. A best-fit or sorted-by-size approach could + // improve memory efficiency if parameter packing becomes a bottleneck. 
+ FailureOr> + insertDataOutput(const EncodingExpr *expr, + const EncodingExpr::Output *output) { + if (encodingPolicy.packParameters) { + for (auto *existingBuilder : dataParameters) { + if (existingBuilder->affinityAttr == expr->affinityAttr) { + std::optional offset = + existingBuilder->tryReserve(expr->getLoc(), output->size); + if (offset.has_value()) { + auto subrange = + ParameterSubrange(existingBuilder->scope, existingBuilder->key, + offset.value(), output->size); + return {subrange}; + } + } + } + } + + auto newBuilder = std::make_unique( + scope, makeParameterName(), expr->affinityAttr, + expr->getResourceConfigAttr()); + std::optional offset = + newBuilder->tryReserve(expr->getLoc(), output->size); + if (offset.has_value()) { + auto subrange = ParameterSubrange(newBuilder->scope, newBuilder->key, + offset.value(), output->size); + dataParameters.push_back(newBuilder.get()); + parameters.push_back(std::move(newBuilder)); + return {subrange}; + } + + LLVM_DEBUG(llvm::dbgs() + << " ! failed to reserve " << output->size + << " bytes for output at " << output->getLoc() << "\n"); + return mlir::emitError(output->getLoc(), + "failed to reserve parameter space for output\n"); + } + + StringAttr scope; + const EncodingPolicy &encodingPolicy; + SmallVector> parameters; + SmallVector dataParameters; + unsigned nextId = 0; +}; + +//===----------------------------------------------------------------------===// +// Encoder work scheduling +//===----------------------------------------------------------------------===// + +// A target configuration for a set of specialized encodings. +// Contains the parameter indices (what parameters will be produced), the +// execution schedule (steps), and a lookup map from (scope, key) to entries. +// Targets may specialize for multiple devices simultaneously if the +// configuration is for heterogeneous execution and may produce multiple +// parameter indices. Currently only a single "all" target is supported. 
+struct TargetPlan { + // Name of the target for the user to specify in tools. + std::string name; + + // Affinity of the device performing the encoding in the encoder module. + // When cross-targeting encoders this will differ from the devices in the + // original program. For consistency it always has a new name. + IREE::Stream::AffinityAttr affinityAttr; + + // Parameter indices produced by the target. + std::vector parameterIndices; + + // A map of (scope, key) to the parameter in the specified index. + DenseMap, ParameterEntry> parameterEntries; + + // A discrete step in the encoding process. + struct Step { + std::string description; + int64_t globalByteOffset = 0; + int64_t globalByteLength = 0; + const EncodingExpr *expr = nullptr; + OutputParameterSubrangeMap outputMap; + + Location getLoc() const { return expr->getLoc(); } + }; + + // An unordered sequence of encoding steps. + // Steps _generally_ start in order but may end in any order and can be + // considered more as "chunks of work" than some point on a timeline. + // Each step may encode more than one parameter. + SmallVector steps; + + // Cumulative size of all writes to all parameters in all scopes. + int64_t globalByteSize = 0; + + // Appends an encoding expression and its output mapping to the schedule. + void appendExpr(const EncodingExpr *expr, + OutputParameterSubrangeMap outputMap) { + Step step; + + // Today we just name the steps in sequence, but could use the parameter + // names in the output map. + step.description = "step" + std::to_string(steps.size()); + + // Since order is largely undefined and each step may produce multiple + // parameters we track a cumulative write offset in a virtual global + // parameter file and use that. Tools can present % completed or use the + // virtual subranges to indicate fine-grained progress. 
+ step.globalByteOffset = globalByteSize; + step.globalByteLength = std::accumulate( + expr->outputs.begin(), expr->outputs.end(), int64_t{0}, + [](int64_t sum, const EncodingExpr::Output &output) -> int64_t { + return sum + output.size; + }); + LLVM_DEBUG(DBGS() << "defining step `" << step.description << "` (at " + << step.globalByteOffset << " for " + << step.globalByteLength << ")\n"); + globalByteSize += step.globalByteLength; + + step.expr = expr; + step.outputMap = std::move(outputMap); + steps.push_back(std::move(step)); + } + + // Returns the named parameter reference attribute for the given subrange. + IREE::Stream::NamedParameterAttr + getNamedParameterAttr(const ParameterSubrange &subrange) const { + auto parameterEntryIt = + parameterEntries.find(std::make_pair(subrange.scope, subrange.key)); + assert(parameterEntryIt != parameterEntries.end() && + "map must contain all entries"); + const ParameterEntry ¶meterEntry = parameterEntryIt->second; + return subrange.createNamedParameterAttr(parameterEntry.length); + } +}; + +//===----------------------------------------------------------------------===// +// Parameter encoder construction +//===----------------------------------------------------------------------===// + +// Adds a function to the new encoder module that tries to automatically detect +// the target configuration given the list of HAL devices. The intent is that it +// performs the same device detection logic the main module performs at runtime +// but with a provided list instead of what the HAL module provides: the only +// device(s) we have at the global level are those of the host performing the +// encoding. 
+// +// Signature, returning a string constant target name: +// util.func public @__encode_parameter_detect_target( +// %devices: !util.list) -> !util.buffer +static void addAutoTargetDetectFunc(Location loc, + ArrayRef targetPlans, + OpBuilder &encoderBuilder) { + std::string funcName = "__encode_parameter_detect_target"; + LLVM_DEBUG(DBGS() << "emitting auto target detection function: " << funcName + << "...\n"); + + auto bufferType = encoderBuilder.getType(); + auto deviceType = encoderBuilder.getType(); + auto deviceListType = + encoderBuilder.getType(deviceType); + auto funcOp = IREE::Util::FuncOp::create( + encoderBuilder, loc, funcName, + encoderBuilder.getFunctionType({deviceListType}, {bufferType})); + funcOp.setVisibility(SymbolTable::Visibility::Public); + OpBuilder funcBuilder = OpBuilder::atBlockBegin(funcOp.addEntryBlock()); + funcOp->setAttr( + "iree.reflection", + funcBuilder.getDictionaryAttr({ + NamedAttribute("iree.encode.function", + funcBuilder.getStringAttr("detect_target")), + })); + + // Always unconditionally choose the first target today. + assert(!targetPlans.empty()); + Value targetName = IREE::Util::BufferConstantOp::create( + funcBuilder, loc, targetPlans.front().name); + IREE::Util::ReturnOp::create(funcBuilder, loc, targetName); +} + +// Builds a struct of `[scope name, [entries]]`. 
+// Supported entry types: +// +// SPLAT (iree_io_parameter_archive_builder_add_splat_entry): +// [0]: i64 type=0 +// [1]: !util.buffer key (not be NUL terminated) +// [2]: !util.buffer metadata (optional) +// [3]: i64 data length (total size of the parameter) +// [4]: i64 pattern (only up to pattern_bytes_length bytes used) +// [5]: i64 pattern_byte_length +// +// DATA (iree_io_parameter_archive_builder_add_data_entry): +// [0]: i64 type=1 +// [1]: !util.buffer key (not be NUL terminated) +// [2]: !util.buffer metadata (optional) +// [3]: i64 data length (total size of the parameter) +// [4]: i64 minimum alignment (or 0 if don't care) +static Value buildParameterIndexStruct(const ParameterIndex ¶meterIndex, + IntegerSet &i64Set, + OpBuilder &builder) { + LLVM_DEBUG({ + DBGS() << "emitting index with scope: `" << parameterIndex.scope << "` (" + << parameterIndex.entries.size() << " entries)\n"; + parameterIndex.dump(llvm::dbgs()); + }); + + auto loc = parameterIndex.loc; + auto listType = builder.getType(); + + Value scopeName = + IREE::Util::BufferConstantOp::create(builder, loc, parameterIndex.scope); + + SmallVector entryValues; + for (auto &entry : parameterIndex.entries) { + Location entryLoc = entry.getLoc(); + Value typeValue = i64Set.get(static_cast(entry.type)); + Value keyValue = + IREE::Util::BufferConstantOp::create(builder, entryLoc, entry.key); + Value metadataValue = IREE::Util::BufferConstantOp::createOrNull( + builder, entryLoc, entry.metadata); + SmallVector structFields = { + typeValue, + keyValue, + metadataValue, + i64Set.get(entry.length), + }; + switch (entry.type) { + case ParameterEntry::Type::SPLAT: + structFields.push_back(i64Set.get(entry.value.splat.pattern)); + structFields.push_back(i64Set.get(entry.value.splat.patternLength)); + break; + case ParameterEntry::Type::DATA: + structFields.push_back(i64Set.get(entry.value.data.minimumAlignment)); + break; + } + Value entryValue = IREE::Util::ListConstructOp::create( + builder, entryLoc, 
listType, structFields); + entryValues.push_back(entryValue); + } + + Value entryList = + IREE::Util::ListConstructOp::create(builder, loc, listType, entryValues); + Value indexStruct = IREE::Util::ListConstructOp::create( + builder, loc, listType, {scopeName, entryList}); + return indexStruct; +}; + +// Adds a function to the new encoder module that returns the parameter indices +// produced for a given target. A single target may result in more than one +// parameter file in cases where we want to shard parameters. +// +// Signature: +// util.func public @__encode_parameter_indices_TARGET() -> !util.list +static void addTargetIndexBuilderFunc(Location loc, + const TargetPlan &targetPlan, + OpBuilder &encoderBuilder) { + auto listType = encoderBuilder.getType(); + std::string funcName = "__encode_parameter_indices_" + targetPlan.name; + LLVM_DEBUG(DBGS() << "emitting index builder function: " << funcName + << "...\n"); + auto funcOp = IREE::Util::FuncOp::create( + encoderBuilder, loc, funcName, + encoderBuilder.getFunctionType({}, {listType})); + funcOp.setVisibility(SymbolTable::Visibility::Public); + OpBuilder funcBuilder = OpBuilder::atBlockBegin(funcOp.addEntryBlock()); + + // Reflection information lets the tool list available targets and required + // scopes without having to call each function. + // VM bytecode only supports string/integer reflection attributes, so we + // encode scopes as a comma-separated string. + // Note: Empty string values crash the flatbuffer serializer, so we only + // include the scopes attribute if there are non-empty scopes. 
+ std::string scopesStr; + for (const auto ¶meterIndex : targetPlan.parameterIndices) { + StringRef scope = parameterIndex.scope.getValue(); + if (scope.empty()) { + continue; + } + if (!scopesStr.empty()) { + scopesStr += ","; + } + scopesStr += scope; + } + SmallVector reflectionAttrs; + reflectionAttrs.push_back(NamedAttribute( + "iree.encode.function", funcBuilder.getStringAttr("indices"))); + reflectionAttrs.push_back(NamedAttribute( + "iree.encode.target", funcBuilder.getStringAttr(targetPlan.name))); + if (!scopesStr.empty()) { + reflectionAttrs.push_back(NamedAttribute( + "iree.encode.scopes", funcBuilder.getStringAttr(scopesStr))); + } + funcOp->setAttr("iree.reflection", + funcBuilder.getDictionaryAttr(reflectionAttrs)); + + IntegerSet i64Set(loc, funcBuilder); + SmallVector indicesStructs; + for (const auto ¶meterIndex : targetPlan.parameterIndices) { + indicesStructs.push_back( + buildParameterIndexStruct(parameterIndex, i64Set, funcBuilder)); + } + + Value indicesList = IREE::Util::ListConstructOp::create( + funcBuilder, loc, listType, indicesStructs); + IREE::Util::ReturnOp::create(funcBuilder, loc, {indicesList}); +} + +// Adds a function to the new encoder module that produces a list of steps +// involved in encoding the parameters for a specific target. Steps do not +// correspond 1:1 with parameters in either the input or output module and may +// complete in any order so we return a list of structs and fences that can be +// used to observe the state and report on progress. If progress capture is +// desired the list needs to be passed back into the encoder function so that it +// can instrument the encoding process with the fences. +// +// A global byte range is attached to each step for presentation purposes only: +// multiple parameter indices may be constructed and an individual step may +// produce values for each. Tools may only use the global byte range to denote +// cumulative bytes written by each step. 
+// +// Each step entry consists of: +// [0]: i64 reserved 0 +// [1]: !hal.fence indicating encoding has begun +// [2]: !hal.fence indicating encoding has ended +// [3]: !util.buffer descriptive comment (not be NUL terminated) +// [4]: i64 synthetic global byte offset +// [5]: i64 synthetic global byte length +// +// Signature: +// util.func public @__encode_parameter_steps_TARGET() -> !util.list +static void addTargetEncoderStepsFunc(Location loc, + const TargetPlan &targetPlan, + OpBuilder &encoderBuilder) { + std::string funcName = "__encode_parameter_steps_" + targetPlan.name; + LLVM_DEBUG(DBGS() << "emitting encoder steps function: " << funcName + << "...\n"); + auto listType = encoderBuilder.getType(); + auto funcOp = IREE::Util::FuncOp::create( + encoderBuilder, loc, funcName, + encoderBuilder.getFunctionType({}, {listType})); + funcOp.setVisibility(SymbolTable::Visibility::Public); + OpBuilder funcBuilder = OpBuilder::atBlockBegin(funcOp.addEntryBlock()); + funcOp->setAttr( + "iree.reflection", + funcBuilder.getDictionaryAttr({ + NamedAttribute("iree.encode.function", + funcBuilder.getStringAttr("steps")), + NamedAttribute("iree.encode.target", + funcBuilder.getStringAttr(targetPlan.name)), + })); + + Type deviceType = funcBuilder.getType(); + Value deviceValue = + IREE::Stream::ContextResolveOp::create(funcBuilder, loc, {deviceType}, + targetPlan.affinityAttr) + .getResult(0); + + SmallVector stepStructs; + IntegerSet i64Set(loc, funcBuilder); + for (auto &step : targetPlan.steps) { + Value beginFence = IREE::HAL::FenceCreateOp::create( + funcBuilder, loc, deviceValue, IREE::HAL::FenceFlagBitfield::None); + Value endFence = IREE::HAL::FenceCreateOp::create( + funcBuilder, loc, deviceValue, IREE::HAL::FenceFlagBitfield::None); + Value descriptionValue = IREE::Util::BufferConstantOp::create( + funcBuilder, loc, step.description); + stepStructs.push_back(IREE::Util::ListConstructOp::create( + funcBuilder, loc, listType, + { + i64Set.get(0), + beginFence, + 
endFence, + descriptionValue, + i64Set.get(step.globalByteOffset), + i64Set.get(step.globalByteLength), + })); + } + + Value stepsList = IREE::Util::ListConstructOp::create(funcBuilder, loc, + listType, stepStructs); + IREE::Util::ReturnOp::create(funcBuilder, loc, stepsList); +} + +using MarkObjectReference = + std::function; + +// Adds a function to the new encoder module that encodes parameters for a +// specific target. Encoding will wait for the provided `wait_fence` prior to +// starting any processing and signal the provided `signal_fence` when all +// processing has completed. The steps list is the result of the paired +// `__encode_parameter_steps_TARGET` function and the fences within will be +// signaled as encoding progresses. +// +// Signature (follows standard coarse-fences ABI with fences at end): +// util.func public @__encode_parameters_TARGET( +// %steps: !util.list, +// %wait_fence: !hal.fence, +// %signal_fence: !hal.fence) +static LogicalResult +addTargetEncoderFunc(Location loc, const TargetPlan &targetPlan, + const MarkObjectReference &markObjectReference, + OpBuilder &encoderBuilder) { + std::string funcName = "__encode_parameters_" + targetPlan.name; + LLVM_DEBUG(DBGS() << "emitting encoder function: " << funcName << "...\n"); + auto fenceType = encoderBuilder.getType(); + auto listType = encoderBuilder.getType(); + auto funcOp = IREE::Util::FuncOp::create( + encoderBuilder, loc, funcName, + encoderBuilder.getFunctionType({listType, fenceType, fenceType}, {})); + funcOp.setVisibility(SymbolTable::Visibility::Public); + OpBuilder funcBuilder = OpBuilder::atBlockBegin(funcOp.addEntryBlock()); + funcOp->setAttr( + "iree.reflection", + funcBuilder.getDictionaryAttr({ + NamedAttribute("iree.abi.model", + funcBuilder.getStringAttr("coarse-fences")), + NamedAttribute("iree.encode.function", + funcBuilder.getStringAttr("encode")), + NamedAttribute("iree.encode.target", + funcBuilder.getStringAttr(targetPlan.name)), + })); + + // TODO(benvanik): 
make steps optional, probably by just calling the steps + // function internally when not provided so that we can keep all the encoding + // code branch-free. For now we require it be provided. + + Value waitFence = funcOp.getArgument(1); + Value signalFence = funcOp.getArgument(2); + + Type timepointType = funcBuilder.getType(); + Value lastTimepoint = IREE::Stream::TimepointImportOp::create( + funcBuilder, loc, timepointType, waitFence, targetPlan.affinityAttr); + + // Use explicit transient lifetime for all output slab allocations. + // This storage is allocated at the start of each step and deallocated at the + // end, making transient the correct lifetime. + Type resourceType = funcBuilder.getType( + IREE::Stream::Lifetime::Transient); + IndexSet indexSet(loc, funcBuilder); + IntegerSet i64Set(loc, funcBuilder); + for (const auto &step : targetPlan.steps) { + Location stepLoc = step.getLoc(); + + // Build a map of scope name to the outputs going to it and their parameter + // references. Note that this mapping is target-specific (as each target may + // have a different mix of parameters and parameter sizes due to differences + // in encodings). 
+ struct OutputReservation { + const EncodingExpr::Output *output = nullptr; + const ParameterSubrange *parameterSubrange = nullptr; + IREE::Stream::NamedParameterAttr parameterAttr; + size_t slabOffsetOrdinal = 0; + }; + llvm::MapVector> scopeOutputs; + SmallVector outputSizes; + for (auto &output : step.expr->outputs) { + auto it = step.outputMap.find(&output); + if (it == step.outputMap.end()) { + continue; // no serialization required + } + const ParameterSubrange &subrange = it->second; + OutputReservation reservation; + reservation.output = &output; + reservation.parameterSubrange = &subrange; + reservation.parameterAttr = targetPlan.getNamedParameterAttr(subrange); + reservation.slabOffsetOrdinal = outputSizes.size(); + scopeOutputs[reservation.parameterAttr.getScope()].push_back(reservation); + outputSizes.push_back(indexSet.get(subrange.length)); + } + + // Allocate transient storage for all the parameter outputs. + // If we were overlapping we'd want to get this from a ringbuffer. + // TODO(benvanik): stream.async.ringbuffer-style ops for safely doing bump + // pointer allocation with timeline-awareness at this level. + auto reservationPackOp = IREE::Stream::ResourcePackOp::create( + funcBuilder, stepLoc, /*offset=*/nullptr, outputSizes, + targetPlan.affinityAttr); + Value outputSlabSize = reservationPackOp.getTotalLength(); + auto outputSlabAllocaOp = IREE::Stream::ResourceAllocaOp::create( + funcBuilder, stepLoc, resourceType, timepointType, outputSlabSize, + /*indeterminate_lifetime=*/nullptr, lastTimepoint, + targetPlan.affinityAttr); + Value outputSlab = outputSlabAllocaOp.getResult(); + + // Note: Input parameters are NOT included in this slab allocation. + // Inputs are loaded via stream.async.constant operations (cloned below) + // which reference external parameter storage and don't require allocation. + // Only outputs need slab allocation as transient working memory before + // being scattered to their final parameter locations. 
+ // + // Wait for the slab to be ready before we transition back into async IR. + outputSlab = IREE::Stream::TimepointAwaitOp::create( + funcBuilder, stepLoc, {outputSlab}, {outputSlabSize}, + outputSlabAllocaOp.getResultTimepoint()) + .getResult(0); + + // Clone the expression IR and fix it up for use in the new module. + // We have to remove any affinities referencing the devices in the source + // program and ensure we also bring along any referenced objects + // (executables, etc). + // + // The slice is already in topological order from getBackwardSlice, and + // all captured values from nested regions have been included via + // getUsedValuesDefinedAbove, so we can clone directly without sorting. + // + // AsyncConstantOp with parameter values are converted to + // AsyncParameterLoadOp during cloning because the lowering path through + // ResourceConstantsOp does not preserve await_timepoint. + // AsyncParameterLoadOp lowers directly to CmdParameterLoadOp which does + // preserve await_timepoint. + IRMapping exprMapping; + for (auto *sourceOp : step.expr->ops) { + auto *clonedOp = funcBuilder.clone(*sourceOp, exprMapping); + if (auto affinityOp = + dyn_cast(clonedOp)) { + affinityOp.removeAffinityAttrs(); + } + // Convert AsyncConstantOp with parameter values to AsyncParameterLoadOp. + // This ensures await_timepoint is preserved through lowering, since + // AsyncConstantOp goes through ResourceConstantsOp which drops await. + if (auto constantOp = dyn_cast(clonedOp)) { + if (auto parameterAttr = + dyn_cast(constantOp.getValue())) { + // Extract parameter scope and key from the attribute. + StringAttr scopeAttr = parameterAttr.getScope(); + StringAttr keyAttr = parameterAttr.getKey(); + // Create zero offset for full parameter load. + Value zeroOffset = i64Set.get(0); + Value resultSize = constantOp.getResultSize(); + // Create AsyncParameterLoadOp with the wait fence as await. 
+ auto paramLoadOp = IREE::Stream::AsyncParameterLoadOp::create( + funcBuilder, constantOp.getLoc(), + constantOp.getResult().getType(), + funcBuilder.getType(), + /*await_timepoint=*/lastTimepoint, scopeAttr, keyAttr, zeroOffset, + resultSize, targetPlan.affinityAttr); + // Await the result timepoint to get a resolved resource that can be + // used by streamable ops without explicit synchronization. + auto awaitOp = IREE::Stream::TimepointAwaitOp::create( + funcBuilder, constantOp.getLoc(), paramLoadOp.getResult(), + resultSize, paramLoadOp.getResultTimepoint()); + // Update mapping to use the awaited result. + exprMapping.map(sourceOp->getResult(0), awaitOp.getResults().front()); + // Erase the cloned AsyncConstantOp. + constantOp.erase(); + clonedOp = awaitOp; + } else { + // Non-parameter constant: just set await_timepoint. + constantOp.getAwaitTimepointMutable().assign(lastTimepoint); + } + } + auto symbolUses = SymbolTable::getSymbolUses(clonedOp); + if (symbolUses.has_value()) { + for (auto &use : symbolUses.value()) { + if (failed(markObjectReference(clonedOp, use.getSymbolRef()))) { + return failure(); + } + } + } + } + + // Scatter the outputs into the parameter(s) for each scope. 
+ for (auto [scope, outputReservations] : scopeOutputs) { + for (auto &reservation : outputReservations) { + Location outputLoc = reservation.output->getLoc(); + Value outputValue = + exprMapping.lookup(reservation.output->producedValue); + Value packedOffset = + reservationPackOp.getPackedOffsets()[reservation.slabOffsetOrdinal]; + Value packedEnd = + indexSet.add(packedOffset, reservation.parameterSubrange->length); + Value outputSize = indexSet.get(reservation.parameterSubrange->length); + auto updateOp = IREE::Stream::AsyncUpdateOp::create( + funcBuilder, outputLoc, outputSlab.getType(), outputSlab, + outputSlabSize, packedOffset, packedEnd, outputValue, outputSize, + targetPlan.affinityAttr); + outputSlab = updateOp.getResult(); + } + } + auto outputBarrierOp = IREE::Stream::TimepointBarrierOp::create( + funcBuilder, step.getLoc(), outputSlab, outputSlabSize, + targetPlan.affinityAttr); + outputSlab = outputBarrierOp.getResult(); + + // Scatter parameters from the transient slab into each target scope. + SmallVector scatterTimepoints; + for (auto [scope, outputReservations] : scopeOutputs) { + SmallVector outputLocs; + SmallVector sourceOffsets; + SmallVector sourceEnds; + SmallVector sourceLengths; + SmallVector targetKeys; + SmallVector targetOffsets; + for (auto &reservation : outputReservations) { + outputLocs.push_back(reservation.output->getLoc()); + Value packedOffset = + reservationPackOp.getPackedOffsets()[reservation.slabOffsetOrdinal]; + Value packedSize = indexSet.get(reservation.parameterSubrange->length); + sourceOffsets.push_back(packedOffset); + sourceLengths.push_back(packedSize); + targetKeys.push_back(reservation.parameterAttr.getKey()); + targetOffsets.push_back( + i64Set.get(reservation.parameterSubrange->offset)); + } + // Compute source ends (offset + length) for async parameter scatter. 
+ for (auto [offset, length] : + llvm::zip_equal(sourceOffsets, sourceLengths)) { + auto end = funcBuilder.createOrFold( + funcBuilder.getFusedLoc(outputLocs), offset, length); + sourceEnds.push_back(end); + } + auto scatterOp = IREE::Stream::AsyncParameterScatterOp::create( + funcBuilder, funcBuilder.getFusedLoc(outputLocs), outputSlab, + outputSlabSize, sourceOffsets, sourceEnds, sourceLengths, scope, + funcBuilder.getArrayAttr(targetKeys), targetOffsets, + outputBarrierOp.getResultTimepoint(), targetPlan.affinityAttr); + // AsyncParameterScatterOp returns (resource, timepoint) tuple. + outputSlab = scatterOp.getResult(); + scatterTimepoints.push_back(scatterOp.getResultTimepoint()); + } + Value scattersTimepoint = IREE::Stream::TimepointJoinOp::create( + funcBuilder, stepLoc, scatterTimepoints); + + // Deallocate the output slab (now the scattered resource). + Value deallocaTimepoint = IREE::Stream::ResourceDeallocaOp::create( + funcBuilder, stepLoc, outputSlab, outputSlabSize, + /*prefer_origin=*/false, scattersTimepoint, targetPlan.affinityAttr); + + lastTimepoint = deallocaTimepoint; + } + + // Chain the final timepoint (which depends on all steps via the loop above) + // with the external signal fence. This signals completion of all encoding + // steps. We use a single chain at the end rather than chaining after each + // step because: (1) the function has only one signal fence parameter, and + // (2) callers wait on the fence to know when all encoding is complete, not + // individual steps. + IREE::Stream::TimepointChainExternalOp::create(funcBuilder, funcOp.getLoc(), + lastTimepoint, {signalFence}, + targetPlan.affinityAttr); + + IREE::Util::ReturnOp::create(funcBuilder, loc); + + return success(); +} + +// Replaces all encoded exprs in the original module with loads/gathers from the +// new encoded parameters. 
+static void replaceEncodedExprs(ArrayRef targetPlans) { + // TODO: support multiple targets by emitting a big switch, a detection + // function, and then conditionally execute each plan. Each plan should + // encompass all the required expressions but heterogeneous makes things + // more complicated in a way I can't yet see. For now we assume all + // expressions are grouped into a single target and always evaluated (vs. + // conditionally evaluated per target). + const TargetPlan &targetPlan = targetPlans.front(); + + // Since expressions may share ops we accumulate all the root ops we believe + // are dead and then burn them down after we're done accessing them. + SmallVector deadOpWorklist; + + // Note that it's possible for targets to not have all expressions: if we are + // specializing a heterogeneous module we may produce one encoder module per + // target each with its own set of placed parameters. + IndexSetCollection indexSetCollection; + for (auto &step : targetPlan.steps) { + // Collect external timepoints once per expression (shared by all outputs). + Value expressionAwaitTimepoint; + if (!step.expr->outputs.empty()) { + OpBuilder timepointBuilder(step.expr->outputs.front().storeOp); + expressionAwaitTimepoint = + collectExternalTimepoints(*step.expr, timepointBuilder); + } + + for (auto &output : step.expr->outputs) { + auto it = step.outputMap.find(&output); + if (it == step.outputMap.end()) { + continue; // no serialization required + } + auto *indexSet = indexSetCollection.get(output.storeOp); + OpBuilder builder(output.storeOp); + + // Since each target may have a unique size and packing of their + // encoded parameters we need to reference the plan-specific parameter. 
+ const ParameterSubrange &subrange = it->second; + auto parameterAttr = targetPlan.getNamedParameterAttr(subrange); + const int64_t storageSize = parameterAttr.getStorageSize(); + Value storageSizeValue = indexSet->get(storageSize); + + // Embed an inline constant referencing the parameter and slice out the + // subrange (if any). + Value oldValue = output.storeOp.getStoredGlobalValue(); + + Value constantValue = IREE::Stream::AsyncConstantOp::create( + builder, output.getLoc(), oldValue.getType(), + expressionAwaitTimepoint, parameterAttr, storageSizeValue, + step.expr->affinityAttr); + Value newValue = constantValue; + if (subrange.offset != 0 || subrange.length != storageSize) { + // TODO(benvanik): use AsyncSliceOp instead; today ElideAsyncCopiesPass + // does not do any IPO and inserting slices here forces each parameter + // to be cloned at execution. Inserting ResourceSubviewOp is only barely + // safe here because we otherwise don't allow it and know we can run a + // propagation pass immediately after this pass. It's shady, though, and + // may block other optimizations. + // + // Should be: + // newValue = IREE::Stream::AsyncSliceOp::create( + // builder, output.getLoc(), constantValue, storageSizeValue, + // indexSet->get(subrange.offset), + // indexSet->add(subrange.offset, subrange.length), + // indexSet->get(subrange.length), step.expr->affinityAttr); + newValue = IREE::Stream::ResourceSubviewOp::create( + builder, output.getLoc(), constantValue, storageSizeValue, + indexSet->get(subrange.offset), indexSet->get(subrange.length)); + } + output.storeOp.setStoredGlobalValue(newValue); + + // Now that we've replaced a use (but maybe not all uses!) we may be able + // to kill one or more ops. Since expressions/outputs may share IR we + // enqueue the deletion check to the end. + if (auto *producerRootOp = oldValue.getDefiningOp()) { + // Enqueue ops with no uses for pruning - pruneDeadOps will determine + // if they're actually safe to delete. 
+ if (producerRootOp->use_empty()) { + deadOpWorklist.push_back(producerRootOp); + } + } + } + } + + // Recursively delete unused operations and their producers. + pruneDeadOps(std::move(deadOpWorklist)); +} + +//===----------------------------------------------------------------------===// +// --iree-stream-split-parameter-encoder +//===----------------------------------------------------------------------===// + +// Placeholder planning for taking an expression set and producing a +// target-specialized set of parameter indices and an encoding schedule. +// +// TODO: use analysis to identify a set of a target configurations. This may +// be too tricky to do automatically (what would we call the +// configurations?) and require the user to specify the exact names and +// constituent devices. We'd want to take the configuration and prune the +// expression set to those used with involved devices, potentially allow for +// a second specialization round, etc. For now we just have one default +// target and let the tool auto select it. +static FailureOr planDefaultTarget(const EncodingExprSet &exprSet, + StringAttr scope, + EncodingPolicy encodingPolicy) { + LLVM_DEBUG( + DBGS() + << "building parameter index and schedule for default target in scope `" + << scope << "`\n"); + + TargetPlan targetPlan; + targetPlan.name = "all"; + + // For now we leave the encoding host target unspecified. This allows the + // user to compile for any device they want. We could copy the device from + // the source module if we wanted to do 1:1 encoding:execution. 
+ targetPlan.affinityAttr = IREE::HAL::DevicePromiseAttr::get( + scope.getContext(), StringAttr::get(scope.getContext(), "__device_0"), + -1); + + ParameterIndexBuilder parameterIndexBuilder(scope, encodingPolicy); + for (int i = 0; i < exprSet.exprs.size(); ++i) { + const EncodingExpr &expr = exprSet.exprs[i]; + auto outputMapOr = parameterIndexBuilder.insertExpr(&expr); + if (failed(outputMapOr)) { + return mlir::emitError(expr.getLoc(), + "failed to add expression to parameter index"); + } + targetPlan.appendExpr(&expr, std::move(outputMapOr.value())); + } + ParameterIndex parameterIndex = parameterIndexBuilder.finalize(); + for (auto &entry : parameterIndex.entries) { + targetPlan.parameterEntries[std::make_pair(scope, entry.key)] = entry; + } + targetPlan.parameterIndices.push_back(std::move(parameterIndex)); + return targetPlan; +} + +struct SplitParameterEncoderPass + : public IREE::Stream::impl::SplitParameterEncoderPassBase< + SplitParameterEncoderPass> { + using IREE::Stream::impl::SplitParameterEncoderPassBase< + SplitParameterEncoderPass>::SplitParameterEncoderPassBase; + void runOnOperation() override { + MLIRContext *context = &getContext(); + mlir::ModuleOp moduleOp = getOperation(); + + // Scan the program and find candidate expressions. + EncodingPolicy encodingPolicy; + encodingPolicy.includeUnmodified = + mode == IREE::Stream::ParameterEncoderMode::Consolidate; + encodingPolicy.hoistParameterExpressions = hoistParameterExpressions; + encodingPolicy.hoistConstantExpressions = hoistConstantExpressions; + encodingPolicy.maxEncodingGrowthFactor = maxEncodingGrowthFactor; + + EncodingExprSet exprSet = gatherEncodingExprSet(moduleOp, encodingPolicy); + + // Filter expressions by policy (size growth, expression type). 
+ EncodingExprSet filteredExprSet; + for (const auto &expr : exprSet.exprs) { + if (shouldHoistExpression(expr, encodingPolicy)) { + filteredExprSet.exprs.push_back(expr); + } else { + LLVM_DEBUG(DBGS() << "skipping expression based on policy\n"); + } + } + + if (filteredExprSet.empty()) { + // No candidates detected (or none the policy approves) so no-op. + // + // The user invoking this pass did ask for a new file, though, so we need + // to at least delete any existing one so the user doesn't get confused + // (old artifacts from a run where we did write something carried across). + LLVM_DEBUG(DBGS() << "no candidate expressions detected; skipping pass " + "and deleting existing output file\n"); + if (!outputFile.empty()) { + (void)llvm::sys::fs::remove(outputFile); + } + return; + } + + // Create the new encoder module we'll be populating. Note that we may have + // multiple targets that contribute functions to the module. + OwningOpRef encoderModuleOpRef = + mlir::ModuleOp::create(moduleOp.getLoc(), "encoder"); + mlir::ModuleOp encoderModuleOp = *encoderModuleOpRef; + encoderModuleOp->setAttr( + "iree.reflection", + DictionaryAttr::get( + context, { + NamedAttribute("iree.tool", + StringAttr::get( + context, "iree-encode-parameters")), + NamedAttribute("iree.encode.version", + IntegerAttr::get( + IntegerType::get(context, 32), 1)), + })); + OpBuilder encoderBuilder = + OpBuilder::atBlockBegin(encoderModuleOp.getBody()); + + // Today we only support a single target and build the index for that. + // A few things in here will need to change when we specialize but most of + // the data structures are set up for it. + std::string targetOutputScope = + outputScope.hasValue() ? 
outputScope.getValue() : ""; + auto defaultTargetOr = planDefaultTarget( + filteredExprSet, StringAttr::get(context, targetOutputScope), + encodingPolicy); + if (failed(defaultTargetOr)) { + return signalPassFailure(); + } + LLVM_DEBUG(DBGS() << "note: default target '" << defaultTargetOr->name + << "' used in place of target specialization\n"); + SmallVector targetPlans; + targetPlans.push_back(std::move(defaultTargetOr).value()); + + // Emit the target detection function used by tools to try to infer the host + // target (useful for post-deployment encoding). + addAutoTargetDetectFunc(moduleOp->getLoc(), targetPlans, encoderBuilder); + + // Emit the per-target metadata functions. + for (const auto &targetPlan : targetPlans) { + addTargetIndexBuilderFunc(moduleOp->getLoc(), targetPlan, encoderBuilder); + addTargetEncoderStepsFunc(moduleOp->getLoc(), targetPlan, encoderBuilder); + } + + // Accumulate object references during cloning so that we can deduplicate + // and clone them all afterward. This avoids interleaving the objects with + // the encoder functions - sometimes that is good, but it's easier to read + // the IR when they aren't. + SymbolTable sourceSymbolTable(moduleOp); + SetVector objectsToClone; + + // Capture the last op (if any) so we can insert after it later. + // This ensures objects go before any encoder functions we're about to add. 
+ Operation *lastOpBeforeEncoders = + &*std::prev(encoderModuleOp.getBody()->end(), 1); + auto markObjectReference = [&](Operation *userOp, + SymbolRefAttr symbolRef) -> LogicalResult { + auto objectNameAttr = symbolRef.getRootReference(); + auto *objectOp = sourceSymbolTable.lookup(objectNameAttr); + if (!objectOp) { + return userOp->emitOpError() + << "reference to undefined symbol " << symbolRef; + } + if (!objectOp->hasTrait()) { + return userOp->emitOpError() + << "reference to non-object-like symbol " << symbolRef; + } + objectsToClone.insert(objectOp); + return success(); + }; + + // Produce all of the encoder functions and gather the objects we need to + // clone. + for (const auto &targetPlan : targetPlans) { + if (failed(addTargetEncoderFunc(moduleOp->getLoc(), targetPlan, + markObjectReference, encoderBuilder))) { + return signalPassFailure(); + } + } + + // Clone all objects referenced by the encoder module. + // Object-like ops are isolated and safe to copy wholesale. + // Insert after the last op that existed before we added encoder functions. + encoderBuilder.setInsertionPointAfter(lastOpBeforeEncoders); + for (Operation *objectOp : objectsToClone) { + encoderBuilder.clone(*objectOp); + } + + // Replace the expressions in the original module with parameter lookups. + replaceEncodedExprs(targetPlans); + + // CSE to clean up the encoder IR before dumping. + // This is important for deduplicating operations shared across multiple + // encoding expressions. When expressions are cloned into the encoder + // module, shared intermediate operations get duplicated at clone time. CSE + // removes these duplicates, ensuring efficient encoder module output. The + // original module likely needs a bit of cleanup but as compilation + // continues that'll happen. 
+ { + IRRewriter rewriter(context); + DominanceInfo domInfo; + mlir::eliminateCommonSubExpressions(rewriter, domInfo, encoderModuleOp); + } + + if (failed(mlir::verify(encoderModuleOp))) { + mlir::emitError(encoderModuleOp.getLoc()) + << "failed to verify produced encoder module"; + return signalPassFailure(); + } + + // Write module to the file specified, or stdout if empty. + if (outputFile.empty()) { + LLVM_DEBUG(DBGS() << "writing encoder module to stdout...\n"); + OpPrintingFlags flags; + encoderModuleOp.print(llvm::outs(), flags); + llvm::outs() << "\n"; + } else { + LLVM_DEBUG(DBGS() << "writing encoder module to '" << outputFile + << "'...\n"); + if (failed(writeModule(encoderModuleOp, outputFile))) { + LLVM_DEBUG(DBGS() << "MODULE WRITE FAILED\n"); + return signalPassFailure(); + } + } + } +}; + +} // namespace + +} // namespace mlir::iree_compiler::IREE::Stream diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/UnifyEncodingForGlobals.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/UnifyEncodingForGlobals.cpp index 304c24d89f68..8d66b30189db 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/UnifyEncodingForGlobals.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/UnifyEncodingForGlobals.cpp @@ -16,6 +16,7 @@ #include "iree/compiler/Dialect/Util/IR/UtilOps.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/DebugLog.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/SymbolTable.h" @@ -538,18 +539,151 @@ static void updateTensorDispatchOp(TensorDispatchOp dispatchOp, } } +// Inserts a re-encode op before the given op if the source encoding doesn't +// match the new (unified) encoding. Returns the re-encoded value, or the +// original source if no re-encoding is needed. 
+static Value maybeInsertReencode(IRRewriter &rewriter, Operation *op, + Value source, Type sourceEncodingType, + ValueRange sourceEncodingDims, + Value sourceSize, Attribute newEncoding, + AffinityAttr affinityAttr) { + auto expectedType = cast(sourceEncodingType); + Attribute expectedEncoding = expectedType.getEncoding(); + + // No re-encode needed if encodings match. + if (expectedEncoding == newEncoding) { + return source; + } + + LDBG() << " Inserting re-encode: " << newEncoding << " -> " + << expectedEncoding; + + // Build the source type (with unified encoding). + RankedTensorType unifiedType = expectedType.cloneWithEncoding(newEncoding); + + // Compute sizes for unified encoding. + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); + Value unifiedSize = TensorSizeOfOp::create( + rewriter, op->getLoc(), rewriter.getIndexType(), + TypeAttr::get(unifiedType), sourceEncodingDims, affinityAttr); + + // Insert the re-encode op: unified -> expected. + auto reencodeOp = TensorEncodeOp::create( + rewriter, op->getLoc(), source.getType(), source, + TypeAttr::get(unifiedType), + /*source_encoding_dims=*/sourceEncodingDims, unifiedSize, + TypeAttr::get(expectedType), + /*result_encoding_dims=*/sourceEncodingDims, sourceSize, affinityAttr); + + LDBG() << " Created: " << reencodeOp; + return reencodeOp.getResult(); +} + +// Updates TensorCloneOp by inserting re-encode if needed. 
+static void updateTensorCloneOp(TensorCloneOp cloneOp, + const OperandEncodingUpdates &operandUpdates, + IRRewriter &rewriter) { + int operandNumber = cloneOp.getSourceMutable().getOperandNumber(); + if (!operandUpdates.contains(operandNumber)) { + return; + } + Attribute newEncoding = operandUpdates.lookup(operandNumber); + Value reencoded = maybeInsertReencode( + rewriter, cloneOp, cloneOp.getSource(), cloneOp.getSourceEncoding(), + cloneOp.getSourceEncodingDims(), cloneOp.getSourceSize(), newEncoding, + cloneOp.getAffinityAttr()); + if (reencoded != cloneOp.getSource()) { + rewriter.modifyOpInPlace( + cloneOp, [&] { cloneOp.getSourceMutable().set(reencoded); }); + } +} + +// Updates TensorEncodeOp by updating the source_encoding attribute. +static void updateTensorEncodeOp(TensorEncodeOp encodeOp, + const OperandEncodingUpdates &operandUpdates, + IRRewriter &rewriter) { + int operandNumber = encodeOp.getSourceMutable().getOperandNumber(); + if (!operandUpdates.contains(operandNumber)) { + return; + } + Attribute newEncoding = operandUpdates.lookup(operandNumber); + auto oldSourceType = cast(encodeOp.getSourceEncoding()); + RankedTensorType newSourceType = oldSourceType.cloneWithEncoding(newEncoding); + rewriter.modifyOpInPlace(encodeOp, [&] { + encodeOp.setSourceEncodingAttr(TypeAttr::get(newSourceType)); + }); + LDBG() << " Updated TensorEncodeOp source encoding to " << newEncoding; +} + +// Updates TensorUpdateOp by inserting re-encode if needed. +static void updateTensorUpdateOp(TensorUpdateOp updateOp, + const OperandEncodingUpdates &operandUpdates, + IRRewriter &rewriter) { + // Handle target operand. 
+ int targetOperandNum = updateOp.getTargetMutable().getOperandNumber(); + if (operandUpdates.contains(targetOperandNum)) { + Attribute newEncoding = operandUpdates.lookup(targetOperandNum); + Value reencoded = maybeInsertReencode( + rewriter, updateOp, updateOp.getTarget(), updateOp.getTargetEncoding(), + updateOp.getTargetEncodingDims(), updateOp.getTargetSize(), newEncoding, + updateOp.getAffinityAttr()); + if (reencoded != updateOp.getTarget()) { + rewriter.modifyOpInPlace( + updateOp, [&] { updateOp.getTargetMutable().set(reencoded); }); + } + } + + // Handle update operand. + unsigned updateOperandNum = updateOp.getUpdateMutable().getOperandNumber(); + if (operandUpdates.contains(updateOperandNum)) { + Attribute newEncoding = operandUpdates.lookup(updateOperandNum); + Value reencoded = maybeInsertReencode( + rewriter, updateOp, updateOp.getUpdate(), updateOp.getUpdateEncoding(), + updateOp.getUpdateEncodingDims(), updateOp.getUpdateSize(), newEncoding, + updateOp.getAffinityAttr()); + if (reencoded != updateOp.getUpdate()) { + rewriter.modifyOpInPlace( + updateOp, [&] { updateOp.getUpdateMutable().set(reencoded); }); + } + } +} + // Applies all cached encoding updates to tensor ops. static void applyTensorEncodingUpdates(TensorEncodingUpdates &updates) { for (auto &[op, operandUpdates] : updates) { + // Copy to local variable to allow capture in C++17 lambdas. + const OperandEncodingUpdates &opUpdates = operandUpdates; IRRewriter rewriter(op->getContext()); - // TODO: Handle other TensorPhaseOp ops (TensorFillOp, etc.) via TypeSwitch. 
- if (auto dispatchOp = dyn_cast(op)) { - updateTensorDispatchOp(dispatchOp, operandUpdates, rewriter); - } + TypeSwitch(op) + .Case([&](auto dispatchOp) { + updateTensorDispatchOp(dispatchOp, opUpdates, rewriter); + }) + .Case([&](auto cloneOp) { + updateTensorCloneOp(cloneOp, opUpdates, rewriter); + }) + .Case([&](auto encodeOp) { + updateTensorEncodeOp(encodeOp, opUpdates, rewriter); + }) + .Case([&](auto updateOp) { + updateTensorUpdateOp(updateOp, opUpdates, rewriter); + }) + .Case( + [&](auto) { + assert(false && "unexpected tensor op needing encoding update"); + }) + .Default([](Operation *op) { + LDBG() << " Unhandled op: " << op->getName() + << ", maybe it is a new tensor op?"; + assert(false); + }); } } -// Collects updates for stream tensor ops by walking from global loads. +// Collects updates for stream tensor ops by walking from global loads. Fixup +// should be applied to all stream tensor ops that use the encoded global's +// data. static void collectUpdatesForStreamTensorOps(Explorer &explorer, EncodedGlobalInfo &encodedInfo, Attribute newEncoding, @@ -582,18 +716,24 @@ static void collectUpdatesForStreamTensorOps(Explorer &explorer, return WalkResult::advance(); } - // TODO: Handle other tensor phase ops (TensorFillOp, etc.) - auto dispatchOp = dyn_cast(user); - if (!dispatchOp) { - return WalkResult::advance(); - } - - // The operand number is the index in the full operand list (including - // workload). We need the index in getMixedOperands() for encoding lookup. - unsigned mixedOperandIdx = - operand.getOperandNumber() - dispatchOp.getWorkload().size(); - LDBG() << " Found TensorDispatchOp operand " << mixedOperandIdx; - updates[user][mixedOperandIdx] = newEncoding; + // Do not continue walking past these ops because this is the end point. + // The fixup will be applied directly to these ops, so updates are not + // needed for their users. 
+ TypeSwitch(user) + .Case([&](auto dispatchOp) { + // The operand number is the index in the full operand list + // (including workload). We need the index in getMixedOperands() for + // encoding lookup. + unsigned mixedOperandIdx = + operand.getOperandNumber() - dispatchOp.getWorkload().size(); + LDBG() << " Found TensorDispatchOp operand " + << mixedOperandIdx; + updates[user][mixedOperandIdx] = newEncoding; + }) + .Case([&](auto op) { + updates[user][operand.getOperandNumber()] = newEncoding; + }) + .Default([](Operation *op) {}); return WalkResult::advance(); }); } @@ -673,10 +813,12 @@ struct UnifyEncodingForGlobalsPass [](TensorDispatchOp a, TensorDispatchOp b) { std::string aStr, bStr; llvm::raw_string_ostream aStream(aStr), bStream(bStr); - if (auto aAffinity = a.getAffinityAttr()) + if (auto aAffinity = a.getAffinityAttr()) { aStream << aAffinity; - if (auto bAffinity = b.getAffinityAttr()) + } + if (auto bAffinity = b.getAffinityAttr()) { bStream << bAffinity; + } return aStr < bStr; }); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Utils.h b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Utils.h index 5d5b972602f7..4cbc1405be19 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Utils.h +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Utils.h @@ -25,8 +25,9 @@ SmallVector gatherUsedDialectInterfaces(mlir::ModuleOp moduleOp) { SmallPtrSet resultSet; for (auto dialect : moduleOp.getContext()->getLoadedDialects()) { auto *dialectInterface = dialect->getRegisteredInterface(); - if (!dialectInterface) + if (!dialectInterface) { continue; + } resultSet.insert(dialectInterface); } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAffinities.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAffinities.cpp index 5cd52e393a04..dc2be3195fbd 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAffinities.cpp +++ 
b/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAffinities.cpp @@ -60,8 +60,9 @@ struct VerifyAffinitiesPass ? WalkResult::skip() : WalkResult::advance(); }) - .wasInterrupted()) + .wasInterrupted()) { return signalPassFailure(); + } // Preserve all analyses since this is a read-only verification pass. markAllAnalysesPreserved(); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAsyncAccessRanges.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAsyncAccessRanges.cpp index a19861592066..3f19b271217c 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAsyncAccessRanges.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAsyncAccessRanges.cpp @@ -23,11 +23,13 @@ namespace mlir::iree_compiler::IREE::Stream { namespace { static std::optional matchConstant(Value value) { - if (!value) + if (!value) { return std::nullopt; + } APInt constant; - if (!matchPattern(value, m_ConstantInt(&constant))) + if (!matchPattern(value, m_ConstantInt(&constant))) { return std::nullopt; + } return constant.getSExtValue(); } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyLowerings.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyLowerings.cpp index a6deb12024eb..e1defae5de3d 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyLowerings.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyLowerings.cpp @@ -122,15 +122,17 @@ class Verifier { // Check types for operands/results. 
for (auto operandType : llvm::enumerate(op->getOperandTypes())) { - if (isTypeLegal(operandType.value())) + if (isTypeLegal(operandType.value())) { continue; + } emitIllegalTypeError(op, "operand", operandType.index(), operandType.value()); foundAnyIllegal = true; } for (auto resultType : llvm::enumerate(op->getResultTypes())) { - if (isTypeLegal(resultType.value())) + if (isTypeLegal(resultType.value())) { continue; + } emitIllegalTypeError(op, "result", resultType.index(), resultType.value()); foundAnyIllegal = true; @@ -358,8 +360,9 @@ struct VerifyLoweringToAsyncPass } // Allow metadata ops outside of execution regions. - if (op.isMetadata()) + if (op.isMetadata()) { return Verifier::Legality::LEGAL; + } // TODO(benvanik): execution region interface to make this generic. if (!op->template getParentOfType()) { diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel index c36bb57955ce..98a16f2cda2b 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "annotate_affinities.mlir", "annotate_constant_transient_size.mlir", @@ -67,6 +68,7 @@ iree_lit_test_suite( "schedule_execution_timeline_aware.mlir", "specialize_dispatches.mlir", "specialize_encodings.mlir", + "split_parameter_encoder.mlir", "sync_initializers.mlir", "unify_encoding_for_globals.mlir", "verify_affinities.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt index 09c45c587c0a..16bbbb2da114 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt @@ -65,6 +65,7 @@ 
iree_lit_test_suite( "schedule_execution_timeline_aware.mlir" "specialize_dispatches.mlir" "specialize_encodings.mlir" + "split_parameter_encoder.mlir" "sync_initializers.mlir" "unify_encoding_for_globals.mlir" "verify_affinities.mlir" diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/e2e/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/e2e/BUILD.bazel index e17c889f3868..a7c22e2d0f88 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/e2e/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/e2e/BUILD.bazel @@ -20,6 +20,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "async_parameters.mlir", ], diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/split_parameter_encoder.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/split_parameter_encoder.mlir new file mode 100644 index 000000000000..8d3cd53aa5bf --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/split_parameter_encoder.mlir @@ -0,0 +1,1770 @@ +// RUN: iree-opt --split-input-file --iree-stream-split-parameter-encoder %s | FileCheck %s +// RUN: iree-opt --split-input-file --iree-stream-split-parameter-encoder='mode=overlay' %s | FileCheck %s --check-prefix=OVERLAY +// RUN: iree-opt --split-input-file --iree-stream-split-parameter-encoder='mode=overlay' %s | FileCheck %s --check-prefix=OVERLAY-MIXED +// RUN: iree-opt --split-input-file --iree-stream-split-parameter-encoder='mode=consolidate' %s | FileCheck %s --check-prefix=COMPARE-CONSOLIDATE +// RUN: iree-opt --split-input-file --iree-stream-split-parameter-encoder='mode=overlay' %s | FileCheck %s --check-prefix=COMPARE-OVERLAY +// RUN: iree-opt --split-input-file --iree-stream-split-parameter-encoder='output-scope=my_custom_scope' %s | FileCheck %s --check-prefix=SCOPE +// RUN: iree-opt --split-input-file --iree-stream-split-parameter-encoder='max-encoding-growth-factor=2.0' %s 
| FileCheck %s --check-prefix=GROWTH2 +// RUN: iree-opt --split-input-file --iree-stream-split-parameter-encoder='mode=overlay' %s | FileCheck %s --check-prefix=EMPTY + +// Tests simple constant with splat initialization. +// This is the most basic case - a global initialized with a constant splat. +// This should NOT be hoisted (no parameter input). + +// CHECK-LABEL: module { +// CHECK-NOT: module @encoder +// CHECK: util.global private @simple_constant : !stream.resource +util.global private @simple_constant : !stream.resource + +// CHECK: util.initializer { +util.initializer { + // CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32 + %c0_i32 = arith.constant 0 : i32 + // CHECK: %[[C1024:.+]] = arith.constant 1024 : index + %c1024 = arith.constant 1024 : index + // CHECK: %[[SPLAT:.+]] = stream.async.splat %[[C0_I32]] : i32 -> !stream.resource{%[[C1024]]} + %splat = stream.async.splat %c0_i32 : i32 -> !stream.resource{%c1024} + // CHECK: util.global.store %[[SPLAT]], @simple_constant : !stream.resource + util.global.store %splat, @simple_constant : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests computed constant with transformation. +// This tests a constant that undergoes some computation (fill operation). +// Should NOT be hoisted (no parameter input). 
+ +// CHECK-LABEL: module { +// CHECK-NOT: module @encoder +// CHECK: util.global private @computed_constant : !stream.resource +util.global private @computed_constant : !stream.resource + +// CHECK: util.initializer { +util.initializer { + // CHECK-DAG: %[[C0_I32:.+]] = arith.constant 0 : i32 + %c0_i32 = arith.constant 0 : i32 + // CHECK-DAG: %[[C42_I32:.+]] = arith.constant 42 : i32 + %c42_i32 = arith.constant 42 : i32 + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + // CHECK-DAG: %[[C256:.+]] = arith.constant 256 : index + %c256 = arith.constant 256 : index + // CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index + %c1024 = arith.constant 1024 : index + + // Create base splat. + // CHECK: %[[SPLAT:.+]] = stream.async.splat %[[C0_I32]] : i32 -> !stream.resource{%[[C1024]]} + %splat = stream.async.splat %c0_i32 : i32 -> !stream.resource{%c1024} + + // Fill a region with different value. + // CHECK: %[[FILLED:.+]] = stream.async.fill %[[C42_I32]], %[[SPLAT]][%[[C0]] to %[[C256]] for %[[C256]]] : i32 -> %[[SPLAT]] as !stream.resource{%[[C1024]]} + %filled = stream.async.fill %c42_i32, %splat[%c0 to %c256 for %c256] : i32 -> %splat as !stream.resource{%c1024} + + // CHECK: util.global.store %[[FILLED]], @computed_constant : !stream.resource + util.global.store %filled, @computed_constant : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests multiple constants with different patterns. +// This tests that the pass can handle multiple globals, some with splat and some +// with more complex initialization. 
+ +// CHECK-LABEL: module { +// CHECK-NOT: module @encoder +// CHECK: util.global private @constant_a : !stream.resource +util.global private @constant_a : !stream.resource +// CHECK: util.global private @constant_b : !stream.resource +util.global private @constant_b : !stream.resource + +// CHECK: util.initializer { +util.initializer { + // CHECK-DAG: %[[C0_I32:.+]] = arith.constant 0 : i32 + %c0_i32 = arith.constant 0 : i32 + // CHECK-DAG: %[[C1_I32:.+]] = arith.constant 1 : i32 + %c1_i32 = arith.constant 1 : i32 + // CHECK-DAG: %[[C512:.+]] = arith.constant 512 : index + %c512 = arith.constant 512 : index + // CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index + %c1024 = arith.constant 1024 : index + + // First constant: simple splat. + // CHECK: %[[SPLAT_A:.+]] = stream.async.splat %[[C0_I32]] : i32 -> !stream.resource{%[[C512]]} + %splat_a = stream.async.splat %c0_i32 : i32 -> !stream.resource{%c512} + // CHECK: util.global.store %[[SPLAT_A]], @constant_a : !stream.resource + util.global.store %splat_a, @constant_a : !stream.resource + + // Second constant: different splat value and size. + // CHECK: %[[SPLAT_B:.+]] = stream.async.splat %[[C1_I32]] : i32 -> !stream.resource{%[[C1024]]} + %splat_b = stream.async.splat %c1_i32 : i32 -> !stream.resource{%c1024} + // CHECK: util.global.store %[[SPLAT_B]], @constant_b : !stream.resource + util.global.store %splat_b, @constant_b : !stream.resource + + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests parameter transformation. +// This tests loading a parameter and applying a transformation (fill operation). +// This SHOULD be hoisted since it has a parameter input. 
+ +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK: util.global private @parameter_transformed : !stream.resource +util.global private @parameter_transformed : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c42_i32 = arith.constant 42 : i32 + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + // Load parameter from external source. + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"param0"> : vector<1024xi8> + + // Fill a region with different value. + // CHECK-NOT: stream.async.fill + %filled = stream.async.fill %c42_i32, %param[%c0 to %c256 for %c256] : i32 -> %param as !stream.resource{%c1024} + + // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @parameter_transformed + util.global.store %filled, @parameter_transformed : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests pure splat should NOT be hoisted (negative case). +// This tests that a pure splat with no inputs and no transformation is not hoisted. +// The pass should leave this module unchanged. + +// CHECK-LABEL: module { +// CHECK-NOT: module @encoder +// CHECK: util.global private @pure_splat_only : !stream.resource +util.global private @pure_splat_only : !stream.resource + +// CHECK: util.initializer { +util.initializer { + // CHECK: %[[C99_I32:.+]] = arith.constant 99 : i32 + %c99_i32 = arith.constant 99 : i32 + // CHECK: %[[C2048:.+]] = arith.constant 2048 : index + %c2048 = arith.constant 2048 : index + + // Pure splat with no parameter input - should NOT be hoisted. 
+ // CHECK: %[[SPLAT:.+]] = stream.async.splat %[[C99_I32]] : i32 -> !stream.resource{%[[C2048]]} + %splat = stream.async.splat %c99_i32 : i32 -> !stream.resource{%c2048} + + // CHECK: util.global.store %[[SPLAT]], @pure_splat_only : !stream.resource + util.global.store %splat, @pure_splat_only : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests a parameter transformed by a dispatch operation. +// Should be hoisted as it represents expensive computation on a parameter. +// Real-world: Elementwise operations, quantization, or encoding on weights. + +stream.executable private @executable { + stream.executable.export public @dispatch +} + +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK: util.global private @param_with_dispatch : !stream.resource +util.global private @param_with_dispatch : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + + // Load parameter from external source. + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"dispatch_param"> : vector<1024xi8> + + // Dispatch performing operation on parameter. 
+ // CHECK-NOT: stream.async.dispatch + %result = stream.async.dispatch @executable::@dispatch[%c1, %c1, %c1](%param[%c0 to %c1024 for %c1024]) : + (!stream.resource{%c1024}) -> !stream.resource{%c1024} + + // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @param_with_dispatch : !stream.resource + util.global.store %result, @param_with_dispatch : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests parameter + splat + dispatch pattern. +// Splat should be cloned to consumers but not serialized (preferCloneToConsumers). +// Real-world: Parameter combined with constant baseline (e.g., weight + bias). + +stream.executable private @executable { + stream.executable.export public @add_dispatch +} + +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK: util.global private @param_splat_dispatch : !stream.resource +util.global private @param_splat_dispatch : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + + // Load parameter from external source. + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"weights"> : vector<1024xi8> + + // Create splat (should be cloned but not serialized). + %c0_i32 = arith.constant 0 : i32 + // CHECK-NOT: stream.async.splat + %splat = stream.async.splat %c0_i32 : i32 -> !stream.resource{%c1024} + + // Dispatch using both parameter and splat. 
+ // CHECK-NOT: stream.async.dispatch + %result = stream.async.dispatch @executable::@add_dispatch[%c1, %c1, %c1]( + %param[%c0 to %c1024 for %c1024], + %splat[%c0 to %c1024 for %c1024] + ) : (!stream.resource{%c1024}, !stream.resource{%c1024}) -> !stream.resource{%c1024} + + // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @param_splat_dispatch : !stream.resource + util.global.store %result, @param_splat_dispatch : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests parameter with metadata operations (subview). +// Metadata operations should not prevent hoisting. +// Real-world: Extract layer weights from combined parameter. + +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK: util.global private @param_metadata_ops : !stream.resource +util.global private @param_metadata_ops : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c1024 = arith.constant 1024 : index + + // Load larger parameter. + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"combined_param"> : vector<1024xi8> + + // Extract slice (metadata operation). + // CHECK-NOT: stream.async.slice + %slice = stream.async.slice %param[%c256 to %c512] : !stream.resource{%c1024} -> !stream.resource{%c256} + + // Apply transformation to slice. 
+ %c100_i32 = arith.constant 100 : i32 + // CHECK-NOT: stream.async.fill + %filled = stream.async.fill %c100_i32, %slice[%c0 to %c256 for %c256] : i32 -> %slice as !stream.resource{%c256} + + // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @param_metadata_ops : !stream.resource + util.global.store %filled, @param_metadata_ops : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests borderline growth (1.15x) - should pass. +// Within threshold growth should be allowed. +// Real-world: Small padding for alignment. + +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK: util.global private @param_acceptable_growth : !stream.resource +util.global private @param_acceptable_growth : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c1024 = arith.constant 1024 : index + %c1180 = arith.constant 1180 : index // ~1.15x growth + + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"param_growth"> : vector<1024xi8> + + // Slight growth for padding - should be within 1.2x threshold + %c0 = arith.constant 0 : index + %c156 = arith.constant 156 : index + %c0_i32 = arith.constant 0 : i32 + + // Create slightly larger buffer + // CHECK-NOT: stream.async.splat + %padded = stream.async.splat %c0_i32 : i32 -> !stream.resource{%c1180} + + // Copy parameter into padded buffer + // CHECK-NOT: stream.async.update + %result = stream.async.update %param, %padded[%c0 to %c1024] : + !stream.resource{%c1024} -> %padded as !stream.resource{%c1180} + + // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: 
%[[RESULT:.+]] = stream.resource.subview %[[PARAM]] + // CHECK: util.global.store %[[RESULT]], @param_acceptable_growth : !stream.resource + util.global.store %result, @param_acceptable_growth : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +//===----------------------------------------------------------------------===// +// Control Flow in Initializers +//===----------------------------------------------------------------------===// + +// Tests scf.for loop with fixed bounds. +// Loop should be unrolled if bounds are constant. +// Real-world: Fixed preprocessing iterations. +// Tests scf.for loop with constant bounds. +// Should hoist the loop and its body since bounds are constant. +// Real-world: Iterative parameter transformations. + +// Encoder module should be generated with scf.for hoisted. +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// Encoder should contain the scf.for loop with result scattered to parameter. 
+// CHECK: %[[IMPORT_TP:.+]] = stream.timepoint.import {{.+}} %arg1 : (!hal.fence) => !stream.timepoint +// CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index +// CHECK: %[[PACK_SIZE:.+]]:2 = stream.resource.pack {{.+}} slices({ +// CHECK-NEXT: [0, 0] = %[[C1024]] +// CHECK-NEXT: }) : index +// CHECK: %[[ALLOCA:.+]], %[[ALLOCA_TP:.+]] = stream.resource.alloca uninitialized {{.+}} await(%[[IMPORT_TP]]) => !stream.resource{%[[PACK_SIZE]]#0} => !stream.timepoint +// CHECK: %[[ALLOCA_READY:.+]] = stream.timepoint.await %[[ALLOCA_TP]] => %[[ALLOCA]] : !stream.resource{%[[PACK_SIZE]]#0} +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C256:.+]] = arith.constant 256 : index +// CHECK-DAG: %[[C0_I64:.+]] = arith.constant 0 : i64 +// CHECK: %[[PARAM_RESOURCE:.+]], %[[PARAM_TP:.+]] = stream.async.parameter.load {{.+}} await(%[[IMPORT_TP]]) "model"::"iterative_param"[%[[C0_I64]]] : !stream.resource{%[[C1024]]} => !stream.timepoint +// CHECK: %[[INPUT:.+]] = stream.timepoint.await %[[PARAM_TP]] => %[[PARAM_RESOURCE]] : !stream.resource{%[[C1024]]} +// CHECK: %[[LOOP_RESULT:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG:.+]] = %[[INPUT]]) -> (!stream.resource) { +// CHECK: %[[C100:.+]] = arith.constant 100 : i32 +// CHECK: %[[FILLED:.+]] = stream.async.fill %[[C100]], %[[ARG]][%[[C0]] to %[[C256]] for %[[C256]]] : i32 -> %[[ARG]] as !stream.resource{%[[C1024]]} +// CHECK: scf.yield %[[FILLED]] : !stream.resource +// CHECK: } +// CHECK: %[[UPDATE_END:.+]] = arith.addi %[[PACK_SIZE]]#1, %[[C1024]] : index +// CHECK: %[[UPDATED:.+]] = stream.async.update {{.+}} %[[LOOP_RESULT]], %[[ALLOCA_READY]][%[[PACK_SIZE]]#1 to %[[UPDATE_END]]] : !stream.resource{%[[C1024]]} -> %[[ALLOCA_READY]] as !stream.resource{%[[PACK_SIZE]]#0} +// CHECK: %[[BARRIER_RESULT:.+]], %[[BARRIER_TP:.+]] = stream.timepoint.barrier {{.+}} %[[UPDATED]] : 
!stream.resource{%[[PACK_SIZE]]#0} => !stream.timepoint +// CHECK: %[[SCATTER_RESULT:.+]], %[[SCATTER_TP:.+]] = stream.async.parameter.scatter {{.+}} await(%[[BARRIER_TP]]) { +// CHECK-NEXT: %[[BARRIER_RESULT]][%[[PACK_SIZE]]#1 to %[[UPDATE_END]] for %[[C1024]]] : !stream.resource{%[[PACK_SIZE]]#0} -> ""::"parameter0"[%[[C0_I64]]] +// CHECK-NEXT: } : !stream.resource => !stream.timepoint +// CHECK: %[[JOIN_TP:.+]] = stream.timepoint.join max(%[[SCATTER_TP]]) => !stream.timepoint +// CHECK: %[[DEALLOCA_TP:.+]] = stream.resource.dealloca {{.+}} await(%[[JOIN_TP]]) => %[[SCATTER_RESULT]] : !stream.resource{%[[PACK_SIZE]]#0} => !stream.timepoint +// CHECK: stream.timepoint.chain_external {{.+}} %[[DEALLOCA_TP]] => (%arg2 : !hal.fence) + +// Original module should have parameter load instead of scf.for. +// CHECK-LABEL: util.global private @scf_for_fixed_bounds +util.global private @scf_for_fixed_bounds : !stream.resource + +util.initializer { + %c0 = arith.constant 0 : index + %c3 = arith.constant 3 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"iterative_param"> : vector<1024xi8> + + // Fixed-bound loop that could be unrolled. + // CHECK-NOT: scf.for + // CHECK-NOT: stream.async.fill + %result = scf.for %i = %c0 to %c3 step %c1 + iter_args(%arg = %param) -> (!stream.resource) { + // Apply transformation in each iteration. + %c100_i32 = arith.constant 100 : i32 + %processed = stream.async.fill %c100_i32, %arg[%c0 to %c256 for %c256] : + i32 -> %arg as !stream.resource{%c1024} + scf.yield %processed : !stream.resource + } + + // Original module loads from parameter instead of executing loop. 
+ // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @scf_for_fixed_bounds + util.global.store %result, @scf_for_fixed_bounds : !stream.resource + util.return +} + +// ----- + +// Tests scf.if conditional with compile-time constant condition. +// Should hoist the taken branch if condition is constant. +// Real-world: Conditional initialization for specific target. + +// Encoder module should be generated with scf.if hoisted. +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// Encoder should contain the scf.if conditional with result scattered to parameter. +// CHECK: %[[IMPORT_TP:.+]] = stream.timepoint.import {{.+}} %arg1 : (!hal.fence) => !stream.timepoint +// CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index +// CHECK: %[[PACK_SIZE:.+]]:2 = stream.resource.pack {{.+}} slices({ +// CHECK-NEXT: [0, 0] = %[[C1024]] +// CHECK-NEXT: }) : index +// CHECK: %[[ALLOCA:.+]], %[[ALLOCA_TP:.+]] = stream.resource.alloca uninitialized {{.+}} await(%[[IMPORT_TP]]) => !stream.resource{%[[PACK_SIZE]]#0} => !stream.timepoint +// CHECK: %[[ALLOCA_READY:.+]] = stream.timepoint.await %[[ALLOCA_TP]] => %[[ALLOCA]] : !stream.resource{%[[PACK_SIZE]]#0} +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C256:.+]] = arith.constant 256 : index +// CHECK-DAG: %[[C0_I64:.+]] = arith.constant 0 : i64 +// CHECK: %[[PARAM_RESOURCE:.+]], %[[PARAM_TP:.+]] = stream.async.parameter.load {{.+}} await(%[[IMPORT_TP]]) "model"::"conditional_param"[%[[C0_I64]]] : !stream.resource{%[[C1024]]} => !stream.timepoint +// CHECK: %[[INPUT:.+]] = stream.timepoint.await %[[PARAM_TP]] => %[[PARAM_RESOURCE]] : !stream.resource{%[[C1024]]} +// 
CHECK: %[[IF_RESULT:.+]] = scf.if %[[TRUE]] -> (!stream.resource) { +// CHECK: %[[C42:.+]] = arith.constant 42 : i32 +// CHECK: %[[FILLED:.+]] = stream.async.fill %[[C42]], %[[INPUT]][%[[C0]] to %[[C256]] for %[[C256]]] : i32 -> %[[INPUT]] as !stream.resource{%[[C1024]]} +// CHECK: scf.yield %[[FILLED]] : !stream.resource +// CHECK: } else { +// CHECK: scf.yield %[[INPUT]] : !stream.resource +// CHECK: } +// CHECK: %[[UPDATE_END:.+]] = arith.addi %[[PACK_SIZE]]#1, %[[C1024]] : index +// CHECK: %[[UPDATED:.+]] = stream.async.update {{.+}} %[[IF_RESULT]], %[[ALLOCA_READY]][%[[PACK_SIZE]]#1 to %[[UPDATE_END]]] : !stream.resource{%[[C1024]]} -> %[[ALLOCA_READY]] as !stream.resource{%[[PACK_SIZE]]#0} +// CHECK: %[[BARRIER_RESULT:.+]], %[[BARRIER_TP:.+]] = stream.timepoint.barrier {{.+}} %[[UPDATED]] : !stream.resource{%[[PACK_SIZE]]#0} => !stream.timepoint +// CHECK: %[[SCATTER_RESULT:.+]], %[[SCATTER_TP:.+]] = stream.async.parameter.scatter {{.+}} await(%[[BARRIER_TP]]) { +// CHECK-NEXT: %[[BARRIER_RESULT]][%[[PACK_SIZE]]#1 to %[[UPDATE_END]] for %[[C1024]]] : !stream.resource{%[[PACK_SIZE]]#0} -> ""::"parameter0"[%[[C0_I64]]] +// CHECK-NEXT: } : !stream.resource => !stream.timepoint +// CHECK: %[[JOIN_TP:.+]] = stream.timepoint.join max(%[[SCATTER_TP]]) => !stream.timepoint +// CHECK: %[[DEALLOCA_TP:.+]] = stream.resource.dealloca {{.+}} await(%[[JOIN_TP]]) => %[[SCATTER_RESULT]] : !stream.resource{%[[PACK_SIZE]]#0} => !stream.timepoint +// CHECK: stream.timepoint.chain_external {{.+}} %[[DEALLOCA_TP]] => (%arg2 : !hal.fence) + +// Original module should have parameter load instead of scf.if. 
+// CHECK-LABEL: util.global private @scf_if_constant_condition +util.global private @scf_if_constant_condition : !stream.resource + +util.initializer { + %true = arith.constant true + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"conditional_param"> : vector<1024xi8> + + // Conditional with compile-time constant. + // CHECK-NOT: scf.if + // CHECK-NOT: stream.async.fill + %result = scf.if %true -> (!stream.resource) { + // True branch - should be taken. + %c42_i32 = arith.constant 42 : i32 + %processed = stream.async.fill %c42_i32, %param[%c0 to %c256 for %c256] : + i32 -> %param as !stream.resource{%c1024} + scf.yield %processed : !stream.resource + } else { + // False branch - should be eliminated. + scf.yield %param : !stream.resource + } + + // Original module loads from parameter instead of executing conditional. + // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @scf_if_constant_condition + util.global.store %result, @scf_if_constant_condition : !stream.resource + util.return +} + +// ----- + +//===----------------------------------------------------------------------===// +// Multiple Outputs from Single Parameter +//===----------------------------------------------------------------------===// + +// Tests single parameter producing multiple transformed outputs. +// Should hoist both transformations, outputs packed. +// Real-world: Different quantization formats for different layers. 
+ +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK-DAG: util.global private @single_param_multi_output_a : !stream.resource +util.global private @single_param_multi_output_a : !stream.resource +// CHECK-DAG: util.global private @single_param_multi_output_b : !stream.resource +util.global private @single_param_multi_output_b : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c768 = arith.constant 768 : index + %c1024 = arith.constant 1024 : index + + // Single parameter input + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"shared_param"> : vector<1024xi8> + + // First transformation + %c100_i32 = arith.constant 100 : i32 + // CHECK-NOT: stream.async.fill + %output_a = stream.async.fill %c100_i32, %param[%c0 to %c256 for %c256] : + i32 -> %param as !stream.resource{%c1024} + + // Second transformation + %c200_i32 = arith.constant 200 : i32 + %output_b = stream.async.fill %c200_i32, %param[%c512 to %c768 for %c256] : + i32 -> %param as !stream.resource{%c1024} + + // Both outputs are packed into a single parameter, loaded twice and extracted via subviews. 
+ // CHECK-DAG: %[[PACKED_SIZE:.+]] = arith.constant 2048 : index + // CHECK-DAG: %[[SUBVIEW_OFFSET_0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[SUBVIEW_SIZE:.+]] = arith.constant 1024 : index + // CHECK-DAG: %[[PARAM_A:.+]] = stream.async.constant : !stream.resource{%[[PACKED_SIZE]]} = #stream.parameter.named<""::"parameter0"> + // CHECK-DAG: %[[RESULT_A:.+]] = stream.resource.subview %[[PARAM_A]][%[[SUBVIEW_OFFSET_0]]] : !stream.resource{%[[PACKED_SIZE]]} -> !stream.resource{%[[SUBVIEW_SIZE]]} + // CHECK-DAG: util.global.store %[[RESULT_A]], @single_param_multi_output_a : !stream.resource + util.global.store %output_a, @single_param_multi_output_a : !stream.resource + // CHECK-DAG: %[[PARAM_B:.+]] = stream.async.constant : !stream.resource{%[[PACKED_SIZE]]} = #stream.parameter.named<""::"parameter0"> + // CHECK-DAG: %[[RESULT_B:.+]] = stream.resource.subview %[[PARAM_B]][%[[SUBVIEW_SIZE]]] : !stream.resource{%[[PACKED_SIZE]]} -> !stream.resource{%[[SUBVIEW_SIZE]]} + // CHECK-DAG: util.global.store %[[RESULT_B]], @single_param_multi_output_b : !stream.resource + util.global.store %output_b, @single_param_multi_output_b : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +//===----------------------------------------------------------------------===// +// Device Specialization & Affinity +//===----------------------------------------------------------------------===// + +// Tests parameter with affinity annotation. +// Should hoist with affinity preserved in encoder. +// Real-world: GPU-specific parameter transformation. 
+ +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK: util.global private @param_with_affinity : !stream.resource +util.global private @param_with_affinity : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + // Parameter with device affinity + %param = stream.async.constant on(#hal.device.affinity<@device_0>) : + !stream.resource{%c1024} = + #stream.parameter.named<"model"::"gpu_param"> : vector<1024xi8> + + // Transformation maintaining affinity + %c42_i32 = arith.constant 42 : i32 + // CHECK-NOT: stream.async.fill + %result = stream.async.fill on(#hal.device.affinity<@device_0>) %c42_i32, + %param[%c0 to %c256 for %c256] : i32 -> %param as !stream.resource{%c1024} + + // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @param_with_affinity : !stream.resource + util.global.store %result, @param_with_affinity : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + + +// ----- + +//===----------------------------------------------------------------------===// +// Stress Tests +//===----------------------------------------------------------------------===// + +// Tests very small parameter (1 byte). +// Should handle minimum size parameters. 
+ +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK: util.global private @minimum_size_param : !stream.resource +util.global private @minimum_size_param : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c1 = arith.constant 1 : index + + // Tiny 1-byte parameter + %param = stream.async.constant : !stream.resource{%c1} = + #stream.parameter.named<"model"::"tiny"> : vector<1xi8> + + // Even tiny transform should work + %c0 = arith.constant 0 : index + %c42_i32 = arith.constant 42 : i32 + // CHECK-NOT: stream.async.fill + %filled = stream.async.fill %c42_i32, %param[%c0 to %c1 for %c1] : + i32 -> %param as !stream.resource{%c1} + + // CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + // CHECK: %[[PARAM:.+]] = stream.async.constant : !stream.resource{%[[C64]]} = #stream.parameter.named<""::"parameter0"> + // CHECK: %[[RESULT:.+]] = stream.resource.subview %[[PARAM]][%[[C0]]] : !stream.resource{%[[C64]]} -> !stream.resource{%[[C1]]} + // CHECK: util.global.store %[[RESULT]], @minimum_size_param : !stream.resource + util.global.store %filled, @minimum_size_param : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests deep expression DAG (multiple levels of operations). +// Should handle deep computation chains. 
+ +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK: util.global private @deep_expression_dag : !stream.resource +util.global private @deep_expression_dag : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c192 = arith.constant 192 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + // Load parameter + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"deep_param"> : vector<1024xi8> + + // Deep chain of transformations + // CHECK-NOT: stream.async.fill + %c1_i32 = arith.constant 1 : i32 + %stage1 = stream.async.fill %c1_i32, %param[%c0 to %c64 for %c64] : + i32 -> %param as !stream.resource{%c1024} + + %c2_i32 = arith.constant 2 : i32 + %stage2 = stream.async.fill %c2_i32, %stage1[%c64 to %c128 for %c64] : + i32 -> %stage1 as !stream.resource{%c1024} + + %c3_i32 = arith.constant 3 : i32 + %stage3 = stream.async.fill %c3_i32, %stage2[%c128 to %c192 for %c64] : + i32 -> %stage2 as !stream.resource{%c1024} + + %c4_i32 = arith.constant 4 : i32 + %stage4 = stream.async.fill %c4_i32, %stage3[%c192 to %c256 for %c64] : + i32 -> %stage3 as !stream.resource{%c1024} + + // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @deep_expression_dag : !stream.resource + util.global.store %stage4, @deep_expression_dag : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +//===----------------------------------------------------------------------===// +// Advanced Growth Factor Tests 
+//===----------------------------------------------------------------------===// + +// Tests exact 1.2x growth threshold - should pass. + +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK: util.global private @exact_growth_threshold : !stream.resource +util.global private @exact_growth_threshold : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c1000 = arith.constant 1000 : index + %c1200 = arith.constant 1200 : index // Exactly 1.2x + + %param = stream.async.constant : !stream.resource{%c1000} = + #stream.parameter.named<"model"::"exact_threshold"> : vector<1000xi8> + + %c0 = arith.constant 0 : index + %c0_i32 = arith.constant 0 : i32 + // CHECK-NOT: stream.async.splat + %padded = stream.async.splat %c0_i32 : i32 -> !stream.resource{%c1200} + // CHECK-NOT: stream.async.update + %result = stream.async.update %param, %padded[%c0 to %c1000] : + !stream.resource{%c1000} -> %padded as !stream.resource{%c1200} + + // CHECK-DAG: %[[C1216:.+]] = arith.constant 1216 : index + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[C1200:.+]] = arith.constant 1200 : index + // CHECK: %[[PARAM:.+]] = stream.async.constant : !stream.resource{%[[C1216]]} = #stream.parameter.named<""::"parameter0"> + // CHECK: %[[RESULT:.+]] = stream.resource.subview %[[PARAM]][%[[C0]]] : !stream.resource{%[[C1216]]} -> !stream.resource{%[[C1200]]} + // CHECK: util.global.store %[[RESULT]], @exact_growth_threshold : !stream.resource + util.global.store %result, @exact_growth_threshold : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests just over 1.2x growth (1.21x) - should reject. 
+ +// CHECK-LABEL: module { +// CHECK-NOT: module @encoder +// CHECK: util.global private @over_growth_threshold : !stream.resource +util.global private @over_growth_threshold : !stream.resource + +// CHECK: util.initializer { +util.initializer { + // CHECK-DAG: %[[C1000:.+]] = arith.constant 1000 : index + %c1000 = arith.constant 1000 : index + // CHECK-DAG: %[[C1210:.+]] = arith.constant 1210 : index + %c1210 = arith.constant 1210 : index // 1.21x - over threshold + + // CHECK: %[[PARAM:.+]] = stream.async.constant : !stream.resource{%[[C1000]]} = #stream.parameter.named<"model"::"over_threshold"> + %param = stream.async.constant : !stream.resource{%c1000} = + #stream.parameter.named<"model"::"over_threshold"> : vector<1000xi8> + + %c0 = arith.constant 0 : index + %c0_i32 = arith.constant 0 : i32 + // CHECK: %[[PADDED:.+]] = stream.async.splat + %padded = stream.async.splat %c0_i32 : i32 -> !stream.resource{%c1210} + // CHECK: %[[RESULT:.+]] = stream.async.update %[[PARAM]], %[[PADDED]] + %result = stream.async.update %param, %padded[%c0 to %c1000] : + !stream.resource{%c1000} -> %padded as !stream.resource{%c1210} + + // CHECK: util.global.store %[[RESULT]], @over_growth_threshold : !stream.resource + util.global.store %result, @over_growth_threshold : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +//===----------------------------------------------------------------------===// +// Complex Data Flow Patterns +//===----------------------------------------------------------------------===// + +// Tests parameter used by multiple operations (wide DAG). +// Single parameter with many consumers. 
+ +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK-DAG: util.global private @wide_expression_dag_a : !stream.resource +util.global private @wide_expression_dag_a : !stream.resource +// CHECK-DAG: util.global private @wide_expression_dag_b : !stream.resource +util.global private @wide_expression_dag_b : !stream.resource +// CHECK-DAG: util.global private @wide_expression_dag_c : !stream.resource +util.global private @wide_expression_dag_c : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c0 = arith.constant 0 : index + %c100 = arith.constant 100 : index + %c200 = arith.constant 200 : index + %c300 = arith.constant 300 : index + %c1024 = arith.constant 1024 : index + + // Single parameter used by many operations + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"wide_param"> : vector<1024xi8> + + // Many transformations using the same parameter + %c10_i32 = arith.constant 10 : i32 + // CHECK-NOT: stream.async.fill + %out_a = stream.async.fill %c10_i32, %param[%c0 to %c100 for %c100] : + i32 -> %param as !stream.resource{%c1024} + + %c20_i32 = arith.constant 20 : i32 + %out_b = stream.async.fill %c20_i32, %param[%c100 to %c200 for %c100] : + i32 -> %param as !stream.resource{%c1024} + + %c30_i32 = arith.constant 30 : i32 + %out_c = stream.async.fill %c30_i32, %param[%c200 to %c300 for %c100] : + i32 -> %param as !stream.resource{%c1024} + + // All outputs packed into a single parameter and extracted via subviews. 
+ // CHECK-DAG: %[[PACKED_SIZE:.+]] = arith.constant 3072 : index + // CHECK-DAG: %[[OFFSET_0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[SUBVIEW_SIZE:.+]] = arith.constant 1024 : index + // CHECK-DAG: %[[OFFSET_2048:.+]] = arith.constant 2048 : index + // CHECK-DAG: %[[PARAM_A:.+]] = stream.async.constant : !stream.resource{%[[PACKED_SIZE]]} = #stream.parameter.named<""::"parameter0"> + // CHECK-DAG: %[[RESULT_A:.+]] = stream.resource.subview %[[PARAM_A]][%[[OFFSET_0]]] : !stream.resource{%[[PACKED_SIZE]]} -> !stream.resource{%[[SUBVIEW_SIZE]]} + // CHECK-DAG: util.global.store %[[RESULT_A]], @wide_expression_dag_a : !stream.resource + util.global.store %out_a, @wide_expression_dag_a : !stream.resource + // CHECK-DAG: %[[PARAM_B:.+]] = stream.async.constant : !stream.resource{%[[PACKED_SIZE]]} = #stream.parameter.named<""::"parameter0"> + // CHECK-DAG: %[[RESULT_B:.+]] = stream.resource.subview %[[PARAM_B]][%[[SUBVIEW_SIZE]]] : !stream.resource{%[[PACKED_SIZE]]} -> !stream.resource{%[[SUBVIEW_SIZE]]} + // CHECK-DAG: util.global.store %[[RESULT_B]], @wide_expression_dag_b : !stream.resource + util.global.store %out_b, @wide_expression_dag_b : !stream.resource + // CHECK-DAG: %[[PARAM_C:.+]] = stream.async.constant : !stream.resource{%[[PACKED_SIZE]]} = #stream.parameter.named<""::"parameter0"> + // CHECK-DAG: %[[RESULT_C:.+]] = stream.resource.subview %[[PARAM_C]][%[[OFFSET_2048]]] : !stream.resource{%[[PACKED_SIZE]]} -> !stream.resource{%[[SUBVIEW_SIZE]]} + // CHECK-DAG: util.global.store %[[RESULT_C]], @wide_expression_dag_c : !stream.resource + util.global.store %out_c, @wide_expression_dag_c : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests parameter transformation with clone operation. +// Clone operations should be handled (may have preferCloneToConsumers). 
+ +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK: util.global private @param_with_clone : !stream.resource +util.global private @param_with_clone : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"clone_param"> : vector<1024xi8> + + // Clone operation (might have preferCloneToConsumers) + // CHECK-NOT: stream.async.clone + %cloned = stream.async.clone %param : !stream.resource{%c1024} -> + !stream.resource{%c1024} + + // Transform the clone + %c99_i32 = arith.constant 99 : i32 + // CHECK-NOT: stream.async.fill + %result = stream.async.fill %c99_i32, %cloned[%c0 to %c256 for %c256] : + i32 -> %cloned as !stream.resource{%c1024} + + // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @param_with_clone : !stream.resource + util.global.store %result, @param_with_clone : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests parameter transformation with clone at END of expression. +// This tests findProducedValue skipping past final clone to find producer. +// Pattern: param → clone(to *) → dispatch → clone(to constant) → store. 
+ +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +stream.executable private @dispatch_for_clone_test { + stream.executable.export public @fill +} + +// CHECK-LABEL: module { +// CHECK: util.global private @param_with_trailing_clone : !stream.resource +util.global private @param_with_trailing_clone : !stream.resource + +// The original ops (clone → dispatch → clone) should all be hoisted to encoder. +// CHECK: util.initializer { +// CHECK-NOT: stream.async.clone +// CHECK-NOT: stream.async.dispatch +// CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> +// CHECK: util.global.store %[[PARAM]], @param_with_trailing_clone +// CHECK: util.return +// CHECK: } +util.initializer { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"trailing_clone_param"> : vector<1024xi8> + + // Clone to unknown lifetime for dispatch input. + %for_dispatch = stream.async.clone %param : + !stream.resource{%c1024} -> !stream.resource<*>{%c1024} + + // Dispatch transforms the parameter. + %dispatched = stream.async.dispatch @dispatch_for_clone_test::@fill[%c1, %c1, %c1](%for_dispatch[%c0 to %c1024 for %c1024]) : + (!stream.resource<*>{%c1024}) -> !stream.resource<*>{%c1024} + + // Clone at END of expression back to constant lifetime. + // findProducedValue must skip this to find the dispatch as the producer. 
+ %result = stream.async.clone %dispatched : + !stream.resource<*>{%c1024} -> !stream.resource{%c1024} + + util.global.store %result, @param_with_trailing_clone : !stream.resource + util.return +} + +// ----- + +//===----------------------------------------------------------------------===// +// Transfer Operations +//===----------------------------------------------------------------------===// + +// Tests parameter with transfer operations. +// Transfers should be handled correctly. + +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK: util.global private @param_transfer : !stream.resource +util.global private @param_transfer : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"transfer_param"> : vector<1024xi8> + + // Transfer to different lifetime (if needed) + // CHECK-NOT: stream.async.transfer + %transferred = stream.async.transfer %param : + !stream.resource{%c1024} -> !stream.resource{%c1024} + + // Transform transferred value + %c88_i32 = arith.constant 88 : i32 + // CHECK-NOT: stream.async.fill + %result = stream.async.fill %c88_i32, %transferred[%c0 to %c256 for %c256] : + i32 -> %transferred as !stream.resource{%c1024} + + // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @param_transfer : !stream.resource + util.global.store %result, @param_transfer : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests parameter copying (combining two parameters). 
+// Copy operations should be hoisted to encoder.
+// Real-world: Combining parameter shards into a single buffer.
+
+// CHECK: module @encoder
+// CHECK: util.func public @__encode_parameter_detect_target
+// CHECK: util.func public @__encode_parameter_indices_all
+// CHECK: util.func public @__encode_parameter_steps_all
+// CHECK: util.func public @__encode_parameters_all
+
+// CHECK-LABEL: module {
+// CHECK: util.global private @param_copy_combine : !stream.resource
+util.global private @param_copy_combine : !stream.resource
+
+// CHECK: util.initializer {
+util.initializer {
+  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : index
+  %c1024 = arith.constant 1024 : index
+
+  // Load two parameters that will be combined.
+  %param1 = stream.async.constant : !stream.resource{%c512} =
+    #stream.parameter.named<"model"::"shard0"> : vector<512xi8>
+  %param2 = stream.async.constant : !stream.resource{%c512} =
+    #stream.parameter.named<"model"::"shard1"> : vector<512xi8>
+
+  // Create destination buffer.
+  %c0_i32 = arith.constant 0 : i32
+  // CHECK-NOT: stream.async.splat
+  %combined = stream.async.splat %c0_i32 : i32 -> !stream.resource{%c1024}
+
+  // Copy first parameter.
+  // CHECK-NOT: stream.async.copy
+  %with_first = stream.async.copy %param1[%c0 to %c512], %combined[%c0 to %c512], %c512 :
+    !stream.resource{%c512} -> %combined as !stream.resource{%c1024}
+
+  // Copy second parameter.
+ // CHECK-NOT: stream.async.copy + %result = stream.async.copy %param2[%c0 to %c512], %with_first[%c512 to %c1024], %c512 : + !stream.resource{%c512} -> %with_first as !stream.resource{%c1024} + + // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @param_copy_combine : !stream.resource + util.global.store %result, @param_copy_combine : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +//===----------------------------------------------------------------------===// +// Multiple Initializers +//===----------------------------------------------------------------------===// + +// Tests multiple initializers in same module. +// All should be processed independently. + +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK-DAG: util.global private @multi_init_a : !stream.resource +util.global private @multi_init_a : !stream.resource +// CHECK-DAG: util.global private @multi_init_b : !stream.resource +util.global private @multi_init_b : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c512 = arith.constant 512 : index + %param_a = stream.async.constant : !stream.resource{%c512} = + #stream.parameter.named<"model"::"init_a"> : vector<512xi8> + + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c11_i32 = arith.constant 11 : i32 + // CHECK-NOT: stream.async.fill + %result_a = stream.async.fill %c11_i32, %param_a[%c0 to %c256 for %c256] : + i32 -> %param_a as !stream.resource{%c512} + + // CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[C512:.+]] = arith.constant 512 : index + // CHECK: 
%[[PARAM_A:.+]] = stream.async.constant : !stream.resource{%[[C1024]]} = #stream.parameter.named<""::"parameter0"> + // CHECK: %[[RESULT_A:.+]] = stream.resource.subview %[[PARAM_A]][%[[C0]]] : !stream.resource{%[[C1024]]} -> !stream.resource{%[[C512]]} + // CHECK: util.global.store %[[RESULT_A]], @multi_init_a : !stream.resource + util.global.store %result_a, @multi_init_a : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// CHECK: util.initializer { +util.initializer { + %c512 = arith.constant 512 : index + %param_b = stream.async.constant : !stream.resource{%c512} = + #stream.parameter.named<"model"::"init_b"> : vector<512xi8> + + %c256 = arith.constant 256 : index + %c512_0 = arith.constant 512 : index + %c22_i32 = arith.constant 22 : i32 + // CHECK-NOT: stream.async.fill + %result_b = stream.async.fill %c22_i32, %param_b[%c256 to %c512_0 for %c256] : + i32 -> %param_b as !stream.resource{%c512} + + // CHECK-DAG: %[[C1024_0:.+]] = arith.constant 1024 : index + // CHECK-DAG: %[[C512_0:.+]] = arith.constant 512 : index + // CHECK: %[[PARAM_B:.+]] = stream.async.constant : !stream.resource{%[[C1024_0]]} = #stream.parameter.named<""::"parameter0"> + // CHECK: %[[RESULT_B:.+]] = stream.resource.subview %[[PARAM_B]][%[[C512_0]]] : !stream.resource{%[[C1024_0]]} -> !stream.resource{%[[C512_0]]} + // CHECK: util.global.store %[[RESULT_B]], @multi_init_b : !stream.resource + util.global.store %result_b, @multi_init_b : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +//===----------------------------------------------------------------------===// +// Resource Lifetime Tests +//===----------------------------------------------------------------------===// + +// Tests non-constant resource lifetime (should skip). 
+ +// CHECK-LABEL: module { +// CHECK-NOT: module @encoder +// CHECK: util.global private @non_constant_lifetime : !stream.resource +util.global private @non_constant_lifetime : !stream.resource + +// CHECK: util.initializer { +util.initializer { + // CHECK: %[[C1024:.+]] = arith.constant 1024 : index + %c1024 = arith.constant 1024 : index + + // Transient resource (not constant) - should skip + // CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32 + %c0_i32 = arith.constant 0 : i32 + // CHECK: %[[TRANSIENT:.+]] = stream.async.splat %[[C0_I32]] : i32 -> !stream.resource{%[[C1024]]} + %transient = stream.async.splat %c0_i32 : i32 -> !stream.resource{%c1024} + + // CHECK: util.global.store %[[TRANSIENT]], @non_constant_lifetime : !stream.resource + util.global.store %transient, @non_constant_lifetime : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +//===----------------------------------------------------------------------===// +// Mode Testing: Consolidate vs Overlay +//===----------------------------------------------------------------------===// + +// Tests pass-through parameter in consolidate mode (default). +// A parameter loaded and stored directly with no transformation should be +// included in the encoder output when in consolidate mode. + +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all + +// CHECK-LABEL: module { +// CHECK: util.global private @passthrough_consolidate : !stream.resource +util.global private @passthrough_consolidate : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c1024 = arith.constant 1024 : index + + // Load parameter directly without transformation. 
+  %param = stream.async.constant : !stream.resource{%c1024} =
+    #stream.parameter.named<"model"::"passthrough_param"> : vector<1024xi8>
+
+  // Store directly - this is a pass-through (no transformation).
+  // In consolidate mode, this should be included in encoder output.
+  // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0">
+  // CHECK: util.global.store %[[PARAM]], @passthrough_consolidate
+  util.global.store %param, @passthrough_consolidate : !stream.resource
+  // CHECK: util.return
+  util.return
+  // CHECK: }
+}
+
+// -----
+
+// Tests pass-through parameter in overlay mode.
+// Same as previous test but with overlay mode enabled.
+// The parameter should NOT be included in the encoder output since it's
+// unmodified (includeUnmodified=false in overlay mode).
+
+// Anchor to this specific test's main module so OVERLAY checks are scoped to it.
+// OVERLAY-LABEL: util.global private @passthrough_overlay : !stream.resource
+util.global private @passthrough_overlay : !stream.resource
+
+// OVERLAY: util.initializer {
+util.initializer {
+  %c1024 = arith.constant 1024 : index
+
+  // Load parameter directly without transformation.
+  %param = stream.async.constant : !stream.resource{%c1024} =
+    #stream.parameter.named<"model"::"passthrough_param_overlay"> : vector<1024xi8>
+
+  // Store directly - pass-through with no transformation.
+  // In overlay mode, this should NOT be in encoder output.
+  // The original parameter load should remain unchanged.
+  // OVERLAY: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<"model"::"passthrough_param_overlay">
+  // OVERLAY: util.global.store %[[PARAM]], @passthrough_overlay
+  util.global.store %param, @passthrough_overlay : !stream.resource
+  // OVERLAY: util.return
+  util.return
+  // OVERLAY: }
+}
+
+// -----
+
+// Tests mixed parameters in consolidate mode.
+// One parameter with transformation, one pass-through.
+// Consolidate mode should include both in encoder output.
+ +// CHECK-LABEL: util.global private @mixed_transformed : !stream.resource +util.global private @mixed_transformed : !stream.resource +// CHECK: util.global private @mixed_passthrough : !stream.resource +util.global private @mixed_passthrough : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c42_i32 = arith.constant 42 : i32 + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + // Parameter 1: Transformed with fill operation. + %param1 = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"mixed_param1"> : vector<1024xi8> + // CHECK-NOT: stream.async.fill + %filled = stream.async.fill %c42_i32, %param1[%c0 to %c256 for %c256] : i32 -> %param1 as !stream.resource{%c1024} + + // Parameter 2: Pass-through (no transformation). + %param2 = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"mixed_param2"> : vector<1024xi8> + + // In consolidate mode, both should be loaded from encoder output. 
+ // CHECK-DAG: %[[C2048:.+]] = arith.constant 2048 : index + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index + // CHECK: %[[PARAM:.+]] = stream.async.constant : !stream.resource{%[[C2048]]} = #stream.parameter.named<""::"parameter0"> + // CHECK: %[[SUBVIEW1:.+]] = stream.resource.subview %[[PARAM]][%[[C0]]] : !stream.resource{%[[C2048]]} -> !stream.resource{%[[C1024]]} + // CHECK: util.global.store %[[SUBVIEW1]], @mixed_transformed + util.global.store %filled, @mixed_transformed : !stream.resource + + // CHECK: %[[PARAM_0:.+]] = stream.async.constant : !stream.resource{%[[C2048]]} = #stream.parameter.named<""::"parameter0"> + // CHECK: %[[SUBVIEW2:.+]] = stream.resource.subview %[[PARAM_0]][%[[C1024]]] : !stream.resource{%[[C2048]]} -> !stream.resource{%[[C1024]]} + // CHECK: util.global.store %[[SUBVIEW2]], @mixed_passthrough + util.global.store %param2, @mixed_passthrough : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests mixed parameters in overlay mode. +// One parameter with transformation, one pass-through. +// Overlay mode should only include the transformed parameter. + +// Anchor to the main module's first global to scope checks to this section +// OVERLAY-MIXED-LABEL: util.global private @mixed_transformed_overlay : !stream.resource +util.global private @mixed_transformed_overlay : !stream.resource +// OVERLAY-MIXED: util.global private @mixed_passthrough_overlay : !stream.resource +util.global private @mixed_passthrough_overlay : !stream.resource + +// OVERLAY-MIXED: util.initializer { +util.initializer { + %c42_i32 = arith.constant 42 : i32 + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + // Parameter 1: Transformed with fill operation. 
+ %param1 = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"mixed_param1_overlay"> : vector<1024xi8> + // OVERLAY-MIXED-NOT: stream.async.fill + %filled = stream.async.fill %c42_i32, %param1[%c0 to %c256 for %c256] : i32 -> %param1 as !stream.resource{%c1024} + + // Parameter 2: Pass-through (no transformation). + %param2 = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"mixed_param2_overlay"> : vector<1024xi8> + + // Overlay mode: transformed parameter from encoder, pass-through from original. + // Parameters can be loaded in any order (SSA), use DAG to allow flexibility. + // OVERLAY-MIXED-DAG: %{{.+}} = stream.async.constant {{.+}} #stream.parameter.named<"model"::"mixed_param2_overlay"> + // OVERLAY-MIXED-DAG: %{{.+}} = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + + // Stores should happen in this order. + // OVERLAY-MIXED: util.global.store %{{.+}}, @mixed_transformed_overlay + util.global.store %filled, @mixed_transformed_overlay : !stream.resource + + // OVERLAY-MIXED: util.global.store %{{.+}}, @mixed_passthrough_overlay + util.global.store %param2, @mixed_passthrough_overlay : !stream.resource + // OVERLAY-MIXED: util.return + util.return + // OVERLAY-MIXED: } +} + +// ----- + +// Tests side-by-side mode comparison. +// Same input tested with both consolidate and overlay modes using different +// check prefixes to verify behavioral differences. + +// Anchor to this test's unique globals. 
+// COMPARE-CONSOLIDATE-LABEL: util.global private @compare_transformed : !stream.resource +// COMPARE-OVERLAY-LABEL: util.global private @compare_transformed : !stream.resource +util.global private @compare_transformed : !stream.resource +// COMPARE-CONSOLIDATE: util.global private @compare_passthrough : !stream.resource +// COMPARE-OVERLAY: util.global private @compare_passthrough : !stream.resource +util.global private @compare_passthrough : !stream.resource + +// COMPARE-CONSOLIDATE: util.initializer { +// COMPARE-OVERLAY: util.initializer { +util.initializer { + %c42_i32 = arith.constant 42 : i32 + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + // Transformed parameter. + %param1 = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"compare_param1"> : vector<1024xi8> + // COMPARE-CONSOLIDATE-NOT: stream.async.fill + // COMPARE-OVERLAY-NOT: stream.async.fill + %filled = stream.async.fill %c42_i32, %param1[%c0 to %c256 for %c256] : i32 -> %param1 as !stream.resource{%c1024} + + // Pass-through parameter. + %param2 = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"compare_param2"> : vector<1024xi8> + + // Consolidate: Both from encoder output, packed into single parameter0, then subviewed. + // Just verify key operations exist without strict ordering. + // COMPARE-CONSOLIDATE-DAG: stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // COMPARE-CONSOLIDATE-DAG: stream.resource.subview + // COMPARE-CONSOLIDATE-DAG: util.global.store %{{.+}}, @compare_transformed + util.global.store %filled, @compare_transformed : !stream.resource + + // COMPARE-CONSOLIDATE-DAG: util.global.store %{{.+}}, @compare_passthrough + + // Overlay: Transformed from encoder (parameter0), pass-through from original. 
+ // COMPARE-OVERLAY-DAG: stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // COMPARE-OVERLAY-DAG: stream.async.constant {{.+}} #stream.parameter.named<"model"::"compare_param2"> + // COMPARE-OVERLAY-DAG: util.global.store %{{.+}}, @compare_transformed + + // COMPARE-OVERLAY-DAG: util.global.store %{{.+}}, @compare_passthrough + util.global.store %param2, @compare_passthrough : !stream.resource + + // COMPARE-CONSOLIDATE: util.return + // COMPARE-OVERLAY: util.return + util.return + // COMPARE-CONSOLIDATE: } + // COMPARE-OVERLAY: } +} + +// ----- + +// Tests custom output scope. +// Verifies that the encoder uses a custom scope name instead of default "encoded". + +// Anchor to this test's unique global. +// SCOPE-LABEL: util.global private @custom_scope_global : !stream.resource +util.global private @custom_scope_global : !stream.resource + +// SCOPE: util.initializer { +util.initializer { + %c42_i32 = arith.constant 42 : i32 + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + // Parameter with transformation. + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"custom_scope_param"> : vector<1024xi8> + // SCOPE-NOT: stream.async.fill + %filled = stream.async.fill %c42_i32, %param[%c0 to %c256 for %c256] : i32 -> %param as !stream.resource{%c1024} + + // Should load from custom scope "my_custom_scope" instead of default "encoded". + // SCOPE: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<"my_custom_scope"::"parameter0"> + // SCOPE: util.global.store %[[PARAM]], @custom_scope_global + util.global.store %filled, @custom_scope_global : !stream.resource + // SCOPE: util.return + util.return + // SCOPE: } +} + +// ----- + +// Tests growth factor threshold with increased limit. +// A parameter that grows 1.8x should be rejected with default threshold (1.2x) +// but accepted with custom threshold (2.0x). 
+ +// Anchor to this test's unique global. +// GROWTH2-LABEL: util.global private @growth_factor_test : !stream.resource +util.global private @growth_factor_test : !stream.resource + +// GROWTH2: util.initializer { +util.initializer { + %c42_i32 = arith.constant 42 : i32 + %c0 = arith.constant 0 : index + %c1000 = arith.constant 1000 : index + %c1800 = arith.constant 1800 : index + + // Parameter that grows from 1000 bytes (input) to 1800 bytes (after fill/pad). + // 1.8x growth exceeds default 1.2x threshold but passes with 2.0x threshold. + %param = stream.async.constant : !stream.resource{%c1000} = + #stream.parameter.named<"model"::"growth_param"> : vector<1000xi8> + + // Fill operation that expands the parameter size (1000 -> 1800 bytes). + // GROWTH2-NOT: stream.async.fill + %expanded = stream.async.fill %c42_i32, %param[%c0 to %c1800 for %c1800] : i32 -> %param as !stream.resource{%c1800} + + // With growth factor 2.0, this should be hoisted (1.8x < 2.0). + // Verify transformation was hoisted: parameter loads from encoder output. + // GROWTH2-DAG: stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // GROWTH2-DAG: util.global.store %{{.+}}, @growth_factor_test + util.global.store %expanded, @growth_factor_test : !stream.resource + // GROWTH2: util.return + util.return + // GROWTH2: } +} + +// ----- + +// Tests empty encoder module in overlay mode. +// When all parameters are pass-through (no transformations) and in overlay mode, +// no encoder module should be generated since there's nothing to encode. + +// Anchor to this test's unique global. 
+// EMPTY-LABEL: util.global private @empty_test_1 : !stream.resource +util.global private @empty_test_1 : !stream.resource +// EMPTY: util.global private @empty_test_2 : !stream.resource +util.global private @empty_test_2 : !stream.resource + +// EMPTY: util.initializer { +util.initializer { + %c512 = arith.constant 512 : index + %c1024 = arith.constant 1024 : index + + // Pass-through parameter 1 (no transformation). + %param1 = stream.async.constant : !stream.resource{%c512} = + #stream.parameter.named<"model"::"empty_param1"> : vector<512xi8> + + // Pass-through parameter 2 (no transformation). + %param2 = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"empty_param2"> : vector<1024xi8> + + // Both should load from original parameters (no encoder output). + // EMPTY-DAG: stream.async.constant {{.+}} #stream.parameter.named<"model"::"empty_param1"> + // EMPTY-DAG: stream.async.constant {{.+}} #stream.parameter.named<"model"::"empty_param2"> + // EMPTY-DAG: util.global.store %{{.+}}, @empty_test_1 + // EMPTY-DAG: util.global.store %{{.+}}, @empty_test_2 + util.global.store %param1, @empty_test_1 : !stream.resource + + util.global.store %param2, @empty_test_2 : !stream.resource + + // EMPTY: util.return + util.return + // EMPTY: } +} + +// ----- + +//===----------------------------------------------------------------------===// +// Multi-Block Slice Ordering Tests +//===----------------------------------------------------------------------===// +// These tests exercise the slice ordering logic when operations span multiple +// blocks or regions. The backward slice collection must maintain proper +// topological order even when captured values from nested regions are involved. + +// Tests that captured values from scf.if regions are handled correctly. +// This exercises the multi-root slice ordering logic where values defined +// outside an scf.if are used inside its regions. 
+ +stream.executable private @captured_dispatch { + stream.executable.export public @dispatch +} + +// CHECK-LABEL: util.global private @captured_value_if_ordering : !stream.resource +util.global private @captured_value_if_ordering : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c1024 = arith.constant 1024 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + // Load parameter (will be in slice). + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"captured_param"> : vector<1024xi8> + + // Value defined outside scf.if but captured inside - this tests that + // the slice ordering handles captured values correctly. + %outside_value = arith.constant 42 : i32 + + // The scf.if captures %outside_value and %param from outside. + // When building the backward slice, we collect both the stored value + // and captured values. The ordering must ensure %outside_value's producer + // (arith.constant) comes before any op inside the region that uses it. + %cond = arith.constant true + // CHECK-NOT: scf.if + %result = scf.if %cond -> !stream.resource { + // Uses %outside_value (captured) and %param. + %filled = stream.async.fill %outside_value, %param[%c0 to %c1024 for %c1024] + : i32 -> %param as !stream.resource{%c1024} + scf.yield %filled : !stream.resource + } else { + scf.yield %param : !stream.resource + } + + // Encoder should transform this to load from encoded parameter. + // CHECK: stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %{{.+}}, @captured_value_if_ordering + util.global.store %result, @captured_value_if_ordering : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests captured values with scf.for loop. +// Similar to the scf.if test but with loop-carried values. 
+ +stream.executable private @for_dispatch { + stream.executable.export public @dispatch +} + +// CHECK-LABEL: util.global private @captured_value_for_ordering : !stream.resource +util.global private @captured_value_for_ordering : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c1024 = arith.constant 1024 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %c256 = arith.constant 256 : index + + // Load parameter. + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"for_param"> : vector<1024xi8> + + // Value captured by the loop body. + %fill_pattern = arith.constant 7 : i32 + + // Loop that captures %fill_pattern from outside. + // CHECK-NOT: scf.for + %result = scf.for %i = %c0 to %c3 step %c1 iter_args(%acc = %param) -> !stream.resource { + // Uses captured %fill_pattern. + %offset = arith.muli %i, %c256 : index + %end = arith.addi %offset, %c256 : index + %filled = stream.async.fill %fill_pattern, %acc[%offset to %end for %c256] + : i32 -> %acc as !stream.resource{%c1024} + scf.yield %filled : !stream.resource + } + + // CHECK: stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %{{.+}}, @captured_value_for_ordering + util.global.store %result, @captured_value_for_ordering : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests nested scf.if with dispatch that uses multiple captured values. +// This more complex case exercises ordering across multiple region levels. 
+ +stream.executable private @nested_dispatch { + stream.executable.export public @compute +} + +// CHECK-LABEL: util.global private @nested_captured_ordering : !stream.resource +util.global private @nested_captured_ordering : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c2048 = arith.constant 2048 : index + %c1024 = arith.constant 1024 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + // Load two parameters that will both be used inside nested regions. + %param_a = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"nested_a"> : vector<1024xi8> + %param_b = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"nested_b"> : vector<1024xi8> + + // Dispatch using both parameters - creates slice with multiple inputs. + // CHECK-NOT: stream.async.dispatch + %combined = stream.async.dispatch @nested_dispatch::@compute[%c1, %c1, %c1]( + %param_a[%c0 to %c1024 for %c1024], + %param_b[%c0 to %c1024 for %c1024] + ) : (!stream.resource{%c1024}, !stream.resource{%c1024}) -> !stream.resource{%c2048} + + // CHECK: stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %{{.+}}, @nested_captured_ordering + util.global.store %combined, @nested_captured_ordering : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests that an empty module with no parameter expressions runs cleanly. +// This verifies that when no output file is specified (the default) and +// no encoding work is found, the pass completes without errors. + +// CHECK-LABEL: module { +// CHECK-NOT: module @encoder +module { + // A simple global that doesn't involve any parameters. 
+ // CHECK: util.global private @no_params : i32 + util.global private @no_params : i32 + util.initializer { + // CHECK: %[[C42:.+]] = arith.constant 42 : i32 + %c42 = arith.constant 42 : i32 + // CHECK: util.global.store %[[C42]], @no_params + util.global.store %c42, @no_params : i32 + // CHECK: util.return + util.return + } +} + +// ----- + +//===----------------------------------------------------------------------===// +// External Timepoint Synchronization Tests +//===----------------------------------------------------------------------===// + +// Tests that when a parameter load awaits on an external timepoint, the +// replacement async.constant also awaits on that timepoint. +// This exercises Source A of collectExternalTimepoints: external await +// timepoints from TimelineOpInterface ops in the expression. + +// CHECK: module @encoder +// CHECK-LABEL: module { +// CHECK: util.global private @param_with_external_await : !stream.resource +util.global private @param_with_external_await : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c1024 = arith.constant 1024 : index + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c42_i32 = arith.constant 42 : i32 + + // External timeline op that produces a timepoint we must wait on. + // This is NOT part of the encoding expression (doesn't feed into the store). + // CHECK: %[[EXTERNAL_RESOURCE:.+]], %[[EXTERNAL_TP:.+]] = stream.test.timeline_op + %external_resource, %external_tp = stream.test.timeline_op + with() : () -> !stream.resource{%c1024} => !stream.timepoint + + // Parameter load that awaits on the external timepoint. + // The expression starts here - this op and the fill below form the expression. + %param = stream.async.constant await(%external_tp) : + !stream.resource{%c1024} = + #stream.parameter.named<"model"::"awaiting_param"> : vector<1024xi8> + + // Transform the parameter so it gets hoisted. 
+ // CHECK-NOT: stream.async.fill + %filled = stream.async.fill %c42_i32, %param[%c0 to %c256 for %c256] : + i32 -> %param as !stream.resource{%c1024} + + // The replacement should await on the external timepoint. + // CHECK: %[[PARAM:.+]] = stream.async.constant await(%[[EXTERNAL_TP]]) + // CHECK-SAME: #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @param_with_external_await + util.global.store %filled, @param_with_external_await : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests that when a parameter load awaits on a joined timepoint from multiple +// external timeline ops, the replacement async.constant awaits on that same +// joined timepoint. This exercises the case where the join is in the expression +// slice but is not a resource contributor (it only produces a timepoint). + +// CHECK: module @encoder +// CHECK-LABEL: module { +// CHECK: util.global private @param_with_joined_external_timepoints : !stream.resource +util.global private @param_with_joined_external_timepoints : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c1024 = arith.constant 1024 : index + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c42_i32 = arith.constant 42 : i32 + + // Two external timeline ops that produce timepoints we must wait on. + // Their resources are unused, so they're not resource contributors. + // CHECK-DAG: %[[EXT_R1:.+]], %[[EXT_TP1:.+]] = stream.test.timeline_op + %ext_r1, %ext_tp1 = stream.test.timeline_op + with() : () -> !stream.resource{%c1024} => !stream.timepoint + // CHECK-DAG: %[[EXT_R2:.+]], %[[EXT_TP2:.+]] = stream.test.timeline_op + %ext_r2, %ext_tp2 = stream.test.timeline_op + with() : () -> !stream.resource{%c1024} => !stream.timepoint + + // Join the timepoints. The join is in the expression but doesn't contribute + // resources, so its result timepoint should be considered external. 
+ // CHECK: %[[JOINED_TP:.+]] = stream.timepoint.join max(%[[EXT_TP1]], %[[EXT_TP2]]) => !stream.timepoint + %joined_tp = stream.timepoint.join max(%ext_tp1, %ext_tp2) => !stream.timepoint + + // Parameter load that awaits on the joined timepoint. + %param = stream.async.constant await(%joined_tp) : + !stream.resource{%c1024} = + #stream.parameter.named<"model"::"joined_await_param"> : vector<1024xi8> + + // Transform the parameter so it gets hoisted. + // CHECK-NOT: stream.async.fill + %filled = stream.async.fill %c42_i32, %param[%c0 to %c256 for %c256] : + i32 -> %param as !stream.resource{%c1024} + + // The replacement should await on the same joined timepoint. + // CHECK: %[[PARAM:.+]] = stream.async.constant await(%[[JOINED_TP]]) + // CHECK-SAME: #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @param_with_joined_external_timepoints + util.global.store %filled, @param_with_joined_external_timepoints : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/unify_encoding_for_globals.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/unify_encoding_for_globals.mlir index 1e2640f19af7..296bb1cb5a75 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/unify_encoding_for_globals.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/unify_encoding_for_globals.mlir @@ -859,3 +859,105 @@ util.initializer { util.return } + +// ----- + +// Test: TensorCloneOp, TensorEncodeOp, and TensorUpdateOp in dispatch site. 
+ +#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {iree.encoding.resolver = #iree_encoding.specialization_resolver<123>}> +#device_target_local = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device +// CHECK-DAG: #[[$ENC:.+]] = #iree_encoding.testing]> +// CHECK-DAG: #[[$ENC2:.+]] = #iree_encoding.testing]> +#encoding1 = #iree_encoding.testing]> +#encoding2 = #iree_encoding.testing]> + +// CHECK: util.global private @[[$DEVICE_A:.+]] = +util.global private @device_a = #device_target_local +util.global private @weight : !stream.resource +util.global private @weight_size : index +util.global private @encoded_v1 : !stream.resource +util.global private @encoded_v1_size : index +util.global private @encoded_v2 : !stream.resource +util.global private @encoded_v2_size : index + +// CHECK: util.initializer +util.initializer { + %cst = stream.tensor.constant on(#hal.device.affinity<@device_a>) : tensor<4096x4096xf32> in !stream.resource = #stream.parameter.named<"model"::"weight"> : tensor<4096x4096xf32> + %0 = stream.resource.size %cst : !stream.resource + util.global.store %cst, @weight : !stream.resource + util.global.store %0, @weight_size : index + // CHECK: %[[SOURCE:.+]] = util.global.load @weight + %source = util.global.load @weight : !stream.resource + %source_size = util.global.load @weight_size : index + + // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #iree_encoding.specialized<123>> + // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) %[[SOURCE]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.specialized<123>> + %size1 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4096x4096xf32, #encoding1> : index + %enc1 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %source : tensor<4096x4096xf32> in !stream.resource{%source_size} -> tensor<4096x4096xf32, #encoding1> in 
!stream.resource{%size1} + util.global.store %enc1, @encoded_v1 : !stream.resource + util.global.store %size1, @encoded_v1_size : index + + // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #iree_encoding.specialized<123>> + // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) %[[SOURCE]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.specialized<123>> + %size2 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4096x4096xf32, #encoding2> : index + %enc2 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %source : tensor<4096x4096xf32> in !stream.resource{%source_size} -> tensor<4096x4096xf32, #encoding2> in !stream.resource{%size2} + util.global.store %enc2, @encoded_v2 : !stream.resource + util.global.store %size2, @encoded_v2_size : index + + util.return +} + +// CHECK-LABEL: util.func public @tensor_clone_reencode +util.func public @tensor_clone_reencode(%arg0: !stream.resource<*>, %arg1: !stream.resource<*>, %arg2: index) { + %loaded_v1 = util.global.load @encoded_v1 : !stream.resource + %loaded_v1_size = util.global.load @encoded_v1_size : index + + // Re-encode should be inserted before the clone op. 
+ // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #iree_encoding.specialized<123>> + // CHECK: %[[REENC:.+]] = stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) + // CHECK-SAME: tensor<4096x4096xf32, #iree_encoding.specialized<123>> + // CHECK-SAME: -> tensor<4096x4096xf32, #[[$ENC]]> + // CHECK: stream.tensor.clone on(#hal.device.affinity<@[[$DEVICE_A]]>) %[[REENC]] + // CHECK-SAME: tensor<4096x4096xf32, #[[$ENC]]> + %0 = stream.tensor.clone on(#hal.device.affinity<@device_a>) %loaded_v1 + : tensor<4096x4096xf32, #encoding1> in !stream.resource{%loaded_v1_size} + -> tensor<4096x4096xf32, #encoding1> in !stream.resource<*>{%loaded_v1_size} + + util.return +} + +// CHECK-LABEL: util.func public @tensor_encode_update_source +util.func public @tensor_encode_update_source(%arg0: !stream.resource<*>, %arg1: !stream.resource<*>, %arg2: index) { + %loaded_v1 = util.global.load @encoded_v1 : !stream.resource + %loaded_v1_size = util.global.load @encoded_v1_size : index + + // The encode op's source_encoding should be updated to the unified encoding. + // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) + // CHECK-SAME: tensor<4096x4096xf32, #iree_encoding.specialized<123>> + // CHECK-SAME: -> tensor<4096x4096xf32, #[[$ENC2]]> + %1 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %loaded_v1 + : tensor<4096x4096xf32, #encoding1> in !stream.resource{%loaded_v1_size} + -> tensor<4096x4096xf32, #encoding2> in !stream.resource<*>{%loaded_v1_size} + + util.return +} + +// CHECK-LABEL: util.func public @tensor_update_reencode +util.func public @tensor_update_reencode(%arg0: !stream.resource<*>, %arg1: !stream.resource<*>, %arg2: index) { + %loaded_v1 = util.global.load @encoded_v1 : !stream.resource + %loaded_v1_size = util.global.load @encoded_v1_size : index + %c0 = arith.constant 0 : index + + // Re-encode should be inserted before the update op. 
+ // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #iree_encoding.specialized<123>> + // CHECK: %[[REENC:.+]] = stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) + // CHECK-SAME: tensor<4096x4096xf32, #iree_encoding.specialized<123>> + // CHECK-SAME: -> tensor<4096x4096xf32, #[[$ENC]]> + // CHECK: stream.tensor.update on(#hal.device.affinity<@[[$DEVICE_A]]>) %[[REENC]] + // CHECK-SAME: tensor<4096x4096xf32, #[[$ENC]]> + %2 = stream.tensor.update on(#hal.device.affinity<@device_a>) + %loaded_v1, %arg0[%c0, %c0] : tensor<4096x4096xf32, #encoding1> in !stream.resource{%loaded_v1_size} + -> tensor<4096x4096xf32, #encoding1> in %arg0 as !stream.resource<*>{%arg2} + + util.return +} diff --git a/compiler/src/iree/compiler/Dialect/TensorExt/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/TensorExt/IR/BUILD.bazel index a91cd2f27e82..a278378e9ce0 100644 --- a/compiler/src/iree/compiler/Dialect/TensorExt/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/TensorExt/IR/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "TensorExtAttrs.td", "TensorExtBase.td", diff --git a/compiler/src/iree/compiler/Dialect/TensorExt/IR/TensorExtOpFolders.cpp b/compiler/src/iree/compiler/Dialect/TensorExt/IR/TensorExtOpFolders.cpp index 93082e5a23ef..04faecbfeb40 100644 --- a/compiler/src/iree/compiler/Dialect/TensorExt/IR/TensorExtOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/TensorExt/IR/TensorExtOpFolders.cpp @@ -22,8 +22,9 @@ struct ReplaceBitCastIfTensorOperandEmpty final : OpRewritePattern { PatternRewriter &rewriter) const override { auto emptyOp = dyn_cast_if_present(op.getSource().getDefiningOp()); - if (!emptyOp) + if (!emptyOp) { return failure(); + } rewriter.replaceOpWithNewOp(op, op.getResult().getType(), op.getResultDims()); return success(); @@ -36,8 +37,9 @@ struct BitCastOfTensorCastStaticInfo final : OpRewritePattern { LogicalResult 
matchAndRewrite(BitCastOp bitcastOp, PatternRewriter &rewriter) const final { auto tensorCastOp = bitcastOp.getSource().getDefiningOp(); - if (!tensorCastOp) + if (!tensorCastOp) { return failure(); + } auto tensorCastSrcType = dyn_cast(tensorCastOp.getOperand().getType()); if (!tensorCastSrcType) { @@ -66,8 +68,9 @@ struct BitCastOfTensorCastStaticInfo final : OpRewritePattern { // Drop the dynamic dims that become static after incorporating the cast. for (auto [castSize, sourceSize] : llvm::zip_equal( tensorCastSrcType.getShape(), intermediateTensorType.getShape())) { - if (!ShapedType::isDynamic(sourceSize)) + if (!ShapedType::isDynamic(sourceSize)) { continue; + } while (!ShapedType::isDynamic(resShape[resDynamicDim])) { ++resDynamicDim; @@ -135,8 +138,9 @@ static bool updateTensorOpDims(RewriterBase &rewriter, Operation *op, MutableOperandRange mutableDimValues) { auto dynamicDimsOr = IREE::Util::findDynamicDims(tensorValue, op->getBlock(), Block::iterator(op)); - if (!dynamicDimsOr.has_value()) + if (!dynamicDimsOr.has_value()) { return false; + } auto dynamicDims = dynamicDimsOr.value(); bool anyChanged = false; OperandRange oldValueRange = mutableDimValues; @@ -235,8 +239,9 @@ canonicalizeSubViewParts(OpTy op, RankedTensorType sliceType, llvm::SmallVector newShape; llvm::SmallBitVector droppedDims = op.getDroppedDims(); for (auto size : llvm::enumerate(mixedSizes)) { - if (droppedDims.test(size.index())) + if (droppedDims.test(size.index())) { continue; + } std::optional staticSize = getConstantIntValue(size.value()); newShape.push_back(staticSize ? 
staticSize.value() : ShapedType::kDynamic); } @@ -256,8 +261,9 @@ struct DispatchTensorLoadOpWithOffsetSizesAndStridesConstantArgumentFolder final RankedTensorType resultType = loadOp.getType(); auto newResultType = canonicalizeSubViewParts( loadOp, resultType, mixedOffsets, mixedSizes, mixedStrides); - if (failed(newResultType)) + if (failed(newResultType)) { return failure(); + } // We need to resolve the new inferred type with the specified type. Location loc = loadOp.getLoc(); @@ -355,8 +361,9 @@ struct DispatchTensorStoreOpWithOffsetSizesAndStridesConstantArgumentFolder RankedTensorType valueType = storeOp.getValueType(); auto newValueType = canonicalizeSubViewParts( storeOp, valueType, mixedOffsets, mixedSizes, mixedStrides); - if (failed(newValueType)) + if (failed(newValueType)) { return failure(); + } Value value = storeOp.getValue(); Location loc = storeOp.getLoc(); diff --git a/compiler/src/iree/compiler/Dialect/TensorExt/IR/TensorExtTypes.cpp b/compiler/src/iree/compiler/Dialect/TensorExt/IR/TensorExtTypes.cpp index d6f62207d57c..b70ff74860c3 100644 --- a/compiler/src/iree/compiler/Dialect/TensorExt/IR/TensorExtTypes.cpp +++ b/compiler/src/iree/compiler/Dialect/TensorExt/IR/TensorExtTypes.cpp @@ -58,8 +58,9 @@ int64_t DispatchTensorType::getNumElements() const { assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); auto shape = getShape(); int64_t num = 1; - for (auto dim : shape) + for (auto dim : shape) { num *= dim; + } return num; } @@ -197,10 +198,12 @@ void printType(DispatchTensorType &type, DialectAsmPrinter &p) { Type IREETensorExtDialect::parseType(DialectAsmParser &parser) const { StringRef mnemonic; - if (parser.parseKeyword(&mnemonic)) + if (parser.parseKeyword(&mnemonic)) { return {}; - if (mnemonic == "dispatch.tensor") + } + if (mnemonic == "dispatch.tensor") { return DispatchTensorType::parse(parser); + } parser.emitError(parser.getCurrentLocation()) << "unknown TensorExt type: " << mnemonic; return {}; diff 
--git a/compiler/src/iree/compiler/Dialect/TensorExt/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/TensorExt/IR/test/BUILD.bazel index 9b34141476ab..41a26b2ff6ad 100644 --- a/compiler/src/iree/compiler/Dialect/TensorExt/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/TensorExt/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "dispatch_tensor_folding.mlir", "dispatch_workload_ordinal_folding.mlir", diff --git a/compiler/src/iree/compiler/Dialect/TensorExt/Transforms/Folders.cpp b/compiler/src/iree/compiler/Dialect/TensorExt/Transforms/Folders.cpp index 72dd9f38d568..591fc681a6a3 100644 --- a/compiler/src/iree/compiler/Dialect/TensorExt/Transforms/Folders.cpp +++ b/compiler/src/iree/compiler/Dialect/TensorExt/Transforms/Folders.cpp @@ -23,8 +23,9 @@ struct FoldTensorLoadWithExtractSlice auto dispatchTensorLoadOp = extractSliceOp.getSource() .getDefiningOp(); - if (!dispatchTensorLoadOp) + if (!dispatchTensorLoadOp) { return failure(); + } SmallVector offsets, sizes, strides; // `tensor.extract_slice` (i.e. the producer) folds **into** @@ -56,8 +57,9 @@ struct FoldInsertSliceWithTensorStoreOp PatternRewriter &rewriter) const override { auto insertSliceOp = dispatchTensorStoreOp.getValue().getDefiningOp(); - if (!insertSliceOp) + if (!insertSliceOp) { return failure(); + } SmallVector offsets, sizes, strides; // `tensor.insert_slice` (i.e. 
the producer) folds **into** diff --git a/compiler/src/iree/compiler/Dialect/TensorExt/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/TensorExt/Transforms/test/BUILD.bazel index 2b205eaaa6ef..83ce0a2d4733 100644 --- a/compiler/src/iree/compiler/Dialect/TensorExt/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/TensorExt/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "sparse_interface_methods.mlir", "sparse_interface_methods_estimated_loop_range_fail.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes/Range.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes/Range.cpp index 62c5f9ac8b23..074e028fbf3b 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes/Range.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes/Range.cpp @@ -43,8 +43,9 @@ void FloatRangeStats::addDomainValue(double value) { } std::string FloatRangeStats::getAsStr(AsmState &asmState) const { - if (!valid) + if (!valid) { return std::string("<>"); + } std::string s("["); s += std::to_string(minValue); s += ", "; @@ -192,8 +193,9 @@ ChangeStatus FloatRangeValueElement::updateValue(Value value, newState ^= inner; // Stop traversal if tied OpOperand is not used in the op body. 
if (!linalgOp.payloadUsesValueFromOperand( - linalgOp.getDpsInitOperand(result.getResultNumber()))) + linalgOp.getDpsInitOperand(result.getResultNumber()))) { return WalkResult::skip(); + } return WalkResult::advance(); } else if (auto minfOp = dyn_cast(definingOp)) { auto lhs = solver.getElementFor( diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.cpp index 586809176887..b03e8116379d 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.cpp @@ -24,8 +24,9 @@ namespace mlir::iree_compiler::IREE::Util { static OpOperand *findOperandFor(Operation *op, Value input) { for (OpOperand &operand : op->getOpOperands()) { - if (operand.get() == input) + if (operand.get() == input) { return &operand; + } } return nullptr; } @@ -33,10 +34,12 @@ static OpOperand *findOperandFor(Operation *op, Value input) { bool ConstExprAnalysis::isConstExprOperation(Operation *queryOp) const { if (queryOp->getNumResults() == 0) { bool hasNoMemoryEffects = false; - if (auto effectOp = dyn_cast(queryOp)) + if (auto effectOp = dyn_cast(queryOp)) { hasNoMemoryEffects = effectOp.hasNoEffect(); - if (hasNoMemoryEffects && queryOp->hasTrait()) + } + if (hasNoMemoryEffects && queryOp->hasTrait()) { return true; + } return false; } // NOTE: this only checks the first result as all results are added to the map @@ -79,25 +82,31 @@ ConstExprAnalysis::ConstExprAnalysis(Operation *rootOp) // such as if they are initialized based on values only available at runtime. explorer.forEachGlobal([&](const Explorer::GlobalInfo *info) { // Rely on globals having been canonicalized to immutable correctly. 
- if (info->isIndirect || info->op.isGlobalMutable()) + if (info->isIndirect || info->op.isGlobalMutable()) { return; - if (!isLegalConstExprRootType(info->op.getGlobalType())) + } + if (!isLegalConstExprRootType(info->op.getGlobalType())) { return; - for (auto loadOp : info->getLoads()) + } + for (auto loadOp : info->getLoads()) { constantRoots[loadOp.getLoadedGlobalValue()] = loadOp; + } }); // Populate the constant roots for all inline constants in the program. explorer.forEachFunctionLikeOp([&](FunctionOpInterface funcOp) { funcOp.walk([&](Operation *op) { - if (!op->hasTrait()) + if (!op->hasTrait()) { return; + } for (auto resultType : op->getResultTypes()) { - if (!isLegalConstExprRootType(resultType)) + if (!isLegalConstExprRootType(resultType)) { return; + } } - for (auto result : op->getResults()) + for (auto result : op->getResults()) { constantRoots[result] = op; + } }); }); @@ -135,8 +144,9 @@ ConstExprAnalysis::ConstExprAnalysis(Operation *rootOp) iterWorklist.clear(); iterWorklist.swap(worklist); for (ConstValueInfo *info : iterWorklist) { - if (info->state != ConstValueInfo::UNKNOWN) + if (info->state != ConstValueInfo::UNKNOWN) { continue; + } bool allConstants = true; for (ConstValueInfo *producerInfo : info->producers) { assert(producerInfo->state != ConstValueInfo::UNANALYZED && @@ -220,12 +230,14 @@ void ConstExprAnalysis::expandToOpStep( ConstExprOpInfo opInfo = ConstExprOpInfo::getForOp(op); for (auto result : op->getResults()) { auto *valueInfo = constInfoMap.lookup(result); - if (valueInfo && valueInfo->state != ConstValueInfo::UNANALYZED) + if (valueInfo && valueInfo->state != ConstValueInfo::UNANALYZED) { continue; + } // Generate new info record. - if (!valueInfo) + if (!valueInfo) { valueInfo = addInfo(result); + } // Update the producers first as we might early-return below. 
for (Value producer : opInfo.producers) { @@ -288,8 +300,9 @@ void ConstExprAnalysis::expandToOpStep( void ConstExprAnalysis::print(raw_ostream &os) const { os << "[ConstExprAnalysis] found constants:\n"; for (auto &info : allocedConstInfos) { - if (info->state != ConstValueInfo::CONSTANT || info->isRoot) + if (info->state != ConstValueInfo::CONSTANT || info->isRoot) { continue; + } if (!info->roots.empty()) { os << "\n[ConstExprAnalysis] constexpr "; info->constValue.print(os, asmState); @@ -334,8 +347,9 @@ void ConstExprHoistingPolicy::initialize() { for (auto &it : analysis.allocedConstInfos) { auto *info = it.get(); // Skip unanalyzed values. - if (info->state == ConstExprAnalysis::ConstValueInfo::UNANALYZED) + if (info->state == ConstExprAnalysis::ConstValueInfo::UNANALYZED) { continue; + } worklist.push_back(info); } @@ -366,8 +380,9 @@ void ConstExprHoistingPolicy::initialize() { bool madeChange = false; for (auto *info : worklist) { Decision *decision = getDecision(info); - if (decision->getOutcome() != UNDECIDED) + if (decision->getOutcome() != UNDECIDED) { continue; + } makeDecision(info, decision); if (decision->getOutcome() != UNDECIDED) { @@ -481,8 +496,9 @@ void ConstExprHoistingPolicy::makeDecision( if (!hasLegalEscape) { for (auto *consumerInfo : info->consumers) { Decision *consumerDecision = getDecision(consumerInfo); - if (consumerDecision->getOutcome() != DISABLE_HOIST) + if (consumerDecision->getOutcome() != DISABLE_HOIST) { continue; + } Operation *consumerOp = consumerInfo->getOperation(); OpOperand *consumerOperand = findOperandFor(consumerOp, info->constValue); @@ -544,13 +560,15 @@ struct DOTGraphTraits getNodeAttributes(const ConstExprAnalysis::ConstValueInfo *Node, const ConstExprHoistingPolicy *g) { // Roots are colored red. - if (Node->isRoot) + if (Node->isRoot) { return "fillcolor=red,style=filled"; + } // Hoisted values are colored green. 
ConstExprHoistingPolicy::Outcome outcome = g->getOutcome(Node); - if (outcome == ConstExprHoistingPolicy::Outcome::ENABLE_HOIST) + if (outcome == ConstExprHoistingPolicy::Outcome::ENABLE_HOIST) { return "fillcolor=green,style=filled"; + } return ""; } diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h index 9c0bb4290767..235743357bf9 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h @@ -44,8 +44,9 @@ class ConstExprAnalysis { // uninitialized. If they are all initialized, then they will either be all // const-expr or all non const-expr, so just return the first result's info. const ConstValueInfo *lookup(Operation *queryOp) const { - if (queryOp->getNumResults() == 0) + if (queryOp->getNumResults() == 0) { return nullptr; + } if (llvm::any_of(queryOp->getResults(), [&](Value v) { return !lookup(v); })) { return nullptr; @@ -56,8 +57,9 @@ class ConstExprAnalysis { // Returns true if the given value is only derived from immutable inputs. 
bool isConstExprValue(Value queryValue) const { ConstValueInfo *found = constInfoMap.lookup(queryValue); - if (!found) + if (!found) { return false; + } return found->state == ConstValueInfo::CONSTANT; } diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.cpp index da22db8f0bcd..4045bace9b0a 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.cpp @@ -97,8 +97,9 @@ static bool isEligibleConstExpr(Operation *op) { Operation *parent = op; while (auto hoistableParent = parent->getParentOfType()) { - if (hoistableParent.isAtomicallyHoistableOp()) + if (hoistableParent.isAtomicallyHoistableOp()) { return false; + } parent = hoistableParent; } @@ -154,8 +155,9 @@ bool isHoistableConstExprLeaf(const ConstExprAnalysis::ConstValueInfo *info) { // If implementing the HoistableOpInterface, check whether the op is legal to // hoist. We still need to check for type legality afterwards though. 
if (auto hoistableOp = dyn_cast(op)) { - if (!hoistableOp.isHoistableLeafOp()) + if (!hoistableOp.isHoistableLeafOp()) { return false; + } } // If implementing the HoistableTypeInterface, at this point we can just diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/DepGraph.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/DepGraph.cpp index d1a5a116e9cf..5361ad5fa02a 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/DepGraph.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/DepGraph.cpp @@ -49,8 +49,9 @@ void DepGraph::dumpGraph() { std::error_code ec; llvm::raw_fd_ostream file(filename, ec, llvm::sys::fs::OF_TextWithCRLF); - if (!ec) + if (!ec) { llvm::WriteGraph(file, this); + } callTimes++; } diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Element.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Element.cpp index ca8d7329460b..4724a5c26d82 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Element.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Element.cpp @@ -15,8 +15,9 @@ namespace mlir::iree_compiler::DFX { ChangeStatus AbstractElement::update(Solver &solver) { ChangeStatus changeStatus = ChangeStatus::UNCHANGED; - if (getState().isAtFixpoint()) + if (getState().isAtFixpoint()) { return changeStatus; + } LLVM_DEBUG({ llvm::dbgs() << "[Solver] updating: "; diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Element.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Element.h index 0a7d95fc28aa..803feb60958f 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Element.h +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Element.h @@ -143,8 +143,9 @@ struct TypedOperationElement : public AbstractElement { ChangeStatus updateImpl(Solver &solver) override { if (isOperation()) { auto op = dyn_cast(getOperation()); - if (op) + if (op) { return updateOperation(op, solver); + } } return 
getState().indicatePessimisticFixpoint(); } diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Solver.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Solver.cpp index adbe01cf13f6..80df8a1736ef 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Solver.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Solver.cpp @@ -113,8 +113,9 @@ LogicalResult Solver::runTillFixpoint(int maxIterations) { // Use the invalidElements vector to propagate invalid states fast // transitively without requiring updates. - if (!elementState.isValidState()) + if (!elementState.isValidState()) { invalidElements.insert(element); + } } // Add elements to the changed set if they have been created in the last @@ -140,8 +141,9 @@ LogicalResult Solver::runTillFixpoint(int maxIterations) { SmallPtrSet visitedElements; for (size_t i = 0; i < changedElements.size(); i++) { auto *changedElement = changedElements[i]; - if (!visitedElements.insert(changedElement).second) + if (!visitedElements.insert(changedElement).second) { continue; + } auto &elementState = changedElement->getState(); if (!elementState.isAtFixpoint()) { @@ -183,8 +185,9 @@ ChangeStatus Solver::updateElement(AbstractElement &element) { // will not change and we can indicate that right away. elementState.indicateOptimisticFixpoint(); } - if (!elementState.isAtFixpoint()) + if (!elementState.isAtFixpoint()) { rememberDependencies(); + } // Verify the stack is balanced by ensuring we pop the vector we pushed above. 
auto *poppedDependencies = dependencyStack.pop_back_val(); @@ -198,15 +201,18 @@ ChangeStatus Solver::updateElement(AbstractElement &element) { void Solver::recordDependency(const AbstractElement &fromElement, const AbstractElement &toElement, Resolution resolution) { - if (resolution == Resolution::NONE) + if (resolution == Resolution::NONE) { return; + } // If we are outside of an update, thus before the actual fixpoint iteration // started (= when we create elements), we do not track dependencies because // we will put all elements into the initial worklist anyway. - if (dependencyStack.empty()) + if (dependencyStack.empty()) { return; - if (fromElement.getState().isAtFixpoint()) + } + if (fromElement.getState().isAtFixpoint()) { return; + } // NOTE: this may record several of the same dependency as there is no // deduplication. Deduplication is more expensive than the rarer case of // duplication, though, so we deal with it. diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Solver.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Solver.h index d340812ecd83..8bec39cd311c 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Solver.h +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Solver.h @@ -193,8 +193,9 @@ class Solver { // Lookup the abstract element of type ElementT and if found return it after // registering a dependence of queryingElement on the one returned element. auto *elementPtr = elementMap.lookup({&ElementT::ID, pos}); - if (!elementPtr) + if (!elementPtr) { return nullptr; + } auto *element = static_cast(elementPtr); // Do not register a dependence on an element with an invalid state. 
diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/State.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/State.cpp index d765f4c838c4..3966328941ec 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/State.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/State.cpp @@ -24,10 +24,12 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, if (!S.isValidState()) { os << "full-set"; } else { - for (auto &it : S.getAssumedSet()) + for (auto &it : S.getAssumedSet()) { os << it << ", "; - if (S.isUndefContained()) + } + if (S.isUndefContained()) { os << "undef "; + } } os << "} >)"; return os; diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/State.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/State.h index 75ac3c11e831..5782b9a3d272 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/State.h +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/State.h @@ -195,12 +195,14 @@ struct BooleanState : public IntegerStateBase { private: void handleNewKnownValue(base_t value) override { - if (value) + if (value) { known = (assumed = value); + } } void handleNewAssumedValue(base_t value) override { - if (!value) + if (!value) { assumed = known; + } } void joinOR(base_t assumedValue, base_t knownValue) override { @@ -423,12 +425,15 @@ struct PotentialValuesState : AbstractState { } bool operator==(const PotentialValuesState &rhs) const { - if (isValidState() != rhs.isValidState()) + if (isValidState() != rhs.isValidState()) { return false; - if (!isValidState() && !rhs.isValidState()) + } + if (!isValidState() && !rhs.isValidState()) { return true; - if (isUndefContained() != rhs.isUndefContained()) + } + if (isUndefContained() != rhs.isUndefContained()) { return false; + } return set == rhs.getAssumedSet(); } @@ -487,8 +492,9 @@ struct PotentialValuesState : AbstractState { // Inserts an element into this set. 
void insert(const MemberTy &c) { - if (!isValidState()) + if (!isValidState()) { return; + } set.insert(c); checkAndInvalidate(); } @@ -496,15 +502,17 @@ struct PotentialValuesState : AbstractState { // Takes union with |rhs|. void unionWith(const PotentialValuesState &rhs) { // If this is a full set, do nothing. - if (!isValidState()) + if (!isValidState()) { return; + } // If rhs is full set, change L to a full set. if (!rhs.isValidState()) { indicatePessimisticFixpoint(); return; } - for (const MemberTy &c : rhs.set) + for (const MemberTy &c : rhs.set) { set.insert(c); + } undefIsContained |= rhs.isUndefContained(); checkAndInvalidate(); } @@ -518,8 +526,9 @@ struct PotentialValuesState : AbstractState { // Takes intersection with |rhs|. void intersectWith(const PotentialValuesState &rhs) { // If rhs is a full set, do nothing. - if (!rhs.isValidState()) + if (!rhs.isValidState()) { return; + } // If this is a full set, change this to rhs. if (!isValidState()) { *this = rhs; @@ -527,8 +536,9 @@ struct PotentialValuesState : AbstractState { } SetTy intersectSet; for (const MemberTy &c : set) { - if (rhs.set.count(c)) + if (rhs.set.count(c)) { intersectSet.insert(c); + } } set = intersectSet; undefIsContained &= rhs.isUndefContained(); diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp index bcbd72cb1918..eb7621c10bf1 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp @@ -33,8 +33,9 @@ static std::optional mapSuccessorOperand(BranchOpInterface branchOp, // I don't know if there's a better way to do this - the interface doesn't // help. 
auto operandRange = branchOp.getSuccessorOperands(successorIdx); - if (operandRange.empty()) + if (operandRange.empty()) { return std::nullopt; + } unsigned beginIdx = operandRange.getForwardedOperands().getBeginOperandIndex(); if (operandIdx >= beginIdx && operandIdx < beginIdx + operandRange.size()) { @@ -187,8 +188,9 @@ void Explorer::initializeInverseCallGraph() { const Explorer::GlobalInfo * Explorer::getGlobalInfo(IREE::Util::GlobalOpInterface globalOp) { auto it = globalInfos.find(globalOp); - if (it == globalInfos.end()) + if (it == globalInfos.end()) { return nullptr; + } return it->second.get(); } @@ -198,11 +200,13 @@ const Explorer::GlobalInfo *Explorer::queryGlobalInfoFrom(StringRef globalName, auto &symbolTable = symbolTables.getSymbolTable(symbolTableOp); auto op = symbolTable.lookupNearestSymbolFrom( from, StringAttr::get(from->getContext(), globalName)); - if (!op) + if (!op) { return nullptr; + } auto it = globalInfos.find(op); - if (it == globalInfos.end()) + if (it == globalInfos.end()) { return nullptr; + } return it->second.get(); } @@ -259,8 +263,9 @@ void Explorer::forEachFunctionLikeOp( } bool Explorer::mayValuesAlias(Value a, Value b) { - if (a == b) + if (a == b) { return true; + } bool mayAlias = false; auto traversalResult = walkTransitiveUses(a, [&](OpOperand &value) { mayAlias = value.get() == b; @@ -287,8 +292,9 @@ TraversalResult Explorer::walk(OperationWalkFn fn) { LLVM_DEBUG(llvm::dbgs() << "? entering scc slice with " << scc.size() << " callables\n"); for (auto *node : scc) { - if (node->isExternal()) + if (node->isExternal()) { continue; + } // Ensure we want to step into this region. 
// Note that SCC returns every function like in the whole program, @@ -296,8 +302,9 @@ TraversalResult Explorer::walk(OperationWalkFn fn) { auto &callableRegion = *node->getCallableRegion(); auto *callableOp = callableRegion.getParentOp(); auto action = getTraversalAction(callableOp); - if (action == TraversalAction::IGNORE) + if (action == TraversalAction::IGNORE) { continue; + } bool validInPlace = true; for (auto *parentOp = callableOp->getParentOp(); parentOp != rootOp; parentOp = parentOp->getParentOp()) { @@ -315,10 +322,12 @@ TraversalResult Explorer::walk(OperationWalkFn fn) { LLVM_DEBUG(llvm::dbgs() << " + entering callable region @" << getRegionName(callableRegion) << "\n"); auto emitResult = recursiveWalk(callableOp, fn); - if (emitResult.wasInterrupted()) + if (emitResult.wasInterrupted()) { break; - if (emitResult.wasSkipped()) + } + if (emitResult.wasSkipped()) { continue; + } } } @@ -338,10 +347,12 @@ WalkResult Explorer::recursiveWalk(Operation *parentOp, LLVM_DEBUG(llvm::dbgs() << " == emitting op " << getOpName(parentOp) << "\n"); auto emitResult = fn(parentOp); - if (emitResult.wasInterrupted()) + if (emitResult.wasInterrupted()) { return WalkResult::interrupt(); - if (emitResult.wasSkipped()) + } + if (emitResult.wasSkipped()) { return WalkResult::advance(); + } if (parentOp->getNumRegions() == 0 || parentAction != TraversalAction::RECURSE) { @@ -355,8 +366,9 @@ WalkResult Explorer::recursiveWalk(Operation *parentOp, for (auto &block : region.getBlocks()) { for (auto &op : block) { auto opResult = recursiveWalk(&op, fn); - if (opResult.wasInterrupted()) + if (opResult.wasInterrupted()) { return WalkResult::interrupt(); + } } } } @@ -374,8 +386,9 @@ TraversalResult Explorer::walkAllValues(ValueWalkFn fn, LLVM_DEBUG(llvm::dbgs() << "? entering scc slice with " << scc.size() << " callables\n"); for (auto *node : scc) { - if (node->isExternal()) + if (node->isExternal()) { continue; + } // Ensure we want to step into this region. 
// Note that SCC returns every function like in the whole program, @@ -383,8 +396,9 @@ TraversalResult Explorer::walkAllValues(ValueWalkFn fn, auto &callableRegion = *node->getCallableRegion(); auto *callableOp = callableRegion.getParentOp(); auto action = getTraversalAction(callableOp); - if (action == TraversalAction::IGNORE) + if (action == TraversalAction::IGNORE) { continue; + } bool validInPlace = true; for (auto *parentOp = callableOp->getParentOp(); parentOp != rootOp; parentOp = parentOp->getParentOp()) { @@ -403,10 +417,12 @@ TraversalResult Explorer::walkAllValues(ValueWalkFn fn, << getRegionName(callableRegion) << "\n"); auto emitResult = recursiveWalkValues(callableOp, visitedValues, fn, typeID); - if (emitResult.wasInterrupted()) + if (emitResult.wasInterrupted()) { break; - if (emitResult.wasSkipped()) + } + if (emitResult.wasSkipped()) { continue; + } } } @@ -442,16 +458,18 @@ WalkResult Explorer::recursiveWalkValues(Operation *parentOp, LLVM_DEBUG(llvm::dbgs() << " + processing op results " << getOpName(parentOp) << "\n"); for (auto result : parentOp->getResults()) { - if (typeID.has_value() && result.getType().getTypeID() != *typeID) + if (typeID.has_value() && result.getType().getTypeID() != *typeID) { continue; + } if (visitedValues.insert(result).second) { LLVM_DEBUG({ llvm::dbgs() << " == emitting value "; result.printAsOperand(llvm::dbgs(), asmState); llvm::dbgs() << "\n"; }); - if (fn(result).wasInterrupted()) + if (fn(result).wasInterrupted()) { return WalkResult::interrupt(); + } } } } @@ -473,23 +491,26 @@ WalkResult Explorer::recursiveWalkValues(Operation *parentOp, llvm::dbgs() << " arguments\n"; }); for (auto arg : block.getArguments()) { - if (typeID.has_value() && arg.getType().getTypeID() != *typeID) + if (typeID.has_value() && arg.getType().getTypeID() != *typeID) { continue; + } if (visitedValues.insert(arg).second) { LLVM_DEBUG({ llvm::dbgs() << " == emitting block arg "; arg.printAsOperand(llvm::dbgs(), asmState); llvm::dbgs() 
<< "\n"; }); - if (fn(arg).wasInterrupted()) + if (fn(arg).wasInterrupted()) { return WalkResult::interrupt(); + } } } } for (auto &op : block) { auto opResult = recursiveWalkValues(&op, visitedValues, fn, typeID); - if (opResult.wasInterrupted()) + if (opResult.wasInterrupted()) { return WalkResult::interrupt(); + } } } } @@ -502,8 +523,9 @@ Explorer::walkIncomingCalls(CallableOpInterface callableOp, auto it = callGraphInv.find(callableOp.getCallableRegion()); if (it != callGraphInv.end()) { for (auto &callOp : it->second) { - if (fn(callOp).wasInterrupted()) + if (fn(callOp).wasInterrupted()) { break; + } } } bool isPublic = false; @@ -560,8 +582,9 @@ TraversalResult Explorer::walkReturnOps(Operation *parentOp, return WalkResult::advance(); }; for (auto ®ion : regionOp->getRegions()) { - if (enumerateTerminatorOps(region).wasInterrupted()) + if (enumerateTerminatorOps(region).wasInterrupted()) { break; + } } } else if (auto parentFuncOp = dyn_cast(parentOp)) { @@ -582,8 +605,9 @@ TraversalResult Explorer::walkReturnOps(Operation *parentOp, terminatorOp->print(llvm::dbgs(), asmState); llvm::dbgs() << "\n"; }); - if (fn(terminatorOp).wasInterrupted()) + if (fn(terminatorOp).wasInterrupted()) { break; + } } } } @@ -711,8 +735,9 @@ TraversalResult Explorer::walkOutgoingBranchOperandArguments( ++successorIdx) { auto successorOperandIdx = mapSuccessorOperand(branchOp, successorIdx, operandIdx); - if (!successorOperandIdx.has_value()) + if (!successorOperandIdx.has_value()) { continue; + } auto *targetBlock = branchOp->getSuccessor(successorIdx); auto blockArg = targetBlock->getArgument(*successorOperandIdx); if (fn(targetBlock, blockArg).wasInterrupted()) { @@ -833,8 +858,9 @@ TraversalResult Explorer::walkDefiningOps(Value value, ResultWalkFn fn, << loadOp.getGlobalName() << ":\n"); for (auto *user : globalInfo->uses) { auto storeOp = dyn_cast(user); - if (!storeOp) + if (!storeOp) { continue; + } LLVM_DEBUG({ llvm::dbgs() << " + queuing stored value from "; 
storeOp.print(llvm::dbgs(), asmState); @@ -886,8 +912,9 @@ TraversalResult Explorer::walkDefiningOps(Value value, ResultWalkFn fn, do { // Pop the next work item; avoiding processing values more than once. auto work = worklist.pop_back_val(); - if (!processedValues.insert(work.getAsOpaquePointer()).second) + if (!processedValues.insert(work.getAsOpaquePointer()).second) { continue; + } LLVM_DEBUG({ llvm::dbgs() << " ? working on "; @@ -1115,8 +1142,9 @@ TraversalResult Explorer::walkTransitiveUses(Value value, UseWalkFn fn, << storeOp.getGlobalName() << ":\n"); for (auto *user : globalInfo->uses) { auto loadOp = dyn_cast(user); - if (!loadOp) + if (!loadOp) { continue; + } LLVM_DEBUG({ llvm::dbgs() << " + queuing loaded value from "; loadOp.print(llvm::dbgs(), asmState); @@ -1143,8 +1171,9 @@ TraversalResult Explorer::walkTransitiveUses(Value value, UseWalkFn fn, // times!). for (auto &use : work.getUses()) { auto *ownerOp = use.getOwner(); - if (!processedValues.insert(&use).second) + if (!processedValues.insert(&use).second) { continue; + } auto action = getTraversalAction(ownerOp); if (action == TraversalAction::IGNORE) { diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.cpp index f87864ebb0ad..2b17acd1bcb7 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.cpp @@ -37,8 +37,9 @@ LogicalResult IntegerDivisibilityAnalysis::visitOperation( }); auto joinCallback = [&](Value v, const IntegerDivisibility &newDiv) { auto result = dyn_cast(v); - if (!result) + if (!result) { return; + } assert(llvm::is_contained(op->getResults(), result)); LLVM_DEBUG(dbgs() << "Inferred divisibility " << newDiv << "\n"); diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/FuncToUtil/Patterns.cpp 
b/compiler/src/iree/compiler/Dialect/Util/Conversion/FuncToUtil/Patterns.cpp index a3e19b82aa48..e66e96252bb1 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Conversion/FuncToUtil/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/FuncToUtil/Patterns.cpp @@ -98,8 +98,9 @@ struct FuncFuncOpPattern : public OpConversionPattern { for (auto retainAttrName : retainedAttributes) { StringRef attrName(retainAttrName); Attribute attr = srcOp->getAttr(attrName); - if (attr) + if (attr) { newFuncOp->setAttr(attrName, attr); + } } // Copy all arg/result attrs. We could filter these. diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/FuncToUtil/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Conversion/FuncToUtil/test/BUILD.bazel index 5db13f411ee6..3e92f3d8edc3 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Conversion/FuncToUtil/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/FuncToUtil/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "func_ops.mlir", ], diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/BUILD.bazel index 794f7ca76376..67706eb1ea6a 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "memref_ops.mlir", ], diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Conversion/test/BUILD.bazel index 54a003dd2c21..a4b61738a425 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Conversion/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/test/BUILD.bazel @@ -15,6 +15,7 @@ package( 
iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "compiler_hints.mlir", "structural_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/IR/BUILD.bazel index 2ca83b21fdd2..0a8bf199db8c 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/IR/BUILD.bazel @@ -18,6 +18,7 @@ exports_files(["UtilBase.td"]) iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "UtilAttrs.td", "UtilBase.td", diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp index d58a5c1cb9aa..de119b161b51 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp @@ -155,8 +155,9 @@ static SmallVector findDuplicateRegionResults(Region ®ion) { auto uniformDupeIndexMap = llvm::to_vector(llvm::seq(0u, resultCount)); // old -> new for (unsigned idx = 0; idx < resultCount; ++idx) { - if (deadResultsMap.test(idx)) + if (deadResultsMap.test(idx)) { continue; + } // Each bit represents a result that duplicates the result at idx. // We walk all the sites and AND their masks together to get the safe // set of duplicate results. @@ -257,15 +258,17 @@ static void inlineClosureOperands(const ClosureOptimizationOptions &options, for (auto opArg : llvm::enumerate(closureOp.getClosureOperands())) { auto outerValue = opArg.value(); auto *sourceOp = outerValue.getDefiningOp(); - if (!sourceOp) + if (!sourceOp) { continue; // can't clone block arguments into closures + } // We cannot just simply inline and replace all users if this is an // argument that can be written; for example, the region might perform // work after loading a initial constant from the argument and then // write back. 
- if (!closureOp.getOperandAccess(opArg.index()).isReadOnly()) + if (!closureOp.getOperandAccess(opArg.index()).isReadOnly()) { continue; + } if (closureOp.canClosureContainOp(sourceOp) && shouldInlineIntoClosure(options, outerValue)) { diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp index 12b2127db67a..097fc4ad3395 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp @@ -258,15 +258,17 @@ class PackedWriter { : logicalBitWidth(logicalBitWidth), endian(endian), os(os) {} void write(const uint64_t value) { - if (bitOffset + logicalBitWidth > physicalBitWidth) + if (bitOffset + logicalBitWidth > physicalBitWidth) { flush(); + } physicalBuffer |= value << bitOffset; bitOffset += logicalBitWidth; } void flush() { - if (bitOffset == 0) + if (bitOffset == 0) { return; + } physicalType physicalValue = llvm::support::endian::byte_swap(physicalBuffer, endian); os.write((const char *)&physicalValue, sizeof(physicalValue)); @@ -533,8 +535,9 @@ LogicalResult BytePatternAttr::serializeToStream(Location loc, //===----------------------------------------------------------------------===// Attribute ByteRangeAttr::parse(AsmParser &p, Type type) { - if (failed(p.parseLess())) + if (failed(p.parseLess())) { return {}; + } // TODO(benvanik): support the range syntax; the dialect asm parser fights // with it though by checking for proper []/() nesting. @@ -573,8 +576,9 @@ Attribute ByteRangeAttr::parse(AsmParser &p, Type type) { return {}; } - if (failed(p.parseGreater())) + if (failed(p.parseGreater())) { return {}; + } start = startInclusive ? start : start + 1; end = endInclusive ? 
end : end - 1; @@ -912,8 +916,9 @@ void HoistableAttrInterface::gatherHoistableAttrs(Operation *fromOp, } } } - if (auto *parentOp = fromOp->getParentOp()) + if (auto *parentOp = fromOp->getParentOp()) { gatherHoistableAttrs(parentOp, dialectAttrs); + } } // static @@ -923,8 +928,9 @@ void HoistableAttrInterface::gatherHoistableAttrs(Operation *fromOp, // precedence over any from ancestors. We also want to preserve any // non-hoistable attrs when we reassign the dialect attrs. NamedAttrList dialectAttrs; - for (auto attr : toOp->getDialectAttrs()) + for (auto attr : toOp->getDialectAttrs()) { dialectAttrs.push_back(attr); + } // Gather attributes from the op and its parents, only adding ones not already // set on the op. diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilDialect.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilDialect.cpp index 86ed3a8281a5..d116e5905e12 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilDialect.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilDialect.cpp @@ -54,8 +54,9 @@ struct UtilInlinerInterface : public DialectInlinerInterface { if (auto inliningPolicy = callable->getAttrOfType( "inlining_policy")) { - if (!inliningPolicy.isLegalToInline(call, callable)) + if (!inliningPolicy.isLegalToInline(call, callable)) { return false; + } } // Check any extended inlining policies that may come from dialect @@ -64,8 +65,9 @@ struct UtilInlinerInterface : public DialectInlinerInterface { if (auto inliningPolicy = dyn_cast( attr.getValue())) { - if (!inliningPolicy.isLegalToInline(call, callable)) + if (!inliningPolicy.isLegalToInline(call, callable)) { return false; + } } } @@ -86,8 +88,9 @@ struct UtilInlinerInterface : public DialectInlinerInterface { } void handleTerminator(Operation *op, Block *newDest) const final { - if (!op->hasTrait()) + if (!op->hasTrait()) { return; + } OpBuilder builder(op); if (auto returnOp = dyn_cast(op)) { @@ -159,8 +162,9 @@ struct FoldDimOp : public OpRewritePattern { } auto 
shapeAwareOp = dyn_cast_if_present(source.getDefiningOp()); - if (!shapeAwareOp) + if (!shapeAwareOp) { return failure(); + } // We only support static dimension indices today (as in general we only // support ranked shapes). If we find dynamic indices sneaking in we will diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp index 59f09570eb48..44cd74d55d91 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp @@ -81,14 +81,16 @@ static LogicalResult canonicalizeAssumeIntOp(AssumeIntOp op, needsRewrite = true; } } - if (!needsRewrite) + if (!needsRewrite) { return failure(); + } // Need to rewrite the assumption. auto normalizeAssumptions = [](Attribute row, bool &madeChange) { auto rowArray = cast(row); - if (rowArray.size() <= 1) + if (rowArray.size() <= 1) { return rowArray; + } bool allSame = true; for (unsigned i = 1; i < rowArray.size(); ++i) { @@ -98,8 +100,9 @@ static LogicalResult canonicalizeAssumeIntOp(AssumeIntOp op, } } - if (!allSame) + if (!allSame) { return rowArray; + } // All entries are the same: compress down to a single column. 
madeChange = true; @@ -350,8 +353,9 @@ struct FoldCastIntoNullOp : public OpRewritePattern { PatternRewriter &rewriter) const override { auto nullOp = dyn_cast_if_present(castOp.getOperand().getDefiningOp()); - if (!nullOp) + if (!nullOp) { return failure(); + } rewriter.replaceOpWithNewOp(castOp, castOp.getResult().getType()); return success(); } @@ -425,8 +429,9 @@ static OpFoldResult foldRangeOp(Type type, ValueRange operands, int64_t value = initialValue; for (auto operand : attrOperands) { auto intValue = dyn_cast_if_present(operand); - if (!intValue) + if (!intValue) { return {}; + } value = expr(value, intValue.getValue().getSExtValue()); } return IntegerAttr::get(type, value); @@ -566,8 +571,9 @@ struct FoldConstantRanges : public OpRewritePattern { lengths.push_back(length); } } - if (offsets.size() == op.getOffsets().size()) + if (offsets.size() == op.getOffsets().size()) { return failure(); + } // Preserve dynamic ranges. Value min; @@ -627,8 +633,9 @@ struct ExpandSimpleRangeExtentsOp : public OpRewritePattern { op.getLengths().back(), one, rewriter); maxValue = arith::MaxUIOp::create(rewriter, loc, endLhs, endRhs); } - if (!minValue || !maxValue) + if (!minValue || !maxValue) { return failure(); + } rewriter.replaceOp(op, {minValue, maxValue}); return success(); } @@ -645,8 +652,9 @@ struct DeduplicateRangeExtentsOp : public OpRewritePattern { for (auto range : llvm::zip_equal(op.getOffsets(), op.getLengths())) { ranges.insert(range); } - if (ranges.size() == op.getOffsets().size()) + if (ranges.size() == op.getOffsets().size()) { return failure(); + } // Recreate with the deduplicated ranges. SmallVector offsets; @@ -702,8 +710,9 @@ static bool isAlignedTo(Value value, Value alignment) { // If the value is produced by an align op we can check that. if (auto sourceAlignOp = value.getDefiningOp()) { // Check for same exact alignment - even if dynamic. 
- if (sourceAlignOp.getAlignment() == alignment) + if (sourceAlignOp.getAlignment() == alignment) { return true; + } // If the alignments are constant we can compare them inline. APInt sourceAlignment; @@ -762,8 +771,9 @@ static bool isAlignedTo(Value value, Value alignment) { OpFoldResult AlignOp::fold(FoldAdaptor operands) { // If aligning an already-aligned value then fold if this is provably a // no-op. We can check this for equality even with dynamic alignments. - if (isAlignedTo(getValue(), getAlignment())) + if (isAlignedTo(getValue(), getAlignment())) { return getValue(); + } // If values are static we can perform the alignment here. APInt staticValue; @@ -992,8 +1002,9 @@ struct DropEmptyInitializerOp : public OpRewritePattern { LogicalResult matchAndRewrite(InitializerOp op, PatternRewriter &rewriter) const override { - if (op.getBody().getBlocks().size() != 1) + if (op.getBody().getBlocks().size() != 1) { return failure(); + } auto &block = op.getBody().front(); // Empty block or block with only a ReturnLike terminator. 
if (block.empty() || (block.getOperations().size() == 1 && @@ -1128,8 +1139,9 @@ struct FoldBufferSubspanOps : public OpRewritePattern { LogicalResult matchAndRewrite(BufferSubspanOp op, PatternRewriter &rewriter) const override { auto parentOp = BufferSubspanOp::findSubspanOp(op.getSource()); - if (!parentOp) + if (!parentOp) { return failure(); + } auto fusedLoc = rewriter.getFusedLoc({parentOp.getLoc(), op.getLoc()}); auto newOffset = rewriter.createOrFold( fusedLoc, parentOp.getSourceOffset(), op.getSourceOffset()); @@ -1159,8 +1171,9 @@ struct FoldBufferSubspanOpsIntoConsumers for (auto &use : llvm::make_early_inc_range(op.getResult().getUses())) { auto subrangeOp = dyn_cast(use.getOwner()); - if (!subrangeOp) + if (!subrangeOp) { continue; + } didUpdateAny = true; rewriter.setInsertionPoint(subrangeOp); auto oldRange = subrangeOp.getSubrangeOperand(use.getOperandNumber()); @@ -1193,14 +1206,16 @@ struct SinkSubspanAcrossSelectOps using Base::Base; LogicalResult matchAndRewrite(mlir::arith::SelectOp op, PatternRewriter &rewriter) const override { - if (!isa(op.getType())) + if (!isa(op.getType())) { return failure(); + } auto trueSubspan = dyn_cast_if_present( op.getTrueValue().getDefiningOp()); auto falseSubspan = dyn_cast_if_present( op.getFalseValue().getDefiningOp()); - if (!trueSubspan || !falseSubspan) + if (!trueSubspan || !falseSubspan) { return failure(); + } if (trueSubspan.getSource() != falseSubspan.getSource() || trueSubspan.getResultSize() != falseSubspan.getResultSize()) { return failure(); @@ -1275,8 +1290,9 @@ struct SelectBufferSizeOp : public OpRewritePattern { LogicalResult matchAndRewrite(BufferSizeOp op, PatternRewriter &rewriter) const override { auto selectOp = op.getOperand().getDefiningOp(); - if (!selectOp) + if (!selectOp) { return failure(); + } auto trueSize = rewriter.createOrFold( op.getLoc(), selectOp.getTrueValue()); auto falseSize = rewriter.createOrFold( @@ -1313,8 +1329,9 @@ struct FoldSubspansIntoStorageOp : public 
OpRewritePattern { LogicalResult matchAndRewrite(BufferStorageOp op, PatternRewriter &rewriter) const override { auto subspanOp = BufferSubspanOp::findSubspanOp(op.getOperand()); - if (!subspanOp) + if (!subspanOp) { return failure(); + } auto fusedLoc = rewriter.getFusedLoc({subspanOp.getLoc(), op.getLoc()}); rewriter.setInsertionPointAfter(op); auto newOffset = rewriter.createOrFold( diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp index 698b983d160e..1baaaf55b6c1 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp @@ -66,8 +66,9 @@ Value buildIfElseTree( ArrayAttr deduplicateArrayElements(ArrayAttr arrayAttr) { SetVector attrsSet(arrayAttr.begin(), arrayAttr.end()); - if (attrsSet.size() == arrayAttr.size()) + if (attrsSet.size() == arrayAttr.size()) { return arrayAttr; + } return ArrayAttr::get(arrayAttr.getContext(), attrsSet.takeVector()); } @@ -202,8 +203,9 @@ void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type, needsSpace = true; // subsequent attr value needs a space separator } if (attr) { - if (needsSpace) + if (needsSpace) { p << ' '; + } p << "= "; p.printAttribute(attr); } @@ -249,12 +251,14 @@ void printSymbolAlias(OpAsmPrinter &p, Operation *op, StringAttr sym_name, ParseResult parseTypeAlias(OpAsmParser &parser, TypeAttr &encodingTypeAttr, Type &storageType) { Type encodingType; - if (failed(parser.parseType(encodingType))) + if (failed(parser.parseType(encodingType))) { return failure(); + } storageType = encodingType; if (succeeded(parser.parseOptionalKeyword("as"))) { - if (failed(parser.parseType(storageType))) + if (failed(parser.parseType(storageType))) { return failure(); + } } encodingTypeAttr = TypeAttr::get(encodingType); return success(); @@ -356,18 +360,22 @@ void printSizeAwareType(OpAsmPrinter &p, Operation *op, Type type, Value size) { ParseResult 
parseOperandTypeList(OpAsmParser &parser, SmallVectorImpl &operandTypes) { - if (failed(parser.parseLParen())) + if (failed(parser.parseLParen())) { return failure(); - if (succeeded(parser.parseOptionalRParen())) + } + if (succeeded(parser.parseOptionalRParen())) { return success(); // empty + } do { Type type; - if (failed(parser.parseType(type))) + if (failed(parser.parseType(type))) { return failure(); + } operandTypes.push_back(type); } while (succeeded(parser.parseOptionalComma())); - if (failed(parser.parseRParen())) + if (failed(parser.parseRParen())) { return failure(); + } return success(); } @@ -403,8 +411,9 @@ parseTiedResultList(OpAsmParser &parser, } if (succeeded(parser.parseOptionalKeyword("as"))) { // Type _may_ differ from the operand. - if (failed(parser.parseType(type))) + if (failed(parser.parseType(type))) { return failure(); + } } else { // Use the operands type. type = operandTypes[tiedOperandIndex]; @@ -443,8 +452,9 @@ void printTiedResultList(OpAsmPrinter &p, Operation *op, ValueRange operands, if (printType) { p.printType(resultType); } - if (i < resultTypes.size() - 1) + if (i < resultTypes.size() - 1) { p << ", "; + } } } @@ -476,8 +486,9 @@ parseTiedFunctionResultListImpl(OpAsmParser &parser, } if (succeeded(parser.parseOptionalKeyword("as"))) { // Type _may_ differ from the operand. - if (failed(parser.parseType(type))) + if (failed(parser.parseType(type))) { return failure(); + } } else { // Use the operands type. 
type = arguments[tiedOperandIndex].type; @@ -566,11 +577,13 @@ void printTiedFunctionResultList(OpAsmPrinter &p, Operation *op, ValueRange operands, TypeRange operandTypes, TypeRange resultTypes, ArrayAttr tiedOperands) { - if (resultTypes.size() != 1) + if (resultTypes.size() != 1) { p << "("; + } printTiedResultList(p, op, operands, operandTypes, resultTypes, tiedOperands); - if (resultTypes.size() != 1) + if (resultTypes.size() != 1) { p << ")"; + } } //===----------------------------------------------------------------------===// @@ -583,8 +596,9 @@ parseShapedTypeList(OpAsmParser &parser, SmallVectorImpl &types, SmallVectorImpl &dims) { do { Type type; - if (failed(parser.parseType(type))) + if (failed(parser.parseType(type))) { return failure(); + } if (auto shapedType = dyn_cast(type)) { if (!shapedType.hasStaticShape()) { SmallVector dynamicDims; @@ -639,8 +653,9 @@ ParseResult parseShapedTypeList(OpAsmParser &parser, SmallVectorImpl &types0, SmallVectorImpl &types1, SmallVectorImpl &dims) { - if (failed(parseShapedTypeList(parser, types0, dims))) + if (failed(parseShapedTypeList(parser, types0, dims))) { return failure(); + } types1 = types0; return success(); } @@ -672,11 +687,13 @@ ParseResult parseShapedTiedResult( int64_t tiedOperandIndex = IREE::Util::TiedOpInterface::kUntiedIndex; if (res.has_value() && succeeded(res.value())) { tiedOperandIndex = 0; - if (failed(parser.parseKeyword("as"))) + if (failed(parser.parseKeyword("as"))) { return failure(); + } } - if (failed(parser.parseType(resultType))) + if (failed(parser.parseType(resultType))) { return failure(); + } if (auto shapedType = dyn_cast(resultType)) { if (!shapedType.hasStaticShape()) { SmallVector dynamicDims; @@ -766,8 +783,9 @@ ParseResult parseShapedResultList( } if (succeeded(parser.parseOptionalKeyword("as"))) { // Type _may_ differ from the operand. - if (failed(parser.parseType(type))) + if (failed(parser.parseType(type))) { return failure(); + } } else { // Use the operands type. 
type = operandTypes[tiedOperandIndex]; @@ -848,8 +866,9 @@ void printShapedResultList(OpAsmPrinter &p, Operation *op, ValueRange operands, p << "}"; resultDims = resultDims.drop_front(1); } - if (i < resultTypes.size() - 1) + if (i < resultTypes.size() - 1) { p << ", "; + } } } @@ -865,16 +884,18 @@ ParseResult parseShapedFunctionType( SmallVectorImpl &resultTypes, SmallVectorImpl &resultDims, ArrayAttr &tiedOperands) { - if (failed(parser.parseLParen())) + if (failed(parser.parseLParen())) { return failure(); + } if (failed(parser.parseOptionalRParen())) { if (failed(parseShapedTypeList(parser, operandTypes, operandDims)) || failed(parser.parseRParen())) { return failure(); } } - if (failed(parser.parseArrow())) + if (failed(parser.parseArrow())) { return failure(); + } if (succeeded(parser.parseOptionalLParen())) { if (succeeded(parser.parseOptionalRParen())) { // Empty list/no results `()`. @@ -905,12 +926,14 @@ void printShapedFunctionType(OpAsmPrinter &p, Operation *op, p << "("; printShapedTypeList(p, op, operandTypes, operandDims); p << ") -> "; - if (resultTypes.size() != 1) + if (resultTypes.size() != 1) { p << "("; + } printShapedResultList(p, op, operands, operandTypes, operandDims, resultTypes, resultDims, tiedOperands); - if (resultTypes.size() != 1) + if (resultTypes.size() != 1) { p << ")"; + } } //===----------------------------------------------------------------------===// @@ -961,8 +984,9 @@ static ParseResult parseShapedFunctionResultList( } if (succeeded(parser.parseOptionalKeyword("as"))) { // Type _may_ differ from the operand. - if (failed(parser.parseType(type))) + if (failed(parser.parseType(type))) { return failure(); + } } else { // Use the operands type. 
type = argTypes[tiedOperandIndex]; @@ -1016,8 +1040,9 @@ static void printShapedFunctionResultList(OpAsmPrinter &p, Operation *op, p.printOptionalAttrDict(attrs.getValue()); } } - if (i < resultTypes.size() - 1) + if (i < resultTypes.size() - 1) { p << ", "; + } } } @@ -1029,8 +1054,9 @@ ParseResult parseShapedFunctionSignature(OpAsmParser &parser, SmallVector args; SmallVector argTypes; SmallVector resultTypes; - if (failed(parser.parseLParen())) + if (failed(parser.parseLParen())) { return failure(); + } if (failed(parser.parseOptionalRParen())) { if (failed(parseShapedFunctionArgumentList(parser, args, argTypes, argAttrs)) || @@ -1074,8 +1100,9 @@ void printShapedFunctionSignature(OpAsmPrinter &p, Operation *op, if (argAttrs) { auto attrs = dyn_cast_if_present(argAttrs.getValue()[argIndex]); - if (attrs && !attrs.empty()) + if (attrs && !attrs.empty()) { p.printOptionalAttrDict(attrs.getValue()); + } } ++argIndex; }); @@ -1087,12 +1114,14 @@ void printShapedFunctionSignature(OpAsmPrinter &p, Operation *op, resultAttrs && !resultAttrs.empty() && llvm::any_of(resultAttrs.getAsValueRange(), [](auto attr) { return !attr.empty(); }); - if (resultTypes.size() != 1 || anyResultAttrs) + if (resultTypes.size() != 1 || anyResultAttrs) { p << "("; + } printShapedFunctionResultList(p, op, functionType.getInputs(), resultTypes, resultAttrs, tiedOperands); - if (resultTypes.size() != 1 || anyResultAttrs) + if (resultTypes.size() != 1 || anyResultAttrs) { p << ")"; + } } } @@ -1121,24 +1150,27 @@ void AlignOp::inferResultRanges(ArrayRef argRanges, auto align = [&](APInt value, bool &invalid) -> APInt { APInt aligned = (value + alignmentM1) & alignmentM1Inv; // Detect overflow, which commonly happens at max range. 
- if (aligned.ult(value)) + if (aligned.ult(value)) { invalid = true; + } return aligned; }; bool invalid = false; auto alignedUmin = align(umin, invalid); auto alignedUmax = align(umax, invalid); - if (!invalid) + if (!invalid) { setResultRange(getResult(), ConstantIntRanges::fromUnsigned(alignedUmin, alignedUmax)); + } } } void AlignOp::inferResultDivisibility(ArrayRef argDivs, SetIntDivisibilityFn setResultDivs) { auto alignmentDiv = argDivs[1]; - if (alignmentDiv.isUninitialized()) + if (alignmentDiv.isUninitialized()) { return; + } setResultDivs(getResult(), alignmentDiv.getValue()); } @@ -1186,8 +1218,9 @@ AssumeIntOp::getUnionedUnsignedRange(unsigned operandIndex) { static bool isConstantZero(IntAssumptionAttr assumption) { std::optional umin = assumption.getUmin(); std::optional umax = assumption.getUmax(); - if (!umin || !umax) + if (!umin || !umax) { return false; + } return *umin == 0 && *umax == 0; } @@ -1199,14 +1232,16 @@ AssumeIntOp::getUnionedUnsignedDivisor(unsigned operandIndex) { auto divisor = assumption.getUdiv(); if (!divisor) { // Constant zero is divisible by anything - if (isConstantZero(assumption)) + if (isConstantZero(assumption)) { continue; + } return std::nullopt; } - if (divisorUnion) + if (divisorUnion) { divisorUnion = std::gcd(*divisor, *divisorUnion); - else + } else { divisorUnion = *divisor; + } } return divisorUnion; } @@ -1216,19 +1251,22 @@ void AssumeIntOp::inferResultRanges(ArrayRef argRanges, for (auto [index, result] : llvm::enumerate(getResults())) { Type type = result.getType(); unsigned bitWidth; - if (isa(type)) + if (isa(type)) { bitWidth = 64; - else if (auto intType = dyn_cast(type)) + } else if (auto intType = dyn_cast(type)) { bitWidth = intType.getWidth(); - else + } else { continue; + } auto [umin, umax] = getUnionedUnsignedRange(index); auto uminAp = APInt::getMinValue(bitWidth); auto umaxAp = APInt::getMaxValue(bitWidth); - if (umin) + if (umin) { uminAp = APInt(bitWidth, *umin); - if (umax) + } + if (umax) 
{ umaxAp = APInt(bitWidth, *umax); + } setResultRange(result, ConstantIntRanges::fromUnsigned(uminAp, umaxAp)); } @@ -1238,8 +1276,9 @@ void AssumeIntOp::inferResultDivisibility(ArrayRef argDivs, SetIntDivisibilityFn setResultDivs) { for (auto [index, result] : llvm::enumerate(getResults())) { Type type = result.getType(); - if (!isa(type) && !isa(type)) + if (!isa(type) && !isa(type)) { continue; + } auto udiv = getUnionedUnsignedDivisor(index); if (udiv) { setResultDivs(result, @@ -1261,8 +1300,9 @@ void AssumeIntOp::build(OpBuilder &builder, OperationState &state, ArrayRef operands, ArrayRef assumptions) { state.addOperands(operands); - for (auto operand : operands) + for (auto operand : operands) { state.addTypes({operand.getType()}); + } state.addAttribute("assumptions", ArrayAttr::get(builder.getContext(), ArrayRef(assumptions.begin(), @@ -1282,12 +1322,14 @@ LogicalResult AssumeIntOp::verify() { llvm::enumerate(allOperandAssumptions)) { auto operandAssumptions = cast(operandAssumptionsAttr); // We always allow a single row to broadcast to any requested size. - if (operandAssumptions.size() == 1) + if (operandAssumptions.size() == 1) { continue; - if (rank && *rank != operandAssumptions.size()) + } + if (rank && *rank != operandAssumptions.size()) { return emitOpError() << "expected operand #" << index << " to have " << *rank << " assumptions but it has " << operandAssumptions.size(); + } rank = operandAssumptions.size(); } @@ -1304,29 +1346,35 @@ ParseResult AssumeIntOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand &parsedOperand = parsedOperands.back(); SmallVector operandAssumptions; - if (parser.parseOperand(parsedOperand)) + if (parser.parseOperand(parsedOperand)) { return failure(); + } // Parse as a single assumption or a list. if (failed(parser.parseOptionalLSquare())) { // Single assumption. 
IntAssumptionAttr singleAssumption; - if (parser.parseCustomAttributeWithFallback(singleAssumption)) + if (parser.parseCustomAttributeWithFallback(singleAssumption)) { return failure(); + } operandAssumptions.push_back(singleAssumption); } else { // Multiple assumptions. if (failed(parser.parseOptionalRSquare())) { if (parser.parseCommaSeparatedList([&]() { IntAssumptionAttr singleAssumption; - if (parser.parseCustomAttributeWithFallback(singleAssumption)) + if (parser.parseCustomAttributeWithFallback( + singleAssumption)) { return failure(); + } operandAssumptions.push_back(singleAssumption); return success(); - })) + })) { return failure(); - if (parser.parseRSquare()) + } + if (parser.parseRSquare()) { return failure(); + } } } @@ -1335,22 +1383,26 @@ ParseResult AssumeIntOp::parse(OpAsmParser &parser, OperationState &result) { parser.getBuilder().getArrayAttr(operandAssumptions)); return success(); - })) + })) { return failure(); + } // Parse `:` type. - if (parser.parseColon() || parser.parseTypeList(parsedOperandTypes)) + if (parser.parseColon() || parser.parseTypeList(parsedOperandTypes)) { return failure(); + } result.addTypes(parsedOperandTypes); if (parser.resolveOperands(parsedOperands, parsedOperandTypes, - parser.getNameLoc(), result.operands)) + parser.getNameLoc(), result.operands)) { return failure(); + } result.attributes.append( "assumptions", parser.getBuilder().getArrayAttr(allOperandAssumptions)); - if (parser.parseOptionalAttrDict(result.attributes)) + if (parser.parseOptionalAttrDict(result.attributes)) { return failure(); + } return success(); } @@ -1425,15 +1477,17 @@ ParseResult UnfoldableConstantOp::parse(OpAsmParser &parser, OperationState &state) { Attribute valueAttr; if (parser.parseOptionalAttrDict(state.attributes) || - parser.parseAttribute(valueAttr, "value", state.attributes)) + parser.parseAttribute(valueAttr, "value", state.attributes)) { return failure(); + } // If the attribute is a symbol reference, then we expect a trailing 
type. Type type; - if (!isa(valueAttr)) + if (!isa(valueAttr)) { type = cast(valueAttr).getType(); - else if (parser.parseColonType(type)) + } else if (parser.parseColonType(type)) { return failure(); + } // Add the attribute type to the list. return parser.addTypeToList(type, state.types); @@ -1444,13 +1498,15 @@ void UnfoldableConstantOp::print(OpAsmPrinter &p) { p << " "; p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"}); - if (op->getAttrs().size() > 1) + if (op->getAttrs().size() > 1) { p << ' '; + } p << getValue(); // If the value is a symbol reference, print a trailing type. - if (isa(getValue())) + if (isa(getValue())) { p << " : " << getType(); + } } //===----------------------------------------------------------------------===// @@ -1458,8 +1514,9 @@ void UnfoldableConstantOp::print(OpAsmPrinter &p) { //===----------------------------------------------------------------------===// bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { - if (inputs.size() != 1 || outputs.size() != 1) + if (inputs.size() != 1 || outputs.size() != 1) { return false; + } Type a = inputs.front(), b = outputs.front(); if (a == b) { // Both types are the same. @@ -1509,8 +1566,9 @@ SmallVector CastOp::getTiedResultOperandIndices() { std::optional> NumericOptionalNarrowOp::getIntegerRange() { - if (!getMinValue() || !getMaxValue()) + if (!getMinValue() || !getMaxValue()) { return {}; + } bool signExtend = isSigned(); // Note: Cannot sign extend 0 bit values. int64_t minValue = signExtend && getMinValue()->getBitWidth() > 0 @@ -1621,22 +1679,26 @@ parseFunctionArgumentList(OpAsmParser &parser, auto argPresent = parser.parseOptionalArgument( argument, /*allowType=*/true, /*allowAttrs=*/true); if (argPresent.has_value()) { - if (failed(argPresent.value())) + if (failed(argPresent.value())) { return failure(); // Present but malformed. 
- if (!arguments.empty() && arguments.back().ssaName.name.empty()) + } + if (!arguments.empty() && arguments.back().ssaName.name.empty()) { return parser.emitError(argument.ssaName.location, "expected type instead of SSA identifier"); + } } else { argument.ssaName.location = parser.getCurrentLocation(); - if (!arguments.empty() && !arguments.back().ssaName.name.empty()) + if (!arguments.empty() && !arguments.back().ssaName.name.empty()) { return parser.emitError(argument.ssaName.location, "expected SSA identifier"); + } NamedAttrList attrs; if (parser.parseType(argument.type) || parser.parseOptionalAttrDict(attrs) || - parser.parseOptionalLocationSpecifier(argument.sourceLoc)) + parser.parseOptionalLocationSpecifier(argument.sourceLoc)) { return failure(); + } argument.attrs = attrs.getDictionary(parser.getContext()); } arguments.push_back(argument); @@ -1648,52 +1710,61 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); StringAttr symVisibilityAttr; - if (failed(parseSymbolVisibility(parser, symVisibilityAttr))) + if (failed(parseSymbolVisibility(parser, symVisibilityAttr))) { return failure(); - if (symVisibilityAttr) + } + if (symVisibilityAttr) { result.addAttribute(SymbolTable::getVisibilityAttrName(), symVisibilityAttr); + } StringAttr nameAttr; if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), - result.attributes)) + result.attributes)) { return failure(); + } SmallVector arguments; - if (parseFunctionArgumentList(parser, arguments)) + if (parseFunctionArgumentList(parser, arguments)) { return failure(); + } SmallVector resultTypes; SmallVector resultAttrs; ArrayAttr tiedOperands; if (succeeded(parser.parseOptionalArrow())) { if (failed(parseTiedFunctionResultList(parser, arguments, resultTypes, - resultAttrs, tiedOperands))) + resultAttrs, tiedOperands))) { return failure(); + } } - if (tiedOperands) + if (tiedOperands) { result.addAttribute("tied_operands", tiedOperands); + 
} SmallVector argumentTypes; - for (auto argument : arguments) + for (auto argument : arguments) { argumentTypes.push_back(argument.type); + } result.addAttribute("function_type", TypeAttr::get(builder.getFunctionType( argumentTypes, resultTypes))); NamedAttrList parsedAttributes; SMLoc attributeDictLocation = parser.getCurrentLocation(); - if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes)) + if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes)) { return failure(); + } for (StringRef disallowed : { SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(), StringRef("function_type"), }) { - if (parsedAttributes.get(disallowed)) + if (parsedAttributes.get(disallowed)) { return parser.emitError(attributeDictLocation, "'") << disallowed << "' is an inferred attribute and should not be specified in the " "explicit attribute dictionary"; + } } result.attributes.append(parsedAttributes); @@ -1707,10 +1778,12 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { auto parseResult = parser.parseOptionalRegion(*body, arguments, /*enableNameShadowing=*/false); if (parseResult.has_value()) { - if (failed(*parseResult)) + if (failed(*parseResult)) { return failure(); - if (body->empty()) + } + if (body->empty()) { return parser.emitError(loc, "expected non-empty function body"); + } } return success(); } @@ -1745,8 +1818,9 @@ bool IREE::Util::FuncOp::canDiscardOnUseEmpty() { bool IREE::Util::FuncOp::hasAnyTiedOperands() { auto tiedOperandsAttr = getTiedOperandsAttr(); - if (!tiedOperandsAttr) + if (!tiedOperandsAttr) { return false; + } return llvm::any_of( tiedOperandsAttr.getAsRange(), [](IntegerAttr attr) { return attr.getInt() != IREE::Util::TiedOpInterface::kUntiedIndex; @@ -1773,8 +1847,9 @@ void IREE::Util::FuncOp::expandSignature( expandArgument(oldIndex, argType, newArgumentTypes); size_t expandedCount = newArgumentTypes.size() - newIndex; for (size_t i = 0; i < adjustedTiedOperands.size(); ++i) { - if 
(adjustedTiedOperands[i] == oldIndex) + if (adjustedTiedOperands[i] == oldIndex) { adjustedTiedOperands[i] = newIndex; + } } newArgumentAttrs.push_back(oldArgumentAttrs[oldIndex]); newArgumentAttrs.append(expandedCount - 1, @@ -1819,8 +1894,9 @@ FunctionType CallOp::getCalleeType() { static bool areTiedOperandsEqual(ArrayAttr a, ArrayAttr b) { auto hasAnyTied = [](ArrayAttr tiedOperandsAttr) { - if (!tiedOperandsAttr) + if (!tiedOperandsAttr) { return false; + } return llvm::any_of( tiedOperandsAttr.getAsRange(), [](IntegerAttr attr) { return attr.getInt() != IREE::Util::TiedOpInterface::kUntiedIndex; @@ -1828,10 +1904,12 @@ static bool areTiedOperandsEqual(ArrayAttr a, ArrayAttr b) { }; bool hasAnyTiedA = hasAnyTied(a); bool hasAnyTiedB = hasAnyTied(b); - if (hasAnyTiedA != hasAnyTiedB) + if (hasAnyTiedA != hasAnyTiedB) { return false; - if (!a || !b) + } + if (!a || !b) { return true; + } return a == b; } @@ -1877,8 +1955,9 @@ IREE::Util::CallOp IREE::Util::CallOp::cloneAndExpand( size_t newIndex = newOperands.size(); expandOperand(oldIndex, operand, newOperands); for (size_t i = 0; i < adjustedTiedOperands.size(); ++i) { - if (adjustedTiedOperands[i] == oldIndex) + if (adjustedTiedOperands[i] == oldIndex) { adjustedTiedOperands[i] = newIndex; + } } } @@ -2026,8 +2105,9 @@ void GlobalLoadOp::getEffects( SmallVectorImpl &effects) { // HACK: mlir doesn't have symbol side effects so we have to mark as a global // read if not immutable and not in an initializer. 
- if (!isGlobalImmutable()) + if (!isGlobalImmutable()) { effects.emplace_back(MemoryEffects::Read::get()); + } } LogicalResult @@ -2221,6 +2301,58 @@ void BufferConstantOp::getAsmResultNames( setNameFn(getResult(), getName().value_or("buffer_cst")); } +void BufferConstantOp::build(OpBuilder &builder, OperationState &state, + Attribute value) { + state.addTypes({builder.getType()}); + state.addAttribute("value", value); +} + +void BufferConstantOp::build(OpBuilder &builder, OperationState &state, + StringRef value) { + state.addTypes({builder.getType()}); + state.addAttribute("value", builder.getStringAttr(value)); +} + +void BufferConstantOp::build(OpBuilder &builder, OperationState &state, + ArrayRef value) { + state.addTypes({builder.getType()}); + state.addAttribute("value", + DenseIntElementsAttr::get( + VectorType::get(static_cast(value.size()), + builder.getI8Type()), + value)); +} + +// static +Value BufferConstantOp::createOrNull(OpBuilder &builder, Location loc, + Attribute value) { + if (!value) { + auto bufferType = builder.getType(); + return IREE::Util::NullOp::create(builder, loc, bufferType).getResult(); + } + return IREE::Util::BufferConstantOp::create(builder, loc, value); +} + +// static +Value BufferConstantOp::createOrNull(OpBuilder &builder, Location loc, + StringRef value) { + if (value.empty()) { + auto bufferType = builder.getType(); + return IREE::Util::NullOp::create(builder, loc, bufferType).getResult(); + } + return IREE::Util::BufferConstantOp::create(builder, loc, value); +} + +// static +Value BufferConstantOp::createOrNull(OpBuilder &builder, Location loc, + ArrayRef value) { + if (value.empty()) { + auto bufferType = builder.getType(); + return IREE::Util::NullOp::create(builder, loc, bufferType).getResult(); + } + return IREE::Util::BufferConstantOp::create(builder, loc, value); +} + LogicalResult BufferConstantOp::verify() { if (!isa(getValue())) { return emitOpError("unsupported non-serializable constant attribute type"); diff 
--git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td index 208d2e958cc0..abb4c0667ecb 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td @@ -1418,6 +1418,20 @@ def Util_BufferConstantOp : Util_PureOp<"buffer.constant", [ ($name^)? attr-dict `:` type($result) `=` $value }]; + let builders = [ + OpBuilder<(ins "Attribute":$value)>, + OpBuilder<(ins "StringRef":$value)>, + OpBuilder<(ins "ArrayRef":$value)>, + ]; + + let extraClassDeclaration = [{ + // Returns a new buffer op with the given contents unless they are + // nullptr/empty in which case returns util.null. + static Value createOrNull(OpBuilder &builder, Location loc, Attribute value); + static Value createOrNull(OpBuilder &builder, Location loc, StringRef value); + static Value createOrNull(OpBuilder &builder, Location loc, ArrayRef value); + }]; + let hasVerifier = 1; } diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp index 1ea0d7f75794..c96075d5c2ba 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp @@ -168,10 +168,12 @@ bool isValueUsableForOp(Value value, Block *block, } if (definingBlock == block) { // Defined in the same block; ensure block order. 
- if (isa(value)) + if (isa(value)) { return true; - if (insertionPoint == block->end()) + } + if (insertionPoint == block->end()) { return true; + } if (value.getDefiningOp()->isBeforeInBlock(&*insertionPoint)) { return true; } @@ -255,12 +257,14 @@ Operation *materializeConstant(OpBuilder &builder, Location loc, bool isPublicOrExternal(CallableOpInterface callableOp) { if (auto symbolOp = dyn_cast(callableOp.getOperation())) { - if (symbolOp.isPublic()) + if (symbolOp.isPublic()) { return true; + } } auto *region = callableOp.getCallableRegion(); - if (!region || region->empty()) + if (!region || region->empty()) { return true; + } return false; } @@ -396,22 +400,27 @@ std::optional detail::getTiedResultOperandIndex(Operation *op, unsigned resultIndex) { auto storageAttr = op->getAttrOfType( IREE::Util::TiedOpInterface::getStorageAttrName()); - if (!storageAttr) + if (!storageAttr) { return std::nullopt; + } auto valueAttrs = storageAttr.getValue(); - if (valueAttrs.empty()) + if (valueAttrs.empty()) { return std::nullopt; + } if (auto tiedOp = dyn_cast(op)) { auto indexAndLength = tiedOp.getTiedResultsIndexAndLength(); - if (resultIndex < indexAndLength.first) + if (resultIndex < indexAndLength.first) { return std::nullopt; + } resultIndex -= indexAndLength.first; - if (resultIndex >= indexAndLength.second) + if (resultIndex >= indexAndLength.second) { return std::nullopt; + } } int64_t value = cast(valueAttrs[resultIndex]).getInt(); - if (value == IREE::Util::TiedOpInterface::kUntiedIndex) + if (value == IREE::Util::TiedOpInterface::kUntiedIndex) { return std::nullopt; + } if (auto tiedOp = dyn_cast(op)) { unsigned tiedOperandsOffset = tiedOp.getTiedOperandsIndexAndLength().first; return tiedOperandsOffset + static_cast(value); @@ -436,8 +445,9 @@ void detail::setTiedResultOperandIndex(Operation *op, unsigned resultIndex, // returned by `getTiedOperandsIndexAndLength`. 
unsigned tiedOperandsOffset = tiedOp.getTiedOperandsIndexAndLength().first; for (auto &index : indices) { - if (index != TiedOpInterface::kUntiedIndex) + if (index != TiedOpInterface::kUntiedIndex) { index -= tiedOperandsOffset; + } } } @@ -451,11 +461,13 @@ SmallVector detail::getTiedResultOperandIndices(Operation *op) { SmallVector indices; auto storageAttr = op->getAttrOfType( IREE::Util::TiedOpInterface::getStorageAttrName()); - if (!storageAttr) + if (!storageAttr) { return indices; + } auto valueAttrs = storageAttr.getValue(); - if (valueAttrs.empty()) + if (valueAttrs.empty()) { return indices; + } auto tiedOp = cast(op); auto resultRange = tiedOp.getTiedResultsIndexAndLength(); unsigned tiedOperandsOffset = tiedOp.getTiedOperandsIndexAndLength().first; @@ -475,8 +487,9 @@ Value TiedOpInterface::findTiedBaseValue(Value derivedValue) { while (auto definingOp = dyn_cast_if_present( baseValue.getDefiningOp())) { auto tiedValue = definingOp.getTiedResultOperand(baseValue); - if (!tiedValue) + if (!tiedValue) { break; + } baseValue = tiedValue; } return baseValue; @@ -503,8 +516,9 @@ bool detail::isOperandTied(Operation *op, unsigned operandIndex) { SmallVector detail::getOperandTiedResults(Operation *op, unsigned operandIndex) { auto tiedOp = dyn_cast(op); - if (!tiedOp) + if (!tiedOp) { return {}; + } auto resultRange = tiedOp.getTiedResultsIndexAndLength(); SmallVector results; auto tiedIndices = tiedOp.getTiedResultOperandIndices(); @@ -518,8 +532,9 @@ SmallVector detail::getOperandTiedResults(Operation *op, LogicalResult detail::verifyTiedOp(IREE::Util::TiedOpInterface tiedOp) { auto tiedOperandIndices = tiedOp.getTiedResultOperandIndices(); - if (tiedOperandIndices.empty()) + if (tiedOperandIndices.empty()) { return success(); + } auto resultRange = tiedOp.getTiedResultsIndexAndLength(); if (tiedOperandIndices.size() != resultRange.second) { return tiedOp.emitError("op results/tied operand indices mismatch"); @@ -566,8 +581,9 @@ void 
excludeTiedOperandAndResultIndices( // Count up the number of removed operands prior to this one. unsigned offset = 0; for (unsigned i = 0; i < tiedOperandIndex; ++i) { - if (i < excludedOperands.size() && excludedOperands[i]) + if (i < excludedOperands.size() && excludedOperands[i]) { ++offset; + } } tiedOperandIndex -= offset; @@ -591,16 +607,18 @@ Value SizeAwareTypeInterface::findSizeValue(Value resourceValue, Block *block, while (!worklist.empty()) { auto value = worklist.pop_back_val(); auto *definingOp = value.getDefiningOp(); - if (!definingOp) + if (!definingOp) { continue; + } if (auto sizeAwareOp = dyn_cast(definingOp)) { return sizeAwareOp.getResultSizeFromValue(value); } if (auto tiedOp = dyn_cast(definingOp)) { auto tiedOperand = tiedOp.getTiedResultOperand(value); - if (tiedOperand) + if (tiedOperand) { worklist.push_back(tiedOperand); + } } } @@ -663,8 +681,9 @@ std::optional findDynamicDims(Value workValue) { // {|block|, |insertionPoint|} implicitly. while (workValue) { auto workOp = workValue.getDefiningOp(); - if (!workOp) + if (!workOp) { break; + } if (auto shapeAwareOp = dyn_cast(workOp)) { return shapeAwareOp.getResultDynamicDimsFromValue(workValue); @@ -708,8 +727,9 @@ std::optional findDynamicDims(Value shapedValue, Block *block, // Look up the use-def chain: always safe, as any value we reach dominates // {|block|, |insertionPoint|} implicitly. auto upwardRange = findDynamicDims(shapedValue); - if (upwardRange.has_value()) + if (upwardRange.has_value()) { return upwardRange.value(); + } // Look down the use-def chain: not safe at some point because we'll move past // where {|block|, |insertionPoint|} is dominated. This is often fine for a @@ -747,8 +767,9 @@ ValueRange findDynamicDimsInList(unsigned idx, ValueRange values, } else if (isa(value.getType())) { dynamicDimCount = 1; } - if (!dynamicDimCount) + if (!dynamicDimCount) { return ValueRange{}; + } // Find where the dynamic dims start in the flattened list. 
unsigned offset = 0; diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.h b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.h index 6b1307a55f72..ecc630c043db 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.h +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.h @@ -233,10 +233,12 @@ class IntegerDivisibility { static IntegerDivisibility join(const IntegerDivisibility &lhs, const IntegerDivisibility &rhs) { - if (lhs.isUninitialized()) + if (lhs.isUninitialized()) { return rhs; - if (rhs.isUninitialized()) + } + if (rhs.isUninitialized()) { return lhs; + } return IntegerDivisibility(lhs.getValue().getUnion(rhs.getValue())); } diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.td index 8a158100c80e..add5dbbef6df 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.td +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.td @@ -58,6 +58,9 @@ def Util_ListType : Util_TypeDef<"List"> { ); let builders = [ + TypeBuilder<(ins), [{ + return $_get($_ctxt, IREE::Util::VariantType::get($_ctxt)); + }]>, TypeBuilderWithInferredContext<(ins "Type":$element_type ), [{ diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD.bazel index a1c60408471f..abb8771cf69c 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "alignment_folding.mlir", "alignment_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Util/TransformOps/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/TransformOps/BUILD.bazel index 5abb54d70073..143b5be5ee62 100644 --- a/compiler/src/iree/compiler/Dialect/Util/TransformOps/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/TransformOps/BUILD.bazel @@ -16,6 
+16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "UtilTransformOps.td", ], diff --git a/compiler/src/iree/compiler/Dialect/Util/TransformOps/UtilTransformOps.cpp b/compiler/src/iree/compiler/Dialect/Util/TransformOps/UtilTransformOps.cpp index 7cbae6aea599..c10a59b3252c 100644 --- a/compiler/src/iree/compiler/Dialect/Util/TransformOps/UtilTransformOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/TransformOps/UtilTransformOps.cpp @@ -75,11 +75,13 @@ IREE::Util::transform_dialect::CreateSerializedModuleOp::apply( DiagnosedSilenceableFailure result = state.applyTransform(cast(transform)); // TODO: Support better error propagation. - if (result.isSilenceableFailure()) + if (result.isSilenceableFailure()) { return DiagnosedSilenceableFailure::definiteFailure(); + } // Pass through the error message from definite failures. - if (result.isDefiniteFailure()) + if (result.isDefiniteFailure()) { return result; + } } // Serialize the module as bytecode to a string. @@ -280,13 +282,15 @@ DiagnosedSilenceableFailure IREE::Util::transform_dialect::CastAndCallOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SmallVector inputs; - if (getInputs()) + if (getInputs()) { llvm::append_range(inputs, state.getPayloadValues(getInputs())); + } SetVector outputs; if (getOutputs()) { - for (auto output : state.getPayloadValues(getOutputs())) + for (auto output : state.getPayloadValues(getOutputs())) { outputs.insert(output); + } // Verify that the set of output values to be replaced is unique. 
if (outputs.size() != @@ -386,10 +390,11 @@ DiagnosedSilenceableFailure IREE::Util::transform_dialect::CastAndCallOp::apply( } } - if (insertAfter) + if (insertAfter) { rewriter.setInsertionPointAfter(insertionPoint); - else + } else { rewriter.setInsertionPoint(insertionPoint); + } for (auto [input, type] : llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) { @@ -504,12 +509,15 @@ LogicalResult IREE::Util::transform_dialect::CastAndCallOp::verify() { void IREE::Util::transform_dialect::CastAndCallOp::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getInsertionPointMutable(), effects); - if (getInputs()) + if (getInputs()) { transform::onlyReadsHandle(getInputsMutable(), effects); - if (getOutputs()) + } + if (getOutputs()) { transform::onlyReadsHandle(getOutputsMutable(), effects); - if (getFunction()) + } + if (getFunction()) { transform::onlyReadsHandle(getFunctionMutable(), effects); + } transform::producesHandle(getOperation()->getOpResults(), effects); transform::modifiesPayload(effects); } diff --git a/compiler/src/iree/compiler/Dialect/Util/TransformOps/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/TransformOps/test/BUILD.bazel index 76afbf00d106..138cbadc3b25 100644 --- a/compiler/src/iree/compiler/Dialect/Util/TransformOps/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/TransformOps/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "create_serialized_module.mlir", "symbol_transforms.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel index f8bedfcc836c..997dfe0f5742 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel @@ -43,6 +43,7 @@ iree_compiler_cc_library( "StripDebugOps.cpp", "TestConversion.cpp", "TestFloatRangeAnalysis.cpp", + 
"TestIntegerDivisibilityAnalysis.cpp", "VerifyInitializationOrder.cpp", "VerifyStructuredControlFlow.cpp", ], diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt index f73fa2a9bca9..e64d47b6ab6d 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt @@ -41,6 +41,7 @@ iree_cc_library( "StripDebugOps.cpp" "TestConversion.cpp" "TestFloatRangeAnalysis.cpp" + "TestIntegerDivisibilityAnalysis.cpp" "VerifyInitializationOrder.cpp" "VerifyStructuredControlFlow.cpp" DEPS diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp index 343e2b29a755..dfa3fdeaade2 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp @@ -29,8 +29,13 @@ struct DropCompilerHintsPass op.replaceAllUsesWith(op.getOperands()); op.erase(); } else if (auto op = dyn_cast(genericOp)) { - if (keepAssumeInt) + // TODO(benvanik): #19348 was a terrible approach and this needs to be + // undone. If LLVMGPU wants to keep the hints it should have its own + // codegen op that carries the information. DropCompilerHints is meant + // to drop all compiler hints. 
+ if (keepAssumeInt) { return; + } op.replaceAllUsesWith(op.getOperands()); op.erase(); } diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/FixedPointIterator.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/FixedPointIterator.cpp index 9690b148edd7..641c280fd74a 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/FixedPointIterator.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/FixedPointIterator.cpp @@ -61,10 +61,12 @@ FixedPointIteratorPass::FixedPointIteratorPass(OpPassManager pipeline) LogicalResult FixedPointIteratorPass::initializeOptions( StringRef options, function_ref errorHandler) { - if (failed(Pass::initializeOptions(options, errorHandler))) + if (failed(Pass::initializeOptions(options, errorHandler))) { return failure(); - if (pipeline) + } + if (pipeline) { return success(); + } // Pipelines are expected to be of the form `()`. // TODO: This was lifted from the Inliner pass. We should provide a parse @@ -73,12 +75,14 @@ LogicalResult FixedPointIteratorPass::initializeOptions( // See: https://github.com/llvm/llvm-project/issues/52813 StringRef pipelineSr = pipelineStr; size_t pipelineStart = pipelineSr.find_first_of('('); - if (pipelineStart == StringRef::npos || !pipelineSr.consume_back(")")) + if (pipelineStart == StringRef::npos || !pipelineSr.consume_back(")")) { return failure(); + } StringRef opName = pipelineSr.take_front(pipelineStart); OpPassManager pm(opName); - if (failed(parsePassPipeline(pipelineSr.drop_front(1 + pipelineStart), pm))) + if (failed(parsePassPipeline(pipelineSr.drop_front(1 + pipelineStart), pm))) { return failure(); + } pipeline = std::move(pm); return success(); } diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/FuseGlobals.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/FuseGlobals.cpp index 38c4a801035f..6721c172b6fc 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/FuseGlobals.cpp +++ 
b/compiler/src/iree/compiler/Dialect/Util/Transforms/FuseGlobals.cpp @@ -80,8 +80,9 @@ class FuseGlobalsPass : public impl::FuseGlobalsPassBase { llvm::dbgs() << ":\n"; }); auto *region = callableOp.getCallableRegion(); - if (!region) + if (!region) { continue; + } for (auto &block : *region) { DenseMap> valueStores; @@ -93,8 +94,9 @@ class FuseGlobalsPass : public impl::FuseGlobalsPassBase { storeOp.print(llvm::dbgs(), *asmState); llvm::dbgs() << "; candidate=" << global.isCandidate() << "\n"; }); - if (!global.isCandidate()) + if (!global.isCandidate()) { continue; + } valueStores[storeOp.getStoredGlobalValue()].push_back(storeOp); } for (auto valueStore : valueStores) { diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp index 98fd7b413c14..007552e38259 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp @@ -49,8 +49,9 @@ static std::string getHoistedName(Type type) { type.print(os); } str = sanitizeSymbolName(str); - if (str.substr(str.size() - 1) == "_") + if (str.substr(str.size() - 1) == "_") { str = str.substr(0, str.size() - 1); // strip trailing _ + } return str; } @@ -97,24 +98,36 @@ class HoistIntoGlobalsPass file.close(); } - // Maps original values to newly materialized values. - HoistedValueMap hoistedMap; - // Walk all operations in the program and hoist any escapes from // const-expr values into globals. Note that we must walk the const-exprs // in topological order so that corresponding initializers will be created // in order without depending on globals that have not been initialized // yet. + OpBuilder builder(&getContext()); for (auto funcOp : getOperation().getOps()) { // Ignore initializers. 
- if (isa(funcOp.getOperation())) + if (isa(funcOp.getOperation())) { continue; + } + + // Maps original values to newly materialized globals (per-function). + HoistedValueMap hoistedMap; + + // Operation order for deterministic sorting (per-function). + llvm::DenseMap opOrder; + unsigned orderIdx = 0; + auto walkRes = funcOp.walk([&](Operation *iterOp) { // We only want to look at const-expr ops (non roots) since they may // have interesting escapes. Early exit here for efficiency. auto *iterInfo = constExprs.lookup(iterOp); - if (!iterInfo) + if (!iterInfo) { return WalkResult::advance(); + } + + // Record operation order for deterministic sorting. Since we walk in + // PreOrder, producers are visited before their users. + opOrder[iterOp] = orderIdx++; for (Value constExprResult : iterOp->getResults()) { auto *resultInfo = constExprs.lookup(constExprResult); assert(resultInfo && "must have const-expr info"); @@ -123,43 +136,51 @@ class HoistIntoGlobalsPass continue; } if (failed(hoistConstExpr(constExprResult, hoistedMap, moduleSymbols, - constExprs))) { + constExprs, opOrder))) { return WalkResult::interrupt(); } } return WalkResult::advance(); }); - if (walkRes.wasInterrupted()) + if (walkRes.wasInterrupted()) { return signalPassFailure(); - } - - // Apply any remaining RAUW cleanups. We have to do these at the cleanup - // phase since modifying the source program can invalidate the analysis. - // Up to this point, we have only been cloning. - OpBuilder builder(&getContext()); - for (auto [originalValue, globalOp] : hoistedMap) { - builder.setInsertionPointAfterValue(originalValue); - auto loadOp = globalOp.createLoadOp(globalOp->getLoc(), builder); - if (!originalValue.getDefiningOp() - ->getParentOfType()) { - loadOp.setGlobalImmutable(true); - } - Value loadedValue = loadOp.getLoadedGlobalValue(); - // Call user hook to cast back to the original type. 
- if (auto hoistableType = dyn_cast( - originalValue.getType())) { - loadedValue = hoistableType.decodeStorageType( - builder, loadedValue.getLoc(), originalValue.getType(), - loadedValue); } - if (loadedValue.getType() != originalValue.getType()) { - getOperation().emitError() - << "Unresolved conflict between casted global of type " - << loadedValue.getType() << " and original type " - << originalValue.getType(); - return signalPassFailure(); + + // Apply RAUW cleanups for this function. We do this after cloning to + // avoid invalidating the analysis during the walk. + // Sort the hoisted values by program order for deterministic output. + using HoistedValue = std::pair; + auto sortedHoisted = llvm::to_vector_of(hoistedMap); + llvm::sort(sortedHoisted, + [&opOrder](const HoistedValue &lhs, const HoistedValue &rhs) { + return opOrder[lhs.first.getDefiningOp()] < + opOrder[rhs.first.getDefiningOp()]; + }); + + for (auto [originalValue, globalOp] : sortedHoisted) { + builder.setInsertionPointAfterValue(originalValue); + auto loadOp = globalOp.createLoadOp(globalOp->getLoc(), builder); + if (!originalValue.getDefiningOp() + ->getParentOfType()) { + loadOp.setGlobalImmutable(true); + } + Value loadedValue = loadOp.getLoadedGlobalValue(); + // Call user hook to cast back to the original type. 
+ if (auto hoistableType = dyn_cast( + originalValue.getType())) { + loadedValue = hoistableType.decodeStorageType( + builder, loadedValue.getLoc(), originalValue.getType(), + loadedValue); + } + if (loadedValue.getType() != originalValue.getType()) { + getOperation().emitError() + << "Unresolved conflict between casted global of type " + << loadedValue.getType() << " and original type " + << originalValue.getType(); + return signalPassFailure(); + } + originalValue.replaceAllUsesWith(loadedValue); } - originalValue.replaceAllUsesWith(loadedValue); } cleanupDeadOps(constExprs); } @@ -167,17 +188,21 @@ class HoistIntoGlobalsPass Operation *getTopLevelOp(Operation *childOp) { auto *moduleBlock = getOperation().getBody(); auto *op = childOp; - while (op->getBlock() != moduleBlock) + while (op->getBlock() != moduleBlock) { op = op->getParentOp(); + } return op; } - LogicalResult hoistConstExpr(Value originalValue, HoistedValueMap &hoistedMap, - SymbolTable &moduleSymbols, - const ConstExprAnalysis &constExprs) { + LogicalResult + hoistConstExpr(Value originalValue, HoistedValueMap &hoistedMap, + SymbolTable &moduleSymbols, + const ConstExprAnalysis &constExprs, + const llvm::DenseMap &opOrder) { IREE::Util::GlobalOp existingGlobal = hoistedMap.lookup(originalValue); - if (existingGlobal) + if (existingGlobal) { return success(); + } // Gather any dialect attributes we may need to preserve. 
auto *topLevelOp = getTopLevelOp(originalValue.getDefiningOp()); @@ -196,7 +221,7 @@ class HoistIntoGlobalsPass if (failed(cloneConstExprInto(initializerOp.getLoc(), moduleBuilder, initializerBuilder, originalValue, dialectAttrs, hoistedMap, moduleSymbols, - constExprs))) { + constExprs, opOrder))) { return failure(); } @@ -212,9 +237,11 @@ class HoistIntoGlobalsPass cloneProducerTreeInto(OpBuilder &initializerBuilder, const ConstExprAnalysis::ConstValueInfo *producerInfo, HoistedValueMap &hoistedMap, IRMapping &cloneMapping, - const ConstExprAnalysis &constExprs) { - if (cloneMapping.contains(producerInfo->constValue)) + const ConstExprAnalysis &constExprs, + const llvm::DenseMap &opOrder) { + if (cloneMapping.contains(producerInfo->constValue)) { return; + } // We either have a global associated already or we need to traverse // down and materialize producers. @@ -236,10 +263,20 @@ class HoistIntoGlobalsPass return; } - // Materialize all producers recursively. - for (auto *producerInfo : producerInfo->producers) { - cloneProducerTreeInto(initializerBuilder, producerInfo, hoistedMap, - cloneMapping, constExprs); + // Materialize all producers recursively. Sort producers by their program + // order for deterministic output. + auto sortedProducers = + llvm::to_vector_of( + producerInfo->producers); + llvm::sort(sortedProducers, + [&opOrder](ConstExprAnalysis::ConstValueInfo *lhs, + ConstExprAnalysis::ConstValueInfo *rhs) { + return opOrder.lookup(lhs->constValue.getDefiningOp()) < + opOrder.lookup(rhs->constValue.getDefiningOp()); + }); + for (ConstExprAnalysis::ConstValueInfo *prodInfo : sortedProducers) { + cloneProducerTreeInto(initializerBuilder, prodInfo, hoistedMap, + cloneMapping, constExprs, opOrder); } // And clone the requested op. @@ -257,13 +294,13 @@ class HoistIntoGlobalsPass // Clones the const expr tree rooted at `constExprValue` into the given // initializer, noting any new hoisted value mappings that result. 
At // a minimum, a mapping will be created for the requested value. - LogicalResult cloneConstExprInto(Location loc, OpBuilder &moduleBuilder, - OpBuilder &initializerBuilder, - Value constExprValue, - NamedAttrList dialectAttrs, - HoistedValueMap &hoistedMap, - SymbolTable &moduleSymbols, - const ConstExprAnalysis &constExprs) { + LogicalResult + cloneConstExprInto(Location loc, OpBuilder &moduleBuilder, + OpBuilder &initializerBuilder, Value constExprValue, + NamedAttrList dialectAttrs, HoistedValueMap &hoistedMap, + SymbolTable &moduleSymbols, + const ConstExprAnalysis &constExprs, + const llvm::DenseMap &opOrder) { // Do a depth first traversal of the producers, emitting them in a valid // def-use order. Operation *rootOp = constExprValue.getDefiningOp(); @@ -274,7 +311,7 @@ class HoistIntoGlobalsPass // Clone the whole tree as needed. IRMapping cloneMapping; cloneProducerTreeInto(initializerBuilder, rootInfo, hoistedMap, - cloneMapping, constExprs); + cloneMapping, constExprs, opOrder); // And for each result, create a global and store into it. for (Value origResult : rootOp->getResults()) { @@ -331,8 +368,9 @@ class HoistIntoGlobalsPass // longer be valid after this point. for (auto funcOp : getOperation().getOps()) { // Ignore initializers. - if (isa(funcOp.getOperation())) + if (isa(funcOp.getOperation())) { continue; + } funcOp.walk( [&](Operation *iterOp) { if (allOps.contains(iterOp) && iterOp->use_empty()) { diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/IPO.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/IPO.cpp index 553742032256..691882708f9d 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/IPO.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/IPO.cpp @@ -191,8 +191,9 @@ static FuncAnalysis analyzeFuncOp(IREE::Util::FuncOp funcOp, // Walk callee arguments. 
for (auto [i, value] : llvm::enumerate(funcOp.getArguments())) { - if (value.use_empty()) + if (value.use_empty()) { analysis.calleeUsedArgs.reset(i); + } } // Walk all return sites in the function. @@ -327,8 +328,9 @@ static FuncAnalysis analyzeFuncOp(IREE::Util::FuncOp funcOp, // Note that we need to track unused results as an AND such that all callers // need to not use them. We'll flip the bits below so that `used = true`. for (auto [i, value] : llvm::enumerate(callOp.getResults())) { - if (!value.use_empty()) + if (!value.use_empty()) { callerUnusedResults.reset(i); + } } } if (!analysis.callOps.empty()) { @@ -376,8 +378,9 @@ static FuncAnalysis analyzeFuncOp(IREE::Util::FuncOp funcOp, // we know all callers will stop passing them. for (unsigned i = 0; i < resultCount; ++i) { int argIndex = analysis.passthroughResultArgs[i]; - if (argIndex == kUnassigned) + if (argIndex == kUnassigned) { continue; + } auto arg = funcOp.getArgument(argIndex); bool onlyReturnUsers = true; for (auto user : arg.getUsers()) { @@ -518,14 +521,16 @@ static bool applyFuncChanges(FuncAnalysis &analysis, } // Early out if no changes. - if (deadArgs.none() && deadResults.none()) + if (deadArgs.none() && deadResults.none()) { return false; + } // Erase dead results from all return sites. funcOp.walk([&](IREE::Util::ReturnOp returnOp) { for (int i = deadResults.size() - 1; i >= 0; --i) { - if (deadResults.test(i)) + if (deadResults.test(i)) { returnOp.getOperandsMutable().erase(i); + } } }); @@ -612,8 +617,9 @@ static bool applyCallChanges(FuncAnalysis &analysis, } // Early out if no changes. - if (deadOperands.none() && deadResults.none()) + if (deadOperands.none() && deadResults.none()) { return false; + } // Fully replace call op because we may have changed result count. // TODO(benvanik): update tied operands, arg_attrs, and res_attrs. 
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/ImportResources.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/ImportResources.cpp index 966ef7f02843..0d8bc5882804 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/ImportResources.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/ImportResources.cpp @@ -102,8 +102,9 @@ class ImportResourcesPass } } } - if (updated) + if (updated) { op->setAttrs(attrs); + } }); LLVM_DEBUG(llvm::dbgs() << "DONE CONVERTING RESOURCES\n"); } diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp index 647183f3f85b..49107588a034 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp @@ -99,8 +99,9 @@ struct ConvertOpToUnsigned : public OpRewritePattern { LogicalResult matchAndRewrite(Signed op, PatternRewriter &rewriter) const override { - if (failed(staticallyLegalToConvertToUnsignedOp(solver, op))) + if (failed(staticallyLegalToConvertToUnsignedOp(solver, op))) { return failure(); + } rewriter.replaceOpWithNewOp(op, op->getResultTypes(), op->getOperands(), op->getAttrs()); return success(); @@ -135,15 +136,18 @@ struct ConvertUnsignedI64IndexCastProducerToIndex PatternRewriter &rewriter) const override { Type inType = origIndexOp.getIn().getType(); Type outType = origIndexOp.getOut().getType(); - if (!inType.isSignlessInteger(64) || !isa(outType)) + if (!inType.isSignlessInteger(64) || !isa(outType)) { return failure(); + } Operation *producer = origIndexOp.getIn().getDefiningOp(); - if (!producer) + if (!producer) { return failure(); + } auto producerResult = producer->getResult(0); - if (!producerResult.hasOneUse()) + if (!producerResult.hasOneUse()) { return failure(); + } auto pred = [&](Value v) -> bool { auto *result = solver.lookupState(v); @@ -163,17 
+167,20 @@ struct ConvertUnsignedI64IndexCastProducerToIndex if (!isa_and_present(producer)) + arith::RemUIOp, arith::SubIOp>(producer)) { return failure(); - if (!isOpStaticallyLegal(producer)) + } + if (!isOpStaticallyLegal(producer)) { return failure(); + } // Make modifications. rewriter.modifyOpInPlace(producer, [&]() { rewriter.setInsertionPoint(producer); for (auto &operand : producer->getOpOperands()) { - if (operand.get().getType() != inType) + if (operand.get().getType() != inType) { continue; + } Value newOperand = arith::IndexCastUIOp::create( rewriter, producer->getLoc(), outType, operand.get()); operand.set(newOperand); @@ -204,20 +211,24 @@ struct RemoveIndexCastForAssumeOfI32 PatternRewriter &rewriter) const override { llvm::SmallBitVector needNarrowing(op.getNumOperands(), false); for (auto [idx, arg] : llvm::enumerate(op.getOperands())) { - if (!arg.getType().isIndex()) + if (!arg.getType().isIndex()) { continue; + } auto castOp = arg.getDefiningOp(); - if (!castOp) + if (!castOp) { continue; + } Value castIn = castOp.getIn(); Type intType = castIn.getType(); - if (intType.getIntOrFloatBitWidth() > 32) + if (intType.getIntOrFloatBitWidth() > 32) { continue; + } needNarrowing[idx] = true; } - if (needNarrowing.none()) + if (needNarrowing.none()) { return failure(); + } SmallVector newArgs; newArgs.reserve(op.getNumOperands()); @@ -267,22 +278,27 @@ struct NarrowSCFForIvToI32 : public OpRewritePattern { Location loc = forOp.getLoc(); Value iv = forOp.getInductionVar(); Type srcType = iv.getType(); - if (!srcType.isIndex() && !srcType.isInteger(64)) + if (!srcType.isIndex() && !srcType.isInteger(64)) { return rewriter.notifyMatchFailure(forOp, "IV isn't an index or i64"); - if (!staticallyLegalToConvertToUnsigned(solver, iv)) + } + if (!staticallyLegalToConvertToUnsigned(solver, iv)) { return rewriter.notifyMatchFailure(forOp, "IV isn't non-negative"); - if (!staticallyLegalToConvertToUnsigned(solver, forOp.getStep())) + } + if 
(!staticallyLegalToConvertToUnsigned(solver, forOp.getStep())) { return rewriter.notifyMatchFailure(forOp, "Step isn't non-negative"); + } auto *ivState = solver.lookupState(iv); - if (ivState->getValue().getValue().smax().getActiveBits() > 31) + if (ivState->getValue().getValue().smax().getActiveBits() > 31) { return rewriter.notifyMatchFailure(forOp, "IV won't fit in signed int32"); + } Type i32 = rewriter.getI32Type(); auto doCastDown = [&](Value v) -> Value { - if (srcType.isIndex()) + if (srcType.isIndex()) { return arith::IndexCastUIOp::create(rewriter, loc, i32, v); - else + } else { return arith::TruncIOp::create(rewriter, loc, i32, v); + } }; Value newLb = doCastDown(forOp.getLowerBound()); Value newUb = doCastDown(forOp.getUpperBound()); @@ -322,9 +338,10 @@ static LogicalResult getDivisibility(DataFlowSolver &solver, Operation *op, Value value, PatternRewriter &rewriter, ConstantIntDivisibility &out) { auto *div = solver.lookupState(value); - if (!div || div->getValue().isUninitialized()) + if (!div || div->getValue().isUninitialized()) { return rewriter.notifyMatchFailure(op, "divisibility could not be determined"); + } out = div->getValue().getValue(); LLVM_DEBUG(dbgs() << " * Resolved divisibility: " << out << "\n"); @@ -338,17 +355,20 @@ struct RemUIDivisibilityByConstant : public OpRewritePattern { LogicalResult matchAndRewrite(arith::RemUIOp op, PatternRewriter &rewriter) const override { APInt rhsConstant; - if (!matchPattern(op.getRhs(), m_ConstantInt(&rhsConstant))) + if (!matchPattern(op.getRhs(), m_ConstantInt(&rhsConstant))) { return rewriter.notifyMatchFailure(op, "rhs is not constant"); + } ConstantIntDivisibility lhsDiv; - if (failed(getDivisibility(solver, op, op.getLhs(), rewriter, lhsDiv))) + if (failed(getDivisibility(solver, op, op.getLhs(), rewriter, lhsDiv))) { return failure(); + } uint64_t rhsValue = rhsConstant.getZExtValue(); if (rhsValue > 0 && lhsDiv.udiv() > 0) { - if (lhsDiv.udiv() % rhsValue != 0) + if (lhsDiv.udiv() % 
rhsValue != 0) { return rewriter.notifyMatchFailure(op, "rhs does not divide lhs"); + } rewriter.replaceOpWithNewOp( op, rewriter.getZeroAttr(op.getResult().getType())); @@ -397,10 +417,12 @@ struct ElideTruncOfIndexCast : public OpRewritePattern { LogicalResult matchAndRewrite(arith::TruncIOp truncOp, PatternRewriter &rewriter) const override { Operation *producer = truncOp.getOperand().getDefiningOp(); - if (!producer) + if (!producer) { return failure(); - if (!isa(producer)) + } + if (!isa(producer)) { return failure(); + } rewriter.replaceOpWithNewOp( truncOp, truncOp.getResult().getType(), producer->getOperand(0)); return success(); @@ -418,8 +440,9 @@ class DataFlowListener : public RewriterBase::Listener { protected: void notifyOperationErased(Operation *op) override { s.eraseState(s.getProgramPointAfter(op)); - for (Value res : op->getResults()) + for (Value res : op->getResults()) { s.eraseState(res); + } } void notifyOperationModified(Operation *op) override { @@ -463,8 +486,9 @@ class OptimizeIntArithmeticPass // Populate canonicalization patterns. auto arithDialect = ctx->getOrLoadDialect(); for (const RegisteredOperationName &name : ctx->getRegisteredOperations()) { - if (&name.getDialect() == arithDialect) + if (&name.getDialect() == arithDialect) { name.getCanonicalizationPatterns(patterns, ctx); + } } // General optimization patterns. 
@@ -513,8 +537,9 @@ class OptimizeIntArithmeticPass return signalPassFailure(); } - if (!changed) + if (!changed) { break; + } } } }; diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h index 5445337ddf1a..188c5457b7d0 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h @@ -73,6 +73,7 @@ createHoistIntoGlobalsPass(const ExprHoistingOptions &options); #define GEN_PASS_DECL_STRIPDEBUGOPSPASS #define GEN_PASS_DECL_TESTCONVERSIONPASS #define GEN_PASS_DECL_TESTFLOATRANGEANALYSISPASS +#define GEN_PASS_DECL_TESTINTEGERDIVISIBILITYANALYSISPASS #define GEN_PASS_DECL_VERIFYINITIALIZATIONORDERPASS #define GEN_PASS_DECL_VERIFYSTRUCTUREDCONTROLFLOWPASS #include "iree/compiler/Dialect/Util/Transforms/Passes.h.inc" // IWYU pragma: keep diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td index 7093bed69d38..b3f46f78add6 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td @@ -346,4 +346,14 @@ def TestFloatRangeAnalysisPass : Pass<"iree-util-test-float-range-analysis", ""> }]; } +def TestIntegerDivisibilityAnalysisPass : + Pass<"iree-util-test-integer-divisibility-analysis", ""> { + let summary = "Tests integer divisibility analysis."; + let description = [{ + Tests integer divisibility analysis by evaluating any + 'iree_unregistered.test_int_divisibility' op and setting the results on an + attribute. 
+ }]; +} + #endif // IREE_DIALECT_UTIL_PASSES diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/Patterns.cpp index 6c1d0305106c..01a9d460dcb7 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Patterns.cpp @@ -118,11 +118,13 @@ struct FoldBlockArgumentsPattern using OpInterfaceRewritePattern::OpInterfaceRewritePattern; LogicalResult matchAndRewrite(CallableOpInterface op, PatternRewriter &rewriter) const override { - if (!op.getCallableRegion()) + if (!op.getCallableRegion()) { return failure(); + } auto ®ion = *op.getCallableRegion(); - if (region.empty() || region.hasOneBlock()) + if (region.empty() || region.hasOneBlock()) { return failure(); + } // Analyze all branches in the op to compute the information we'll need to // analyze across branch sources. @@ -171,11 +173,13 @@ struct FoldBlockArgumentsPattern for (auto &block : llvm::make_range(++region.getBlocks().begin(), region.getBlocks().end())) { unsigned numArgs = block.getNumArguments(); - if (numArgs == 0) + if (numArgs == 0) { continue; + } auto blockSources = llvm::ArrayRef(blockSourceMap[&block]); - if (blockSources.size() == 0) + if (blockSources.size() == 0) { continue; + } // Which args we'll end up erasing. // We need to do the actual removal after we've done the remapping below @@ -263,11 +267,13 @@ struct ElideBranchOperandsPattern using OpInterfaceRewritePattern::OpInterfaceRewritePattern; LogicalResult matchAndRewrite(CallableOpInterface op, PatternRewriter &rewriter) const override { - if (!op.getCallableRegion()) + if (!op.getCallableRegion()) { return failure(); + } auto ®ion = *op.getCallableRegion(); - if (region.empty()) + if (region.empty()) { return failure(); + } DominanceInfo dominance(op); // Analyze all branches to build a map of blocks to their sources. 
@@ -298,11 +304,13 @@ struct ElideBranchOperandsPattern for (auto &block : llvm::make_range(++region.getBlocks().begin(), region.getBlocks().end())) { unsigned numArgs = block.getNumArguments(); - if (numArgs == 0) + if (numArgs == 0) { continue; + } auto blockSources = llvm::ArrayRef(blockSourceMap[&block]); - if (blockSources.size() == 0) + if (blockSources.size() == 0) { continue; + } // Which args we'll end up erasing. // We need to do the actual removal after we've done the remapping below @@ -342,8 +350,9 @@ struct ElideBranchOperandsPattern uniformValue = nullptr; break; } - if (!uniformValue) + if (!uniformValue) { continue; + } // See if the uniform value dominates this block; if so we can use it. if (!uniformValue.getDefiningOp() || @@ -354,8 +363,9 @@ struct ElideBranchOperandsPattern elidedArgs.set(argIndex); } } - if (elidedArgs.none()) + if (elidedArgs.none()) { continue; + } // Erase all the block arguments we remapped. for (auto &blockSource : blockSources) { @@ -407,8 +417,9 @@ struct IndexSwitchToIfPattern : public OpRewritePattern { using Base::Base; LogicalResult matchAndRewrite(scf::IndexSwitchOp switchOp, PatternRewriter &rewriter) const override { - if (switchOp.getNumCases() != 1) + if (switchOp.getNumCases() != 1) { return failure(); + } Value caseValue = arith::ConstantIndexOp::create( rewriter, switchOp.getLoc(), switchOp.getCases().front()); Value isCaseValue = rewriter.createOrFold( @@ -472,16 +483,19 @@ struct MergeIndexSwitchPattern : public OpRewritePattern { // Inspect the previous op to see if it's also a switch. auto prevOp = dyn_cast_if_present(nextOp->getPrevNode()); - if (!prevOp) + if (!prevOp) { return failure(); + } // Require that the cases line up exactly. There's probably some merging // we could do in other cases but it'd be best to leave other patterns to // hoist/CSE cases/etc instead. 
- if (prevOp.getNumCases() != nextOp.getNumCases()) + if (prevOp.getNumCases() != nextOp.getNumCases()) { return rewriter.notifyMatchFailure(nextOp, "number of cases differ"); - if (!llvm::equal(prevOp.getCases(), nextOp.getCases())) + } + if (!llvm::equal(prevOp.getCases(), nextOp.getCases())) { return rewriter.notifyMatchFailure(nextOp, "case values differ"); + } // Create a new switch to replace nextOp that contains the same cases but // combined results from both ops. @@ -518,8 +532,9 @@ struct MergeIndexSwitchPattern : public OpRewritePattern { // values for the particular case. auto yieldA = *regionA.getOps().begin(); for (auto &op : regionA.getOps()) { - if (op.hasTrait()) + if (op.hasTrait()) { continue; + } // Clone each op and map its original value to the new local value. targetBuilder.clone(op, localMapping); } @@ -534,8 +549,9 @@ struct MergeIndexSwitchPattern : public OpRewritePattern { // Clone regionB into target. auto yieldB = *regionB.getOps().begin(); for (auto &op : regionB.getOps()) { - if (op.hasTrait()) + if (op.hasTrait()) { continue; + } // Clone each op and map its original value to the new local value. targetBuilder.clone(op, localMapping); } diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubranges.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubranges.cpp index 15735e9a0243..c6fffa1f66f1 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubranges.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubranges.cpp @@ -65,8 +65,9 @@ static ExpandedGlobalMap expandResourceGlobals(Operation *rootOp, // Gather all of the resource globals in the root. 
for (auto ®ion : rootOp->getRegions()) { for (auto globalOp : region.getOps()) { - if (!isResourceType(globalOp.getType())) + if (!isResourceType(globalOp.getType())) { continue; + } expandedGlobals[globalOp.getName()].resourceOp = globalOp; } } @@ -127,8 +128,9 @@ static void expandType(Type type, SmallVectorImpl &newTypes) { // Expands resources in the given |types| list to (resource, size, offset, len). // This could be changed to some iterator magic to avoid the alloc. static SmallVector expandTypes(TypeRange types) { - if (types.empty()) + if (types.empty()) { return {}; + } SmallVector newTypes; newTypes.reserve(types.size() * 2); for (auto type : types) { @@ -221,14 +223,16 @@ static void expandSubranges(Operation *op, SymbolTable &symbolTable, static void expandRegion(Region ®ion, bool canModifyEntryBlock, SymbolTable &symbolTable, ExpandedGlobalMap &globalMap, IndexSet &indexSet, SubrangeMap subrangeMap) { - if (region.empty()) + if (region.empty()) { return; + } // Update all block arguments. auto indexType = IndexType::get(region.getContext()); for (auto &block : region.getBlocks()) { - if (!llvm::any_of(block.getArgumentTypes(), isResourceType)) + if (llvm::none_of(block.getArgumentTypes(), isResourceType)) { continue; + } // Entry blocks that we can't modify are fully handled by // MutableRegionBranchOpInterface (via wrapExpandedBlockArgFn callback). @@ -245,8 +249,9 @@ static void expandRegion(Region ®ion, bool canModifyEntryBlock, // Insert new arguments for each resource argument. for (int i = block.getNumArguments() - 1; i >= 0; --i) { auto arg = block.getArgument(i); - if (!isResourceType(arg.getType())) + if (!isResourceType(arg.getType())) { continue; + } Subrange subrange; subrange.resource = arg; subrange.resourceSize = @@ -306,10 +311,12 @@ static void updateSubrangeOp(IREE::Util::SubrangeOpInterface op, // Ignore ops that are already in the map (we likely inserted them ourselves // earlier). 
auto resultResource = op.getSubrangeResult(); - if (!resultResource) + if (!resultResource) { return; - if (subrangeMap.count(resultResource)) + } + if (subrangeMap.count(resultResource)) { return; + } // Get the subrange of the source resource which we should have by way of the // other insertions (func/block args, etc). @@ -317,8 +324,9 @@ static void updateSubrangeOp(IREE::Util::SubrangeOpInterface op, builder.setInsertionPointAfter(op); auto sourceSubrange = consumeSubrange(op.getLoc(), op.getSubrangeResource(), subrangeMap, indexSet, builder); - if (op.getSubrangeResource() == sourceSubrange.resource) + if (op.getSubrangeResource() == sourceSubrange.resource) { return; + } // Update the subrange in the map by adding the source offset and the local // offset from the op. Future ops that consume subranges will reference back @@ -347,8 +355,9 @@ static void updateSubrangeOp(IREE::Util::SubrangeOpInterface op, static void expandGlobalLoadOp(IREE::Util::GlobalLoadOpInterface op, ExpandedGlobalMap &globalMap, IndexSet &indexSet, SubrangeMap &subrangeMap) { - if (!usesResources(op)) + if (!usesResources(op)) { return; + } OpBuilder builder(op); builder.setInsertionPointAfter(op); auto &expandedGlobal = globalMap[op.getGlobalName()]; @@ -386,8 +395,9 @@ static void expandGlobalLoadOp(IREE::Util::GlobalLoadOpInterface op, static void expandGlobalStoreOp(IREE::Util::GlobalStoreOpInterface op, ExpandedGlobalMap &globalMap, IndexSet &indexSet, SubrangeMap &subrangeMap) { - if (!usesResources(op)) + if (!usesResources(op)) { return; + } OpBuilder builder(op); builder.setInsertionPointAfter(op); auto subrange = consumeSubrange(op.getLoc(), op.getStoredGlobalValue(), @@ -460,13 +470,15 @@ static void expandFuncOp(IREE::Util::FuncOp op, SymbolTable &symbolTable, // %2 = stream.resource.subview %r[%ro] : {%rsz} -> {%rl} static void expandCallOp(IREE::Util::CallOp op, SymbolTable &symbolTable, IndexSet &indexSet, SubrangeMap &subrangeMap) { - if (!usesResources(op)) + if 
(!usesResources(op)) { return; + } // Ignore calls to public/external functions. auto calleeOp = symbolTable.lookup(op.getCallee()); - if (IREE::Util::isPublicOrExternal(calleeOp)) + if (IREE::Util::isPublicOrExternal(calleeOp)) { return; + } // Build the new call op with expanded operands and results. OpBuilder builder(op); @@ -518,10 +530,13 @@ static void expandCallOp(IREE::Util::CallOp op, SymbolTable &symbolTable, // util.return %0, %sz, %o, %l static void expandReturnOp(IREE::Util::ReturnOp op, IndexSet &indexSet, SubrangeMap &subrangeMap) { - if (!usesResources(op)) + if (!usesResources(op)) { return; - if (IREE::Util::isPublicOrExternal(op->getParentOfType())) + } + if (IREE::Util::isPublicOrExternal( + op->getParentOfType())) { return; + } OpBuilder builder(op); auto operands = expandOperands(op.getLoc(), op.getOperands(), subrangeMap, indexSet, builder); @@ -551,8 +566,9 @@ static void expandBranchOp(mlir::cf::BranchOp op, IndexSet &indexSet, static void expandCondBranchOp(mlir::cf::CondBranchOp op, IndexSet &indexSet, SubrangeMap &subrangeMap) { - if (!usesResources(op)) + if (!usesResources(op)) { return; + } OpBuilder builder(op); mlir::cf::CondBranchOp::create( builder, op.getLoc(), op.getCondition(), op.getTrueDest(), @@ -568,8 +584,9 @@ static ValueRange asValueRange(ArrayRef values) { return values; } static void expandSwitchOp(mlir::cf::SwitchOp op, IndexSet &indexSet, SubrangeMap &subrangeMap) { - if (!usesResources(op)) + if (!usesResources(op)) { return; + } OpBuilder builder(op); auto caseOperands = llvm::to_vector( llvm::map_range(op.getCaseOperands(), [&](ValueRange operands) { @@ -742,8 +759,9 @@ class PropagateSubrangesPass // NOTE: the callable may be empty (like when an extern) - we still want // to process it but don't need an IndexSet. 
auto *region = callableOp.getCallableRegion(); - if (!region || region->empty()) + if (!region || region->empty()) { continue; + } IndexSet indexSet(callableOp.getLoc(), OpBuilder::atBlockBegin(®ion->front())); SubrangeMap subrangeMap; diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp index 62f84fdff442..cddc2e7ee12b 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp @@ -50,8 +50,9 @@ static void hoistImmutableLoads(Region ®ion, auto ops = llvm::to_vector<8>(block.getOps()); for (auto &op : ops) { - if (!immutableGlobals.contains(op.getGlobalName())) + if (!immutableGlobals.contains(op.getGlobalName())) { continue; + } auto globalRef = cast(op.getGlobalAttr()); auto it = loadOps.find(globalRef); if (it == loadOps.end()) { @@ -89,8 +90,9 @@ static bool doesOpBlockMotion(Operation *op) { static SetVector getOpsThatBlockMotion(Block &block) { SetVector ops; for (auto &op : block.getOperations()) { - if (doesOpBlockMotion(&op)) + if (doesOpBlockMotion(&op)) { ops.insert(&op); + } } return ops; } @@ -100,12 +102,14 @@ static void moveOpUpInBlock(Block &block, Operation *op, // Find the earliest node that does not block op motion then move before it. mlir::Operation *earliestValidNode = op; while (earliestValidNode->getPrevNode()) { - if (opsThatBlockMotion.contains(earliestValidNode->getPrevNode())) + if (opsThatBlockMotion.contains(earliestValidNode->getPrevNode())) { break; + } earliestValidNode = earliestValidNode->getPrevNode(); } - if (earliestValidNode != op) + if (earliestValidNode != op) { op->moveBefore(earliestValidNode); + } } static void @@ -114,12 +118,14 @@ moveOpDownInBlock(Block &block, Operation *op, // Find the latest node that does not block op motion then move after it. 
mlir::Operation *latestValidNode = op; while (latestValidNode->getNextNode()) { - if (opsThatBlockMotion.contains(latestValidNode->getNextNode())) + if (opsThatBlockMotion.contains(latestValidNode->getNextNode())) { break; + } latestValidNode = latestValidNode->getNextNode(); } - if (latestValidNode != op) + if (latestValidNode != op) { op->moveAfter(latestValidNode); + } } // Optimizes the load/store ops for each given bucket. @@ -176,8 +182,9 @@ optimizeBuckets(Block &block, didRemoveAny = true; } } - if (ops.empty()) + if (ops.empty()) { continue; + } if (auto loadOp = dyn_cast(ops.front())) { diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/TestIntegerDivisibilityAnalysis.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/TestIntegerDivisibilityAnalysis.cpp new file mode 100644 index 000000000000..21954d4ec0dc --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/TestIntegerDivisibilityAnalysis.cpp @@ -0,0 +1,68 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h" +#include "iree/compiler/Dialect/Util/Transforms/Passes.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlowFramework.h" + +namespace mlir::iree_compiler::IREE::Util { + +#define GEN_PASS_DEF_TESTINTEGERDIVISIBILITYANALYSISPASS +#include "iree/compiler/Dialect/Util/Transforms/Passes.h.inc" + +namespace { + +class TestIntegerDivisibilityAnalysisPass + : public impl::TestIntegerDivisibilityAnalysisPassBase< + TestIntegerDivisibilityAnalysisPass> { +public: + void runOnOperation() override { + Operation *rootOp = getOperation(); + MLIRContext *context = &getContext(); + + // The pass is rooted on `iree_unregistered.test_int_divisibility` ops, + // which are expected to have a single operand for which to annotate + // divisibility information. + SmallVector> queryOps; + rootOp->walk([&](Operation *op) { + if (op->getName().getStringRef() == + "iree_unregistered.test_int_divisibility" && + op->getNumOperands() == 1) { + queryOps.emplace_back(op, op->getOperand(0)); + } + }); + + DataFlowSolver solver; + // DeadCodeAnalysis is the base analysis that allows the solver to traverse + // control flow. We include it to make the divisibility analysis more + // powerful. + solver.load(); + solver.load(); + if (failed(solver.initializeAndRun(rootOp))) { + return signalPassFailure(); + } + + for (auto &[op, value] : queryOps) { + auto *lattice = solver.lookupState(value); + if (!lattice || lattice->getValue().isUninitialized()) { + op->setAttr("divisibility", StringAttr::get(context, "uninitialized")); + continue; + } + + // Format for the divisibility information is "udiv = X, sdiv = Y". 
+ const auto &div = lattice->getValue().getValue(); + std::string result; + llvm::raw_string_ostream os(result); + os << "udiv = " << div.udiv() << ", sdiv = " << div.sdiv(); + op->setAttr("divisibility", StringAttr::get(context, os.str())); + } + } +}; + +} // namespace + +} // namespace mlir::iree_compiler::IREE::Util diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel index d3fe86862e8b..7c52ce11c6b3 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "annotate_op_ordinals.mlir", "attribute_call_graph.mlir", @@ -41,6 +42,7 @@ iree_lit_test_suite( "strip_debug_ops.mlir", "test_float_range_analysis.mlir", "test_float_range_analysis_linalg.mlir", + "test_integer_divisibility_analysis.mlir", "verify_initialization_order.mlir", "verify_structured_control_flow.mlir", ], diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt index abca38549966..658f9a9582f3 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt @@ -39,6 +39,7 @@ iree_lit_test_suite( "strip_debug_ops.mlir" "test_float_range_analysis.mlir" "test_float_range_analysis_linalg.mlir" + "test_integer_divisibility_analysis.mlir" "verify_initialization_order.mlir" "verify_structured_control_flow.mlir" TOOLS diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/test_integer_divisibility_analysis.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/test_integer_divisibility_analysis.mlir new file mode 100644 index 000000000000..998b6f9a5592 --- /dev/null +++ 
b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/test_integer_divisibility_analysis.mlir @@ -0,0 +1,188 @@ +// RUN: iree-opt --split-input-file --iree-util-test-integer-divisibility-analysis --allow-unregistered-dialect %s | FileCheck %s + +// CHECK-LABEL: @affine_apply_mul_divisibility +util.func @affine_apply_mul_divisibility(%arg0 : index) -> index { + %0 = util.assume.int %arg0 : index + %1 = affine.apply affine_map<(d0) -> (d0 * 4)>(%0) + // CHECK: divisibility = "udiv = 32, sdiv = 32" + %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + util.return %2 : index +} + +// ----- + +// CHECK-LABEL: @affine_apply_mul_negative +util.func @affine_apply_mul_negative(%arg0 : index) -> index { + %0 = util.assume.int %arg0 : index + %1 = affine.apply affine_map<(d0) -> (d0 * -4)>(%0) + // CHECK: divisibility = "udiv = 32, sdiv = 32" + %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + util.return %2 : index +} + +// ----- + +// CHECK-LABEL: @affine_apply_add_gcd +util.func @affine_apply_add_gcd(%arg0 : index, %arg1 : index) -> index { + %0:2 = util.assume.int %arg0, + %arg1 : index, index + %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%0#0, %0#1) + // CHECK: divisibility = "udiv = 8, sdiv = 8" + %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + util.return %2 : index +} + +// ----- + +// CHECK-LABEL: @affine_apply_floordiv_exact +util.func @affine_apply_floordiv_exact(%arg0 : index) -> index { + %0 = util.assume.int %arg0 : index + %1 = affine.apply affine_map<(d0) -> (d0 floordiv 4)>(%0) + // CHECK: divisibility = "udiv = 16, sdiv = 16" + %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + util.return %2 : index +} + +// ----- + +// CHECK-LABEL: @affine_apply_ceildiv_exact +util.func @affine_apply_ceildiv_exact(%arg0 : index) -> index { + %0 = util.assume.int %arg0 : index + %1 = affine.apply affine_map<(d0) -> (d0 ceildiv 4)>(%0) + // CHECK: divisibility = "udiv = 16, 
sdiv = 16" + %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + util.return %2 : index +} + +// ----- + +// CHECK-LABEL: @affine_apply_floordiv_non_exact +util.func @affine_apply_floordiv_non_exact(%arg0 : index) -> index { + %0 = util.assume.int %arg0 : index + %1 = affine.apply affine_map<(d0) -> (d0 floordiv 3)>(%0) + // CHECK: divisibility = "udiv = 1, sdiv = 1" + %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + util.return %2 : index +} + +// ----- + +// CHECK-LABEL: @affine_apply_mod +util.func @affine_apply_mod(%arg0 : index) -> index { + %0 = util.assume.int %arg0 : index + %1 = affine.apply affine_map<(d0) -> (d0 mod 16)>(%0) + // CHECK: divisibility = "udiv = 1, sdiv = 1" + %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + util.return %2 : index +} + +// ----- + +// CHECK-LABEL: @affine_apply_composition +util.func @affine_apply_composition(%arg0 : index) -> index { + %0 = util.assume.int %arg0 : index + %1 = affine.apply affine_map<(d0) -> (d0 * 4 + 16)>(%0) + // CHECK: divisibility = "udiv = 16, sdiv = 16" + %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + util.return %2 : index +} + +// ----- + +// CHECK-LABEL: @affine_apply_with_symbol +util.func @affine_apply_with_symbol(%arg0 : index, %arg1 : index) -> index { + %0:2 = util.assume.int %arg0, + %arg1 : index, index + %1 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%0#0)[%0#1] + // CHECK: divisibility = "udiv = 16, sdiv = 16" + %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + util.return %2 : index +} + +// ----- + +// CHECK-LABEL: @affine_min_uniform_divisibility +util.func @affine_min_uniform_divisibility(%arg0 : index) -> index { + %0 = util.assume.int %arg0 : index + %1 = affine.min affine_map<(d0) -> (d0, d0 + 64)>(%0) + // CHECK: divisibility = "udiv = 16, sdiv = 16" + %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + util.return %2 : index +} + +// ----- + 
+// CHECK-LABEL: @affine_min_different_divisibilities +util.func @affine_min_different_divisibilities(%arg0 : index, %arg1 : index) -> index { + %0:2 = util.assume.int %arg0, + %arg1 : index, index + %1 = affine.min affine_map<(d0, d1) -> (d0, d1)>(%0#0, %0#1) + // CHECK: divisibility = "udiv = 8, sdiv = 8" + %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + util.return %2 : index +} + +// ----- + +// CHECK-LABEL: @affine_max_uniform_divisibility +util.func @affine_max_uniform_divisibility(%arg0 : index) -> index { + %0 = util.assume.int %arg0 : index + %1 = affine.max affine_map<(d0) -> (d0, d0 - 64)>(%0) + // CHECK: divisibility = "udiv = 32, sdiv = 32" + %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + util.return %2 : index +} + +// ----- + +// CHECK-LABEL: @affine_max_different_divisibilities +util.func @affine_max_different_divisibilities(%arg0 : index, %arg1 : index, %arg2 : index) -> index { + %0:3 = util.assume.int %arg0, + %arg1, + %arg2 : index, index, index + %3 = affine.max affine_map<(d0, d1, d2) -> (d0, d1, d2)>(%0#0, %0#1, %0#2) + // CHECK: divisibility = "udiv = 6, sdiv = 6" + %4 = "iree_unregistered.test_int_divisibility"(%3) : (index) -> index + util.return %4 : index +} + +// ----- + +// CHECK-LABEL: @affine_apply_constant +util.func @affine_apply_constant() -> index { + %0 = affine.apply affine_map<() -> (64)>() + // CHECK: divisibility = "udiv = 64, sdiv = 64" + %1 = "iree_unregistered.test_int_divisibility"(%0) : (index) -> index + util.return %1 : index +} + +// ----- + +// CHECK-LABEL: @affine_apply_chained_operations +util.func @affine_apply_chained_operations(%arg0 : index) -> index { + %0 = util.assume.int %arg0 : index + %1 = affine.apply affine_map<(d0) -> (d0 * 8)>(%0) + %2 = affine.apply affine_map<(d0) -> (d0 + 16)>(%1) + // CHECK: divisibility = "udiv = 16, sdiv = 16" + %3 = "iree_unregistered.test_int_divisibility"(%2) : (index) -> index + util.return %3 : index +} + +// ----- + +// 
CHECK-LABEL: @complex_chained_affine_ops +util.func @complex_chained_affine_ops(%arg0 : index, %arg1 : index, %arg2 : index) -> index { + %0:3 = util.assume.int %arg0, + %arg1, + %arg2 : index, index, index + %1 = affine.apply affine_map<(d0, d1) -> (d0 + 2 * d1)>(%0#0, %0#1) + // CHECK: divisibility = "udiv = 14, sdiv = 14" + %div_1 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + %2 = affine.max affine_map<(d0, d1) -> (d0 floordiv 6, d1 * 3)>(%0#0, %0#2) + // CHECK: divisibility = "udiv = 5, sdiv = 5" + %div_2 = "iree_unregistered.test_int_divisibility"(%2) : (index) -> index + %3 = affine.min affine_map<(d0)[s0] -> (2 * (s0 * d0 - 14) ceildiv 7, d0 floordiv 3 * 2)>(%2)[%1] + // CHECK: divisibility = "udiv = 2, sdiv = 2" + %div_3 = "iree_unregistered.test_int_divisibility"(%3) : (index) -> index + util.return %div_3 : index +} diff --git a/compiler/src/iree/compiler/Dialect/VM/Analysis/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Analysis/BUILD.bazel index aab9e7200cb6..fbf14d3a3edb 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Analysis/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/Analysis/BUILD.bazel @@ -55,3 +55,18 @@ iree_compiler_cc_library( "@llvm-project//mlir:Support", ], ) + +iree_compiler_cc_library( + name = "OrdinalAnalysis", + srcs = [ + "OrdinalAnalysis.cpp", + ], + hdrs = [ + "OrdinalAnalysis.h", + ], + deps = [ + "//compiler/src/iree/compiler/Dialect/Util/IR", + "//compiler/src/iree/compiler/Dialect/VM/IR", + "@llvm-project//llvm:Support", + ], +) diff --git a/compiler/src/iree/compiler/Dialect/VM/Analysis/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/VM/Analysis/CMakeLists.txt index d0a857a822a1..bec12e5727a3 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Analysis/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/VM/Analysis/CMakeLists.txt @@ -52,4 +52,18 @@ iree_cc_library( PUBLIC ) +iree_cc_library( + NAME + OrdinalAnalysis + HDRS + "OrdinalAnalysis.h" + SRCS + "OrdinalAnalysis.cpp" + 
DEPS + LLVMSupport + iree::compiler::Dialect::Util::IR + iree::compiler::Dialect::VM::IR + PUBLIC +) + ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/iree/compiler/Dialect/VM/Analysis/LinearScan/LiveIntervals.cpp b/compiler/src/iree/compiler/Dialect/VM/Analysis/LinearScan/LiveIntervals.cpp index 7d0a940e5ed7..1241ee0e43a5 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Analysis/LinearScan/LiveIntervals.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Analysis/LinearScan/LiveIntervals.cpp @@ -36,8 +36,9 @@ LogicalResult LiveIntervals::annotateIR(IREE::VM::FuncOp funcOp) { // Annotate each block with its instruction range. for (auto *block : liveIntervals.getBlockOrder()) { - if (block->empty()) + if (block->empty()) { continue; + } uint32_t blockStart = liveIntervals.getInstructionIndex(&block->front()); uint32_t blockEnd = liveIntervals.getInstructionIndex(&block->back()); @@ -55,8 +56,9 @@ LogicalResult LiveIntervals::annotateIR(IREE::VM::FuncOp funcOp) { uint32_t opIndex = liveIntervals.getInstructionIndex(&op); op.setAttr("op_index", builder.getI32IntegerAttr(opIndex)); - if (op.getNumResults() == 0) + if (op.getNumResults() == 0) { continue; + } SmallVector intervalStrs; for (auto result : op.getResults()) { @@ -141,8 +143,9 @@ LogicalResult LiveIntervals::build(IREE::VM::FuncOp funcOp) { const LiveInterval *LiveIntervals::getInterval(Value value) const { auto it = valueToInterval_.find(value); - if (it == valueToInterval_.end()) + if (it == valueToInterval_.end()) { return nullptr; + } return &intervals_[it->second]; } @@ -168,8 +171,9 @@ void LiveIntervals::sortBlocksInDominanceOrder(IREE::VM::FuncOp funcOp) { } llvm::SmallSetVector markedBlocks; std::function visit = [&](Block *block) { - if (markedBlocks.count(block) > 0) + if (markedBlocks.count(block) > 0) { return; + } for (auto *childBlock : dominanceInfo.getNode(block)->children()) { visit(childBlock->getBlock()); } @@ -201,8 +205,9 @@ void 
LiveIntervals::buildIntervals(ValueLiveness &liveness) { for (auto *block : blockOrder_) { // Process block arguments. for (auto blockArg : block->getArguments()) { - if (valueToInterval_.count(blockArg)) + if (valueToInterval_.count(blockArg)) { continue; + } // Block arguments are "defined" at the start of the block. // We use the first op's index as the start. @@ -228,8 +233,9 @@ void LiveIntervals::buildIntervals(ValueLiveness &liveness) { uint32_t opIndex = opToIndex_[&op]; for (auto result : op.getResults()) { - if (valueToInterval_.count(result)) + if (valueToInterval_.count(result)) { continue; + } uint32_t start = opIndex; uint32_t end = findLastUse(result, liveness); diff --git a/compiler/src/iree/compiler/Dialect/VM/Analysis/OrdinalAnalysis.cpp b/compiler/src/iree/compiler/Dialect/VM/Analysis/OrdinalAnalysis.cpp new file mode 100644 index 000000000000..605eca8bfad1 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/VM/Analysis/OrdinalAnalysis.cpp @@ -0,0 +1,109 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/VM/Analysis/OrdinalAnalysis.h" + +#include "iree/compiler/Dialect/Util/IR/UtilTypes.h" +#include "llvm/Support/MathExtras.h" + +namespace mlir::iree_compiler::IREE::VM { + +// Returns the size in bytes of the global when stored in memory. +// Valid only for globals using primitive storage. +static size_t getGlobalStorageSize(IREE::Util::GlobalOpInterface globalOp) { + auto storageType = globalOp.getGlobalType(); + assert(storageType.isIntOrFloat()); + assert(storageType.getIntOrFloatBitWidth() % 8 == 0); + return IREE::Util::getRoundedElementByteWidth(storageType); +} + +OrdinalAnalysis::OrdinalAnalysis(IREE::VM::ModuleOp moduleOp) { + // Assign ordinals based on IR order (which should be deterministic). 
+ int nextFuncOrdinal = 0; + int nextImportOrdinal = 0; + int nextExportOrdinal = 0; + int nextGlobalRefOrdinal = 0; + int nextRodataOrdinal = 0; + + // Bucket the primitive global ops by byte size for alignment packing. + SmallVector, 8> primitiveGlobalOps( + sizeof(int64_t) + 1); + + for (auto &op : moduleOp.getBlock().getOperations()) { + if (auto funcOp = dyn_cast(op)) { + ordinals_[&op] = nextFuncOrdinal++; + } else if (isa(op)) { + ordinals_[&op] = nextExportOrdinal++; + } else if (isa(op)) { + ordinals_[&op] = nextImportOrdinal++; + } else if (isa(op)) { + ordinals_[&op] = nextRodataOrdinal++; + } else if (auto globalOp = dyn_cast(op)) { + if (isa(globalOp.getGlobalType())) { + ordinals_[&op] = nextGlobalRefOrdinal++; + } else { + // Bucket the primitive global ops by byte size for alignment packing. + size_t storageSize = getGlobalStorageSize(globalOp); + primitiveGlobalOps[storageSize].push_back(globalOp); + } + } + } + + // Assign byte offset values to primitive globals, ensuring that we meet + // natural alignment requirements on each size type. + int nextGlobalBytesOrdinal = 0; + int globalBytes = 0; + for (auto sizeGlobalOps : llvm::enumerate(primitiveGlobalOps)) { + size_t storageSize = sizeGlobalOps.index(); + if (sizeGlobalOps.value().empty()) { + continue; + } + nextGlobalBytesOrdinal = llvm::alignTo(nextGlobalBytesOrdinal, storageSize); + for (auto &globalOp : sizeGlobalOps.value()) { + ordinals_[globalOp] = nextGlobalBytesOrdinal; + nextGlobalBytesOrdinal += storageSize; + globalBytes = std::max(globalBytes, nextGlobalBytesOrdinal); + } + } + + // Record counts. 
+ counts_.importFuncs = nextImportOrdinal; + counts_.exportFuncs = nextExportOrdinal; + counts_.internalFuncs = nextFuncOrdinal; + counts_.globalBytes = globalBytes; + counts_.globalRefs = nextGlobalRefOrdinal; + counts_.rodatas = nextRodataOrdinal; + counts_.rwdatas = 0; +} + +int64_t OrdinalAnalysis::getOrdinal(IREE::VM::FuncOp op) const { + return getOrdinal(op.getOperation()); +} + +int64_t OrdinalAnalysis::getOrdinal(IREE::VM::ExportOp op) const { + return getOrdinal(op.getOperation()); +} + +int64_t OrdinalAnalysis::getOrdinal(IREE::VM::ImportOp op) const { + return getOrdinal(op.getOperation()); +} + +int64_t OrdinalAnalysis::getOrdinal(IREE::VM::RodataOp op) const { + return getOrdinal(op.getOperation()); +} + +int64_t +OrdinalAnalysis::getGlobalOrdinal(IREE::Util::GlobalOpInterface op) const { + return getOrdinal(op.getOperation()); +} + +int64_t OrdinalAnalysis::getOrdinal(Operation *op) const { + auto it = ordinals_.find(op); + assert(it != ordinals_.end() && "ordinal not found for operation"); + return it->second; +} + +} // namespace mlir::iree_compiler::IREE::VM diff --git a/compiler/src/iree/compiler/Dialect/VM/Analysis/OrdinalAnalysis.h b/compiler/src/iree/compiler/Dialect/VM/Analysis/OrdinalAnalysis.h new file mode 100644 index 000000000000..67e2580d5e91 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/VM/Analysis/OrdinalAnalysis.h @@ -0,0 +1,78 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_VM_ANALYSIS_ORDINALANALYSIS_H_ +#define IREE_COMPILER_DIALECT_VM_ANALYSIS_ORDINALANALYSIS_H_ + +#include "iree/compiler/Dialect/VM/IR/VMOps.h" +#include "llvm/ADT/DenseMap.h" + +namespace mlir::iree_compiler::IREE::VM { + +// Computes ordinal assignments for module-level symbols. 
+// +// Each ordinal is unique per-category and ordinals are contiguous starting +// from zero. Categories include: +// - Internal functions (vm.func) +// - Import functions (vm.import) +// - Export functions (vm.export) +// - Rodata segments (vm.rodata) +// - Global refs (vm.global.ref) +// - Global bytes (byte offset for primitive globals) +// +// This analysis is computed on-demand when ordinals are needed for +// serialization, avoiding the need to store ordinals as attributes on ops. +class OrdinalAnalysis { +public: + // Summary counts of module-level symbols. + struct OrdinalCounts { + int32_t importFuncs = 0; + int32_t exportFuncs = 0; + int32_t internalFuncs = 0; + int32_t globalBytes = 0; + int32_t globalRefs = 0; + int32_t rodatas = 0; + int32_t rwdatas = 0; // Currently unused, reserved. + }; + + OrdinalAnalysis() = default; + explicit OrdinalAnalysis(IREE::VM::ModuleOp moduleOp); + + OrdinalAnalysis(OrdinalAnalysis &&) = default; + OrdinalAnalysis &operator=(OrdinalAnalysis &&) = default; + OrdinalAnalysis(const OrdinalAnalysis &) = delete; + OrdinalAnalysis &operator=(const OrdinalAnalysis &) = delete; + + // Returns the ordinal counts for the module. + const OrdinalCounts &getCounts() const { return counts_; } + + // Returns the ordinal for a vm.func op. + int64_t getOrdinal(IREE::VM::FuncOp op) const; + + // Returns the ordinal for a vm.export op. + int64_t getOrdinal(IREE::VM::ExportOp op) const; + + // Returns the ordinal for a vm.import op. + int64_t getOrdinal(IREE::VM::ImportOp op) const; + + // Returns the ordinal for a vm.rodata op. + int64_t getOrdinal(IREE::VM::RodataOp op) const; + + // Returns the byte offset ordinal for a primitive global. + // Returns -1 if the global is a ref type. + int64_t getGlobalOrdinal(IREE::Util::GlobalOpInterface op) const; + + // Generic ordinal lookup for any operation with an ordinal. 
+ int64_t getOrdinal(Operation *op) const; + +private: + OrdinalCounts counts_; + llvm::DenseMap ordinals_; +}; + +} // namespace mlir::iree_compiler::IREE::VM + +#endif // IREE_COMPILER_DIALECT_VM_ANALYSIS_ORDINALANALYSIS_H_ diff --git a/compiler/src/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.cpp b/compiler/src/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.cpp index 49c8510917c8..717a2d648724 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.cpp @@ -62,8 +62,9 @@ LogicalResult RegisterAllocation::annotateIR(IREE::VM::FuncOp funcOp) { registerAllocation.remapSuccessorRegisters(&op, i); auto succOperands = branchOp.getSuccessorOperands(i).getForwardedOperands(); - if (succOperands.empty()) + if (succOperands.empty()) { continue; + } unsigned baseIdx = succOperands.getBeginOperandIndex(); // remapSuccessorRegisters only returns pairs where src != dst. // For display, we need ALL operands with correct MOVE bits. 
@@ -114,8 +115,9 @@ LogicalResult RegisterAllocation::annotateIR(IREE::VM::FuncOp funcOp) { op.setAttr("operand_registers", getStrArrayAttr(builder, operandRegStrs)); } - if (op.getNumResults() == 0) + if (op.getNumResults() == 0) { continue; + } SmallVector regStrs; regStrs.reserve(op.getNumResults()); for (auto result : op.getResults()) { @@ -167,8 +169,9 @@ sortBlocksInDominanceOrder(IREE::VM::FuncOp funcOp) { } llvm::SmallSetVector markedBlocks; std::function visit = [&](Block *block) { - if (markedBlocks.count(block) > 0) + if (markedBlocks.count(block) > 0) { return; + } for (auto *childBlock : dominanceInfo.getNode(block)->children()) { visit(childBlock->getBlock()); } @@ -369,15 +372,18 @@ LogicalResult RegisterAllocation::recalculate(IREE::VM::FuncOp funcOp) { llvm::DenseMap coalesceSource; auto recordCoalesceCandidate = [&](Value dest, Value src) { - if (dest.getType() != src.getType()) + if (dest.getType() != src.getType()) { return; + } auto srcInterval = liveIntervals.getInterval(src); auto destInterval = liveIntervals.getInterval(dest); - if (!srcInterval || !destInterval) + if (!srcInterval || !destInterval) { return; + } // Only coalesce if intervals meet exactly (hand-off). - if (srcInterval->end != destInterval->start) + if (srcInterval->end != destInterval->start) { return; + } coalesceSource[dest] = src; }; @@ -386,17 +392,20 @@ LogicalResult RegisterAllocation::recalculate(IREE::VM::FuncOp funcOp) { // Block arguments can coalesce with branch operands from predecessors. 
for (auto *pred : block->getPredecessors()) { auto branchOp = dyn_cast(pred->getTerminator()); - if (!branchOp) + if (!branchOp) { continue; + } for (unsigned succIdx = 0; succIdx < pred->getTerminator()->getNumSuccessors(); ++succIdx) { - if (pred->getTerminator()->getSuccessor(succIdx) != block) + if (pred->getTerminator()->getSuccessor(succIdx) != block) { continue; + } OperandRange operands = branchOp.getSuccessorOperands(succIdx).getForwardedOperands(); for (auto [idx, operand] : llvm::enumerate(operands)) { - if (idx >= block->getNumArguments()) + if (idx >= block->getNumArguments()) { break; + } recordCoalesceCandidate(block->getArgument(idx), operand); } } @@ -481,8 +490,9 @@ void RegisterAllocation::computeElidableDiscards(IREE::VM::FuncOp funcOp) { for (auto &block : funcOp.getBlocks()) { for (auto &op : block.getOperations()) { auto discardOp = dyn_cast(&op); - if (!discardOp) + if (!discardOp) { continue; + } SmallVector operandElidability; for (Value ref : discardOp.getRefs()) { @@ -510,8 +520,9 @@ void RegisterAllocation::computeElidableDiscards(IREE::VM::FuncOp funcOp) { break; } } - if (hasPrecedingMoveUse) + if (hasPrecedingMoveUse) { break; + } } operandElidability.push_back(hasPrecedingMoveUse); } @@ -633,18 +644,21 @@ struct FeedbackArcSet { SmallVector outEdges; outEdges.reserve(node->outdegree); for (auto &edge : edges) { - if (edge.sink == node) + if (edge.sink == node) { inEdges.push_back(edge); - if (edge.source == node) + } + if (edge.source == node) { outEdges.push_back(edge); + } } bool collectInEdges = node->indegree <= node->outdegree; bool collectOutEdges = !collectInEdges; SmallVector results; for (auto &edge : inEdges) { - if (edge.source == node) + if (edge.source == node) { continue; + } if (collectInEdges) { results.push_back({edge.source->id, edge.sink->id}); } @@ -654,8 +668,9 @@ struct FeedbackArcSet { assignBucket(edge.source); } for (auto &edge : outEdges) { - if (edge.sink == node) + if (edge.sink == node) { continue; + } 
if (collectOutEdges) { results.push_back({edge.source->id, edge.sink->id}); } @@ -681,11 +696,13 @@ struct FeedbackArcSet { ends.erase(ends.begin()); removeNode(node); } - if (remainingNodes.empty()) + if (remainingNodes.empty()) { break; + } for (ssize_t i = buckets.size() - 1; i >= 0; --i) { - if (buckets[i].empty()) + if (buckets[i].empty()) { continue; + } auto *bucket = buckets[i].front(); buckets[i].erase(buckets[i].begin()); auto feedbackEdges = removeNode(bucket); @@ -715,11 +732,13 @@ struct FeedbackArcSet { llvm::SmallSetVector unmarkedNodes = acyclicNodes; llvm::SmallSetVector markedNodes; std::function visit = [&](NodeID node) { - if (markedNodes.count(node) > 0) + if (markedNodes.count(node) > 0) { return; + } for (auto &edge : acyclicEdges) { - if (edge.first != node) + if (edge.first != node) { continue; + } visit(edge.second); } markedNodes.insert(node); @@ -729,8 +748,9 @@ struct FeedbackArcSet { } for (auto node : markedNodes.takeVector()) { for (auto &edge : acyclicEdges) { - if (edge.first != node) + if (edge.first != node) { continue; + } result.acyclicEdges.push_back({edge.first, edge.second}); } } diff --git a/compiler/src/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.h b/compiler/src/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.h index 1af00b0d8775..919a04100d2d 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.h +++ b/compiler/src/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.h @@ -227,8 +227,9 @@ class RegisterAllocation : public VMRegisterAllocation { // operands have already been released via MOVE on preceding operations. 
bool isDiscardElidable(Operation *op) const { auto it = discardOperandElidability_.find(op); - if (it == discardOperandElidability_.end()) + if (it == discardOperandElidability_.end()) { return false; + } return llvm::all_of(it->second, [](bool b) { return b; }); } @@ -238,10 +239,12 @@ class RegisterAllocation : public VMRegisterAllocation { bool isDiscardOperandElidable(Operation *op, unsigned operandIndex) const override { auto it = discardOperandElidability_.find(op); - if (it == discardOperandElidability_.end()) + if (it == discardOperandElidability_.end()) { return false; - if (operandIndex >= it->second.size()) + } + if (operandIndex >= it->second.size()) { return false; + } return it->second[operandIndex]; } diff --git a/compiler/src/iree/compiler/Dialect/VM/Analysis/ValueLiveness.cpp b/compiler/src/iree/compiler/Dialect/VM/Analysis/ValueLiveness.cpp index 5385554ca8bd..ff7aeab3e89d 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Analysis/ValueLiveness.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Analysis/ValueLiveness.cpp @@ -260,14 +260,17 @@ LogicalResult ValueLiveness::computeLiveIntervals(IREE::VM::FuncOp funcOp) { // Handle values entering the block and dying within. for (auto value : blockSets.liveIn) { - if (blockSets.liveOut.count(value)) + if (blockSets.liveOut.count(value)) { continue; + } Operation *lastUse = &block.front(); for (auto &use : value.getUses()) { - if (use.getOwner()->getBlock() != &block) + if (use.getOwner()->getBlock() != &block) { continue; - if (lastUse == use.getOwner()) + } + if (lastUse == use.getOwner()) { continue; + } if (lastUse->isBeforeInBlock(use.getOwner())) { lastUse = use.getOwner(); } @@ -277,14 +280,16 @@ LogicalResult ValueLiveness::computeLiveIntervals(IREE::VM::FuncOp funcOp) { // Handle values defined within the block and not escaping. 
for (auto value : blockSets.defined) { - if (blockSets.liveOut.count(value)) + if (blockSets.liveOut.count(value)) { continue; + } Operation *firstUse = value.getDefiningOp() ? value.getDefiningOp() : &block.front(); Operation *lastUse = firstUse; for (auto &use : value.getUses()) { - if (use.getOwner()->getBlock() != &block) + if (use.getOwner()->getBlock() != &block) { continue; + } if (lastUse->isBeforeInBlock(use.getOwner())) { lastUse = use.getOwner(); } @@ -386,8 +391,9 @@ bool ValueLiveness::isLastRealValueUse(Value value, Operation *useOp, break; } } - if (valueIsSuccessorOperand) + if (valueIsSuccessorOperand) { break; + } } } // Check if the value escapes to any successor blocks. diff --git a/compiler/src/iree/compiler/Dialect/VM/Analysis/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Analysis/test/BUILD.bazel index 90d56411b87c..5f09cbe79644 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Analysis/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/Analysis/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "live_intervals.mlir", "register_allocation.mlir", diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp index 678c8c41b2ce..027e6f8d222f 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp @@ -100,8 +100,9 @@ struct CmpI32OpConversion : public OpConversionPattern { LogicalResult matchAndRewrite(arith::CmpIOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!adaptor.getLhs().getType().isInteger(32)) + if (!adaptor.getLhs().getType().isInteger(32)) { return failure(); + } auto returnType = rewriter.getIntegerType(32); switch (srcOp.getPredicate()) { case arith::CmpIPredicate::eq: @@ -155,8 +156,9 @@ struct 
CmpI64OpConversion : public OpConversionPattern { LogicalResult matchAndRewrite(arith::CmpIOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!adaptor.getLhs().getType().isInteger(64)) + if (!adaptor.getLhs().getType().isInteger(64)) { return failure(); + } auto returnType = rewriter.getIntegerType(32); switch (srcOp.getPredicate()) { case arith::CmpIPredicate::eq: @@ -210,8 +212,9 @@ struct CmpF32OpConversion : public OpConversionPattern { LogicalResult matchAndRewrite(arith::CmpFOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!adaptor.getLhs().getType().isF32()) + if (!adaptor.getLhs().getType().isF32()) { return failure(); + } auto returnType = rewriter.getIntegerType(32); switch (srcOp.getPredicate()) { case arith::CmpFPredicate::AlwaysFalse: // 0 @@ -300,8 +303,9 @@ struct CmpF64OpConversion : public OpConversionPattern { LogicalResult matchAndRewrite(arith::CmpFOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!adaptor.getLhs().getType().isF64()) + if (!adaptor.getLhs().getType().isF64()) { return failure(); + } auto returnType = rewriter.getIntegerType(32); switch (srcOp.getPredicate()) { case arith::CmpFPredicate::AlwaysFalse: // 0 @@ -623,13 +627,15 @@ struct ExtendFOpConversion : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { auto srcType = dyn_cast_if_present(srcOp.getIn().getType()); auto resultType = dyn_cast_if_present(srcOp.getType()); - if (!srcType || !resultType) + if (!srcType || !resultType) { return failure(); + } auto dstType = getTypeConverter()->convertType(resultType); auto srcBits = srcType.getWidth(); auto resultBits = resultType.getWidth(); - if (srcBits != 32 || resultBits != 64) + if (srcBits != 32 || resultBits != 64) { return rewriter.notifyMatchFailure(srcOp, "unsupported extf conversion"); + } rewriter.replaceOpWithNewOp(srcOp, dstType, adaptor.getIn()); return success(); diff 
--git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/BUILD.bazel index 9a0cbbace1fc..3076f512dc6f 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "arithmetic_ops.mlir", "assignment_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp index 20bfb1bd9df0..95c219e49430 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp @@ -63,8 +63,9 @@ LogicalResult ImportTable::build(Operation *rootOp, std::optional ImportTable::find(StringRef symbolName) { auto it = symbols.find(symbolName); - if (it == symbols.end()) + if (it == symbols.end()) { return std::nullopt; + } return it->second; } @@ -110,8 +111,9 @@ LogicalResult appendImportModule(StringRef importModuleSrc, Value castToImportType(Value value, Type targetType, OpBuilder &builder) { auto sourceType = value.getType(); - if (sourceType == targetType) + if (sourceType == targetType) { return value; + } bool sourceIsInteger = isa(sourceType); // Allow bitcast between same width float/int types. 
This is used for @@ -202,8 +204,9 @@ std::optional> rewriteAttrToOperands(Location loc, for (auto elementAttr : arrayAttr) { auto flattenedValues = rewriteAttrToOperands(loc, elementAttr, inputType, builder); - if (!flattenedValues) + if (!flattenedValues) { return std::nullopt; + } allValues.append(flattenedValues->begin(), flattenedValues->end()); } return allValues; @@ -226,8 +229,9 @@ std::optional> rewriteAttrToOperands(Location loc, int ordinal = 0; LogicalResult walkStatus = conversionInterface->walkAttributeStorage( attrValue, [&](Attribute elementAttr) { - if (anyFailed) + if (anyFailed) { return; + } auto elementType = tupleTypes[ordinal++]; auto flattenedValues = rewriteAttrToOperands(loc, elementAttr, elementType, builder); @@ -237,14 +241,16 @@ std::optional> rewriteAttrToOperands(Location loc, } allValues.append(flattenedValues->begin(), flattenedValues->end()); }); - if (failed(walkStatus)) + if (failed(walkStatus)) { return std::nullopt; + } } else { // Custom dialect type maps into zero or more input types (ala arrays). 
LogicalResult walkStatus = conversionInterface->walkAttributeStorage( attrValue, [&](Attribute elementAttr) { - if (anyFailed) + if (anyFailed) { return; + } auto flattenedValues = rewriteAttrToOperands(loc, elementAttr, inputType, builder); if (!flattenedValues) { @@ -253,11 +259,13 @@ std::optional> rewriteAttrToOperands(Location loc, } allValues.append(flattenedValues->begin(), flattenedValues->end()); }); - if (failed(walkStatus)) + if (failed(walkStatus)) { return std::nullopt; + } } - if (anyFailed) + if (anyFailed) { return std::nullopt; + } return allValues; } diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h index c5d557f5b42d..5093a686b7c1 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h @@ -102,8 +102,9 @@ rewriteToCall(T op, Adaptor adaptor, IREE::VM::ImportOp importOp, if (auto attrValue = op->getAttr(inputName)) { auto flattenedAttrs = detail::rewriteAttrToOperands( op.getLoc(), attrValue, inputType, builder); - if (!flattenedAttrs) + if (!flattenedAttrs) { return std::nullopt; + } state.addOperands(*flattenedAttrs); if (importOp.isFuncArgumentVariadic(input.index())) { segmentSizes.push_back(flattenedAttrs->size() / @@ -162,8 +163,9 @@ rewriteToCall(T op, Adaptor adaptor, IREE::VM::ImportOp importOp, for (auto [result, targetType] : llvm::zip_equal(callOp->getResults(), operation->getResultTypes())) { targetType = typeConverter.convertType(targetType); - if (!targetType) + if (!targetType) { return std::nullopt; + } results.push_back(castFromImportType(result, targetType, builder)); } return results; @@ -185,8 +187,9 @@ class VMImportOpConversion : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { auto results = rewriteToCall(op, adaptor, importOp, *this->getTypeConverter(), rewriter); - if (!results.has_value()) + if 
(!results.has_value()) { return failure(); + } rewriter.replaceOp(op, results.value()); return success(); } diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM/Patterns.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM/Patterns.cpp index 1a16eacffe91..e74176e95b84 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM/Patterns.cpp @@ -30,8 +30,9 @@ class UnaryArithmeticOpConversion : public OpConversionPattern { matchAndRewrite(SrcOpTy srcOp, typename SrcOpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { // TODO(benvanik): support vectors. - if (isa(srcOp.getResult().getType())) + if (isa(srcOp.getResult().getType())) { return failure(); + } switch (adaptor.getOperand().getType().getIntOrFloatBitWidth()) { case 32: @@ -57,8 +58,9 @@ class BinaryArithmeticOpConversion : public OpConversionPattern { matchAndRewrite(SrcOpTy srcOp, typename SrcOpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { // TODO(benvanik): support vectors. 
- if (isa(srcOp.getResult().getType())) + if (isa(srcOp.getResult().getType())) { return failure(); + } switch (adaptor.getLhs().getType().getIntOrFloatBitWidth()) { case 32: diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM/test/BUILD.bazel index 206cdadd7685..07c99862721b 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "arithmetic_ops.mlir", ], diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.cpp index efcea8fba614..e651bc9b995b 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.cpp @@ -115,16 +115,18 @@ struct FuncOpConversion : public OpConversionPattern { matchAndRewrite(func::FuncOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Handled by import-specific conversion. - if (srcOp.isExternal()) + if (srcOp.isExternal()) { return failure(); + } // Convert function signature. TypeConverter::SignatureConversion signatureConversion( srcOp.getNumArguments()); auto newFuncType = convertFuncSignature(srcOp, *getTypeConverter(), signatureConversion, rewriter); - if (failed(newFuncType)) + if (failed(newFuncType)) { return failure(); + } // Create new function with converted argument and result types. // Note that attributes are dropped. Consider preserving some if needed. 
@@ -189,8 +191,9 @@ struct ExternalFuncOpConversion : public OpConversionPattern { matchAndRewrite(func::FuncOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Handled by internal-specific conversion. - if (!srcOp.isExternal()) + if (!srcOp.isExternal()) { return failure(); + } // If the user declared an intended signature then we can use that instead // of running conversion ourselves. This can be used in cases where the @@ -210,8 +213,9 @@ struct ExternalFuncOpConversion : public OpConversionPattern { srcOp.getNumArguments()); auto convertedSignature = convertFuncSignature( srcOp, *getTypeConverter(), signatureConversion, rewriter); - if (failed(convertedSignature)) + if (failed(convertedSignature)) { return failure(); + } newSignature = *convertedSignature; } @@ -354,8 +358,9 @@ struct CallOpConversion : public OpConversionPattern { rewriter.setInsertionPointToStart(fallbackBlock); auto fallbackResults = convertCallOp(rootOp, loc, fallbackName, operands, resultTypes, importTable, rewriter); - if (failed(fallbackResults)) + if (failed(fallbackResults)) { return failure(); + } IREE::VM::BranchOp::create(rewriter, loc, exitBlock, *fallbackResults); return exitResults; diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/BUILD.bazel index bc6b7ac5b474..35075c7270dc 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "control_flow_ops.mlir", "func_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertBufferOps.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertBufferOps.cpp index f7f6ab226c34..4a6cd3f941ad 100644 --- 
a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertBufferOps.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertBufferOps.cpp @@ -20,15 +20,17 @@ namespace mlir::iree_compiler { namespace { static Value castToI64(Value value, OpBuilder &builder) { - if (value.getType().isInteger(64)) + if (value.getType().isInteger(64)) { return value; + } return builder.createOrFold( value.getLoc(), builder.getI64Type(), value); } static Value castToIndex(Value value, OpBuilder &builder) { - if (value.getType().isIndex()) + if (value.getType().isIndex()) { return value; + } return builder.createOrFold( value.getLoc(), builder.getIndexType(), value); } @@ -161,8 +163,9 @@ struct BufferCompareOpConversion static Value unscaleOffset(Location loc, Value offset, int64_t scale, OpBuilder &builder) { - if (scale == 1) + if (scale == 1) { return offset; + } return builder.createOrFold( loc, offset.getType(), offset, IREE::VM::ConstI64Op::create(builder, loc, scale)); diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertListOps.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertListOps.cpp index 1aefdfab5a80..1aec588a7b53 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertListOps.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertListOps.cpp @@ -20,15 +20,17 @@ namespace mlir::iree_compiler { namespace { static Value castToI32(Value value, OpBuilder &builder) { - if (value.getType().isInteger(32)) + if (value.getType().isInteger(32)) { return value; + } return builder.createOrFold( value.getLoc(), builder.getI32Type(), value); } static Value castToIndex(Value value, OpBuilder &builder) { - if (value.getType().isIndex()) + if (value.getType().isIndex()) { return value; + } return builder.createOrFold( value.getLoc(), builder.getIndexType(), value); } @@ -200,8 +202,9 @@ void populateUtilListToVMPatterns(MLIRContext *context, } else { elementType = 
typeConverter.convertType(type.getElementType()); } - if (!elementType) + if (!elementType) { return std::nullopt; + } return IREE::VM::RefType::get(IREE::VM::ListType::get(elementType)); }); diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertStructuralOps.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertStructuralOps.cpp index b9b2803fa9f7..65f723cf6513 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertStructuralOps.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertStructuralOps.cpp @@ -89,16 +89,18 @@ class FuncOpConversion : public OpConversionPattern { matchAndRewrite(IREE::Util::FuncOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Handled by import-specific conversion. - if (srcOp.isExternal()) + if (srcOp.isExternal()) { return failure(); + } // Convert function signature. TypeConverter::SignatureConversion signatureConversion( srcOp.getNumArguments()); auto newFuncType = convertFuncSignature(srcOp, *getTypeConverter(), signatureConversion, rewriter); - if (failed(newFuncType)) + if (failed(newFuncType)) { return failure(); + } // Create new function with converted argument and result types. // Note that attributes are dropped. Consider preserving some if needed. @@ -165,8 +167,9 @@ class ExternalFuncOpConversion matchAndRewrite(IREE::Util::FuncOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Handled by internal-specific conversion. - if (!srcOp.isExternal()) + if (!srcOp.isExternal()) { return failure(); + } // If the user declared an intended signature then we can use that instead // of running conversion ourselves. 
This can be used in cases where the @@ -186,8 +189,9 @@ class ExternalFuncOpConversion srcOp.getNumArguments()); auto convertedSignature = convertFuncSignature( srcOp, *getTypeConverter(), signatureConversion, rewriter); - if (failed(convertedSignature)) + if (failed(convertedSignature)) { return failure(); + } newSignature = *convertedSignature; } @@ -328,8 +332,9 @@ struct CallOpConversion : public OpConversionPattern { rewriter.setInsertionPointToStart(fallbackBlock); auto fallbackResults = convertCallOp(rootOp, loc, fallbackName, operands, resultTypes, rewriter); - if (failed(fallbackResults)) + if (failed(fallbackResults)) { return failure(); + } IREE::VM::BranchOp::create(rewriter, loc, exitBlock, *fallbackResults); return exitResults; diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.cpp index 27f9a80987fc..13b2ba2687e5 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.cpp @@ -103,6 +103,22 @@ struct CmpNEOpConversion : public OpConversionPattern { } }; +//===----------------------------------------------------------------------===// +// util.optimization_barrier +//===----------------------------------------------------------------------===// + +struct OptimizationBarrierOpConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(IREE::Util::OptimizationBarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, adaptor.getOperands()); + return success(); + } +}; + } // namespace void populateUtilToVMPatterns(MLIRContext *context, @@ -113,6 +129,7 @@ void populateUtilToVMPatterns(MLIRContext *context, patterns.insert(typeConverter, context); patterns.insert(typeConverter, context); patterns.insert(typeConverter, 
context); + patterns.insert(typeConverter, context); populateUtilAlignmentToVMPatterns(context, conversionTarget, typeConverter, patterns); diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/BUILD.bazel index 3a4dd9ee2458..1dc2826c8774 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "alignment_ops.mlir", "assignment_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp index 913b73e6edb5..e7e456b32313 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp @@ -131,8 +131,9 @@ LogicalResult convertFuncOp(IREE::VM::FuncOp funcOp, } if (failed( - funcOp.replaceAllSymbolUses(builder.getStringAttr(name), moduleOp))) + funcOp.replaceAllSymbolUses(builder.getStringAttr(name), moduleOp))) { return funcOp.emitError() << "unable to update symbol name in module"; + } return success(); } @@ -1186,8 +1187,9 @@ createModuleStructure(IREE::VM::ModuleOp moduleOp, rodataOp.getAlignment() ? 
static_cast(rodataOp.getAlignment().value()) : 0; - if (alignment == 0) + if (alignment == 0) { alignment = kDefaultRodataAlignment; + } std::string bufferName = moduleOp.getName().str() + "_" + rodataOp.getName().str(); @@ -1196,8 +1198,9 @@ createModuleStructure(IREE::VM::ModuleOp moduleOp, ") static const uint8_t " + bufferName + "[] = {"; size_t index = 0; for (char value : byteBuffer) { - if (index++ > 0) + if (index++ > 0) { stmt += ", "; + } stmt += std::to_string( static_cast(static_cast(value))); } @@ -2202,6 +2205,15 @@ class ImportOpConverter { return importOp.emitError() << "failed to create call"; } + // Release refs in argument buffer after call returns. Refs that were + // taken by the callee (via assign_ref+memset) will be null and release + // will be a no-op. + if (failed(releaseArgumentBuffer( + flattenInputTypes(importOp, segmentSizes, builder), call.value(), + builder, loc))) { + return importOp.emitError() << "failed to release argument buffer"; + } + if (failed(unpackResultBuffer(importOp.getResultTypes(), newFuncOp, call.value(), builder, loc))) { return importOp.emitError() << "failed to unpack result struct"; @@ -2451,10 +2463,14 @@ class ImportOpConverter { /*operand=*/uint8Ptr) .getResult(); + // Retain the ref into args_storage. The callee may take ownership via + // assign_ref+memset(0), so we must retain (not just assign/borrow). + // After the call returns, releaseArgumentBuffer will release any refs + // that weren't taken by the callee. emitc::CallOpaqueOp::create(builder, /*location=*/loc, /*type=*/TypeRange{}, - /*callee=*/"iree_vm_ref_assign", + /*callee=*/"iree_vm_ref_retain", /*operands=*/ArrayRef{arg, refPtr}); } else { auto argLValue = emitc_builders::asLValue(builder, loc, arg); @@ -2482,6 +2498,83 @@ class ImportOpConverter { return success(); } + // Releases refs in the argument buffer after an import call returns. + // This mirrors packArgumentBuffer but releases instead of retaining. 
+ // Refs that were taken by the callee (via assign_ref+memset) will be null + // and release will be a no-op. + LogicalResult releaseArgumentBuffer(ArrayRef inputTypes, + TypedValue call, + OpBuilder &builder, Location loc) const { + // Find the last ref type index. We only need to iterate up to and including + // that index to release all refs. This avoids generating unused pointer + // arithmetic for trailing non-ref types. + std::optional lastRefIndex; + for (size_t i = 0; i < inputTypes.size(); i++) { + if (isa(inputTypes[i])) { + lastRefIndex = i; + } + } + if (!lastRefIndex) { + return success(); + } + + auto ctx = builder.getContext(); + + auto arguments = + emitc::MemberOp::create(builder, loc, + /*type=*/ + emitc::LValueType::get(emitc::OpaqueType::get( + ctx, "iree_byte_span_t")), + /*memberName=*/"arguments", + /*operand=*/call) + .getResult(); + + Type bytePtrType = + emitc::PointerType::get(builder.getIntegerType(8, false)); + auto uint8Ptr = emitc_builders::structMember(builder, loc, + /*type=*/bytePtrType, + /*memberName=*/"data", + /*operand=*/arguments); + + // Only iterate up to and including the last ref type. + for (size_t i = 0; i <= *lastRefIndex; i++) { + Type inputType = inputTypes[i]; + + // Get the value type and compute alignment (must match packArgumentBuffer + // exactly to ensure we're releasing the correct locations). + Type valueType = typeConverter.convertTypeAsNonPointer(inputType); + size_t alignment = getTypeAlignment(valueType); + if (alignment > 4) { + uint8Ptr = emitc_builders::alignPtr(builder, loc, uint8Ptr, alignment); + } + + // Release refs. If the callee took ownership and zeroed the ref, + // iree_vm_ref_release on a null ref is a no-op. 
+ if (isa(inputType)) { + Type refPtrType = emitc::PointerType::get( + emitc::OpaqueType::get(ctx, "iree_vm_ref_t")); + Value refPtr = emitc::CastOp::create(builder, + /*location=*/loc, + /*type=*/refPtrType, + /*operand=*/uint8Ptr) + .getResult(); + emitc_builders::ireeVmRefRelease(builder, loc, refPtr); + } + + // Advance pointer to next element (only if not at the last ref). + if (i < *lastRefIndex) { + Value size = + emitc_builders::sizeOf(builder, loc, TypeAttr::get(valueType)); + uint8Ptr = + emitc::AddOp::create(builder, + /*location=*/loc, /*type=*/bytePtrType, + /*operands=*/ArrayRef{uint8Ptr, size}) + .getResult(); + } + } + return success(); + } + LogicalResult unpackResultBuffer(ArrayRef resultTypes, mlir::emitc::FuncOp &funcOp, TypedValue call, @@ -2695,11 +2788,13 @@ class CallOpConversion : public EmitCConversionPattern { IREE::VM::ImportOp importOp = lookupSymbolRef(op.getOperation(), "callee"); - if (!funcOp && !importOp) + if (!funcOp && !importOp) { return op.emitError() << "lookup of callee failed"; + } - if (funcOp && importOp) + if (funcOp && importOp) { return op.emitError() << "lookup of callee ambiguous"; + } const bool isImported = importOp != nullptr; @@ -2791,8 +2886,9 @@ class CallOpConversion : public EmitCConversionPattern { return failure(); } - if (!funcName.has_value()) + if (!funcName.has_value()) { return op->emitError() << "Couldn't build name to imported function"; + } auto callee = moduleOp.lookupSymbol(funcName.value()); if (callee == nullptr) { @@ -3733,8 +3829,9 @@ class BranchTableOpConversion { OpBuilder::InsertionGuard guard(rewriter); auto *nextBlock = rewriter.getInsertionBlock()->getNextNode(); - for (size_t i = 0; i < caseDestinations.size(); ++i) + for (size_t i = 0; i < caseDestinations.size(); ++i) { caseBlocks.push_back(rewriter.createBlock(nextBlock)); + } caseBlocks.push_back(rewriter.createBlock(nextBlock)); // default } IREE::VM::BranchOp::create(rewriter, op.getLoc(), caseBlocks.front()); @@ -5172,6 +5269,12 
@@ class ConvertVMToEmitCPass void runOnOperation() override { IREE::VM::ModuleOp moduleOp = getOperation(); + // Erase vm.discard.refs ops before analysis. These are inserted by + // MaterializeRefDiscardsPass for the bytecode backend but are not used + // by EmitC. Erasing them here avoids inflating register pressure during + // the register allocation analysis. + moduleOp.walk([](IREE::VM::DiscardRefsOp op) { op.erase(); }); + ConversionTarget target(getContext()); EmitCTypeConverter typeConverter(moduleOp); diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.cpp index 4f36b525ab11..a4bc38fa0a36 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.cpp @@ -261,8 +261,9 @@ void structDefinition(OpBuilder builder, Location location, std::string decl = std::string("struct ") + structName.str() + " {"; for (auto &field : fields) { decl += field.type + " " + field.name; - if (field.isArray()) + if (field.isArray()) { decl += "[" + std::to_string(field.arraySize.value()) + "]"; + } decl += ";"; } decl += "};"; diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/BUILD.bazel index 2d9dc4c22200..1862bd578552 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/BUILD.bazel @@ -24,35 +24,36 @@ endif() iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ + "arithmetic_ops.mlir", "arithmetic_ops_f32.mlir", "arithmetic_ops_i64.mlir", - "arithmetic_ops.mlir", + "assignment_ops.mlir", "assignment_ops_f32.mlir", "assignment_ops_i64.mlir", - "assignment_ops.mlir", "buffer_ops.mlir", "buffer_ops_f32.mlir", "buffer_ops_f64.mlir", 
"buffer_ops_i64.mlir", + "comparison_ops.mlir", "comparison_ops_f32.mlir", "comparison_ops_i64.mlir", - "comparison_ops.mlir", + "const_ops.mlir", "const_ops_f32.mlir", "const_ops_i64.mlir", - "const_ops.mlir", "control_flow_ops.mlir", + "conversion_ops.mlir", "conversion_ops_f32.mlir", "conversion_ops_i64.mlir", - "conversion_ops.mlir", "func_op.mlir", + "global_ops.mlir", "global_ops_f32.mlir", "global_ops_i64.mlir", - "global_ops.mlir", - "list_ops_i64.mlir", "list_ops.mlir", - "shift_ops_i64.mlir", + "list_ops_i64.mlir", "shift_ops.mlir", + "shift_ops_i64.mlir", "type_conversion.mlir", ], include = ["*.mlir"], diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/control_flow_ops.mlir b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/control_flow_ops.mlir index 3929db967f10..7ee30a829acb 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/control_flow_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/control_flow_ops.mlir @@ -725,7 +725,8 @@ vm.module @my_module { // CHECK: %[[ARGSDATA:.+]] = load %[[ARGSDATA_LVAL]] : > // CHECK: call_opaque "iree_host_align" // CHECK: %[[ARG:.+]] = cast %{{.+}} : !emitc.ptr to !emitc.ptr> - // CHECK: call_opaque "iree_vm_ref_assign"(%arg2, %[[ARG]]) + // Retain the ref into args_storage (not just assign/borrow). + // CHECK: call_opaque "iree_vm_ref_retain"(%arg2, %[[ARG]]) // Create the call to the imported function. // CHECK: %[[MODULE_LVAL:.+]] = "emitc.member_of_ptr"(%[[FUNC_LVAL]]) <{member = "module"}> : (!emitc.lvalue>>) -> !emitc.lvalue>> @@ -735,6 +736,9 @@ vm.module @my_module { // CHECK-NEXT: %[[ARGSTRUCT_RVAL:.+]] = load %[[ARGSTRUCT]] : > // CHECK-NEXT: %{{.+}} = call_opaque "EMITC_CALL_INDIRECT"(%[[BEGIN_CALL]], %[[MODULE]], %arg0, %[[ARGSTRUCT_RVAL]]) + // Release refs in argument buffer after call returns. + // CHECK: call_opaque "iree_vm_ref_release" + // Unpack the function results (with pointer alignment). 
// CHECK: %[[RES_MEMBER:.+]] = "emitc.member"(%[[ARGSTRUCT]]) <{member = "results"}> : (!emitc.lvalue>) -> !emitc.lvalue> // CHECK: %[[RESPTR_MEMBER:.+]] = "emitc.member"(%[[RES_MEMBER]]) <{member = "data"}> : (!emitc.lvalue>) -> !emitc.lvalue> diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/IR/BUILD.bazel index d5dff039a403..803e15bf969e 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/IR/BUILD.bazel @@ -18,6 +18,7 @@ exports_files(["VMOps.td"]) iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "VMBase.td", "VMOpcodesCore.td", diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMDialect.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMDialect.cpp index dfb76203a164..378e77ce9c9a 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMDialect.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMDialect.cpp @@ -74,8 +74,9 @@ struct VMInlinerInterface : public DialectInlinerInterface { if (auto inliningPolicy = callable->getAttrOfType( "inlining_policy")) { - if (!inliningPolicy.isLegalToInline(call, callable)) + if (!inliningPolicy.isLegalToInline(call, callable)) { return false; + } } // Sure! 
return true; @@ -259,8 +260,9 @@ void VMDialect::printType(Type type, DialectAsmPrinter &os) const { Operation *VMDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { auto typedValue = dyn_cast(value); - if (!typedValue) + if (!typedValue) { return nullptr; + } if (ConstI32Op::isBuildableWith(typedValue, type)) { auto convertedValue = ConstI32Op::convertConstValue(typedValue); diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp index 381fc19e8bb4..93454ff077cd 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp @@ -44,8 +44,9 @@ Attribute oneOfType(Type type) { } else if (isa(type)) { auto vtType = cast(type); auto element = oneOfType(vtType.getElementType()); - if (!element) + if (!element) { return {}; + } return DenseElementsAttr::get(vtType, element); } return {}; @@ -64,8 +65,9 @@ struct DropEmptyInitializerOp : public OpRewritePattern { using Base::Base; LogicalResult matchAndRewrite(InitializerOp op, PatternRewriter &rewriter) const override { - if (op.getBody().getBlocks().size() != 1) + if (op.getBody().getBlocks().size() != 1) { return failure(); + } auto &block = op.getBody().front(); if (block.empty() || isa(block.front())) { rewriter.eraseOp(op); @@ -85,12 +87,14 @@ struct InlineConstGlobalInitializer : public OpRewritePattern { PatternRewriter &rewriter) const override { SmallVector deadOps; op.walk([&](Operation *op) { - if (!isGlobalStoreOp(op)) + if (!isGlobalStoreOp(op)) { return; + } auto value = op->getOperand(0); Attribute valueAttr; - if (!matchPattern(value, m_Constant(&valueAttr))) + if (!matchPattern(value, m_Constant(&valueAttr))) { return; + } auto globalRefAttr = op->getAttrOfType("global"); assert(globalRefAttr); auto globalOp = @@ -100,10 +104,12 @@ struct InlineConstGlobalInitializer : public OpRewritePattern { globalOp, [&]() { 
globalOp.setGlobalInitialValue(valueAttr); }); deadOps.push_back(op); }); - if (deadOps.empty()) + if (deadOps.empty()) { return failure(); - for (auto deadOp : deadOps) + } + for (auto deadOp : deadOps) { rewriter.eraseOp(deadOp); + } return success(); } @@ -135,14 +141,17 @@ struct DropDefaultConstGlobalOpInitializer : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const override { - if (!op.getInitialValue().has_value()) + if (!op.getInitialValue().has_value()) { return failure(); + } if (auto value = dyn_cast(op.getInitialValueAttr())) { - if (value.getValue() != 0) + if (value.getValue() != 0) { return failure(); + } } else if (auto value = dyn_cast(op.getInitialValueAttr())) { - if (value.getValue().isNonZero()) + if (value.getValue().isNonZero()) { return failure(); + } } auto visibility = op.getVisibility(); auto newOp = rewriter.replaceOpWithNewOp( @@ -488,8 +497,9 @@ static Attribute constFoldUnaryOp(Attribute rawOperand, dyn_cast_if_present(rawOperand)) { auto elementResult = constFoldUnaryOp( {operand.getSplatValue()}, calculate); - if (!elementResult) + if (!elementResult) { return {}; + } return DenseElementsAttr::get(operand.getType(), elementResult); } else if (auto operand = dyn_cast_if_present(rawOperand)) { return cast(operand).mapValues( @@ -511,8 +521,9 @@ constFoldFloatUnaryOp(Attribute rawOperand, dyn_cast_if_present(rawOperand)) { auto elementResult = constFoldFloatUnaryOp({operand.getSplatValue()}, calculate); - if (!elementResult) + if (!elementResult) { return {}; + } return DenseElementsAttr::get(operand.getType(), elementResult); } else if (auto operand = dyn_cast_if_present(rawOperand)) { return cast(operand).mapValues( @@ -535,33 +546,38 @@ static TypedAttr constFoldBinaryOp(Attribute rawLhs, Attribute rawRhs, const CalculationT &calculate) { if (auto lhs = dyn_cast_if_present(rawLhs)) { auto rhs = dyn_cast_if_present(rawRhs); - if (!rhs) + if (!rhs) { 
return {}; + } return AttrElementT::get(lhs.getType(), calculate(lhs.getValue(), rhs.getValue())); } else if (auto lhs = dyn_cast_if_present(rawLhs)) { // TODO(benvanik): handle splat/otherwise. auto rhs = dyn_cast_if_present(rawRhs); - if (!rhs || lhs.getType() != rhs.getType()) + if (!rhs || lhs.getType() != rhs.getType()) { return {}; + } auto elementResult = constFoldBinaryOp( lhs.getSplatValue(), rhs.getSplatValue(), calculate); - if (!elementResult) + if (!elementResult) { return {}; + } return DenseElementsAttr::get(lhs.getType(), elementResult); } else if (auto lhs = dyn_cast_if_present(rawLhs)) { auto rhs = dyn_cast_if_present(rawRhs); - if (!rhs || lhs.getType() != rhs.getType()) + if (!rhs || lhs.getType() != rhs.getType()) { return {}; + } auto lhsIt = lhs.getValues().begin(); auto rhsIt = rhs.getValues().begin(); SmallVector resultAttrs(lhs.getNumElements()); for (int64_t i = 0; i < lhs.getNumElements(); ++i) { resultAttrs[i] = constFoldBinaryOp(*lhsIt, *rhsIt, calculate); - if (!resultAttrs[i]) + if (!resultAttrs[i]) { return {}; + } ++lhsIt; ++rhsIt; } @@ -597,8 +613,9 @@ static Attribute constFoldTernaryOp(Attribute rawA, Attribute rawB, auto elementResult = constFoldTernaryOp( a.getSplatValue(), b.getSplatValue(), c.getSplatValue(), calculate); - if (!elementResult) + if (!elementResult) { return {}; + } return DenseElementsAttr::get(a.getType(), elementResult); } else if (auto a = dyn_cast_if_present(rawA)) { auto b = dyn_cast_if_present(rawB); @@ -613,8 +630,9 @@ static Attribute constFoldTernaryOp(Attribute rawA, Attribute rawB, for (int64_t i = 0; i < a.getNumElements(); ++i) { resultAttrs[i] = constFoldTernaryOp(*aIt, *bIt, *cIt, calculate); - if (!resultAttrs[i]) + if (!resultAttrs[i]) { return {}; + } ++aIt; ++bIt; ++cIt; @@ -669,14 +687,16 @@ static OpFoldResult foldAddOp(ADD op, Attribute lhs, Attribute rhs) { if (auto subOp = dyn_cast_if_present(op.getLhs().getDefiningOp())) { // t = vm.sub x, y // = vm.add t, z - if (subOp.getRhs() == 
op.getRhs()) // y == z: - return subOp.getLhs(); // (x - y) + y = x + if (subOp.getRhs() == op.getRhs()) { // y == z: + return subOp.getLhs(); // (x - y) + y = x + } } else if (auto subOp = dyn_cast_if_present(op.getRhs().getDefiningOp())) { // t = vm.sub x, y // = vm.add z, t - if (subOp.getRhs() == op.getLhs()) // y == z: - return subOp.getLhs(); // y + (x - y) = x + if (subOp.getRhs() == op.getLhs()) { // y == z: + return subOp.getLhs(); // y + (x - y) = x + } } return constFoldBinaryOp( lhs, rhs, @@ -716,10 +736,12 @@ static OpFoldResult foldSubOp(SUB op, Attribute lhs, Attribute rhs) { if (auto addOp = dyn_cast_if_present(op.getLhs().getDefiningOp())) { // t = vm.add x, y // = vm.sub t, z - if (addOp.getLhs() == op.getRhs()) // x == z: - return addOp.getRhs(); // (x + y) - x = y - if (addOp.getRhs() == op.getRhs()) // y == z: - return addOp.getLhs(); // (x + y) - y = x + if (addOp.getLhs() == op.getRhs()) { // x == z: + return addOp.getRhs(); // (x + y) - x = y + } + if (addOp.getRhs() == op.getRhs()) { // y == z: + return addOp.getLhs(); // (x + y) - y = x + } } return constFoldBinaryOp( lhs, rhs, @@ -764,8 +786,9 @@ struct FoldConstantMulOperand : public OpRewritePattern { LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const override { AttrElementT c1, c2; - if (!matchPattern(op.getRhs(), m_Constant(&c1))) + if (!matchPattern(op.getRhs(), m_Constant(&c1))) { return failure(); + } if (auto mulOp = dyn_cast_if_present(op.getLhs().getDefiningOp())) { if (matchPattern(mulOp.getRhs(), m_Constant(&c2))) { auto c = rewriter.createOrFold( @@ -980,8 +1003,9 @@ OpFoldResult AbsI64Op::fold(FoldAdaptor operands) { } OpFoldResult MinI32SOp::fold(FoldAdaptor operands) { - if (getLhs() == getRhs()) + if (getLhs() == getRhs()) { return getLhs(); + } return constFoldBinaryOp(operands.getLhs(), operands.getRhs(), [](const APInt &lhs, const APInt &rhs) { return llvm::APIntOps::smin(lhs, rhs); @@ -989,8 +1013,9 @@ OpFoldResult MinI32SOp::fold(FoldAdaptor 
operands) { } OpFoldResult MinI64SOp::fold(FoldAdaptor operands) { - if (getLhs() == getRhs()) + if (getLhs() == getRhs()) { return getLhs(); + } return constFoldBinaryOp(operands.getLhs(), operands.getRhs(), [](const APInt &lhs, const APInt &rhs) { return llvm::APIntOps::smin(lhs, rhs); @@ -998,8 +1023,9 @@ OpFoldResult MinI64SOp::fold(FoldAdaptor operands) { } OpFoldResult MinI32UOp::fold(FoldAdaptor operands) { - if (getLhs() == getRhs()) + if (getLhs() == getRhs()) { return getLhs(); + } return constFoldBinaryOp(operands.getLhs(), operands.getRhs(), [](const APInt &lhs, const APInt &rhs) { return llvm::APIntOps::umin(lhs, rhs); @@ -1007,8 +1033,9 @@ OpFoldResult MinI32UOp::fold(FoldAdaptor operands) { } OpFoldResult MinI64UOp::fold(FoldAdaptor operands) { - if (getLhs() == getRhs()) + if (getLhs() == getRhs()) { return getLhs(); + } return constFoldBinaryOp(operands.getLhs(), operands.getRhs(), [](const APInt &lhs, const APInt &rhs) { return llvm::APIntOps::umin(lhs, rhs); @@ -1016,8 +1043,9 @@ OpFoldResult MinI64UOp::fold(FoldAdaptor operands) { } OpFoldResult MaxI32SOp::fold(FoldAdaptor operands) { - if (getLhs() == getRhs()) + if (getLhs() == getRhs()) { return getLhs(); + } return constFoldBinaryOp(operands.getLhs(), operands.getRhs(), [](const APInt &lhs, const APInt &rhs) { return llvm::APIntOps::smax(lhs, rhs); @@ -1025,8 +1053,9 @@ OpFoldResult MaxI32SOp::fold(FoldAdaptor operands) { } OpFoldResult MaxI64SOp::fold(FoldAdaptor operands) { - if (getLhs() == getRhs()) + if (getLhs() == getRhs()) { return getLhs(); + } return constFoldBinaryOp(operands.getLhs(), operands.getRhs(), [](const APInt &lhs, const APInt &rhs) { return llvm::APIntOps::smax(lhs, rhs); @@ -1034,8 +1063,9 @@ OpFoldResult MaxI64SOp::fold(FoldAdaptor operands) { } OpFoldResult MaxI32UOp::fold(FoldAdaptor operands) { - if (getLhs() == getRhs()) + if (getLhs() == getRhs()) { return getLhs(); + } return constFoldBinaryOp(operands.getLhs(), operands.getRhs(), [](const APInt &lhs, const 
APInt &rhs) { return llvm::APIntOps::umax(lhs, rhs); @@ -1043,8 +1073,9 @@ OpFoldResult MaxI32UOp::fold(FoldAdaptor operands) { } OpFoldResult MaxI64UOp::fold(FoldAdaptor operands) { - if (getLhs() == getRhs()) + if (getLhs() == getRhs()) { return getLhs(); + } return constFoldBinaryOp(operands.getLhs(), operands.getRhs(), [](const APInt &lhs, const APInt &rhs) { return llvm::APIntOps::umax(lhs, rhs); @@ -1274,16 +1305,18 @@ OpFoldResult MinF64Op::fold(FoldAdaptor operands) { } OpFoldResult MaxF32Op::fold(FoldAdaptor operands) { - if (getLhs() == getRhs()) + if (getLhs() == getRhs()) { return getLhs(); + } return constFoldBinaryOp( operands.getLhs(), operands.getRhs(), [](const APFloat &a, const APFloat &b) { return llvm::maxnum(a, b); }); } OpFoldResult MaxF64Op::fold(FoldAdaptor operands) { - if (getLhs() == getRhs()) + if (getLhs() == getRhs()) { return getLhs(); + } return constFoldBinaryOp( operands.getLhs(), operands.getRhs(), [](const APFloat &a, const APFloat &b) { return llvm::maxnum(a, b); }); @@ -1810,8 +1843,9 @@ struct FoldCastRefIntoOpResult : public OpRewritePattern { PatternRewriter &rewriter) const override { auto zeroOp = dyn_cast_if_present( castOp.getOperand().getDefiningOp()); - if (!zeroOp) + if (!zeroOp) { return failure(); + } rewriter.replaceOpWithNewOp(castOp, castOp.getResult().getType()); return success(); @@ -1821,8 +1855,9 @@ struct FoldCastRefIntoOpResult : public OpRewritePattern { } // namespace OpFoldResult CastAnyRefOp::fold(FoldAdaptor operands) { - if (getOperand().getType() == getResult().getType()) + if (getOperand().getType() == getResult().getType()) { return getOperand(); + } if (auto castOp = dyn_cast_if_present(getOperand().getDefiningOp())) { if (castOp.getOperand().getType() == getResult().getType()) { @@ -1838,8 +1873,9 @@ void CastAnyRefOp::getCanonicalizationPatterns(RewritePatternSet &results, } OpFoldResult CastRefAnyOp::fold(FoldAdaptor operands) { - if (getOperand().getType() == getResult().getType()) + if 
(getOperand().getType() == getResult().getType()) { return getOperand(); + } if (auto castOp = dyn_cast_if_present(getOperand().getDefiningOp())) { if (castOp.getOperand().getType() == getResult().getType()) { @@ -1894,8 +1930,9 @@ static Attribute constFoldBinaryCmpOp(Attribute rawLhs, Attribute rawRhs, const CalculationT &calculate) { if (auto lhs = dyn_cast_if_present(rawLhs)) { auto rhs = dyn_cast_if_present(rawRhs); - if (!rhs) + if (!rhs) { return {}; + } auto boolType = IntegerType::get(lhs.getContext(), 32); return AttrElementT::get(boolType, calculate(lhs.getValue(), rhs.getValue())); @@ -2321,35 +2358,40 @@ static TypedAttr constFoldBinaryCmpFOp(Attribute rawLhs, Attribute rawRhs, const CalculationT &calculate) { if (auto lhs = dyn_cast_if_present(rawLhs)) { auto rhs = dyn_cast_if_present(rawRhs); - if (!rhs) + if (!rhs) { return {}; + } return IntegerAttr::get(IntegerType::get(lhs.getContext(), 32), calculate(lhs.getValue(), rhs.getValue())); } else if (auto lhs = dyn_cast_if_present(rawLhs)) { // TODO(benvanik): handle splat/otherwise. 
auto rhs = dyn_cast_if_present(rawRhs); - if (!rhs || lhs.getType() != rhs.getType()) + if (!rhs || lhs.getType() != rhs.getType()) { return {}; + } auto elementResult = constFoldBinaryCmpFOp( lhs.getSplatValue(), rhs.getSplatValue(), calculate); - if (!elementResult) + if (!elementResult) { return {}; + } auto resultType = lhs.getType().clone({}, IntegerType::get(lhs.getContext(), 32)); return DenseElementsAttr::get(resultType, elementResult); } else if (auto lhs = dyn_cast_if_present(rawLhs)) { auto rhs = dyn_cast_if_present(rawRhs); - if (!rhs || lhs.getType() != rhs.getType()) + if (!rhs || lhs.getType() != rhs.getType()) { return {}; + } auto lhsIt = lhs.getValues().begin(); auto rhsIt = rhs.getValues().begin(); SmallVector resultAttrs(lhs.getNumElements()); for (int64_t i = 0; i < lhs.getNumElements(); ++i) { resultAttrs[i] = constFoldBinaryCmpFOp(*lhsIt, *rhsIt, calculate); - if (!resultAttrs[i]) + if (!resultAttrs[i]) { return {}; + } ++lhsIt; ++rhsIt; } @@ -2979,22 +3021,27 @@ static LogicalResult collapseBranch(Block *&successor, return failure(); } // Check that the successor only contains a unconditional branch. - if (std::next(successor->begin()) != successor->end()) + if (std::next(successor->begin()) != successor->end()) { return failure(); + } // Check that the terminator is an unconditional branch. BranchOp successorBranch = dyn_cast(successor->getTerminator()); - if (!successorBranch) + if (!successorBranch) { return failure(); + } // Check that the arguments are only used within the terminator. for (BlockArgument arg : successor->getArguments()) { - for (Operation *user : arg.getUsers()) - if (user != successorBranch) + for (Operation *user : arg.getUsers()) { + if (user != successorBranch) { return failure(); + } + } } // Don't try to collapse branches to infinite loops. 
Block *successorDest = successorBranch.getDest(); - if (successorDest == successor) + if (successorDest == successor) { return failure(); + } // Update the operands to the successor. If the branch parent has no // arguments, we can use the branch operands directly. @@ -3008,10 +3055,11 @@ static LogicalResult collapseBranch(Block *&successor, // Otherwise, we need to remap any argument operands. for (Value operand : operands) { BlockArgument argOperand = dyn_cast(operand); - if (argOperand && argOperand.getOwner() == successor) + if (argOperand && argOperand.getOwner() == successor) { argStorage.push_back(successorOperands[argOperand.getArgNumber()]); - else + } else { argStorage.push_back(operand); + } } successor = successorDest; successorOperands = argStorage; @@ -3312,8 +3360,9 @@ struct RequiredImportResolver : public OpRewritePattern { PatternRewriter &rewriter) const override { auto importOp = SymbolTable::lookupNearestSymbolFrom( op, op.getImportAttr()); - if (!importOp || importOp.getIsOptional()) + if (!importOp || importOp.getIsOptional()) { return failure(); + } rewriter.replaceOpWithNewOp(op, 1); return success(); } diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp index 9e466acb63f0..199d9bc99786 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp @@ -65,19 +65,23 @@ void setResultIntegerName(OpAsmSetValueNameFn &setNameFn, Value result, // (type, type, ...) 
ParseResult parseResultTypeList(OpAsmParser &parser, ArrayAttr &resultTypes) { - if (failed(parser.parseLParen())) + if (failed(parser.parseLParen())) { return failure(); + } SmallVector typeAttrs; - if (succeeded(parser.parseOptionalRParen())) + if (succeeded(parser.parseOptionalRParen())) { goto done; // empty list + } do { Type type; - if (failed(parser.parseType(type))) + if (failed(parser.parseType(type))) { return failure(); + } typeAttrs.push_back(TypeAttr::get(type)); } while (succeeded(parser.parseOptionalComma())); - if (failed(parser.parseRParen())) + if (failed(parser.parseRParen())) { return failure(); + } done: resultTypes = parser.getBuilder().getArrayAttr(typeAttrs); return success(); @@ -172,9 +176,10 @@ Block *FuncOp::addEntryBlock() { LogicalResult FuncOp::verifyType() { auto type = getFunctionTypeAttr().getValue(); - if (!isa(type)) + if (!isa(type)) { return emitOpError("requires '" + getFunctionTypeAttrName().getValue() + "' attribute of function type"); + } return success(); } @@ -404,9 +409,10 @@ void ImportOp::build(OpBuilder &builder, OperationState &result, StringRef name, LogicalResult ImportOp::verifyType() { auto type = getFunctionTypeAttr().getValue(); - if (!isa(type)) + if (!isa(type)) { return emitOpError("requires '" + getFunctionTypeAttrName().getValue() + "' attribute of function type"); + } return success(); } @@ -609,8 +615,9 @@ static bool isConstFloatBuildableWith(TypedAttr value, Type type) { } else if (auto elementsAttr = dyn_cast(value)) { elementType = elementsAttr.getShapedType().getElementType(); } - if (!elementType) + if (!elementType) { return false; + } return elementType.getIntOrFloatBitWidth() == SZ; } @@ -920,8 +927,9 @@ static std::string makeSafeIdentifier(StringRef unsafeIdentifier) { llvm::raw_string_ostream os(result); bool lastUnderscore = true; for (char c : unsafeIdentifier) { - if (!llvm::isPrint(c)) + if (!llvm::isPrint(c)) { continue; + } if (llvm::isAlnum(c)) { os << llvm::toLower(c); lastUnderscore 
= false; @@ -1410,8 +1418,9 @@ void CallVariadicOp::print(OpAsmPrinter &p) { } p << tupleOperands; p << ')'; - if (i < segmentSize - 1) + if (i < segmentSize - 1) { p << ", "; + } } } else { SmallVector segmentOperands; @@ -1562,32 +1571,39 @@ static ParseResult parseBranchTableCases( SmallVectorImpl> &caseOperands, SmallVectorImpl> &caseOperandTypes) { if (parser.parseKeyword("default") || parser.parseColon() || - parser.parseSuccessor(defaultDestination)) + parser.parseSuccessor(defaultDestination)) { return failure(); + } if (succeeded(parser.parseOptionalLParen())) { if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None, /*allowResultNumber=*/false) || - parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen()) + parser.parseColonTypeList(defaultOperandTypes) || + parser.parseRParen()) { return failure(); + } } while (succeeded(parser.parseOptionalComma())) { int64_t index = 0; - if (failed(parser.parseInteger(index))) + if (failed(parser.parseInteger(index))) { return failure(); - if (index != caseDestinations.size()) + } + if (index != caseDestinations.size()) { return failure(); + } Block *destination; SmallVector operands; SmallVector operandTypes; if (failed(parser.parseColon()) || - failed(parser.parseSuccessor(destination))) + failed(parser.parseSuccessor(destination))) { return failure(); + } if (succeeded(parser.parseOptionalLParen())) { if (failed(parser.parseOperandList(operands, OpAsmParser::Delimiter::None, /*allowResultNumber=*/false)) || failed(parser.parseColonTypeList(operandTypes)) || - failed(parser.parseRParen())) + failed(parser.parseRParen())) { return failure(); + } } caseDestinations.push_back(destination); caseOperands.emplace_back(operands); @@ -1628,8 +1644,9 @@ Block *BranchTableOp::getSuccessorForOperands(ArrayRef operands) { SuccessorRange caseDestinations = getCaseDestinations(); if (auto valueAttr = dyn_cast_if_present(operands.front())) { int64_t value = valueAttr.getValue().getSExtValue(); - if 
(value < 0 || value >= caseDestinations.size()) + if (value < 0 || value >= caseDestinations.size()) { return getDefaultDestination(); + } return caseDestinations[value]; } return nullptr; @@ -1742,6 +1759,18 @@ SuccessorOperands CondBreakOp::getSuccessorOperands(unsigned index) { return SuccessorOperands(getDestOperandsMutable()); } +//===----------------------------------------------------------------------===// +// vm.optimization_barrier +//===----------------------------------------------------------------------===// + +void OptimizationBarrierOp::build(OpBuilder &builder, OperationState &state, + ValueRange operands, + ArrayRef attributes) { + state.addOperands(operands); + state.addTypes(llvm::to_vector(operands.getTypes())); + state.addAttributes(attributes); +} + } // namespace mlir::iree_compiler::IREE::VM //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td index bf36457bbdd6..2e6a99e5df73 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td @@ -5344,4 +5344,47 @@ def VM_CondBreakOp : VM_Op<"cond_break", [ } // OpGroupDebuggingOps +//===----------------------------------------------------------------------===// +// Compiler hints +//===----------------------------------------------------------------------===// + +def OpGroupCompilerHintOps : OpDocGroup { + let summary = "Compiler hint ops"; + let description = ""; +} + +let opDocGroup = OpGroupCompilerHintOps in { + +def VM_OptimizationBarrierOp : VM_Op<"optimization_barrier", [ + VM_PseudoOp, + AllTypesMatch<["operands", "results"]>, + ]> { + let summary = [{Prevents compiler optimizations across a value.}]; + let description = [{ + Wraps any operands in an unoptimizable identity to prevent its results from + being folded. 
It will be dropped during the final step in compilation and + has no effect at runtime. + }]; + + let arguments = (ins + Variadic:$operands + ); + let results = (outs + Variadic:$results + ); + + let assemblyFormat = [{ + attr-dict ($operands^ `:` type($operands))? + }]; + + let builders = [ + OpBuilder<(ins + "ValueRange":$operands, + CArg<"ArrayRef", "{}">:$attributes + )>, + ]; +} + +} // OpGroupCompilerHintOps + #endif // IREE_DIALECT_VM_OPS diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMTypes.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMTypes.cpp index ef0f25011e52..02d2b7df847f 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMTypes.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMTypes.cpp @@ -152,8 +152,9 @@ Attribute VMDialect::parseAttribute(DialectAsmParser &parser, Type type) const { Attribute genAttr; OptionalParseResult parseResult = generatedAttributeParser(parser, &mnemonic, type, genAttr); - if (parseResult.has_value()) + if (parseResult.has_value()) { return genAttr; + } parser.emitError(parser.getNameLoc()) << "unknown HAL attribute: " << mnemonic; return {}; diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/IR/test/BUILD.bazel index a886e9ec5890..0f638408a177 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "arithmetic_folding.mlir", "arithmetic_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/ArchiveWriter.cpp b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/ArchiveWriter.cpp index e8f0bb528770..8784f5296d17 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/ArchiveWriter.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/ArchiveWriter.cpp @@ -670,8 +670,9 @@ LogicalResult ZIPArchiveWriter::flush(FlatbufferBuilder 
&fbb) { return success(); }, os); - if (!zipFile.has_value()) + if (!zipFile.has_value()) { return failure(); + } fileRefs.push_back(*zipFile); } diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BUILD.bazel index 4bf8864c5e00..30df7a9c5403 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BUILD.bazel @@ -29,19 +29,17 @@ iree_compiler_cc_library( ], deps = [ "//compiler/src/iree/compiler/Dialect/Util/IR", - "//compiler/src/iree/compiler/Dialect/Util/Transforms", "//compiler/src/iree/compiler/Dialect/VM/Analysis", + "//compiler/src/iree/compiler/Dialect/VM/Analysis:OrdinalAnalysis", "//compiler/src/iree/compiler/Dialect/VM/Analysis:ValueLiveness", "//compiler/src/iree/compiler/Dialect/VM/Conversion", "//compiler/src/iree/compiler/Dialect/VM/IR", - "//compiler/src/iree/compiler/Dialect/VM/Transforms", "//compiler/src/iree/compiler/Dialect/VM/Utils:CallingConvention", "//compiler/src/iree/compiler/Dialect/VM/Utils:TypeTable", "//compiler/src/iree/compiler/Utils", "//runtime/src/iree/schemas:bytecode_module_def_c_fbs", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp index ad2958f544f4..882c5fb377b9 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp @@ -24,8 +24,10 @@ namespace { class V0BytecodeEncoder : public BytecodeEncoder { public: V0BytecodeEncoder(llvm::DenseMap *typeTable, - RegisterAllocation *registerAllocation) - : typeTable_(typeTable), 
registerAllocation_(registerAllocation) {} + RegisterAllocation *registerAllocation, + const OrdinalAnalysis *ordinalAnalysis) + : typeTable_(typeTable), registerAllocation_(registerAllocation), + ordinalAnalysis_(ordinalAnalysis) {} ~V0BytecodeEncoder() = default; LogicalResult beginBlock(Block *block) override { @@ -59,11 +61,7 @@ class V0BytecodeEncoder : public BytecodeEncoder { if (!symbolOp) { return currentOp_->emitOpError() << "target symbol not found: " << name; } - auto ordinalAttr = symbolOp->getAttrOfType("ordinal"); - if (!ordinalAttr) { - return symbolOp->emitOpError() << "missing ordinal"; - } - int32_t ordinal = ordinalAttr.getInt(); + int32_t ordinal = ordinalAnalysis_->getOrdinal(symbolOp); if (isa(symbolOp)) { // Imported functions have their MSB set. ordinal |= 0x80000000u; @@ -224,12 +222,14 @@ class V0BytecodeEncoder : public BytecodeEncoder { LogicalResult encodeBranchTable(SuccessorRange caseSuccessors, OperandRangeRange caseOperands, int baseSuccessorIndex) override { - if (failed(writeUint16(caseSuccessors.size()))) + if (failed(writeUint16(caseSuccessors.size()))) { return failure(); + } for (auto [successor, operands] : llvm::zip_equal(caseSuccessors, caseOperands)) { - if (failed(encodeBranch(successor, operands, ++baseSuccessorIndex))) + if (failed(encodeBranch(successor, operands, ++baseSuccessorIndex))) { return failure(); + } } return success(); } @@ -323,11 +323,13 @@ class V0BytecodeEncoder : public BytecodeEncoder { LogicalResult ensureAlignment(size_t alignment) { size_t paddedSize = (bytecode_.size() + (alignment - 1)) & ~(alignment - 1); size_t padding = paddedSize - bytecode_.size(); - if (padding == 0) + if (padding == 0) { return success(); + } static const uint8_t kZeros[32] = {0}; - if (padding > sizeof(kZeros)) + if (padding > sizeof(kZeros)) { return failure(); + } return writeBytes(kZeros, padding); } @@ -387,6 +389,7 @@ class V0BytecodeEncoder : public BytecodeEncoder { llvm::DenseMap *typeTable_; RegisterAllocation 
*registerAllocation_; + const OrdinalAnalysis *ordinalAnalysis_; Operation *currentOp_ = nullptr; @@ -400,7 +403,8 @@ class V0BytecodeEncoder : public BytecodeEncoder { // static std::optional BytecodeEncoder::encodeFunction( IREE::VM::FuncOp funcOp, llvm::DenseMap &typeTable, - SymbolTable &symbolTable, DebugDatabaseBuilder &debugDatabase) { + SymbolTable &symbolTable, const OrdinalAnalysis &ordinalAnalysis, + DebugDatabaseBuilder &debugDatabase) { EncodedBytecodeFunction result; // Perform register allocation first so that we can quickly lookup values as @@ -414,7 +418,7 @@ std::optional BytecodeEncoder::encodeFunction( FunctionSourceMap sourceMap; sourceMap.localName = funcOp.getName().str(); - V0BytecodeEncoder encoder(&typeTable, ®isterAllocation); + V0BytecodeEncoder encoder(&typeTable, ®isterAllocation, &ordinalAnalysis); for (auto &block : funcOp.getBlocks()) { size_t blockStart = encoder.getOffset(); diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h index 79ba38f22fa9..e29f1fbf36c5 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h @@ -7,6 +7,7 @@ #ifndef IREE_COMPILER_DIALECT_VM_TARGET_BYTECODE_BYTECODEENCODER_H_ #define IREE_COMPILER_DIALECT_VM_TARGET_BYTECODE_BYTECODEENCODER_H_ +#include "iree/compiler/Dialect/VM/Analysis/OrdinalAnalysis.h" #include "iree/compiler/Dialect/VM/IR/VMFuncEncoder.h" #include "iree/compiler/Dialect/VM/IR/VMOps.h" #include "iree/compiler/Dialect/VM/Target/Bytecode/DebugDatabaseBuilder.h" @@ -43,7 +44,9 @@ class BytecodeEncoder : public VMFuncEncoder { // Returns None on failure. 
static std::optional encodeFunction(IREE::VM::FuncOp funcOp, llvm::DenseMap &typeTable, - SymbolTable &symbolTable, DebugDatabaseBuilder &debugDatabase); + SymbolTable &symbolTable, + const OrdinalAnalysis &ordinalAnalysis, + DebugDatabaseBuilder &debugDatabase); BytecodeEncoder() = default; ~BytecodeEncoder() = default; diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp index e1bbef46d064..0fa7bbfe6eea 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp @@ -11,14 +11,13 @@ #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" -#include "iree/compiler/Dialect/Util/Transforms/Passes.h" +#include "iree/compiler/Dialect/VM/Analysis/OrdinalAnalysis.h" #include "iree/compiler/Dialect/VM/Analysis/RegisterAllocation.h" #include "iree/compiler/Dialect/VM/Analysis/ValueLiveness.h" #include "iree/compiler/Dialect/VM/IR/VMDialect.h" #include "iree/compiler/Dialect/VM/IR/VMOps.h" #include "iree/compiler/Dialect/VM/Target/Bytecode/ArchiveWriter.h" #include "iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h" -#include "iree/compiler/Dialect/VM/Transforms/Passes.h" #include "iree/compiler/Dialect/VM/Utils/CallingConvention.h" #include "iree/compiler/Dialect/VM/Utils/TypeTable.h" #include "iree/compiler/Utils/FlatbufferUtils.h" @@ -33,12 +32,9 @@ #include "mlir/IR/Diagnostics.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Visitors.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" #include "mlir/Tools/mlir-translate/Translation.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/LocationSnapshot.h" -#include "mlir/Transforms/Passes.h" IREE_DEFINE_COMPILER_OPTION_FLAGS( 
mlir::iree_compiler::IREE::VM::BytecodeTargetOptions); @@ -125,15 +121,16 @@ serializeEmbeddedData(Location loc, Attribute valueAttr, uint64_t alignment, } // Canonicalizes the module to its final form prior to emission. -// This verifies that we only have ops we can serialize and performs any of the -// required transformations (such as debug op stripping). +// This verifies that we only have ops we can serialize and removes any +// pseudo-ops and debug ops (when stripping is enabled). +// All transformation passes should have run in the main VM transformation +// pipeline before this is called. static LogicalResult canonicalizeModule(IREE::VM::BytecodeTargetOptions bytecodeOptions, IREE::VM::ModuleOp moduleOp) { RewritePatternSet patterns(moduleOp.getContext()); ConversionTarget target(*moduleOp.getContext()); target.addLegalDialect(); - target.addLegalOp(); // Add all VM canonicalization patterns and mark pseudo-ops illegal. auto *context = moduleOp.getContext(); @@ -145,7 +142,6 @@ canonicalizeModule(IREE::VM::BytecodeTargetOptions bytecodeOptions, } // Debug ops must not be present when stripping. - // TODO(benvanik): add RemoveDisabledDebugOp pattern. if (op.hasTrait() && bytecodeOptions.stripDebugOps) { target.setOpAction(op, ConversionTarget::LegalizationAction::Illegal); @@ -156,48 +152,6 @@ canonicalizeModule(IREE::VM::BytecodeTargetOptions bytecodeOptions, return moduleOp.emitError() << "unable to fully apply conversion to module"; } - PassManager passManager(context); - // TODO(12938): Handle or investigate failure result. - auto logicalRes = mlir::applyPassManagerCLOptions(passManager); - (void)logicalRes; - mlir::applyDefaultTimingPassManagerCLOptions(passManager); - passManager.addInstrumentation(std::make_unique()); - auto &modulePasses = passManager.nest(); - - // TODO(benvanik): these ideally happen beforehand but when performing - // serialization the input IR often has some of these low-level VM ops. 
In - // real workflows these have already run earlier and are no-ops. - modulePasses.addPass(IREE::VM::createGlobalInitializationPass()); - modulePasses.addPass(IREE::VM::createDropEmptyModuleInitializersPass()); - - if (bytecodeOptions.optimize) { - // TODO(benvanik): run this as part of a fixed-point iteration. - modulePasses.addPass(mlir::createInlinerPass()); - modulePasses.addPass(mlir::createCSEPass()); - // TODO(benvanik): re-evaluate whether this canonicalizer pass should exist - // in the bytecode target. It may be removing ops (like vm.discard.refs) - // that were intentionally inserted by earlier passes. - modulePasses.addPass(mlir::createCanonicalizerPass()); - } - - modulePasses.addPass(IREE::Util::createDropCompilerHintsPass()); - - // Insert explicit discard ops for ref values at their last use points. - // Uses edge-based placement: refs dying on control flow edges get discards - // inserted on those edges, refs dying mid-block get discards after last use. - modulePasses.addPass(IREE::VM::createMaterializeRefDiscardsPass()); - - // Mark up the module with ordinals for each top-level op (func, etc). - // This will make it easier to correlate the MLIR textual output to the - // binary output. - // We don't want any more modifications after this point as they could - // invalidate the ordinals. - modulePasses.addPass(IREE::VM::createOrdinalAllocationPass()); - - if (failed(passManager.run(moduleOp->getParentOfType()))) { - return moduleOp.emitError() << "failed during transform passes"; - } - return success(); } @@ -205,8 +159,9 @@ canonicalizeModule(IREE::VM::BytecodeTargetOptions bytecodeOptions, // empty/null list). 
static iree_vm_AttrDef_vec_ref_t makeAttrDefs(DictionaryAttr attrs, FlatbufferBuilder &fbb) { - if (!attrs || attrs.empty()) + if (!attrs || attrs.empty()) { return 0; + } SmallVector attrRefs; for (auto attr : attrs) { auto key = attr.getName().strref(); @@ -262,8 +217,9 @@ makeImportFunctionSignatureDef(IREE::VM::ImportOp importOp, FlatbufferBuilder &fbb) { // Generate the signature calling convention string based on types. auto cconv = makeImportCallingConventionString(importOp); - if (!cconv.has_value()) + if (!cconv.has_value()) { return {}; + } return createFunctionSignatureDef(importOp.getFunctionType(), typeTable, cconv.value(), /*attrsRef=*/0, fbb); } @@ -275,8 +231,9 @@ makeFunctionSignatureDef(IREE::VM::FuncOp funcOp, FlatbufferBuilder &fbb) { // Generate the signature calling convention string based on types. auto cconv = makeCallingConventionString(funcOp); - if (!cconv.has_value()) + if (!cconv.has_value()) { return {}; + } // Encode reflection attributes. iree_vm_AttrDef_vec_ref_t attrsRef = makeAttrDefs( @@ -317,12 +274,11 @@ static iree_vm_FeatureBits_enum_t findRequiredFeatures(Operation *rootOp) { // has been packed into the top-level table. This results in a messier function // here during serialization but a much more trivial (and cache-friendly) // representation at runtime. -static LogicalResult -buildFlatBufferModule(IREE::VM::TargetOptions vmOptions, - IREE::VM::BytecodeTargetOptions bytecodeOptions, - IREE::VM::ModuleOp moduleOp, - MutableArrayRef rodataRefs, - FlatbufferBuilder &fbb) { +static LogicalResult buildFlatBufferModule( + IREE::VM::TargetOptions vmOptions, + IREE::VM::BytecodeTargetOptions bytecodeOptions, + IREE::VM::ModuleOp moduleOp, const OrdinalAnalysis &ordinalAnalysis, + MutableArrayRef rodataRefs, FlatbufferBuilder &fbb) { // Start the buffer so that we can begin recording data prior to the root // table (which we do at the very end). 
This does not change the layout of the // file and is only used to prime the flatcc builder. @@ -334,26 +290,22 @@ buildFlatBufferModule(IREE::VM::TargetOptions vmOptions, DebugDatabaseBuilder debugDatabase; SymbolTable symbolTable(moduleOp); - OrdinalCountsAttr ordinalCounts = moduleOp.getOrdinalCountsAttr(); - if (!ordinalCounts) { - return moduleOp.emitError() << "ordinal_counts attribute not found. The " - "OrdinalAllocationPass must be run before."; - } + const auto &ordinalCounts = ordinalAnalysis.getCounts(); // Find all structural ops in the module. std::vector importFuncOps; std::vector exportFuncOps; std::vector internalFuncOps; - importFuncOps.resize(ordinalCounts.getImportFuncs()); - exportFuncOps.resize(ordinalCounts.getExportFuncs()); - internalFuncOps.resize(ordinalCounts.getInternalFuncs()); + importFuncOps.resize(ordinalCounts.importFuncs); + exportFuncOps.resize(ordinalCounts.exportFuncs); + internalFuncOps.resize(ordinalCounts.internalFuncs); for (auto &op : moduleOp.getBlock().getOperations()) { if (auto funcOp = dyn_cast(op)) { - internalFuncOps[funcOp.getOrdinal()->getLimitedValue()] = funcOp; + internalFuncOps[ordinalAnalysis.getOrdinal(funcOp)] = funcOp; } else if (auto exportOp = dyn_cast(op)) { - exportFuncOps[exportOp.getOrdinal()->getLimitedValue()] = exportOp; + exportFuncOps[ordinalAnalysis.getOrdinal(exportOp)] = exportOp; } else if (auto importOp = dyn_cast(op)) { - importFuncOps[importOp.getOrdinal()->getLimitedValue()] = importOp; + importFuncOps[ordinalAnalysis.getOrdinal(importOp)] = importOp; if (!importOp.getName().contains('.')) { return importOp.emitOpError("must reference a function in a module " "(@module_name.func_name); got unscoped `@") @@ -380,7 +332,7 @@ buildFlatBufferModule(IREE::VM::TargetOptions vmOptions, size_t totalBytecodeLength = 0; for (auto [i, funcOp] : llvm::enumerate(internalFuncOps)) { auto encodedFunction = BytecodeEncoder::encodeFunction( - funcOp, typeOrdinalMap, symbolTable, debugDatabase); + funcOp, 
typeOrdinalMap, symbolTable, ordinalAnalysis, debugDatabase); if (!encodedFunction) { return funcOp.emitError() << "failed to encode function bytecode"; } @@ -441,8 +393,9 @@ buildFlatBufferModule(IREE::VM::TargetOptions vmOptions, flatbuffers_uint8_vec_ref_t embeddedRef = serializeEmbeddedData( rodataRef.rodataOp.getLoc(), rodataRef.rodataOp.getValue(), rodataRef.alignment, rodataRef.totalSize, fbb); - if (!embeddedRef) + if (!embeddedRef) { return failure(); + } iree_vm_RodataSegmentDef_start(fbb); iree_vm_RodataSegmentDef_embedded_data_add(fbb, embeddedRef); rodataSegmentRefs.push_back(iree_vm_RodataSegmentDef_end(fbb)); @@ -466,7 +419,7 @@ buildFlatBufferModule(IREE::VM::TargetOptions vmOptions, iree_vm_ExportFunctionDef_start(fbb); iree_vm_ExportFunctionDef_local_name_add(fbb, localNameRef); iree_vm_ExportFunctionDef_internal_ordinal_add( - fbb, funcOp.getOrdinal()->getLimitedValue()); + fbb, ordinalAnalysis.getOrdinal(funcOp)); return iree_vm_ExportFunctionDef_end(fbb); }); @@ -530,8 +483,8 @@ buildFlatBufferModule(IREE::VM::TargetOptions vmOptions, auto dependenciesRef = fbb.createOffsetVecDestructive(dependencyRefs); auto typesRef = fbb.createOffsetVecDestructive(typeRefs); - int32_t globalRefs = ordinalCounts.getGlobalRefs(); - int32_t globalBytes = ordinalCounts.getGlobalBytes(); + int32_t globalRefs = ordinalCounts.globalRefs; + int32_t globalBytes = ordinalCounts.globalBytes; iree_vm_ModuleStateDef_ref_t moduleStateDef = 0; if (globalBytes || globalRefs) { @@ -553,10 +506,12 @@ buildFlatBufferModule(IREE::VM::TargetOptions vmOptions, // so that we can multi-version. For now the moduleRequirements will be the OR // of all functions. 
iree_vm_FeatureBits_enum_t allowedFeatures = 0; - if (vmOptions.f32Extension) + if (vmOptions.f32Extension) { allowedFeatures |= iree_vm_FeatureBits_EXT_F32; - if (vmOptions.f64Extension) + } + if (vmOptions.f64Extension) { allowedFeatures |= iree_vm_FeatureBits_EXT_F64; + } // Yield/unwind are core VM semantics once supported by the runtime. allowedFeatures |= iree_vm_FeatureBits_YIELD; allowedFeatures |= iree_vm_FeatureBits_UNWIND; @@ -656,15 +611,18 @@ translateModuleToBytecode(IREE::VM::ModuleOp moduleOp, assert(false && "unhandled output format combination"); } + // Compute ordinals for all module-level symbols. + OrdinalAnalysis ordinalAnalysis(moduleOp); + // Declare all rodata entries we want to end up as external data first. This // allows us to compute offsets if needed without having had to perform // serialization yet. Note that not all rodata ends up as external data: if // it's small (like strings) we can avoid the extra seeks and keep it more // local by embedding it in the FlatBuffer. std::vector rodataOps; - rodataOps.resize(moduleOp.getOrdinalCountsAttr().getRodatas()); + rodataOps.resize(ordinalAnalysis.getCounts().rodatas); for (auto rodataOp : moduleOp.getOps()) { - rodataOps[rodataOp.getOrdinal()->getLimitedValue()] = rodataOp; + rodataOps[ordinalAnalysis.getOrdinal(rodataOp)] = rodataOp; } SmallVector rodataRefs; rodataRefs.resize(rodataOps.size()); @@ -699,7 +657,7 @@ translateModuleToBytecode(IREE::VM::ModuleOp moduleOp, llvm::endianness::little, os); }); } - rodataRefs[rodataOp.getOrdinal()->getLimitedValue()] = rodataRef; + rodataRefs[ordinalAnalysis.getOrdinal(rodataOp)] = rodataRef; } // NOTE: we order things so that all of the metadata is close to the start of @@ -708,7 +666,7 @@ translateModuleToBytecode(IREE::VM::ModuleOp moduleOp, // can be large bulk data. 
FlatbufferBuilder fbb; if (failed(buildFlatBufferModule(vmOptions, bytecodeOptions, moduleOp, - rodataRefs, fbb))) { + ordinalAnalysis, rodataRefs, fbb))) { return failure(); } if (failed(archiveWriter->flush(fbb))) { @@ -751,11 +709,6 @@ void BytecodeTargetOptions::bindOptions(OptionsBinder &binder) { clEnumValN(BytecodeOutputFormat::kAnnotatedMlirText, "annotated-mlir-text", "MLIR module file in the VM dialect with annotations"))); - binder.opt( - "iree-vm-bytecode-module-optimize", optimize, - llvm::cl::cat(vmBytecodeOptionsCategory), - llvm::cl::desc("Optimizes the VM module with CSE/inlining/etc prior to " - "serialization")); binder.opt( "iree-vm-bytecode-source-listing", sourceListing, llvm::cl::cat(vmBytecodeOptionsCategory), diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h index 7f7e8788ed93..e2137ad49143 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h @@ -35,9 +35,6 @@ struct BytecodeTargetOptions { // Format of the module written to the output stream. BytecodeOutputFormat outputFormat = BytecodeOutputFormat::kFlatBufferBinary; - // Run basic CSE/inlining/etc passes prior to serialization. - bool optimize = true; - // Dump a VM MLIR file and annotate source locations with it. // This allows for the runtime to serve stack traces referencing both the // original source locations and the VM IR. 
diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/CMakeLists.txt index c2c94c4a327a..cabe350656e9 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/CMakeLists.txt @@ -27,18 +27,16 @@ iree_cc_library( DEPS LLVMSupport MLIRIR - MLIRPass MLIRSupport MLIRTransformUtils MLIRTransforms MLIRTranslateLib iree::compiler::Dialect::Util::IR - iree::compiler::Dialect::Util::Transforms iree::compiler::Dialect::VM::Analysis + iree::compiler::Dialect::VM::Analysis::OrdinalAnalysis iree::compiler::Dialect::VM::Analysis::ValueLiveness iree::compiler::Dialect::VM::Conversion iree::compiler::Dialect::VM::IR - iree::compiler::Dialect::VM::Transforms iree::compiler::Dialect::VM::Utils::CallingConvention iree::compiler::Dialect::VM::Utils::TypeTable iree::compiler::Utils diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/DebugDatabaseBuilder.cpp b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/DebugDatabaseBuilder.cpp index 6e1edb581c00..172777efdde6 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/DebugDatabaseBuilder.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/DebugDatabaseBuilder.cpp @@ -34,8 +34,9 @@ struct LocationTable { // Inserts a string into the location table string subtable if needed. flatbuffers_string_ref_t insert(StringRef value) { auto it = strings.find(value); - if (it != strings.end()) + if (it != strings.end()) { return it->second; + } auto stringRef = fbb.createString(value); strings[value] = stringRef; return stringRef; @@ -45,8 +46,9 @@ struct LocationTable { // Returns the ordinal of the location in the table. 
int32_t insert(Location baseLoc) { auto it = map.find(baseLoc); - if (it != map.end()) + if (it != map.end()) { return it->second; + } auto locationRef = llvm::TypeSwitch(baseLoc) .Case([&](CallSiteLoc loc) { @@ -103,8 +105,9 @@ struct LocationTable { iree_vm_DebugDatabaseDef_ref_t DebugDatabaseBuilder::build(FlatbufferBuilder &fbb) { - if (functionSourceMaps.empty()) + if (functionSourceMaps.empty()) { return 0; + } LocationTable locationTable(fbb); diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/BUILD.bazel index 8ef4a8724d18..99467bbbb317 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "constant_encoding.mlir", "dependencies.mlir", diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/constant_encoding.mlir b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/constant_encoding.mlir index 52f160c3c124..a110a46dad4c 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/constant_encoding.mlir +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/constant_encoding.mlir @@ -4,11 +4,6 @@ // CHECK: "name": "constants" vm.module @constants { - vm.export @func - vm.func @func() { - vm.return - } - // CHECK: "rodata_segments": [{ // Tests that we densely pack i2 values. Note that the final element (3) is @@ -18,7 +13,7 @@ vm.module @constants { // CHECK-NEXT: 26, // CHECK-NEXT: 3 // CHECK-NEXT: ] - vm.rodata private @dense_i2 dense<[0, 1, 2, 3, 2, 2, 1, 0, 3]> : tensor<9xi2> + vm.rodata public @dense_i2 dense<[0, 1, 2, 3, 2, 2, 1, 0, 3]> : tensor<9xi2> // Tests that we densely pack i3 values and insert the wasted 2-bits of // padding in each byte. 
Smarter implementations would pack to 16- or 64-bit @@ -29,7 +24,7 @@ vm.module @constants { // CHECK-NEXT: 44, // CHECK-NEXT: 62 // CHECK-NEXT: ] - vm.rodata private @dense_i3 dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi3> + vm.rodata public @dense_i3 dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi3> // Tests that we densely pack i4 values and handle partial values (14). // CHECK: "embedded_data": [ @@ -43,7 +38,7 @@ vm.module @constants { // CHECK-NEXT: 254, // CHECK-NEXT: 14 // CHECK-NEXT: ] - vm.rodata private @dense_i4 dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 14]> : tensor<17xi4> + vm.rodata public @dense_i4 dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 14]> : tensor<17xi4> // CHECK: "embedded_data": [ // CHECK-NEXT: 98, @@ -51,14 +46,14 @@ vm.module @constants { // CHECK-NEXT: 197, // CHECK-NEXT: 28 // CHECK-NEXT: ] - vm.rodata private @dense_i5 dense<[2, 3, 4, 5, 6, 7]> : tensor<6xi5> + vm.rodata public @dense_i5 dense<[2, 3, 4, 5, 6, 7]> : tensor<6xi5> // CHECK: "embedded_data": [ // CHECK-NEXT: 1, // CHECK-NEXT: 2, // CHECK-NEXT: 3 // CHECK-NEXT: ] - vm.rodata private @dense_i8 dense<[1, 2, 3]> : tensor<3xi8> + vm.rodata public @dense_i8 dense<[1, 2, 3]> : tensor<3xi8> // CHECK: "embedded_data": [ // CHECK-NEXT: 1, @@ -70,7 +65,7 @@ vm.module @constants { // CHECK-NEXT: 0, // CHECK-NEXT: 0 // CHECK-NEXT: ] - vm.rodata private @dense_i9 dense<[1, 2, 3, 4, 5]> : tensor<5xi9> + vm.rodata public @dense_i9 dense<[1, 2, 3, 4, 5]> : tensor<5xi9> // CHECK: "embedded_data": [ // CHECK-NEXT: 0, @@ -80,7 +75,7 @@ vm.module @constants { // CHECK-NEXT: 0, // CHECK-NEXT: 66 // CHECK-NEXT: ] - vm.rodata private @dense_f16 dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf16> + vm.rodata public @dense_f16 dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf16> // CHECK: "embedded_data": [ // CHECK-NEXT: 0, @@ -90,7 +85,7 @@ vm.module @constants { // CHECK-NEXT: 0, // CHECK-NEXT: 60 // CHECK-NEXT: ] - vm.rodata private 
@splat_f16 dense<1.000000e+00> : tensor<3xf16> + vm.rodata public @splat_f16 dense<1.000000e+00> : tensor<3xf16> // CHECK: "embedded_data": [ // CHECK-NEXT: 0, @@ -106,7 +101,7 @@ vm.module @constants { // CHECK-NEXT: 64, // CHECK-NEXT: 64 // CHECK-NEXT: ] - vm.rodata private @dense_f32 dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32> + vm.rodata public @dense_f32 dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32> // CHECK: "embedded_data": [ @@ -128,7 +123,7 @@ vm.module @constants { // CHECK-NEXT: 128, // CHECK-NEXT: 64 // CHECK-NEXT: ] - vm.rodata private @dense_resource_complex_f32 dense< + vm.rodata public @dense_resource_complex_f32 dense< "0x0000803F000000400000404000008040" > : tensor<2xcomplex> @@ -146,7 +141,7 @@ vm.module @constants { // CHECK-NEXT: 128, // CHECK-NEXT: 63 // CHECK-NEXT: ] - vm.rodata private @splat_f32 dense<1.000000e+00> : tensor<3xf32> + vm.rodata public @splat_f32 dense<1.000000e+00> : tensor<3xf32> // Tests that elided tensors of sub-byte types get filled with zeros when the // --iree-util-zero-fill-elided-attrs flag is passed. This is useful for @@ -157,7 +152,7 @@ vm.module @constants { // CHECK-NEXT: 0, // CHECK-NEXT: 0 // CHECK-NEXT: ] - vm.rodata private @elided_i2 dense_resource<__elided__> : tensor<9xi2> + vm.rodata public @elided_i2 dense_resource<__elided__> : tensor<9xi2> // CHECK: "embedded_data": [ // CHECK-NEXT: 0, @@ -173,7 +168,7 @@ vm.module @constants { // CHECK-NEXT: 0, // CHECK-NEXT: 0 // CHECK-NEXT: ] - vm.rodata private @elided_f32 dense_resource<__elided__> : tensor<3xf32> + vm.rodata public @elided_f32 dense_resource<__elided__> : tensor<3xf32> // Tests #util.byte_pattern on sub-byte types. 
// CHECK: "embedded_data": [ @@ -181,7 +176,7 @@ vm.module @constants { // CHECK-NEXT: 1, // CHECK-NEXT: 1 // CHECK-NEXT: ] - vm.rodata private @byte_pattern_i2 #util.byte_pattern<1> : tensor<9xi2> + vm.rodata public @byte_pattern_i2 #util.byte_pattern<1> : tensor<9xi2> // CHECK: "embedded_data": [ // CHECK-NEXT: 0, @@ -209,5 +204,5 @@ vm.module @constants { // CHECK-NEXT: 8, // CHECK-NEXT: 64 // CHECK-NEXT: ] - vm.rodata private @dense_f64 dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf64> + vm.rodata public @dense_f64 dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf64> } diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/dependencies.mlir b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/dependencies.mlir index 914ddae9e830..6b3dd3478905 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/dependencies.mlir +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/dependencies.mlir @@ -32,6 +32,16 @@ vm.module @main_module attributes { version = 100 : i32 } { // CHECK: "flags": "OPTIONAL" vm.import private optional @optional.method1() attributes { minimum_version = 11 : i32 } + // Use the imports so they're not eliminated by DCE. 
+ vm.export @use_imports + vm.func private @use_imports() { + vm.call @required.method0() : () -> () + vm.call @required.method1() : () -> () + vm.call @required.method2() : () -> () + vm.call @optional.method0() : () -> () + vm.call @optional.method1() : () -> () + vm.return + } } // ----- diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp b/compiler/src/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp index 2f8e5a2fb3a8..87e2c29dbd7d 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp @@ -28,7 +28,6 @@ canonicalizeModule(IREE::VM::ModuleOp moduleOp, RewritePatternSet patterns(moduleOp.getContext()); ConversionTarget target(*moduleOp.getContext()); target.addLegalDialect(); - target.addLegalOp(); // Add all VM canonicalization patterns and mark pseudo-ops illegal. auto *context = moduleOp.getContext(); @@ -86,15 +85,19 @@ canonicalizeModule(IREE::VM::ModuleOp moduleOp, // invalidate the ordinals. modulePasses.addPass(IREE::VM::createOrdinalAllocationPass()); - // C target specific pass - modulePasses.addPass(createConvertVMToEmitCPass()); + // Drop vm.optimization_barrier ops before EmitC conversion. The barriers + // prevent folding during VM-level optimizations above, but EmitC doesn't + // have conversion patterns for vm.optimization_barrier. + modulePasses.addPass(IREE::VM::createDropOptimizationBarriersPass()); - // Drop optimization barriers after EmitC conversion. Must be after conversion - // so barriers prevent folding during VM-level canonicalization, but the - // subsequent canonicalizer only sees EmitC ops (which don't fold VM - // constants). - modulePasses.addPass(IREE::Util::createDropCompilerHintsPass()); + // Clean up dead code created by dropping barriers. The barriers prevented + // constant folding, so after dropping them we need to eliminate unused + // constants to avoid generating unused variables in EmitC. 
modulePasses.addPass(mlir::createCanonicalizerPass()); + modulePasses.addPass(mlir::createCSEPass()); + + // C target specific pass + modulePasses.addPass(createConvertVMToEmitCPass()); if (failed(passManager.run(moduleOp->getParentOfType()))) { return moduleOp.emitError() << "failed during transform passes"; diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/C/test/control_flow.mlir b/compiler/src/iree/compiler/Dialect/VM/Target/C/test/control_flow.mlir index 2e746bcc51a6..6c7bea18179a 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/C/test/control_flow.mlir +++ b/compiler/src/iree/compiler/Dialect/VM/Target/C/test/control_flow.mlir @@ -18,12 +18,10 @@ vm.module @control_flow_module { // CHECK-NEXT: int32_t [[V0:[^ ]*]]; // CHECK-NEXT: iree_status_t [[STATUS:[^ ]*]]; // CHECK-NEXT: int32_t [[C:[^ ]*]]; - // CHECK-NEXT: int32_t [[D:[^ ]*]]; // CHECK-NEXT: [[COND_NZ]] = vm_cmp_nz_i32([[COND]]); // CHECK-NEXT: [[COND_BOOL]] = (bool) [[COND_NZ]]; // CHECK-NEXT: if ([[COND_BOOL]]) { // CHECK-NEXT: [[C]] = [[A]]; - // CHECK-NEXT: [[D]] = [[A]]; // CHECK-NEXT: goto [[BB2:[^ ]*]]; // CHECK-NEXT: } else { // CHECK-NEXT: goto [[BB1:[^ ]*]]; @@ -31,10 +29,9 @@ vm.module @control_flow_module { // CHECK-NEXT: [[BB1]]: // CHECK-NEXT: [[B]] = vm_add_i32([[A]], [[A]]); // CHECK-NEXT: [[C]] = [[B]]; - // CHECK-NEXT: [[D]] = [[A]]; // CHECK-NEXT: goto [[BB2:[^ ]*]]; // CHECK-NEXT: [[BB2]]: - // CHECK-NEXT: [[V0]] = vm_add_i32([[C]], [[D]]); + // CHECK-NEXT: [[V0]] = vm_add_i32([[C]], [[A]]); // CHECK-NEXT: EMITC_DEREF_ASSIGN_VALUE([[RESULT]], [[V0]]); // CHECK-NEXT: [[STATUS]] = iree_ok_status(); // CHECK-NEXT: return [[STATUS]]; diff --git a/compiler/src/iree/compiler/Dialect/VM/Tools/VMOpEncoderGen.cpp b/compiler/src/iree/compiler/Dialect/VM/Tools/VMOpEncoderGen.cpp index 821a8701f894..b89d33c5cdcd 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Tools/VMOpEncoderGen.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Tools/VMOpEncoderGen.cpp @@ -36,11 +36,13 @@ 
bool emitEncodeFnDefs(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) { auto defs = recordKeeper.getAllDerivedDefinitions("VM_Op"); for (const auto *def : defs) { - if (def->isValueUnset("encoding")) + if (def->isValueUnset("encoding")) { continue; + } auto encodingExprs = def->getValueAsListOfDefs("encoding"); - if (encodingExprs.empty()) + if (encodingExprs.empty()) { continue; + } Operator op(def); tblgen::DialectNamespaceEmitter emitter(os, op.getDialect()); diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/AnnotateFunctions.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/AnnotateFunctions.cpp index 4b0be6ab54e8..10a5d9a3f223 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/AnnotateFunctions.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/AnnotateFunctions.cpp @@ -50,11 +50,13 @@ static FuncInfo analyzeFunction(IREE::VM::FuncOp funcOp, funcOp.walk([&](Operation *op) { // Collect callees. if (auto callOp = dyn_cast(op)) { - if (auto callee = symbolTable.lookup(callOp.getCallee())) + if (auto callee = symbolTable.lookup(callOp.getCallee())) { info.callees.push_back(callee); + } } else if (auto callOp = dyn_cast(op)) { - if (auto callee = symbolTable.lookup(callOp.getCallee())) + if (auto callee = symbolTable.lookup(callOp.getCallee())) { info.callees.push_back(callee); + } } // Check for yield ops. if (isa(op)) { @@ -127,8 +129,9 @@ class AnnotateFunctionsPass bool sccUnwind = false; for (CallGraphNode *node : scc) { - if (node->isExternal()) + if (node->isExternal()) { continue; + } Operation *op = node->getCallableRegion()->getParentOp(); auto it = funcInfos.find(op); if (it != funcInfos.end()) { @@ -139,17 +142,20 @@ class AnnotateFunctionsPass // Propagate from callees (already processed, outside this SCC). 
for (CallGraphNode *node : scc) { - if (node->isExternal()) + if (node->isExternal()) { continue; + } Operation *op = node->getCallableRegion()->getParentOp(); auto it = funcInfos.find(op); - if (it == funcInfos.end()) + if (it == funcInfos.end()) { continue; + } for (Operation *calleeOp : it->second.callees) { auto calleeIt = funcInfos.find(calleeOp); - if (calleeIt == funcInfos.end()) + if (calleeIt == funcInfos.end()) { continue; + } // Only propagate from callees outside this SCC (they have final // bits). @@ -170,8 +176,9 @@ class AnnotateFunctionsPass // Apply to all nodes in this SCC. for (CallGraphNode *node : scc) { - if (node->isExternal()) + if (node->isExternal()) { continue; + } Operation *op = node->getCallableRegion()->getParentOp(); auto it = funcInfos.find(op); if (it != funcInfos.end()) { @@ -184,8 +191,9 @@ class AnnotateFunctionsPass // Phase 4: Apply attributes to functions. for (auto funcOp : moduleOp.getOps()) { auto it = funcInfos.find(funcOp); - if (it == funcInfos.end()) + if (it == funcInfos.end()) { continue; + } if (it->second.needsYield && !funcOp->hasAttr("vm.yield")) { funcOp->setAttr("vm.yield", UnitAttr::get(funcOp.getContext())); diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD.bazel index a666cc18bb5f..926aba932870 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD.bazel @@ -25,6 +25,7 @@ iree_compiler_cc_library( "ConvertToYieldableCalls.cpp", "DeduplicateRodata.cpp", "DropEmptyModuleInitializers.cpp", + "DropOptimizationBarriers.cpp", "DropUnusedCalls.cpp", "GlobalInitialization.cpp", "HoistInlinedRodata.cpp", diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt index 14917efaaa8e..3e107cc2820c 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt +++ 
b/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt @@ -22,6 +22,7 @@ iree_cc_library( "ConvertToYieldableCalls.cpp" "DeduplicateRodata.cpp" "DropEmptyModuleInitializers.cpp" + "DropOptimizationBarriers.cpp" "DropUnusedCalls.cpp" "GlobalInitialization.cpp" "HoistInlinedRodata.cpp" diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp index 9fa70b26b3fb..50243cf870fd 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp @@ -73,11 +73,13 @@ gatherUsedDialectInterfaces(mlir::ModuleOp moduleOp) { // Generic dialect lookup. dialect = op->getDialect(); } - if (!dialect) + if (!dialect) { return; + } auto *dialectInterface = dialect->getRegisteredInterface(); - if (!dialectInterface) + if (!dialectInterface) { return; + } resultSet.insert(dialectInterface); }); @@ -97,8 +99,9 @@ class ConversionPass : public IREE::VM::impl::ConversionPassBase { using Base::Base; void runOnOperation() override { - if (getOperation().getBody()->empty()) + if (getOperation().getBody()->empty()) { return; + } auto targetOptions = targetOptionsFromConversionPass(); @@ -144,6 +147,8 @@ class ConversionPass // legalization when types need conversion (e.g., index -> i32). conversionTarget.addIllegalOp(); patterns.add(typeConverter, context); + // Convert util.optimization_barrier to vm.optimization_barrier. 
+ conversionTarget.addIllegalOp(); populateUtilToVMPatterns(context, conversionTarget, typeConverter, importTable, patterns); diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/ConvertToYieldableCalls.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/ConvertToYieldableCalls.cpp index fc53769abb33..bebed5bd905a 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/ConvertToYieldableCalls.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/ConvertToYieldableCalls.cpp @@ -120,11 +120,13 @@ class ConvertToYieldableCallsPass // Extract segment info. SmallVector segmentSizes; - for (auto val : callOp.getSegmentSizes()) + for (auto val : callOp.getSegmentSizes()) { segmentSizes.push_back(val.getSExtValue()); + } SmallVector segmentTypes; - for (auto typeAttr : callOp.getSegmentTypes()) + for (auto typeAttr : callOp.getSegmentTypes()) { segmentTypes.push_back(cast(typeAttr).getValue()); + } // Create the vm.call.variadic.yieldable op and erase the original call. builder.setInsertionPoint(callOp); diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/DeduplicateRodata.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/DeduplicateRodata.cpp index 451288028435..e06138a84d63 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/DeduplicateRodata.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/DeduplicateRodata.cpp @@ -77,8 +77,9 @@ class DeduplicateRodataPass replacer.addReplacement( [&](SymbolRefAttr attr) -> std::pair { auto replacement = replacements.find(attr); - if (replacement != replacements.end()) + if (replacement != replacements.end()) { return {replacement->getSecond(), WalkResult::skip()}; + } return {attr, WalkResult::skip()}; }); moduleOp.walk([&](Operation *op) { replacer.replaceElementsIn(op); }); diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/DropEmptyModuleInitializers.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/DropEmptyModuleInitializers.cpp index 
fd0b149f8c82..0ed40e5dcd39 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/DropEmptyModuleInitializers.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/DropEmptyModuleInitializers.cpp @@ -42,8 +42,9 @@ class DropEmptyModuleInitializersPass auto initFuncOp = symbolTable.lookup("__init"); if (initFuncOp && isFuncEmpty(initFuncOp)) { auto exportOp = exportOps[initFuncOp.getName()]; - if (exportOp) + if (exportOp) { exportOp.erase(); + } initFuncOp.erase(); } @@ -51,8 +52,9 @@ class DropEmptyModuleInitializersPass auto deinitFuncOp = symbolTable.lookup("__deinit"); if (deinitFuncOp && isFuncEmpty(deinitFuncOp)) { auto exportOp = exportOps[deinitFuncOp.getName()]; - if (exportOp) + if (exportOp) { exportOp.erase(); + } deinitFuncOp.erase(); } } diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/DropOptimizationBarriers.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/DropOptimizationBarriers.cpp new file mode 100644 index 000000000000..6536a0a6663c --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/DropOptimizationBarriers.cpp @@ -0,0 +1,28 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/VM/IR/VMOps.h" +#include "iree/compiler/Dialect/VM/Transforms/Passes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler::IREE::VM { + +#define GEN_PASS_DEF_DROPOPTIMIZATIONBARRIERSPASS +#include "iree/compiler/Dialect/VM/Transforms/Passes.h.inc" + +class DropOptimizationBarriersPass + : public IREE::VM::impl::DropOptimizationBarriersPassBase< + DropOptimizationBarriersPass> { + void runOnOperation() override { + getOperation()->walk([&](IREE::VM::OptimizationBarrierOp op) { + op.replaceAllUsesWith(op.getOperands()); + op.erase(); + }); + } +}; + +} // namespace mlir::iree_compiler::IREE::VM diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp index 98b47b80ad9c..eb944a133605 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp @@ -96,8 +96,9 @@ static void fixupGlobalMutability(Operation *moduleOp, explorer.initialize(); SmallVector deadOps; explorer.forEachGlobal([&](const Explorer::GlobalInfo *globalInfo) { - if (globalInfo->uses.empty()) + if (globalInfo->uses.empty()) { return; + } // TODO(benvanik): verify we want this behavior - we likely want to change // this to be mutable only if stores exist outside of initializers. 
// diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/MaterializeRefDiscards.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/MaterializeRefDiscards.cpp index d0093e1cce13..03f26cb7a95c 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/MaterializeRefDiscards.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/MaterializeRefDiscards.cpp @@ -15,6 +15,7 @@ #include "iree/compiler/Dialect/VM/IR/VMOps.h" #include "iree/compiler/Dialect/VM/Transforms/Passes.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Builders.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" @@ -75,8 +76,9 @@ class MaterializeRefDiscardsPass // to callee). bool isTerminatorMoveOperand(Value value, Operation *terminator) { auto refMoveOp = dyn_cast(terminator); - if (!refMoveOp) + if (!refMoveOp) { return false; + } // Check if value is forwarded to any successor - if so, it's not a "pure" // MOVE to callee, it's a forward to successor block. @@ -118,14 +120,16 @@ class MaterializeRefDiscardsPass bool isForwardedOnEdge(Value value, Block *pred, Block *succ) { Operation *terminator = pred->getTerminator(); auto branchOp = dyn_cast(terminator); - if (!branchOp) + if (!branchOp) { return false; + } for (unsigned i = 0; i < terminator->getNumSuccessors(); ++i) { if (terminator->getSuccessor(i) == succ) { auto operands = branchOp.getSuccessorOperands(i); - if (llvm::is_contained(operands.getForwardedOperands(), value)) + if (llvm::is_contained(operands.getForwardedOperands(), value)) { return true; + } } } return false; @@ -183,12 +187,22 @@ class MaterializeRefDiscardsPass // Get the operands being passed to succ on this edge. SuccessorOperands succOperands = branchOp.getSuccessorOperands(succIndex); SmallVector operandValues(succOperands.getForwardedOperands()); + unsigned producedCount = succOperands.getProducedOperandCount(); - // Add block arguments to newBlock to receive the operands. 
+ // Add block arguments to newBlock to receive the forwarded operands. for (Value operand : operandValues) { newBlock->addArgument(operand.getType(), operand.getLoc()); } + // Add block arguments for produced operands (e.g., vm.call.yieldable + // results). These are created by the terminator at runtime and must be + // forwarded through the new block to the original successor. + for (unsigned i = 0; i < producedCount; ++i) { + // Produced operands come after forwarded operands in succ's arguments. + Type type = succ->getArgument(operandValues.size() + i).getType(); + newBlock->addArgument(type, loc); + } + // Update predecessor's terminator to go to new block instead of succ. // The operands stay the same - they'll now be passed to newBlock. terminator->setSuccessor(newBlock, succIndex); @@ -212,8 +226,9 @@ class MaterializeRefDiscardsPass LogicalResult processFunction(FuncOp funcOp) { // Skip empty functions. - if (funcOp.getBlocks().empty()) + if (funcOp.getBlocks().empty()) { return success(); + } // Compute liveness information. ValueLiveness liveness; @@ -224,8 +239,10 @@ class MaterializeRefDiscardsPass OpBuilder builder(funcOp.getContext()); - // Collect all refs in the function. - llvm::DenseSet allRefs; + // Collect all refs in the function in deterministic order. + // Walk blocks and operations in order and insert into SetVector, which + // maintains insertion order for deterministic iteration. + llvm::SetVector allRefs; for (Block &block : funcOp.getBlocks()) { for (BlockArgument arg : block.getArguments()) { if (isa(arg.getType())) { @@ -265,8 +282,9 @@ class MaterializeRefDiscardsPass SmallVector dyingRefs; for (Value ref : allRefs) { - if (escapingRefs.count(ref)) + if (escapingRefs.count(ref)) { continue; + } // Check if ref should be discarded on this edge. bool isInLiveOuts = llvm::is_contained(liveOuts, ref); @@ -281,22 +299,26 @@ class MaterializeRefDiscardsPass } // Skip if ref is neither in liveOuts nor forwarded on any edge. 
- if (!isInLiveOuts && !isForwardedOnAny) + if (!isInLiveOuts && !isForwardedOnAny) { continue; + } // Skip if ref is live-in to successor. - if (llvm::is_contained(succLiveIns, ref)) + if (llvm::is_contained(succLiveIns, ref)) { continue; + } // Skip if ref is forwarded on this specific edge. - if (isForwardedOnEdge(ref, &block, succ)) + if (isForwardedOnEdge(ref, &block, succ)) { continue; + } // Skip if ref is a MOVE operand of the terminator. // MOVE operands transfer ownership to the callee, so we must NOT // discard them - the callee takes responsibility for the ref. - if (isTerminatorMoveOperand(ref, terminator)) + if (isTerminatorMoveOperand(ref, terminator)) { continue; + } // Ref dies on this edge. dyingRefs.push_back(ref); @@ -324,17 +346,20 @@ class MaterializeRefDiscardsPass llvm::DenseMap opToIndex; for (Operation &op : block) { - if (isa(&op)) + if (isa(&op)) { continue; + } for (OpOperand &operand : op.getOpOperands()) { Value value = operand.get(); - if (!isa(value.getType())) + if (!isa(value.getType())) { continue; + } // Skip escaping refs. - if (escapingRefs.count(value)) + if (escapingRefs.count(value)) { continue; + } // Check if this is the last use and value doesn't escape via // live-outs. @@ -352,6 +377,19 @@ class MaterializeRefDiscardsPass if (op.hasTrait()) { continue; } + + // Skip refs that are MOVE operands of RefMoveInterface + // operations. When an operand is movable and this is its last + // use, the MOVE bit will be set by the register allocator and + // ownership transfers to the operation (e.g., vm.call, + // vm.call.variadic). Inserting a discard would be incorrect as + // the ref is consumed by the operation. + if (auto refMoveOp = dyn_cast(&op)) { + if (refMoveOp.isRefOperandMovable(operand.getOperandNumber())) { + continue; + } + } + // Group by insertion point. auto it = opToIndex.find(&op); if (it == opToIndex.end()) { @@ -382,8 +420,9 @@ class MaterializeRefDiscardsPass // Unused block arguments. 
SmallVector unusedBlockArgs; for (BlockArgument arg : block.getArguments()) { - if (!isa(arg.getType())) + if (!isa(arg.getType())) { continue; + } if (arg.use_empty() && !escapingRefs.count(arg)) { unusedBlockArgs.push_back(arg); } @@ -399,8 +438,9 @@ class MaterializeRefDiscardsPass llvm::DenseMap opToResultIndex; for (Operation &op : block) { for (Value result : op.getResults()) { - if (!isa(result.getType())) + if (!isa(result.getType())) { continue; + } if (result.use_empty() && !escapingRefs.count(result)) { auto it = opToResultIndex.find(&op); if (it == opToResultIndex.end()) { diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/OrdinalAllocation.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/OrdinalAllocation.cpp index ea072fe8eb85..d539c8209f9b 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/OrdinalAllocation.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/OrdinalAllocation.cpp @@ -85,8 +85,9 @@ class OrdinalAllocationPass int globalBytes = 0; for (auto sizeGlobalOps : llvm::enumerate(primitiveGlobalOps)) { size_t storageSize = sizeGlobalOps.index(); - if (sizeGlobalOps.value().empty()) + if (sizeGlobalOps.value().empty()) { continue; + } nextGlobalBytesOrdinal = llvm::alignTo(nextGlobalBytesOrdinal, storageSize); for (auto &globalOp : sizeGlobalOps.value()) { diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp index 8a3464da60e1..0098e03440ed 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp @@ -169,6 +169,16 @@ void buildVMTransformPassPipeline(OpPassManager &passManager, if (targetOptions.optimizeForStackSize) { passManager.addNestedPass(createSinkDefiningOpsPass()); } + + // Drop vm.optimization_barrier ops now that optimization is complete. 
+ passManager.addNestedPass( + createDropOptimizationBarriersPass()); + + // Insert explicit discard ops for ref values at their last use points. + // Uses edge-based placement: refs dying on control flow edges get discards + // inserted on those edges, refs dying mid-block get discards after last use. + passManager.addNestedPass( + createMaterializeRefDiscardsPass()); } //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.td index 06e540f76673..2300a01fa645 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.td @@ -158,6 +158,16 @@ def DropUnusedCallsPass : let summary = "Drops vm.call ops that have no side effects and are unused."; } +def DropOptimizationBarriersPass : + Pass<"iree-vm-drop-optimization-barriers", "IREE::VM::ModuleOp"> { + let summary = "Drops vm.optimization_barrier ops."; + let description = [{ + Removes vm.optimization_barrier ops by replacing them with their operands. + This pass should run after all optimization passes that could fold through + the barriers. + }]; +} + def SinkDefiningOpsPass : Pass<"iree-vm-sink-defining-ops", "IREE::VM::ModuleOp"> { let summary = "Sinks defining ops with few uses to their use-sites."; diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/ResolveRodataLoads.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/ResolveRodataLoads.cpp index 8869fd662c82..3ede87b5b528 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/ResolveRodataLoads.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/ResolveRodataLoads.cpp @@ -67,11 +67,13 @@ static void processBufferGlobal(Explorer &explorer, const Explorer::GlobalInfo *globalInfo, DenseSet &deadOps) { // Ignore indirect/unanalyzable globals. 
- if (globalInfo->isIndirect) + if (globalInfo->isIndirect) { return; + } // Ignore mutable globals, as they could be changed to various values. - if (globalInfo->op.isGlobalMutable()) + if (globalInfo->op.isGlobalMutable()) { return; + } // If there are no stores to the global then it's always null. if (globalInfo->getStores().empty()) { @@ -90,8 +92,9 @@ static void processBufferGlobal(Explorer &explorer, // the program (there may be multiple initializers or control flow that // determines the stored value). auto rodataOp = findUniformlyStoredRodata(explorer, globalInfo); - if (!rodataOp) + if (!rodataOp) { return; + } // All stores to the global are of the same rodata. // Replace all of the loads with direct references to the rodata and then @@ -136,8 +139,9 @@ class ResolveRodataLoadsPass }); // Erase all ops after we're done iterating them. - for (auto *deadOp : deadOps) + for (auto *deadOp : deadOps) { deadOp->erase(); + } } }; diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/BUILD.bazel index ae4b2025c32c..3290f81e6711 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "annotate_functions.mlir", "convert_to_yieldable_calls.mlir", diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/materialize_ref_discards.mlir b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/materialize_ref_discards.mlir index 06de0d74f865..695fc4dccfbd 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/materialize_ref_discards.mlir +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/materialize_ref_discards.mlir @@ -2,13 +2,15 @@ // RUN: --pass-pipeline="builtin.module(vm.module(iree-vm-materialize-ref-discards))" \ // RUN: %s | FileCheck %s -// Single ref, single 
use - discard after use. +// Single ref, single use - NO discard (vm.call has MOVE semantics). // CHECK-LABEL: @single_ref_single_use // CHECK-SAME: (%[[BUF:.*]]: !vm.buffer) vm.module @my_module { vm.func @single_ref_single_use(%buf: !vm.buffer) { + // vm.call supports MOVE, so ref is consumed by call - no discard needed. // CHECK: vm.call @consume(%[[BUF]]) - // CHECK-NEXT: vm.discard.refs %[[BUF]] : !vm.buffer + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.return vm.call @consume(%buf) : (!vm.buffer) -> () vm.return } @@ -17,16 +19,19 @@ vm.module @my_module { // ----- -// Multiple uses - discard after LAST use only. +// Multiple uses - NO discard (both calls have MOVE semantics, only last matters). // CHECK-LABEL: @multiple_uses // CHECK-SAME: (%[[BUF:.*]]: !vm.buffer) vm.module @my_module { vm.func @multiple_uses(%buf: !vm.buffer) { + // First call: not last use, no discard. // CHECK: vm.call @consume(%[[BUF]]) // CHECK-NOT: vm.discard.refs vm.call @consume(%buf) : (!vm.buffer) -> () + // Second call: last use with MOVE semantics - ref consumed by call. // CHECK: vm.call @consume(%[[BUF]]) - // CHECK-NEXT: vm.discard.refs %[[BUF]] : !vm.buffer + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.return vm.call @consume(%buf) : (!vm.buffer) -> () vm.return } @@ -58,8 +63,10 @@ vm.module @my_module { vm.cond_br %cond, ^bb1(%buf : !vm.buffer), ^bb2 // CHECK: ^[[BB1]](%[[ARG:.*]]: !vm.buffer): ^bb1(%arg: !vm.buffer): + // vm.call has MOVE semantics - ref consumed by call. // CHECK: vm.call @consume(%[[ARG]]) - // CHECK-NEXT: vm.discard.refs %[[ARG]] : !vm.buffer + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.br ^[[EXIT:.*]] vm.call @consume(%arg) : (!vm.buffer) -> () vm.br ^exit ^bb2: @@ -72,13 +79,15 @@ vm.module @my_module { // ----- -// Multiple refs dying at same point - batched into single discard. +// Multiple refs passed to call - NO discards (MOVE semantics). 
// CHECK-LABEL: @multiple_refs_same_death_point // CHECK-SAME: (%[[A:.*]]: !vm.buffer, %[[B:.*]]: !vm.buffer) vm.module @my_module { vm.func @multiple_refs_same_death_point(%a: !vm.buffer, %b: !vm.buffer) { + // Both refs consumed by call with MOVE semantics. // CHECK: vm.call @consume2(%[[A]], %[[B]]) - // CHECK-NEXT: vm.discard.refs %[[A]], %[[B]] : !vm.buffer, !vm.buffer + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.return vm.call @consume2(%a, %b) : (!vm.buffer, !vm.buffer) -> () vm.return } @@ -130,8 +139,10 @@ vm.module @my_module { vm.cond_br %cond, ^then, ^else // CHECK: ^[[THEN]]: ^then: + // vm.call has MOVE semantics - ref consumed by call. // CHECK: vm.call @consume(%[[BUF]]) - // CHECK-NEXT: vm.discard.refs %[[BUF]] : !vm.buffer + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.br ^[[EXIT:.*]] vm.call @consume(%buf) : (!vm.buffer) -> () vm.br ^exit // CHECK: ^[[ELSE]]: @@ -157,14 +168,18 @@ vm.module @my_module { vm.cond_br %cond, ^then, ^else // CHECK: ^[[THEN]]: ^then: + // vm.call has MOVE semantics - ref consumed by call. // CHECK: vm.call @consume(%[[BUF]]) - // CHECK-NEXT: vm.discard.refs %[[BUF]] : !vm.buffer + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.br ^[[EXIT:.*]] vm.call @consume(%buf) : (!vm.buffer) -> () vm.br ^exit // CHECK: ^[[ELSE]]: ^else: + // vm.call has MOVE semantics - ref consumed by call. // CHECK: vm.call @consume(%[[BUF]]) - // CHECK-NEXT: vm.discard.refs %[[BUF]] : !vm.buffer + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.br ^[[EXIT]] vm.call @consume(%buf) : (!vm.buffer) -> () vm.br ^exit ^exit: @@ -189,8 +204,10 @@ vm.module @my_module { vm.cond_br %cond, ^then, ^else // CHECK: ^[[THEN]]: ^then: + // vm.call has MOVE semantics - ref consumed by call. 
// CHECK: vm.call @consume(%[[USED]]) - // CHECK-NEXT: vm.discard.refs %[[USED]] : !vm.buffer + // CHECK-NOT: vm.discard.refs %[[USED]] + // CHECK-NEXT: vm.br ^[[EXIT:.*]] vm.call @consume(%used) : (!vm.buffer) -> () vm.br ^exit // CHECK: ^[[ELSE]]: @@ -235,14 +252,18 @@ vm.module @my_module { vm.cond_br %cond, ^left, ^right // CHECK: ^[[LEFT]]: ^left: + // vm.call has MOVE semantics - ref consumed by call. // CHECK: vm.call @consume(%[[BUF]]) - // CHECK-NEXT: vm.discard.refs %[[BUF]] : !vm.buffer + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.br ^[[MERGE:.*]] vm.call @consume(%buf) : (!vm.buffer) -> () vm.br ^merge // CHECK: ^[[RIGHT]]: ^right: + // vm.call has MOVE semantics - ref consumed by call. // CHECK: vm.call @consume(%[[BUF]]) - // CHECK-NEXT: vm.discard.refs %[[BUF]] : !vm.buffer + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.br ^[[MERGE]] vm.call @consume(%buf) : (!vm.buffer) -> () vm.br ^merge ^merge: @@ -314,8 +335,8 @@ vm.module @my_module { %ref = vm.cast.ref.any %buffer : !vm.buffer -> !vm.ref // CHECK: %[[CAST:.*]] = vm.cast.any.ref %[[REF]] %cast = vm.cast.any.ref %ref : !vm.ref -> !vm.buffer - // %ref is discarded after its last use (vm.cast.any.ref) - // CHECK: vm.discard.refs %[[REF]] + // vm.cast has MOVE semantics - %ref consumed by cast, no discard. + // CHECK-NOT: vm.discard.refs %[[REF]] // CHECK: vm.cmp.eq.ref %[[BUFFER]], %[[CAST]] %eq = vm.cmp.eq.ref %buffer, %cast : !vm.buffer // Both %buffer and %cast die at same point - batched discard @@ -327,7 +348,8 @@ vm.module @my_module { // ----- // Each cast produces a new ref with independent lifetime. -// Refs are discarded after their last use. +// vm.cast has MOVE semantics - refs consumed by casts, not discarded. +// Only refs passed to vm.call operations (which also have MOVE) are consumed. 
// CHECK-LABEL: @chained_casts_independent vm.module @my_module { vm.func @chained_casts_independent() { @@ -339,18 +361,21 @@ vm.module @my_module { %ref1 = vm.cast.ref.any %buf : !vm.buffer -> !vm.ref // CHECK: %[[BUF2:.*]] = vm.cast.any.ref %[[REF1]] %buf2 = vm.cast.any.ref %ref1 : !vm.ref -> !vm.buffer - // ref1's last use is vm.cast.any.ref, discard it now - // CHECK: vm.discard.refs %[[REF1]] + // vm.cast has MOVE semantics - ref1 consumed, no discard. + // CHECK-NOT: vm.discard.refs %[[REF1]] // CHECK: %[[REF2:.*]] = vm.cast.ref.any %[[BUF2]] %ref2 = vm.cast.ref.any %buf2 : !vm.buffer -> !vm.ref - // buf2's last use is vm.cast.ref.any, discard it now - // CHECK: vm.discard.refs %[[BUF2]] + // vm.cast has MOVE semantics - buf2 consumed, no discard. + // CHECK-NOT: vm.discard.refs %[[BUF2]] // CHECK: vm.call @use_buffer(%[[BUF]]) vm.call @use_buffer(%buf) : (!vm.buffer) -> () - // CHECK: vm.discard.refs %[[BUF]] + // vm.call has MOVE semantics - buf consumed, no discard. + // CHECK-NOT: vm.discard.refs %[[BUF]] // CHECK: vm.call @use_ref(%[[REF2]]) vm.call @use_ref(%ref2) : (!vm.ref) -> () - // CHECK: vm.discard.refs %[[REF2]] + // vm.call has MOVE semantics - ref2 consumed, no discard. + // CHECK-NOT: vm.discard.refs %[[REF2]] + // CHECK-NEXT: vm.return vm.return } vm.import private @use_buffer(%buf: !vm.buffer) @@ -359,7 +384,7 @@ vm.module @my_module { // ----- -// Each ref discarded after its last use, no aliasing. +// Each ref consumed by vm.call (MOVE semantics). // CHECK-LABEL: @ref_used_then_original_used vm.module @my_module { vm.func @ref_used_then_original_used() { @@ -371,11 +396,13 @@ vm.module @my_module { %ref = vm.cast.ref.any %buf : !vm.buffer -> !vm.ref // CHECK: vm.call @use_ref(%[[REF]]) vm.call @use_ref(%ref) : (!vm.ref) -> () - // ref's last use is use_ref, discard it - // CHECK: vm.discard.refs %[[REF]] + // vm.call has MOVE semantics - ref consumed, no discard. 
+ // CHECK-NOT: vm.discard.refs %[[REF]] // CHECK: vm.call @use_buffer(%[[BUF]]) vm.call @use_buffer(%buf) : (!vm.buffer) -> () - // CHECK: vm.discard.refs %[[BUF]] + // vm.call has MOVE semantics - buf consumed, no discard. + // CHECK-NOT: vm.discard.refs %[[BUF]] + // CHECK-NEXT: vm.return vm.return } vm.import private @use_buffer(%buf: !vm.buffer) @@ -384,7 +411,7 @@ vm.module @my_module { // ----- -// Ref used in branch, original used after merge - independent lifetimes. +// Ref used in branch (vm.call has MOVE), original used after merge. // CHECK-LABEL: @ref_in_branch_original_after_merge vm.module @my_module { vm.func @ref_in_branch_original_after_merge(%cond: i32) { @@ -399,8 +426,9 @@ vm.module @my_module { ^left: // CHECK: vm.call @use_ref(%[[REF]]) vm.call @use_ref(%ref) : (!vm.ref) -> () - // ref's last use on this path - // CHECK: vm.discard.refs %[[REF]] + // vm.call has MOVE semantics - ref consumed, no discard. + // CHECK-NOT: vm.discard.refs %[[REF]] + // CHECK-NEXT: vm.br ^[[MERGE:.*]] vm.br ^merge ^right: // ref not used on this path - edge discard @@ -409,7 +437,9 @@ vm.module @my_module { ^merge: // CHECK: vm.call @use_buffer(%[[BUF]]) vm.call @use_buffer(%buf) : (!vm.buffer) -> () - // CHECK: vm.discard.refs %[[BUF]] + // vm.call has MOVE semantics - buf consumed, no discard. + // CHECK-NOT: vm.discard.refs %[[BUF]] + // CHECK-NEXT: vm.return vm.return } vm.import private @use_buffer(%buf: !vm.buffer) @@ -418,7 +448,7 @@ vm.module @my_module { // ----- -// Ref used in loop, original used after loop exit - independent lifetimes. +// Ref used in loop (vm.call has MOVE), original used after loop exit. 
// CHECK-LABEL: @ref_in_loop_original_after vm.module @my_module { vm.func @ref_in_loop_original_after(%n: i32) { @@ -438,11 +468,14 @@ vm.module @my_module { %cmp = vm.cmp.lt.i32.s %next, %n : i32 vm.cond_br %cmp, ^loop(%next : i32), ^exit ^exit: - // ref is live throughout loop, dies at exit + // ref is NOT live at exit - it's consumed by vm.call in the loop. + // The last iteration's vm.call has MOVE semantics - ref consumed. // CHECK: vm.discard.refs %[[REF]] // CHECK: vm.call @use_buffer(%[[BUF]]) vm.call @use_buffer(%buf) : (!vm.buffer) -> () - // CHECK: vm.discard.refs %[[BUF]] + // vm.call has MOVE semantics - buf consumed, no discard. + // CHECK-NOT: vm.discard.refs %[[BUF]] + // CHECK-NEXT: vm.return vm.return } vm.import private @use_buffer(%buf: !vm.buffer) @@ -975,3 +1008,175 @@ vm.module @my_module { vm.return } } + +// ----- + +//===----------------------------------------------------------------------===// +// MOVE semantics for regular vm.call and vm.call.variadic +// These are the key tests for the bug fix: non-terminator calls that support +// MOVE semantics must NOT have discards inserted for their ref operands. +//===----------------------------------------------------------------------===// + +// vm.call with ref operand at last use - MOVE semantics, no discard. +// This was the original bug: mid-block discard logic would insert a discard +// after the call, but the call already consumed the ref with MOVE. +// CHECK-LABEL: @call_ref_move_last_use +vm.module @my_module { + vm.import private @consume(!vm.buffer) + vm.func @call_ref_move_last_use(%buf: !vm.buffer) { + // Ref passed to call with MOVE semantics - no discard should be inserted. + // CHECK: vm.call @consume(%[[BUF:.*]]) + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.return + vm.call @consume(%buf) : (!vm.buffer) -> () + vm.return + } +} + +// ----- + +// vm.call.variadic with ref operands at last use - MOVE semantics, no discard. 
+// This is the specific case from the smoketest.mlir bug report. +// CHECK-LABEL: @call_variadic_ref_move_last_use +vm.module @my_module { + vm.import private @hal.command_buffer.dispatch(!vm.buffer, !vm.ref, i32, i32, i32, i32) + vm.func @call_variadic_ref_move_last_use(%cmd: !vm.buffer, %exec: !vm.ref) { + %c0 = vm.const.i32 0 + %c1 = vm.const.i32 1 + // Ref operands passed with MOVE semantics - no discards should be inserted. + // CHECK: vm.call.variadic @hal.command_buffer.dispatch(%[[CMD:.*]], %[[EXEC:.*]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.return + vm.call.variadic @hal.command_buffer.dispatch(%cmd, %exec, %c0, %c1, %c1, %c1) : (!vm.buffer, !vm.ref, i32, i32, i32, i32) -> () + vm.return + } +} + +// ----- + +// Multiple refs passed to vm.call - all with MOVE semantics. +// CHECK-LABEL: @call_multiple_ref_operands_move +vm.module @my_module { + vm.import private @multi(!vm.buffer, !vm.buffer, !vm.ref) + vm.func @call_multiple_ref_operands_move(%buf1: !vm.buffer, %buf2: !vm.buffer, %ref: !vm.ref) { + // Use refs before call to ensure they're live. + // CHECK-DAG: vm.cmp.nz.ref %[[BUF1:[^ ]+]] + %nz1 = vm.cmp.nz.ref %buf1 : !vm.buffer + // CHECK-DAG: vm.cmp.nz.ref %[[BUF2:[^ ]+]] + %nz2 = vm.cmp.nz.ref %buf2 : !vm.buffer + // CHECK-DAG: vm.cmp.nz.ref %[[REF:[^ ]+]] + %nz3 = vm.cmp.nz.ref %ref : !vm.ref + // All refs passed to call with MOVE - no discards. + // CHECK-NOT: vm.discard.refs + // CHECK: vm.call @multi(%[[BUF1]], %[[BUF2]], %[[REF]]) + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.return + vm.call @multi(%buf1, %buf2, %ref) : (!vm.buffer, !vm.buffer, !vm.ref) -> () + vm.return + } +} + +// ----- + +// Ref used, then NOT passed to call - still needs discard. +// This verifies the fix is precise: only refs actually passed to MOVE calls +// are exempted from mid-block discards. 
+// CHECK-LABEL: @call_ref_not_passed +vm.module @my_module { + vm.import private @compute(i32) + vm.func @call_ref_not_passed(%buf: !vm.buffer, %x: i32) { + // CHECK: vm.cmp.nz.ref %[[BUF:[^ ]+]] + %nz = vm.cmp.nz.ref %buf : !vm.buffer + // Ref NOT passed to call, so it needs a discard after its last use. + // CHECK-NEXT: vm.discard.refs %[[BUF]] + // CHECK: vm.call @compute + vm.call @compute(%x) : (i32) -> () + vm.return + } +} + +// ----- + +// Mixed scenario: one ref passed to call (MOVE), another not passed (discard). +// CHECK-LABEL: @call_mixed_ref_operands +vm.module @my_module { + vm.import private @consume(!vm.buffer) + vm.func @call_mixed_ref_operands(%buf1: !vm.buffer, %buf2: !vm.buffer) { + // CHECK-DAG: vm.cmp.nz.ref %[[BUF1:[^ ]+]] + %nz1 = vm.cmp.nz.ref %buf1 : !vm.buffer + // CHECK-DAG: vm.cmp.nz.ref %[[BUF2:[^ ]+]] + %nz2 = vm.cmp.nz.ref %buf2 : !vm.buffer + // buf2 NOT passed to call, discarded after its last use. + // CHECK: vm.discard.refs %[[BUF2]] + // buf1 passed to call with MOVE, not discarded. + // CHECK: vm.call @consume(%[[BUF1]]) + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.return + vm.call @consume(%buf1) : (!vm.buffer) -> () + vm.return + } +} + +// ----- + +// Ref passed to multiple calls - only last call gets MOVE, earlier uses need ref retained. +// CHECK-LABEL: @call_ref_multiple_calls +vm.module @my_module { + vm.import private @consume(!vm.buffer) + vm.func @call_ref_multiple_calls(%buf: !vm.buffer) { + // First call: not last use, no discard. + // CHECK: vm.call @consume(%[[BUF:.*]]) + // CHECK-NOT: vm.discard.refs + vm.call @consume(%buf) : (!vm.buffer) -> () + // Second call: last use, MOVE semantics, no discard. + // CHECK: vm.call @consume(%[[BUF]]) + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.return + vm.call @consume(%buf) : (!vm.buffer) -> () + vm.return + } +} + +// ----- + +// vm.call.variadic with mixed ref and non-ref operands. 
+// CHECK-LABEL: @call_variadic_mixed_operands +vm.module @my_module { + vm.import private @mixed(!vm.buffer, i32, i32, !vm.ref, i32) + vm.func @call_variadic_mixed_operands(%buf: !vm.buffer, %ref: !vm.ref) { + %c1 = vm.const.i32 1 + %c2 = vm.const.i32 2 + %c3 = vm.const.i32 3 + // Refs passed with MOVE, integers are just values. + // CHECK: vm.call.variadic @mixed(%[[BUF:.*]], %{{.*}}, %{{.*}}, %[[REF:.*]], %{{.*}}) + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.return + vm.call.variadic @mixed(%buf, %c1, %c2, %ref, %c3) : (!vm.buffer, i32, i32, !vm.ref, i32) -> () + vm.return + } +} + +// ----- + +// Control flow with vm.call: ref used in one branch, passed to call in another. +// CHECK-LABEL: @call_ref_control_flow +vm.module @my_module { + vm.import private @consume(!vm.buffer) + vm.func @call_ref_control_flow(%buf: !vm.buffer, %cond: i32) { + // CHECK: vm.cond_br %{{.*}}, ^[[USE:.*]], ^[[CALL:.*]] + vm.cond_br %cond, ^use, ^call + ^use: + // Ref used here, then discarded. + // CHECK: vm.cmp.nz.ref %[[BUF:.*]] + // CHECK-NEXT: vm.discard.refs %[[BUF]] + %nz = vm.cmp.nz.ref %buf : !vm.buffer + vm.return + ^call: + // Ref passed to call with MOVE here, not discarded. 
+ // CHECK: vm.call @consume(%[[BUF:.*]]) + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.return + vm.call @consume(%buf) : (!vm.buffer) -> () + vm.return + } +} diff --git a/compiler/src/iree/compiler/Dialect/VM/Utils/TypeTable.cpp b/compiler/src/iree/compiler/Dialect/VM/Utils/TypeTable.cpp index a1c56544d366..b92d4f6cbf23 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Utils/TypeTable.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Utils/TypeTable.cpp @@ -17,8 +17,9 @@ std::vector buildTypeTable(IREE::VM::ModuleOp moduleOp) { if (auto refPtrType = dyn_cast(type)) { type = refPtrType.getObjectType(); } - if (typeMap.count(type)) + if (typeMap.count(type)) { return; + } std::string str; llvm::raw_string_ostream sstream(str); type.print(sstream); @@ -31,10 +32,12 @@ std::vector buildTypeTable(IREE::VM::ModuleOp moduleOp) { }; for (auto funcOp : moduleOp.getBlock().getOps()) { funcOp.walk([&](Operation *op) { - for (auto type : op->getOperandTypes()) + for (auto type : op->getOperandTypes()) { tryInsertType(type); - for (auto type : op->getResultTypes()) + } + for (auto type : op->getResultTypes()) { tryInsertType(type); + } }); } diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Conversion/HALToVMVX/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VMVX/Conversion/HALToVMVX/test/BUILD.bazel index 99b86ec28216..3ab8248cceff 100644 --- a/compiler/src/iree/compiler/Dialect/VMVX/Conversion/HALToVMVX/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VMVX/Conversion/HALToVMVX/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted ["interface_ops.mlir"], include = ["*.mlir"], ), diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Conversion/StandardToVMVX/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VMVX/Conversion/StandardToVMVX/test/BUILD.bazel index 5684262af8b2..5a72a1402263 100644 --- a/compiler/src/iree/compiler/Dialect/VMVX/Conversion/StandardToVMVX/test/BUILD.bazel +++ 
b/compiler/src/iree/compiler/Dialect/VMVX/Conversion/StandardToVMVX/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ ], include = ["*.mlir"], diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Conversion/VMVXToVM/ConvertVMVXToVM.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Conversion/VMVXToVM/ConvertVMVXToVM.cpp index 451168805145..8953a4760746 100644 --- a/compiler/src/iree/compiler/Dialect/VMVX/Conversion/VMVXToVM/ConvertVMVXToVM.cpp +++ b/compiler/src/iree/compiler/Dialect/VMVX/Conversion/VMVXToVM/ConvertVMVXToVM.cpp @@ -66,8 +66,9 @@ class VMVXImportOpConversion : public OpConversionPattern { return failure(); } auto results = emitCall(op, adaptor, importOp, rewriter); - if (!results.has_value()) + if (!results.has_value()) { return failure(); + } rewriter.replaceOp(op, results.value()); return success(); } diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Conversion/VMVXToVM/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VMVX/Conversion/VMVXToVM/test/BUILD.bazel index 4995c8bdca7b..5e2a2c5c577e 100644 --- a/compiler/src/iree/compiler/Dialect/VMVX/Conversion/VMVXToVM/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VMVX/Conversion/VMVXToVM/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "binary.mlir", "copy.mlir", diff --git a/compiler/src/iree/compiler/Dialect/VMVX/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VMVX/IR/BUILD.bazel index b05ab342f6ab..4e49ae3ad0b6 100644 --- a/compiler/src/iree/compiler/Dialect/VMVX/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VMVX/IR/BUILD.bazel @@ -18,6 +18,7 @@ exports_files(["VMLXOps.td"]) iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "VMVXBase.td", "VMVXInterfaces.td", diff --git a/compiler/src/iree/compiler/Dialect/VMVX/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VMVX/IR/test/BUILD.bazel index 
5684262af8b2..5a72a1402263 100644 --- a/compiler/src/iree/compiler/Dialect/VMVX/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VMVX/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ ], include = ["*.mlir"], diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/MaterializeConstants.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/MaterializeConstants.cpp index 73e7a69a455c..a51de1e55cc7 100644 --- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/MaterializeConstants.cpp +++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/MaterializeConstants.cpp @@ -55,8 +55,9 @@ class MaterializeConstantsPass final } // No constants found; omit the constant block entirely. - if (allLoadOps.empty()) + if (allLoadOps.empty()) { return; + } // Create global ops for each constant and replace the HAL ops so they load // from them. Each global will track what constant key it represents for diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp index 0be2c806c15f..4284b1b5b298 100644 --- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp +++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp @@ -230,8 +230,9 @@ struct FromMemRefSubView : public OpRewritePattern { LogicalResult matchAndRewrite(GetBufferDescriptorOp op, PatternRewriter &rewriter) const override { auto subview = op.getSource().template getDefiningOp(); - if (!subview) + if (!subview) { return failure(); + } auto loc = op.getLoc(); IndexSet indexSet(loc, rewriter); @@ -266,8 +267,9 @@ struct FromMemRefSubView : public OpRewritePattern { llvm::SmallBitVector droppedDims = subview.getDroppedDims(); int targetIndex = 0; for (int i = 0; i < sourceRank; ++i) { - if (droppedDims.test(i)) + if (droppedDims.test(i)) { continue; + } 
rewriter.replaceAllUsesWith( op.getSizes()[targetIndex], getValueOrCreateConstantIndexOp(rewriter, loc, @@ -297,8 +299,9 @@ struct FromHalInterfaceBindingSubspan auto binding = op.getSource() .template getDefiningOp(); - if (!binding) + if (!binding) { return failure(); + } auto loc = op.getLoc(); FailureOr resultDescriptor = @@ -379,8 +382,9 @@ struct FromAllocation : public OpRewritePattern { LogicalResult matchAndRewrite(GetBufferDescriptorOp op, PatternRewriter &rewriter) const override { auto alloca = op.getSource().template getDefiningOp(); - if (!alloca) + if (!alloca) { return failure(); + } auto memRefType = cast(alloca.getResult().getType()); if (!memRefType.getLayout().isIdentity()) { return rewriter.notifyMatchFailure(op, "not identity allocation"); @@ -413,8 +417,9 @@ struct FromGlobal : public OpRewritePattern { LogicalResult matchAndRewrite(GetBufferDescriptorOp op, PatternRewriter &rewriter) const override { auto global = op.getSource().template getDefiningOp(); - if (!global) + if (!global) { return failure(); + } auto memRefType = cast(global.getResult().getType()); if (!memRefType.getLayout().isIdentity()) { return rewriter.notifyMatchFailure(op, "not identity allocation"); diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/BUILD.bazel index c02abbca8ad2..2320156440e3 100644 --- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "materialize_constants.mlir", "resolve_buffer_descriptors.mlir", diff --git a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp index 656cc8984f70..3df39eb8ecb2 100644 --- a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp +++ 
b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp @@ -146,8 +146,9 @@ struct SwapExtractSliceOfFill final LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp, PatternRewriter &rewriter) const override { auto fillOp = extractOp.getSource().getDefiningOp(); - if (!fillOp) + if (!fillOp) { return failure(); + } auto newExtractOp = tensor::ExtractSliceOp::create( rewriter, extractOp.getLoc(), extractOp.getType(), diff --git a/compiler/src/iree/compiler/DispatchCreation/CloneProducersIntoDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/CloneProducersIntoDispatchRegions.cpp index 25b7c1915188..352f0dd5bcc5 100644 --- a/compiler/src/iree/compiler/DispatchCreation/CloneProducersIntoDispatchRegions.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/CloneProducersIntoDispatchRegions.cpp @@ -35,8 +35,9 @@ struct CloneProducersIntoDispatchRegionsPass final IREE::Flow::ClonableIntoDispatchOptions options; options.aggressive = aggressive; funcOp->walk([&](IREE::Flow::DispatchRegionOp regionOp) { - if (failed(cloneProducersToRegion(rewriter, regionOp, options))) + if (failed(cloneProducersToRegion(rewriter, regionOp, options))) { return signalPassFailure(); + } }); funcOp->walk([&](Operation *op) { @@ -58,8 +59,9 @@ struct CloneProducersIntoDispatchRegionsPass final // Rerun the cloning again to move still clonable operations into // dispatches. 
funcOp->walk([&](IREE::Flow::DispatchRegionOp regionOp) { - if (failed(cloneProducersToRegion(rewriter, regionOp, options))) + if (failed(cloneProducersToRegion(rewriter, regionOp, options))) { return signalPassFailure(); + } }); } }; diff --git a/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp b/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp index 28ae827d99a3..a2bfeb53bff6 100644 --- a/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp @@ -291,8 +291,9 @@ populateReassocAndMaps(tensor::ExtractSliceOp sliceOp, auto isZeroOffsetAndFullSize = [&](OpFoldResult offset, OpFoldResult sliceSize, int64_t inputDim) { - if (!isZeroInteger(offset)) + if (!isZeroInteger(offset)) { return false; + } ValueBoundsConstraintSet::Variable inputSize(sliceOp.getSource(), inputDim); FailureOr maybeEqual = diff --git a/compiler/src/iree/compiler/DispatchCreation/ConvertTensorToFlow.cpp b/compiler/src/iree/compiler/DispatchCreation/ConvertTensorToFlow.cpp index 375ffe06cff2..56eea555be8e 100644 --- a/compiler/src/iree/compiler/DispatchCreation/ConvertTensorToFlow.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/ConvertTensorToFlow.cpp @@ -44,20 +44,25 @@ static FailureOr wrapInWorkgroupsOp(mlir::TensorDimTrackingRewriter &rewriter, Operation *op) { SmallVector dimOps = rewriter.getTensorDimOps(); - if (failed(IREE::Flow::simplifyDimOps(rewriter, rewriter.getTensorDimOps()))) + if (failed( + IREE::Flow::simplifyDimOps(rewriter, rewriter.getTensorDimOps()))) { return failure(); + } // Wrap operation. 
auto regionOp = IREE::Flow::wrapOpInDispatchRegion(rewriter, op); - if (failed(regionOp)) + if (failed(regionOp)) { return failure(); - if (failed(cloneProducersToRegion(rewriter, *regionOp))) + } + if (failed(cloneProducersToRegion(rewriter, *regionOp))) { return failure(); + } auto workgroupsOp = IREE::Flow::rewriteFlowDispatchRegionToFlowDispatchWorkgroups(*regionOp, rewriter); - if (failed(workgroupsOp)) + if (failed(workgroupsOp)) { return failure(); + } return *workgroupsOp; } @@ -68,8 +73,9 @@ wrapInWorkgroupsOp(mlir::TensorDimTrackingRewriter &rewriter, SmallVector result; for (Operation *rootOp : rootOps) { auto workgroupsOp = wrapInWorkgroupsOp(rewriter, rootOp); - if (failed(workgroupsOp)) + if (failed(workgroupsOp)) { return failure(); + } result.push_back(*workgroupsOp); } return result; @@ -84,8 +90,9 @@ static FailureOr convertInsertSliceOps( // Find eligible InsertSliceOps. SmallVector insertSliceOps; funcOp.walk([&](tensor::InsertSliceOp op) { - if (!isInDispatchRegion(op)) + if (!isInDispatchRegion(op)) { insertSliceOps.push_back(op); + } }); // Rewrite InsertSliceOps to FlowUpdateOps. @@ -102,8 +109,9 @@ static FailureOr convertInsertSliceOps( // Create a DispatchWorkgroupsOp for every remaining InsertSliceOp. FailureOr> newWorkgroupsOps = wrapInWorkgroupsOp(rewriter, remainingInsertSliceOps); - if (failed(newWorkgroupsOps)) + if (failed(newWorkgroupsOps)) { return failure(); + } workgroupsOps.append(newWorkgroupsOps->begin(), newWorkgroupsOps->end()); return numRemainingInsertSliceOps; @@ -118,8 +126,9 @@ static FailureOr convertExtractSliceOps( // Find eligible ExtractSliceOps. SmallVector extractSliceOps; funcOp.walk([&](tensor::ExtractSliceOp op) { - if (!isInDispatchRegion(op)) + if (!isInDispatchRegion(op)) { extractSliceOps.push_back(op); + } }); // Rewrite ExtractSliceOps to FlowSliceOps. @@ -137,8 +146,9 @@ static FailureOr convertExtractSliceOps( // Create a DispatchWorkgroupsOp for every remaining ExtractSliceOp. 
FailureOr> newWorkgroupsOps = wrapInWorkgroupsOp(rewriter, remainingExtractSliceOps); - if (failed(newWorkgroupsOps)) + if (failed(newWorkgroupsOps)) { return failure(); + } workgroupsOps.append(newWorkgroupsOps->begin(), newWorkgroupsOps->end()); return numRemainingExtractSliceOps; diff --git a/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp b/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp index e463015c33f7..c0bff40d7084 100644 --- a/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp @@ -162,8 +162,9 @@ void ElementwiseOpFusionPass::runOnOperation() { operands.insert(std::next(consumer->operand_begin(), fusedOperand->getOperandNumber() + 1), consumer->operand_end()); - if (operands.size() >= kIreeMaxOperandCount) + if (operands.size() >= kIreeMaxOperandCount) { return false; + } ElementwiseOpsFusabilityOptions options; options.fuseMultiReduction = fuseMultiReduction; diff --git a/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp b/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp index 5c5e08f924bd..6a7870bce3a8 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp @@ -390,8 +390,9 @@ foldUnitDimsOnGlobal(IRRewriter &rewriter, IREE::Util::GlobalOpInterface global, } auto newGlobalType = globalType.clone(newShape); auto initialValue = global.getGlobalInitialValue(); - if (!initialValue) + if (!initialValue) { return success(); + } // TODO: Handle other cases auto newInitialValue = llvm::TypeSwitch(initialValue) diff --git a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp index 1695cda731fc..370fd631b096 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp +++ 
b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp @@ -73,8 +73,9 @@ static llvm::SmallBitVector getOuterParallelLoops(Operation *op) { interfaceOp.getLoopIteratorTypes(); llvm::SmallBitVector parallelLoops(loopIteratorTypes.size()); for (auto iteratorType : llvm::enumerate(loopIteratorTypes)) { - if (iteratorType.value() != utils::IteratorType::parallel) + if (iteratorType.value() != utils::IteratorType::parallel) { break; + } parallelLoops.set(iteratorType.index()); } return parallelLoops; @@ -565,8 +566,9 @@ static bool canUseInOperandAsInitOperand(OpOperand *inOperand, // Check that the owner is a `generic` op. auto genericOp = dyn_cast(inOperand->getOwner()); - if (!genericOp) + if (!genericOp) { return false; + } // All loops to be parallel. if (genericOp.getNumLoops() != genericOp.getNumParallelLoops()) { @@ -574,13 +576,15 @@ static bool canUseInOperandAsInitOperand(OpOperand *inOperand, } /// The input operand cannot be an init operand already. - if (genericOp.isDpsInit(inOperand)) + if (genericOp.isDpsInit(inOperand)) { return false; + } // If the init operand value is used it cannot be reused for the input // operand. - if (genericOp.payloadUsesValueFromOperand(initOperand)) + if (genericOp.payloadUsesValueFromOperand(initOperand)) { return false; + } // Indexing map used to access the input and init have to match. 
if (genericOp.getMatchingIndexingMap(inOperand) != @@ -590,8 +594,9 @@ static bool canUseInOperandAsInitOperand(OpOperand *inOperand, // Types have to match for the input operand to reuse the buffer from the init // operand - if (inOperand->get().getType() != initOperand->get().getType()) + if (inOperand->get().getType() != initOperand->get().getType()) { return false; + } return true; } @@ -676,8 +681,9 @@ isFusableWithConsumer(OpOperand &fusedOperand, const FusionTracker &tracker, dyn_cast(producer); auto consumerFusionOp = dyn_cast(consumer); - if (!producerFusionOp || !consumerFusionOp) + if (!producerFusionOp || !consumerFusionOp) { return false; + } // Check that the consumer is all parallel. if (consumerFusionOp.getNumLoops() != @@ -727,10 +733,11 @@ isFusableWithConsumer(OpOperand &fusedOperand, const FusionTracker &tracker, } for (OpOperand *inputOperand : consumerDstOp.getDpsInputOperands()) { - if (inputOperand->get().getDefiningOp() != producer) + if (inputOperand->get().getDefiningOp() != producer) { continue; + } if (isa(producer) && - !llvm::any_of( + llvm::none_of( consumerDstOp.getDpsInitsMutable(), [&](OpOperand &initOperand) { return canUseInOperandAsInitOperand(inputOperand, &initOperand); })) { @@ -876,8 +883,9 @@ fuseRootsWithProducers(MLIRContext *context, Operation *root, Operation *candidate = worklist.pop_back_val(); for (OpOperand &operand : candidate->getOpOperands()) { Operation *producer = operand.get().getDefiningOp(); - if (!producer) + if (!producer) { continue; + } if (IREE::Flow::isClonableIntoDispatchOp(producer, clonableOptions) || tracker.isFusedOp(producer) || tracker.isRootOp(producer)) { continue; @@ -890,8 +898,9 @@ fuseRootsWithProducers(MLIRContext *context, Operation *root, SmallVector fusableUses = getFusableUses(context, producer, dominanceInfo, /*aggressiveFusion=*/options.aggressiveFusion); - if (fusableUses.empty() || fusableUses.front()->getOwner() != candidate) + if (fusableUses.empty() || 
fusableUses.front()->getOwner() != candidate) { continue; + } tracker.appendToFusionGroup(producer, fusionGroup); worklist.push_back(producer); @@ -926,8 +935,9 @@ decideFusableLinalgOps(Region ®ion, DominanceInfo const &dominanceInfo, } // Start with a root operation and fuse its producers. - if (tracker.isFusedOp(&op) || !isRootLikeOp(&op)) + if (tracker.isFusedOp(&op) || !isRootLikeOp(&op)) { continue; + } FusionGroup &newGroup = tracker.createFusionGroup(context, &op); fuseRootsWithProducers(context, &op, newGroup, dominanceInfo, options, tracker, @@ -950,8 +960,9 @@ decideFusableLinalgOps(Region ®ion, DominanceInfo const &dominanceInfo, SmallVector roots; for (Operation &op : llvm::reverse(block)) { // If it is part of a fusion group or root op, ignore it. - if (tracker.isFusedOp(&op) || tracker.isRootOp(&op)) + if (tracker.isFusedOp(&op) || tracker.isRootOp(&op)) { continue; + } // Only look for Linalg ops here. Avoid moving `linalg.fill` that aren't // fused with anything else into their own dispatches since it is better // to convert them to splats. Also avoid moving dequantization-like ops @@ -968,9 +979,8 @@ decideFusableLinalgOps(Region ®ion, DominanceInfo const &dominanceInfo, // by the `isClonableIntoDispatchOp` call above, but for now this is done // as a point fix. 
if (IREE::LinalgExt::isGatherlikeOp(&op) && - llvm::all_of(op.getUsers(), [](Operation *op) { - return isa(op); - })) { + llvm::all_of(op.getUsers(), + llvm::IsaPred)) { continue; } diff --git a/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp b/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp index ed03dc8891ea..68315e478288 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp @@ -253,7 +253,7 @@ static bool isHorizontalToGroup(Operation *op, llvm::SetVector slice; [[maybe_unused]] LogicalResult result = getBackwardSlice(op, &slice, options); assert(result.succeeded()); - return !llvm::any_of(currGroup, [&](Operation *groupedOp) { + return llvm::none_of(currGroup, [&](Operation *groupedOp) { return slice.contains(groupedOp); }); } diff --git a/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp b/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp index ae1ef9271693..3510b6209d4b 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp @@ -287,8 +287,9 @@ void FuseMultiUseElementwiseProducerPass::runOnOperation() { funcOp->emitError("failed to fuse multi-use producers"); return signalPassFailure(); } - if (numOfFusableCandidates.value() == 0) + if (numOfFusableCandidates.value() == 0) { break; + } } } diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionPreprocessing.cpp b/compiler/src/iree/compiler/DispatchCreation/FusionPreprocessing.cpp index 8d404bdc6074..0136587fbed6 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FusionPreprocessing.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FusionPreprocessing.cpp @@ -50,8 +50,9 @@ struct ElementwiseOpInterchangePattern final LogicalResult 
matchAndRewrite(linalg::GenericOp genericOp, PatternRewriter &rewriter) const override { if (!linalg::isElementwise(genericOp) || genericOp.getNumResults() != 1 || - genericOp.getNumDpsInputs() == 0) + genericOp.getNumDpsInputs() == 0) { return failure(); + } // All input maps must be equal and non-identity. All maps, including // output, must be be permutations. Permutation maps are checked by diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp index 1dd745fd964a..7f226201e986 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp @@ -20,16 +20,19 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand, ElementwiseOpsFusabilityOptions options) { Operation *producerOp = fusedOperand->get().getDefiningOp(); Operation *consumerOp = fusedOperand->getOwner(); - if (!producerOp) + if (!producerOp) { return false; + } // Check for i1 return types, if so aggressively fuse to avoid `i1` buffers. if (llvm::all_of(producerOp->getResultTypes(), [](Type t) { - if (t.isInteger(1)) + if (t.isInteger(1)) { return true; + } if (auto shapedType = dyn_cast(t)) { - if (shapedType.getElementType().isInteger(1)) + if (shapedType.getElementType().isInteger(1)) { return true; + } } return false; })) { @@ -38,8 +41,9 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand, // If the generic op is "just" copy, then fuse always. 
Block &body = producerOp->getRegion(0).front(); - if (std::begin(body)->hasTrait()) + if (std::begin(body)->hasTrait()) { return true; + } auto linalgConsumerOp = dyn_cast(consumerOp); if (!linalgConsumerOp) { diff --git a/compiler/src/iree/compiler/DispatchCreation/HoistEncodingOps.cpp b/compiler/src/iree/compiler/DispatchCreation/HoistEncodingOps.cpp index d7c2b6521f64..4561cc366c7d 100644 --- a/compiler/src/iree/compiler/DispatchCreation/HoistEncodingOps.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/HoistEncodingOps.cpp @@ -110,7 +110,7 @@ bubbleUpSetEncodingThroughGenericOp(RewriterBase &rewriter, auto resType = RankedTensorType::get( operandType.getShape(), operandType.getElementType(), newEncoding); Value encodedInput = IREE::Encoding::SetEncodingOp::create( - rewriter, loc, resType, operand->get()); + rewriter, loc, resType, operand->get(), /*encodingDims=*/ValueRange{}); encodedOperands.push_back(encodedInput); } diff --git a/compiler/src/iree/compiler/DispatchCreation/MaterializeDefaultWorkgroupCountRegion.cpp b/compiler/src/iree/compiler/DispatchCreation/MaterializeDefaultWorkgroupCountRegion.cpp index 4d7ee1d7c58d..1936ba45bc68 100644 --- a/compiler/src/iree/compiler/DispatchCreation/MaterializeDefaultWorkgroupCountRegion.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/MaterializeDefaultWorkgroupCountRegion.cpp @@ -57,8 +57,9 @@ static LogicalResult createDefaultWorkgroupCountRegion( SmallVector workloadLocs; for (auto argument : workgroupsOp.getArguments()) { Type argumentType = argument.getType(); - if (!isa(argumentType)) + if (!isa(argumentType)) { continue; + } workload.push_back(argument); workloadTypes.push_back(argumentType); workloadLocs.push_back(argument.getLoc()); @@ -114,8 +115,9 @@ static LogicalResult createDefaultWorkgroupCountRegion( rewriter.setInsertionPointToStart(&body.front()); int ordinalNumber = 0; for (auto [index, operand] : llvm::enumerate(workgroupsOp.getArguments())) { - if (!isa(operand.getType())) + if 
(!isa(operand.getType())) { continue; + } BlockArgument arg = workgroupsOp.getInputBlockArgument(index); auto ordinalOp = IREE::TensorExt::DispatchWorkloadOrdinalOp::create( rewriter, loc, arg, rewriter.getIndexAttr(ordinalNumber++)); diff --git a/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp b/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp index af7fe39078c6..721dbed57803 100644 --- a/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp @@ -36,15 +36,16 @@ using IREE::Encoding::EncodingAttr; //===---------------------------------------------------------------------===// static Value setEncoding(OpBuilder &builder, Location loc, Value source, - Attribute encodingAttr) { + Attribute encodingAttr, ValueRange encodingDims = {}) { auto resultType = cast(source.getType()).cloneWithEncoding(encodingAttr); - return IREE::Encoding::SetEncodingOp::create(builder, loc, resultType, - source); -}; + return IREE::Encoding::SetEncodingOp::create(builder, loc, resultType, source, + encodingDims); +} static Value unsetEncoding(OpBuilder &builder, Location loc, Value source, - SmallVector sizes) { + SmallVector sizes, + ValueRange encodingDims = {}) { SmallVector dynamicSizesVec; SmallVector staticSizesVec; dispatchIndexOpFoldResults(sizes, dynamicSizesVec, staticSizesVec); @@ -53,7 +54,8 @@ static Value unsetEncoding(OpBuilder &builder, Location loc, Value source, auto unsetEncodingReturnType = RankedTensorType::get(sourceType.getShape(), sourceType.getElementType()); return IREE::Encoding::UnsetEncodingOp::create( - builder, loc, unsetEncodingReturnType, source, dynamicSizesVec); + builder, loc, unsetEncodingReturnType, source, dynamicSizesVec, + encodingDims); } static SmallVector @@ -91,15 +93,18 @@ static LogicalResult setDataTilingEncodings(RewriterBase &rewriter, SmallVector encodedInputOperands; for (auto [idx, props] : llvm::enumerate(encProps.operands)) { Value src = 
linalgOp.getDpsInputs()[idx]; - Value encoded = setEncoding(rewriter, loc, src, props.encoding); + Value encoded = + setEncoding(rewriter, loc, src, props.encoding, props.dynamicValues); encodedInputOperands.push_back(encoded); } // Set encoding on init operand. // For now, we assume single init. assert(encProps.inits.size() == 1 && "Expected single init encoding"); - Value encodedInitOperand = setEncoding( - rewriter, loc, linalgOp.getDpsInits()[0], encProps.inits[0].encoding); + IREE::Encoding::EncodingProperties &initProps = encProps.inits[0]; + Value encodedInitOperand = + setEncoding(rewriter, loc, linalgOp.getDpsInits()[0], initProps.encoding, + initProps.dynamicValues); SmallVector encodedOperands(encodedInputOperands); encodedOperands.push_back(encodedInitOperand); @@ -110,7 +115,8 @@ static LogicalResult setDataTilingEncodings(RewriterBase &rewriter, // Sizes are computed by original output size. SmallVector outSizes = tensor::getMixedSizes(rewriter, loc, linalgOp.getDpsInits()[0]); - Value result = unsetEncoding(rewriter, loc, opTiled, outSizes); + Value result = + unsetEncoding(rewriter, loc, opTiled, outSizes, initProps.dynamicValues); rewriter.replaceOp(linalgOp, result); return success(); @@ -126,8 +132,9 @@ struct FoldFillWithSetEncoding final LogicalResult matchAndRewrite(IREE::Encoding::SetEncodingOp encodingOp, PatternRewriter &rewriter) const override { auto fillOp = encodingOp.getSource().getDefiningOp(); - if (!fillOp) + if (!fillOp) { return failure(); + } // Create a new fill op, with outs being defined by a new `tensor.empty` op. RankedTensorType encodingType = encodingOp.getResultType(); @@ -242,7 +249,8 @@ static std::optional padProducerOfValue(RewriterBase &rewriter, // Find the new value to yield. 
Value newYieldedVal = map.lookup(operand); auto encodingOp = IREE::Encoding::SetEncodingOp::create( - rewriter, returnOp->getLoc(), newResultType, newYieldedVal); + rewriter, returnOp->getLoc(), newResultType, newYieldedVal, + /*encodingDims=*/ValueRange{}); rewriter.modifyOpInPlace( returnOp, [&]() { returnOp.setOperand(resultNumber, encodingOp); }); @@ -276,7 +284,7 @@ static SmallVector padOperandsOfOp(RewriterBase &rewriter, Type operandType = operand.get().getType(); auto unsetEncodignOp = IREE::Encoding::UnsetEncodingOp::create( rewriter, op->getLoc(), operandType, paddedVal->paddedValue, - paddedVal->dynamicDims); + paddedVal->dynamicDims, /*encodingDims=*/ValueRange{}); op->setOperand(operandNum, unsetEncodignOp.getResult()); }); } diff --git a/compiler/src/iree/compiler/DispatchCreation/TensorPadToTensorInsertSlice.cpp b/compiler/src/iree/compiler/DispatchCreation/TensorPadToTensorInsertSlice.cpp index d31ccb781d9c..dceb76824250 100644 --- a/compiler/src/iree/compiler/DispatchCreation/TensorPadToTensorInsertSlice.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/TensorPadToTensorInsertSlice.cpp @@ -49,8 +49,9 @@ struct TensorPadOpConversion : public OpRewritePattern { // scalar that is not one of the arguments of the linalg operation. 
Region ®ion = padTensorOp.getRegion(); Block &block = region.front(); - if (!llvm::hasSingleElement(block)) + if (!llvm::hasSingleElement(block)) { return failure(); + } auto yieldOp = cast(block.getTerminator()); Value yieldVal = yieldOp.getValue(); if (llvm::any_of(block.getArguments(), diff --git a/compiler/src/iree/compiler/DispatchCreation/TransposeGenericOps.cpp b/compiler/src/iree/compiler/DispatchCreation/TransposeGenericOps.cpp index b2f5ad28dc4a..edd0c233d00a 100644 --- a/compiler/src/iree/compiler/DispatchCreation/TransposeGenericOps.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/TransposeGenericOps.cpp @@ -37,18 +37,21 @@ struct MakeReductionInnermostPattern final SmallVector interchange; bool needInterchange = false; unsigned numParallelLoop = genericOp.getNumParallelLoops(); - if (numParallelLoop == 0) + if (numParallelLoop == 0) { return failure(); + } for (auto iter : llvm::enumerate(genericOp.getIteratorTypesArray())) { if (linalg::isParallelIterator(iter.value())) { interchange.push_back(iter.index()); - if (iter.index() >= numParallelLoop) + if (iter.index() >= numParallelLoop) { needInterchange = true; + } } } // If all the parallel loops are outter loops skip the pattern. - if (!needInterchange) + if (!needInterchange) { return failure(); + } for (auto iter : llvm::enumerate(genericOp.getIteratorTypesArray())) { if (linalg::isReductionIterator(iter.value())) { interchange.push_back(iter.index()); @@ -83,8 +86,9 @@ struct TransposeGenericOpPattern final // elementwise op) with a single use. auto producer = operand->get().getDefiningOp(); if (!producer || !llvm::hasSingleElement(producer->getUsers()) || - linalg::isElementwise(producer)) + linalg::isElementwise(producer)) { continue; + } // check if the generic op has a non-identity map for the operand. 
auto indexingMap = genericOp.getMatchingIndexingMap(operand); @@ -93,11 +97,13 @@ struct TransposeGenericOpPattern final return rewriter.notifyMatchFailure(genericOp, "already normalized"); } // The map must be a permutation. If not, then look for other operand. - if (!indexingMap.isPermutation()) + if (!indexingMap.isPermutation()) { continue; + } - if (!mapForInterchange) + if (!mapForInterchange) { mapForInterchange = indexingMap; + } } if (!mapForInterchange) { diff --git a/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel b/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel index 98f8f232ddfb..cdc161e94f1c 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel +++ b/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel @@ -15,30 +15,27 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "annotate_data_tiling_hints.mlir", "bitcast_unsupported_element_types.mlir", + "bubble_up_expand_shapes.mlir", + "bubble_up_extract_slice.mlir", "clone_producers_into_dispatch_regions.mlir", "collapse_dimensions.mlir", "collapse_linalg_generic_on_tensors.mlir", - "elementwise_op_fusion.mlir", - "dispatch_region_formation_preprocessing.mlir", - "fold_reshapes_into_tensor_barriers.mlir", - "fold_unit_dims.mlir", - "form_dispatch_regions.mlir", - "dispatch_linalg_on_tensors.mlir", "convert_encoding_to_flow.mlir", "convert_region_to_workgroups.mlir", - "bubble_up_expand_shapes.mlir", - "bubble_up_extract_slice.mlir", - "form_dispatch_workgroups.mlir", "dispatch_linalg_ext_fusion.mlir", - "hoist_encoding_ops.mlir", - "hoist_uniform_scalar_compute.mlir", - "insert_tensor_barriers.mlir", - "remove_tensor_barriers.mlir", + "dispatch_linalg_on_tensors.mlir", "dispatch_linalg_on_tensors_default.mlir", "dispatch_linalg_on_tensors_fusion_with_transpose.mlir", + "dispatch_region_formation_preprocessing.mlir", + "elementwise_op_fusion.mlir", + "fold_reshapes_into_tensor_barriers.mlir", + "fold_unit_dims.mlir", + 
"form_dispatch_regions.mlir", + "form_dispatch_workgroups.mlir", "form_scalar_dispatches.mlir", "form_split_reduction_dispatches.mlir", "fuse_encoding_ops_into_dispatch_regions.mlir", @@ -46,6 +43,9 @@ iree_lit_test_suite( "fuse_multiuse_elementwise_producer.mlir", "fuse_multiuse_intra_dispatch.mlir", "fusion_preprocessing.mlir", + "hoist_encoding_ops.mlir", + "hoist_uniform_scalar_compute.mlir", + "insert_tensor_barriers.mlir", "materialize_default_workgroup_count_region.mlir", "pad_fusion_with_consumer.mlir", "pad_fusion_with_producer.mlir", @@ -53,6 +53,7 @@ iree_lit_test_suite( "pipeline_tests_aggressive.mlir", "pipeline_tests_split_reduction.mlir", "propagate_encodings.mlir", + "remove_tensor_barriers.mlir", "set_encoding.mlir", "set_encoding_padding.mlir", "set_encoding_pipeline.mlir", diff --git a/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel b/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel index 4bb133040940..06ef7e993587 100644 --- a/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel +++ b/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel @@ -42,6 +42,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/TensorExt/IR", "//compiler/src/iree/compiler/Dialect/Util/IR", "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:BufferizationInterfaces", "@llvm-project//mlir:ControlFlowInterfaces", @@ -52,6 +53,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:LinalgOpsIncGen", "@llvm-project//mlir:LinalgStructuredOpsIncGen", "@llvm-project//mlir:MLProgramDialect", + "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:ValueBoundsOpInterface", diff --git a/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt b/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt index a183c5539427..25c6eee54999 100644 --- 
a/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt +++ b/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt @@ -31,6 +31,7 @@ iree_cc_library( "UtilExternalModels.cpp" DEPS LLVMSupport + MLIRAffineDialect MLIRArithDialect MLIRControlFlowInterfaces MLIRGPUDialect @@ -39,6 +40,7 @@ iree_cc_library( MLIRLinalgOpsIncGenLib MLIRLinalgStructuredOpsIncGenLib MLIRMLProgramDialect + MLIRMemRefDialect MLIRSCFDialect MLIRTensorDialect MLIRValueBoundsOpInterface diff --git a/compiler/src/iree/compiler/ExternalInterfaces/EncodingExternalModels.cpp b/compiler/src/iree/compiler/ExternalInterfaces/EncodingExternalModels.cpp index 999ff6b543ba..4b5f7998e384 100644 --- a/compiler/src/iree/compiler/ExternalInterfaces/EncodingExternalModels.cpp +++ b/compiler/src/iree/compiler/ExternalInterfaces/EncodingExternalModels.cpp @@ -71,7 +71,8 @@ static IREE::Encoding::PropagationResult propagateThroughEncodingCastableOp( } // Otherwise, we need to create a new set_encoding op. auto setEncodingOp = IREE::Encoding::SetEncodingOp::create( - builder, op->getLoc(), encodedOperandType, operand); + builder, op->getLoc(), encodedOperandType, operand, + /*encodingDims=*/ValueRange{}); encodedOperands.push_back(setEncodingOp.getResult()); result.generatedEncodingOps.push_back(setEncodingOp); } @@ -100,7 +101,7 @@ static IREE::Encoding::PropagationResult propagateThroughEncodingCastableOp( std::tie(std::ignore, resultDynamicDims) = decomposeMixedValues(mixedSizes); auto unsetEncodingOp = IREE::Encoding::UnsetEncodingOp::create( builder, op->getLoc(), originalResult.getType(), encodedResult, - resultDynamicDims); + resultDynamicDims, /*encodingDims=*/ValueRange{}); result.generatedEncodingOps.push_back(unsetEncodingOp); result.replacements.push_back(unsetEncodingOp.getResult()); } diff --git a/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp b/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp index 3f0be78a7a1f..762b03a09c28 100644 --- 
a/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp +++ b/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp @@ -80,6 +80,10 @@ struct OptionalOpAffinityAttrExternalModel IREE::Stream::AffinityAttr getResultAffinityAttr(Operation *op) const { return getAffinityAttr(op); } + + void removeAffinityAttrs(Operation *op) const { + op->removeAttr("stream.affinity"); + } }; struct FlowBarrierTargetAffinityAttrExternalModel @@ -108,6 +112,8 @@ struct FlowBarrierTargetAffinityAttrExternalModel IREE::Stream::AffinityAttr getResultAffinityAttr(Operation *op) const { return getAffinityAttr(op); } + + void removeAffinityAttrs(Operation *op) const { op->removeAttr("target"); } }; struct FlowTransferTargetAffinityAttrExternalModel @@ -132,6 +138,8 @@ struct FlowTransferTargetAffinityAttrExternalModel IREE::Stream::AffinityAttr getResultAffinityAttr(Operation *op) const { return getAffinityAttr(op); } + + void removeAffinityAttrs(Operation *op) const { op->removeAttr("target"); } }; template @@ -164,6 +172,8 @@ struct HALTensorAffinityAttrExternalModel IREE::Stream::AffinityAttr getResultAffinityAttr(Operation *op) const { return getAffinityAttr(op); } + + void removeAffinityAttrs(Operation *op) const { op->removeAttr("affinity"); } }; template @@ -197,6 +207,10 @@ struct GlobalOpAffinityAttrExternalModel IREE::Stream::AffinityAttr getResultAffinityAttr(Operation *op) const { return getAffinityAttr(op); } + + void removeAffinityAttrs(Operation *op) const { + op->removeAttr("stream.affinity"); + } }; template @@ -227,6 +241,10 @@ struct AffinityOpAttrExternalModel IREE::Stream::AffinityAttr getResultAffinityAttr(Operation *op) const { return getAffinityAttr(op); } + + void removeAffinityAttrs(Operation *op) const { + op->removeAttr("stream.affinity"); + } }; struct TensorAffinityTypeExternalModel diff --git a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp 
b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp index 75d442ba463d..e1ff0e5c6d2b 100644 --- a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp +++ b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp @@ -16,11 +16,14 @@ #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MLProgram/IR/MLProgram.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" @@ -49,6 +52,232 @@ getDivisibilityOfOperand(Value v, return IREE::Util::ConstantIntDivisibility(1, 1); } +/// Visits affine expressions and recursively calculates the divisibilities of +/// each subexpression. The final divisibilities of the expression and its +/// subexpressions will be stored in the map for which a reference is provided +/// to the AffineExprDivisibilityFinder (i.e., `divisibilityMap`). +class AffineExprDivisibilityFinder + : public AffineExprVisitor { +public: + using ExprDivisibilityMap = + llvm::DenseMap; + AffineExprDivisibilityFinder(ExprDivisibilityMap &divisibilityMap) + : divisibilityMap(divisibilityMap) {} + + IREE::Util::ConstantIntDivisibility + visitConstantExpr(AffineConstantExpr expr) { + // Constant expressions are trivial, since they are always static. 
+ uint64_t constValue = std::abs(expr.getValue()); + return IREE::Util::ConstantIntDivisibility(constValue, constValue); + } + + IREE::Util::ConstantIntDivisibility visitDimExpr(AffineDimExpr expr) { + // Dim expressions cannot be analyzed further, so return the divisibility + // in `divisibilityMap` if it has been populated by the caller, or fallback + // to the minimum divisibility. + if (divisibilityMap.contains(expr)) { + return divisibilityMap[expr]; + } + return IREE::Util::IntegerDivisibility::getMinDivisibility().getValue(); + } + + IREE::Util::ConstantIntDivisibility visitSymbolExpr(AffineSymbolExpr expr) { + // Symbol expressions cannot be analyzed further, so return the divisibility + // in `divisibilityMap` if it has been populated by the caller, or fallback + // to the minimum divisibility. + if (divisibilityMap.contains(expr)) { + return divisibilityMap[expr]; + } + return IREE::Util::IntegerDivisibility::getMinDivisibility().getValue(); + } + + /// Infer the divisibility of an addition or subtraction expression by + /// recursively visiting the LHS and RHS, and then unioning the results. + IREE::Util::ConstantIntDivisibility visitAddExpr(AffineBinaryOpExpr expr) { + if (divisibilityMap.contains(expr)) { + return divisibilityMap[expr]; + } + // The divisibility of an addition is the GCD of its constituents' + // divisibilities. + IREE::Util::ConstantIntDivisibility lhsDiv = visit(expr.getLHS()); + IREE::Util::ConstantIntDivisibility rhsDiv = visit(expr.getRHS()); + return lhsDiv.getUnion(rhsDiv); + } + + /// Infer the divisibility of a multiplication expression by recursively + /// visiting the LHS and RHS, and then multiplying the results. + IREE::Util::ConstantIntDivisibility visitMulExpr(AffineBinaryOpExpr expr) { + if (divisibilityMap.contains(expr)) { + return divisibilityMap[expr]; + } + // The divisibility of a multiplication is the product of its constituents' + // divisibilities. 
+ IREE::Util::ConstantIntDivisibility lhsDiv = visit(expr.getLHS()); + IREE::Util::ConstantIntDivisibility rhsDiv = visit(expr.getRHS()); + return IREE::Util::ConstantIntDivisibility(lhsDiv.udiv() * rhsDiv.udiv(), + lhsDiv.sdiv() * rhsDiv.sdiv()); + } + + IREE::Util::ConstantIntDivisibility + visitFloorDivExpr(AffineBinaryOpExpr expr) { + return visitDivExpr(expr); + } + + IREE::Util::ConstantIntDivisibility + visitCeilDivExpr(AffineBinaryOpExpr expr) { + return visitDivExpr(expr); + } + + /// Mod expressions could be inferred to be zero in some cases, but for now + /// just return the minimum divisibility. + /// TODO(Max191): Handle evenly divisible cases, and ensure that the zero + /// divisibility propagates properly through parent expressions. + IREE::Util::ConstantIntDivisibility visitModExpr(AffineBinaryOpExpr expr) { + return visitInvalidExpr(expr); + } + +private: + IREE::Util::ConstantIntDivisibility + visitInvalidExpr(AffineBinaryOpExpr expr) { + return IREE::Util::IntegerDivisibility::getMinDivisibility().getValue(); + } + + /// Helper shared by ceildiv and floordiv implementations. Returns the minimum + /// divisibility as a fallback if the divisor is not a constant, because the + /// divisibility cannot be inferred in this case. If the divisor is a + /// constant, then this function recursively visits the dividend, and returns + /// the quotient of the dividend's divisibility with the divisor. + IREE::Util::ConstantIntDivisibility visitDivExpr(AffineBinaryOpExpr expr) { + if (divisibilityMap.contains(expr)) { + return divisibilityMap[expr]; + } + auto constRhs = dyn_cast(expr.getRHS()); + // Division by zero is undefined, so return the minimum divisibility. + if (!constRhs || constRhs.getValue() == 0) { + return IREE::Util::ConstantIntDivisibility(1, 1); + } + auto constValue = static_cast(std::abs(constRhs.getValue())); + IREE::Util::ConstantIntDivisibility lhsDiv = visit(expr.getLHS()); + uint64_t divUDiv = + lhsDiv.udiv() % constValue == 0 ? 
lhsDiv.udiv() / constValue : 1; + uint64_t divSDiv = + lhsDiv.sdiv() % constValue == 0 ? lhsDiv.sdiv() / constValue : 1; + return IREE::Util::ConstantIntDivisibility(divUDiv, divSDiv); + } + + ExprDivisibilityMap &divisibilityMap; +}; + +/// Returns the divisibilities of each AffineMap result based on the +/// divisibilities of its dims and symbols. The `dimAndSymbolDivisibilities` +/// should contain the divisibilities of the dims, followed by the +/// divisibilities of the symbols in ascending order by their positions. +static SmallVector getResultDivisibilities( + AffineMap map, + ArrayRef dimAndSymbolDivisibilities) { + // Seed the AffineExprDivisibilityFinder with the dimAndSymbolDivisibilities. + llvm::DenseMap + exprDivisibilityMap; + SmallVector inputExprs; + inputExprs.append(llvm::map_to_vector( + llvm::seq(map.getNumDims()), + [&](int64_t dim) { return getAffineDimExpr(dim, map.getContext()); })); + inputExprs.append(llvm::map_to_vector( + llvm::seq(map.getNumSymbols()), + [&](int64_t sym) { return getAffineSymbolExpr(sym, map.getContext()); })); + for (auto [expr, divisibility] : + llvm::zip_equal(inputExprs, dimAndSymbolDivisibilities)) { + exprDivisibilityMap[expr] = divisibility; + } + AffineExprDivisibilityFinder divisibilityFinder(exprDivisibilityMap); + + // Walk each result expression and compute their divisibilities. 
+ SmallVector resultDivisibilities; + for (AffineExpr resultExpr : map.getResults()) { + resultDivisibilities.push_back(divisibilityFinder.visit(resultExpr)); + } + return resultDivisibilities; +} + +struct AffineApplyInferIntDivisibilityOpInterface + : public IREE::Util::InferIntDivisibilityOpInterface::ExternalModel< + AffineApplyInferIntDivisibilityOpInterface, affine::AffineApplyOp> { + + void inferResultDivisibility( + Operation *op, ArrayRef argDivs, + IREE::Util::SetIntDivisibilityFn setResultDivs) const { + auto affineApplyOp = cast(op); + SmallVector operandDivisibilities; + for (auto [operand, divisibility] : + llvm::zip(affineApplyOp.getOperands(), argDivs)) { + operandDivisibilities.push_back( + getDivisibilityOfOperand(operand, divisibility)); + } + + SmallVector resultDivisibilities = + getResultDivisibilities(affineApplyOp.getMap(), operandDivisibilities); + for (auto [result, divisibility] : + llvm::zip_equal(affineApplyOp->getResults(), resultDivisibilities)) { + setResultDivs(result, divisibility); + } + } +}; + +/// Infer the result divisibility of an affine.min or affine.max operation +/// based on its operand divisibilities. The result divisibility is the GCD +/// of the divisibilities of each of the affine map results, because the result +/// of the affine.min/max op could be any of these results. 
+template +static void inferAffineMinOrMaxResultDivisibility( + MinOrMaxTy minOrMaxOp, ArrayRef argDivs, + IREE::Util::SetIntDivisibilityFn setResultDivs) { + static_assert( + llvm::is_one_of::value, + "MinOrMaxTy must be affine::AffineMinOp or affine::AffineMaxOp"); + SmallVector operandDivisibilities; + for (auto [operand, divisibility] : + llvm::zip(minOrMaxOp.getOperands(), argDivs)) { + operandDivisibilities.push_back( + getDivisibilityOfOperand(operand, divisibility)); + } + + SmallVector resultDivisibilities = + getResultDivisibilities(minOrMaxOp.getMap(), operandDivisibilities); + + IREE::Util::ConstantIntDivisibility resultDivisibility = + resultDivisibilities.pop_back_val(); + for (auto divisibility : resultDivisibilities) { + resultDivisibility = resultDivisibility.getUnion(divisibility); + } + setResultDivs(minOrMaxOp.getResult(), resultDivisibility); +} + +struct AffineMinInferIntDivisibilityOpInterface + : public IREE::Util::InferIntDivisibilityOpInterface::ExternalModel< + AffineMinInferIntDivisibilityOpInterface, affine::AffineMinOp> { + + void inferResultDivisibility( + Operation *op, ArrayRef argDivs, + IREE::Util::SetIntDivisibilityFn setResultDivs) const { + auto affineMinOp = cast(op); + inferAffineMinOrMaxResultDivisibility(affineMinOp, argDivs, setResultDivs); + } +}; + +struct AffineMaxInferIntDivisibilityOpInterface + : public IREE::Util::InferIntDivisibilityOpInterface::ExternalModel< + AffineMaxInferIntDivisibilityOpInterface, affine::AffineMaxOp> { + + void inferResultDivisibility( + Operation *op, ArrayRef argDivs, + IREE::Util::SetIntDivisibilityFn setResultDivs) const { + auto affineMaxOp = cast(op); + inferAffineMinOrMaxResultDivisibility(affineMaxOp, argDivs, setResultDivs); + } +}; + struct ArithConstantInferIntDivisibilityOpInterface : public IREE::Util::InferIntDivisibilityOpInterface::ExternalModel< ArithConstantInferIntDivisibilityOpInterface, arith::ConstantOp> { @@ -104,8 +333,13 @@ struct 
ArithDivUIInferIntDivisibilityOpInterface auto lhsDivisibility = getDivisibilityOfOperand(divOp.getLhs(), argDivs[0]); - uint64_t divUDiv = lhsDivisibility.udiv() / intVal.getZExtValue(); - uint64_t divSDiv = lhsDivisibility.sdiv() / std::abs(intVal.getSExtValue()); + uint64_t divUDiv = lhsDivisibility.udiv() % intVal.getZExtValue() == 0 + ? lhsDivisibility.udiv() / intVal.getZExtValue() + : 1; + uint64_t divSDiv = + lhsDivisibility.sdiv() % std::abs(intVal.getSExtValue()) == 0 + ? lhsDivisibility.sdiv() / std::abs(intVal.getSExtValue()) + : 1; setResultDivs(divOp, IREE::Util::ConstantIntDivisibility(divUDiv, divSDiv)); } @@ -174,8 +408,9 @@ struct GlobalOpInterfaceExternalModel IREE::Util::InliningPolicyAttrInterface getGlobalInliningPolicy(Operation *op) const { - if (op->hasAttr("noinline")) + if (op->hasAttr("noinline")) { return IREE::Util::InlineNeverAttr::get(op->getContext()); + } return {}; } void @@ -282,8 +517,9 @@ struct LinalgOpTiedOpInterface SmallVector getTiedResultOperandIndices(Operation *op) const { SmallVector result; - for (unsigned i = 0; i < op->getNumResults(); ++i) + for (unsigned i = 0; i < op->getNumResults(); ++i) { result.push_back(*getTiedResultOperandIndex(op, i)); + } return result; } }; @@ -902,10 +1138,60 @@ struct SCFIndexSwitchOpMutableRegionBranchOpInterface } }; +// Hoistable interface for region-containing control flow operations. +// Control flow is hoistable if control operands are constant and +// nested operations are hoistable (checked via atomic hoisting). +template +struct RegionControlFlowHoistableOpInterface + : public IREE::Util::HoistableOpInterface::ExternalModel< + RegionControlFlowHoistableOpInterface, OpTy> { + bool isHoistableOp(Operation *op) const { + // Control flow is hoistable if all nested operations are hoistable. + for (Region ®ion : op->getRegions()) { + WalkResult result = region.walk([](Operation *nestedOp) { + // Check if nested op is hoistable. 
+ bool isHoistable = false; + if (auto hoistable = + dyn_cast(nestedOp)) { + isHoistable = hoistable.isHoistableOp(); + } else { + // Ops without interface must be memory-effect-free to be hoistable. + isHoistable = mlir::isMemoryEffectFree(nestedOp); + } + if (!isHoistable) { + return WalkResult::interrupt(); + } + // Don't descend into IsolatedFromAbove ops - treat them atomically. + return nestedOp->hasTrait() + ? WalkResult::skip() + : WalkResult::advance(); + }); + if (result.wasInterrupted()) { + return false; + } + } + return true; + } + + bool isHoistableLeafOp(Operation *) const { return false; } + bool isAtomicallyHoistableOp(Operation *) const { return true; } + bool isOperandHoistable(Operation *, OpOperand *) const { return true; } +}; + +template +struct RegionControlFlowHoistableOpInterfaceHelper { + static void registerOpInterface(MLIRContext *context) { + (Ops::template attachInterface>( + *context), + ...); + } +}; + } // namespace void registerUtilExternalModels(DialectRegistry ®istry) { // Must ensure that any dependent dialects are registered. + registry.insert(); registry.insert(); registry.insert(); registry.insert(); @@ -932,6 +1218,16 @@ void registerUtilExternalModels(DialectRegistry ®istry) { *context); }); + registry.addExtension( + +[](MLIRContext *context, affine::AffineDialect *dialect) { + affine::AffineApplyOp::attachInterface< + AffineApplyInferIntDivisibilityOpInterface>(*context); + affine::AffineMinOp::attachInterface< + AffineMinInferIntDivisibilityOpInterface>(*context); + affine::AffineMaxOp::attachInterface< + AffineMaxInferIntDivisibilityOpInterface>(*context); + }); + registry.addExtension( +[](MLIRContext *context, tensor::TensorDialect *dialect) { tensor::InsertSliceOp::attachInterface( @@ -1022,6 +1318,7 @@ void registerUtilExternalModels(DialectRegistry ®istry) { }); // Register MutableRegionBranchOpInterface for SCF ops. + // Register hoistable op interfaces for SCF control flow ops. 
registry.addExtension(+[](MLIRContext *context, scf::SCFDialect *dialect) { scf::ForOp::attachInterface( *context); @@ -1030,6 +1327,8 @@ void registerUtilExternalModels(DialectRegistry ®istry) { *context); scf::IndexSwitchOp::attachInterface< SCFIndexSwitchOpMutableRegionBranchOpInterface>(*context); + RegionControlFlowHoistableOpInterfaceHelper< + scf::ForOp, scf::IfOp, scf::WhileOp>::registerOpInterface(context); }); } diff --git a/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp b/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp index 475acc17a19c..1445e0dce06d 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp @@ -29,8 +29,9 @@ class Convert1x1FilterConvToMatmul : public OpRewritePattern { PatternRewriter &rewriter) const override { auto filterShapeType = dyn_cast( convOp.getDpsInputOperand(1)->get().getType()); - if (!filterShapeType) + if (!filterShapeType) { return failure(); + } constexpr bool isNCHW = std::is_same_v; @@ -48,8 +49,9 @@ class Convert1x1FilterConvToMatmul : public OpRewritePattern { constexpr int khLoopIndex = isNHWC ? 4 : 5; constexpr int kwLoopIndex = isNHWC ? 
5 : 6; - if (filterShape[khIndex] != 1 || filterShape[kwIndex] != 1) + if (filterShape[khIndex] != 1 || filterShape[kwIndex] != 1) { return failure(); + } SmallVector dimReplacements; for (int i = 0; i < numLoops; i++) { diff --git a/compiler/src/iree/compiler/GlobalOptimization/ConvertStridedContractionToContraction.cpp b/compiler/src/iree/compiler/GlobalOptimization/ConvertStridedContractionToContraction.cpp index aec860cd5a67..a7eb49b6111a 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/ConvertStridedContractionToContraction.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/ConvertStridedContractionToContraction.cpp @@ -27,15 +27,18 @@ class ConvertStridedContractionToContraction PatternRewriter &rewriter) const override { // Check if the generic op satisfies all other conditions for being a // contraction. - if (op.getNumDpsInputs() != 2 || op.getNumDpsInits() != 1) + if (op.getNumDpsInputs() != 2 || op.getNumDpsInits() != 1) { return failure(); - if (op.getNumReductionLoops() == 0) + } + if (op.getNumReductionLoops() == 0) { return failure(); + } if (!mlir::linalg::detail::isContractionBody( *op.getBlock(), [](Operation *first, Operation *second) { if ((isa(first) && isa(second)) || - (isa(first) && isa(second))) + (isa(first) && isa(second))) { return true; + } return false; })) { return failure(); @@ -54,16 +57,18 @@ class ConvertStridedContractionToContraction !resultMap.isProjectedPermutation()) { return failure(); } - if (inputMap.isProjectedPermutation()) + if (inputMap.isProjectedPermutation()) { return failure(); + } SmallVector staticShape = op.getStaticLoopRanges(); llvm::SmallDenseMap strides; SmallVector replacementExprs; Value input = op.getDpsInputs()[0]; auto inputTy = dyn_cast(input.getType()); - if (!inputTy) + if (!inputTy) { return failure(); + } SmallVector inputShape(inputTy.getShape()); replacementExprs.reserve(inputMap.getNumResults()); // Walk through input map and look for expressions of the form `dim * cst`. 
@@ -76,8 +81,9 @@ class ConvertStridedContractionToContraction // Look at binary op expressions. auto binexpr = dyn_cast(expr); // Fail if we see some unexpected kind of expression. - if (!binexpr) + if (!binexpr) { return failure(); + } auto rhs = dyn_cast(binexpr.getRHS()); auto lhs = dyn_cast(binexpr.getLHS()); // Binary expressions must be of the form `dim * cst`. @@ -87,15 +93,17 @@ class ConvertStridedContractionToContraction } strides.insert(std::pair(pos, rhs.getValue())); int64_t newSize = staticShape[lhs.getPosition()]; - if (newSize == ShapedType::kDynamic || newSize == 0) + if (newSize == ShapedType::kDynamic || newSize == 0) { return failure(); + } inputShape[pos] = newSize; replacementExprs.push_back(lhs); } // Fail if we don't have any work to do. - if (strides.empty()) + if (strides.empty()) { return failure(); + } mapRange[inputPos] = AffineMap::get(inputMap.getNumDims(), inputMap.getNumSymbols(), diff --git a/compiler/src/iree/compiler/GlobalOptimization/DecomposeConcat.cpp b/compiler/src/iree/compiler/GlobalOptimization/DecomposeConcat.cpp index 6ad30ce8e87c..cfef1f765698 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/DecomposeConcat.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/DecomposeConcat.cpp @@ -54,8 +54,9 @@ struct TransposeInnerConcatenation : public OpRewritePattern { ArrayRef concatShape = concatType.getShape(); int64_t outerMostNonUnitDim = 0; while (outerMostNonUnitDim < concatOp.getRank()) { - if (concatShape[outerMostNonUnitDim] != 1) + if (concatShape[outerMostNonUnitDim] != 1) { break; + } outerMostNonUnitDim++; } diff --git a/compiler/src/iree/compiler/GlobalOptimization/DetachElementwiseFromNamedOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/DetachElementwiseFromNamedOps.cpp index 9b6cdffa8a64..f38ed8f8a312 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/DetachElementwiseFromNamedOps.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/DetachElementwiseFromNamedOps.cpp @@ -40,8 +40,9 
@@ struct DetachElementwisePattern !isa(*linalgOp)) { return failure(); } - if (!linalgOp.hasPureTensorSemantics()) + if (!linalgOp.hasPureTensorSemantics()) { return failure(); + } // Nothing to do if the output tensor operand is already a fill op. SmallVector outputOperands; @@ -52,8 +53,9 @@ struct DetachElementwisePattern } // Right now all the cases we see have one output. This can be relaxed once // we see multiple output ops. - if (outputOperands.size() != 1) + if (outputOperands.size() != 1) { return failure(); + } Value outputOperand = outputOperands.front()->get(); auto outsDefiningOp = outputOperand.getDefiningOp(); @@ -62,8 +64,9 @@ struct DetachElementwisePattern return failure(); } auto outputType = cast(outputOperand.getType()); - if (!outputType.getElementType().isIntOrFloat()) + if (!outputType.getElementType().isIntOrFloat()) { return failure(); + } auto elementType = outputType.getElementType(); Location loc = linalgOp.getLoc(); @@ -139,17 +142,20 @@ struct DetachSplatConstantOutsOperands for (auto outOperand : llvm::enumerate(dpsInterfaceOp.getDpsInits())) { auto constOp = outOperand.value().template getDefiningOp(); - if (!constOp) + if (!constOp) { continue; + } auto resultType = dyn_cast(constOp.getResult().getType()); - if (!resultType || !resultType.getElementType().isIntOrFloat()) + if (!resultType || !resultType.getElementType().isIntOrFloat()) { continue; + } auto attr = dyn_cast(constOp.getValue()); - if (!attr || !attr.isSplat()) + if (!attr || !attr.isSplat()) { continue; + } Location loc = constOp.getLoc(); Type elementType = resultType.getElementType(); diff --git a/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp b/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp index 8fd07bef521a..e2b900ade36e 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp @@ -121,8 +121,9 @@ static void 
expandType(Type type, SmallVectorImpl &newTypes) { // Expands tensors in the given |types| list to (tensor, dynamic dims...). // This could be changed to some iterator magic to avoid the alloc. static SmallVector expandTypes(TypeRange types) { - if (types.empty()) + if (types.empty()) { return {}; + } SmallVector newTypes; newTypes.reserve(types.size() * 2); for (auto type : types) { @@ -205,22 +206,25 @@ static void expandTensorDims(Operation *op, SymbolTable &symbolTable, static void expandRegion(Region ®ion, SymbolTable &symbolTable, ExpandedGlobalMap &globalMap, IndexSet &indexSet, TensorDimMap tensorDimMap) { - if (region.empty()) + if (region.empty()) { return; + } // Update all block arguments. auto indexType = IndexType::get(region.getContext()); for (auto &block : region.getBlocks()) { - if (!llvm::any_of(block.getArgumentTypes(), isDynamicTensor)) + if (llvm::none_of(block.getArgumentTypes(), isDynamicTensor)) { continue; + } // Insert and build a list of expanded (tensor, dynamic dims...) tuples. 
SmallVector expansions; for (int i = block.getNumArguments() - 1; i >= 0; --i) { auto arg = block.getArgument(i); auto tensorType = dyn_cast(arg.getType()); - if (!tensorType || tensorType.hasStaticShape()) + if (!tensorType || tensorType.hasStaticShape()) { continue; + } ExpandedValue expandedValue; expandedValue.tensor = arg; for (unsigned j = 0; j < tensorType.getNumDynamicDims(); ++j) { @@ -302,8 +306,9 @@ static void retieResults(Operation *op, Operation *newOp, static void expandGlobalLoadOp(IREE::Util::GlobalLoadOpInterface op, ExpandedGlobalMap &globalMap, IndexSet &indexSet, TensorDimMap &tensorDimMap) { - if (!usesDynamicTensors(op)) + if (!usesDynamicTensors(op)) { return; + } OpBuilder builder(op); builder.setInsertionPointAfter(op); auto &expandedGlobal = globalMap[op.getGlobalName()]; @@ -335,8 +340,9 @@ static void expandGlobalStoreOp(IREE::Util::GlobalStoreOpInterface op, ExpandedGlobalMap &globalMap, IndexSet &indexSet, TensorDimMap &tensorDimMap) { - if (!usesDynamicTensors(op)) + if (!usesDynamicTensors(op)) { return; + } OpBuilder builder(op); builder.setInsertionPointAfter(op); auto expandedValue = consumeExpandedValue( @@ -395,13 +401,15 @@ static void expandFuncOp(IREE::Util::FuncOp op, SymbolTable &symbolTable, // %2 = flow.tensor.tie_shape %r : tensor{%rd} static void expandCallOp(IREE::Util::CallOp op, SymbolTable &symbolTable, IndexSet &indexSet, TensorDimMap &tensorDimMap) { - if (!usesDynamicTensors(op)) + if (!usesDynamicTensors(op)) { return; + } // Ignore calls to public/external functions. auto calleeOp = symbolTable.lookup(op.getCallee()); - if (IREE::Util::isPublicOrExternal(calleeOp)) + if (IREE::Util::isPublicOrExternal(calleeOp)) { return; + } // Build the new call op with expanded operands and results. 
OpBuilder builder(op); @@ -429,10 +437,13 @@ static void expandCallOp(IREE::Util::CallOp op, SymbolTable &symbolTable, // util.return %0, %d static void expandReturnOp(IREE::Util::ReturnOp op, IndexSet &indexSet, TensorDimMap &tensorDimMap) { - if (!usesDynamicTensors(op)) + if (!usesDynamicTensors(op)) { return; - if (IREE::Util::isPublicOrExternal(op->getParentOfType())) + } + if (IREE::Util::isPublicOrExternal( + op->getParentOfType())) { return; + } OpBuilder builder(op); auto operands = expandOperands(op.getLoc(), op.getOperands(), tensorDimMap, indexSet, builder); @@ -462,8 +473,9 @@ static void expandBranchOp(mlir::cf::BranchOp op, IndexSet &indexSet, static void expandCondBranchOp(mlir::cf::CondBranchOp op, IndexSet &indexSet, TensorDimMap &tensorDimMap) { - if (!usesDynamicTensors(op)) + if (!usesDynamicTensors(op)) { return; + } OpBuilder builder(op); mlir::cf::CondBranchOp::create( builder, op.getLoc(), op.getCondition(), op.getTrueDest(), @@ -487,8 +499,9 @@ static void expandCondBranchOp(mlir::cf::CondBranchOp op, IndexSet &indexSet, // %4 = flow.tensor.tie_shape %2 : tensor{%3} static void expandSelectOp(mlir::arith::SelectOp op, IndexSet &indexSet, TensorDimMap &tensorDimMap) { - if (!usesDynamicTensors(op)) + if (!usesDynamicTensors(op)) { return; + } OpBuilder builder(op); auto trueValue = consumeExpandedValue(op.getLoc(), op.getTrueValue(), diff --git a/compiler/src/iree/compiler/GlobalOptimization/GlobalLoopInvariantCodeMotion.cpp b/compiler/src/iree/compiler/GlobalOptimization/GlobalLoopInvariantCodeMotion.cpp index cbbbe5f4880c..4f8f7258bf94 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/GlobalLoopInvariantCodeMotion.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/GlobalLoopInvariantCodeMotion.cpp @@ -39,8 +39,9 @@ static bool isHoistableOp(LoopLikeOpInterface loopOp, Operation *op, for (OpOperand &operand : op->getOpOperands()) { Value value = operand.get(); // Ignore values defined outside the loop. 
- if (loopOp.isDefinedOutsideOfLoop(value)) + if (loopOp.isDefinedOutsideOfLoop(value)) { continue; + } Operation *producer = value.getDefiningOp(); // If the producer is not an operation, can't hoist it. @@ -61,8 +62,9 @@ static LogicalResult hoistLoopInvariants(LoopLikeOpInterface loopOp, llvm::SetVector hoistableOps; for (Region *region : loopOp.getLoopRegions()) { // Skip loops with multi-block regions to simplify op's dependency. - if (!region->hasOneBlock()) + if (!region->hasOneBlock()) { return failure(); + } // Consider only the top-level ops in the region. The forward visiting in a // single block ensures we are check and add ops in topological order. @@ -73,8 +75,9 @@ static LogicalResult hoistLoopInvariants(LoopLikeOpInterface loopOp, } } } - if (hoistableOps.empty()) + if (hoistableOps.empty()) { return success(); + } // Wrap the loop in zero-trip-check so the hoisted ops will only run when the // loop condition is ever satisfied. @@ -87,8 +90,9 @@ static LogicalResult hoistLoopInvariants(LoopLikeOpInterface loopOp, return scf::wrapWhileLoopInZeroTripCheck(op, rewriter); }) .Default([&](Operation *op) { return failure(); }); - if (failed(wrappedLoop)) + if (failed(wrappedLoop)) { return failure(); + } // Hoist ops out of the loop in topological order. for (Operation *op : hoistableOps) { @@ -118,15 +122,17 @@ struct GlobalLoopInvariantCodeMotionPass // to move across multiple loop levels. funcOp.walk([&](LoopLikeOpInterface op) { // Check if the loop type is supported. 
- if (isa(op)) + if (isa(op)) { candidateLoops.push_back(op); + } return; }); IRRewriter rewriter(context); for (auto loopOp : candidateLoops) { - if (failed(hoistLoopInvariants(loopOp, rewriter))) + if (failed(hoistLoopInvariants(loopOp, rewriter))) { return signalPassFailure(); + } } } }; diff --git a/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp b/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp index 4cc86170c52e..d7f507b868d5 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp @@ -38,8 +38,9 @@ struct MaterializeHomogeneousEncodingsPass final void runOnOperation() override { mlir::ModuleOp moduleOp = getOperation(); IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp); - if (failed(deviceAnalysis.run())) + if (failed(deviceAnalysis.run())) { return signalPassFailure(); + } SetVector executableTargets; deviceAnalysis.gatherAllExecutableTargets(executableTargets); diff --git a/compiler/src/iree/compiler/GlobalOptimization/OptimizeNumerics.cpp b/compiler/src/iree/compiler/GlobalOptimization/OptimizeNumerics.cpp index c8f38fa2c4d8..6f7369dd0a14 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/OptimizeNumerics.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/OptimizeNumerics.cpp @@ -22,8 +22,9 @@ namespace { int getNextPotBitWidth(int bitWidth, int minBitWidth = 8) { for (int i = minBitWidth;; i *= 2) { - if (i >= bitWidth) + if (i >= bitWidth) { return i; + } } } @@ -108,8 +109,9 @@ struct TensorEmptyCast LogicalResult matchAndRewrite(IREE::Util::NumericCastOpInterface castOp, PatternRewriter &rewriter) const override { auto emptyOp = castOp.getInput().getDefiningOp(); - if (!emptyOp) + if (!emptyOp) { return failure(); + } Type resultType = castOp.getCasted().getType(); rewriter.replaceOpWithNewOp(castOp, resultType, @@ -127,8 +129,9 @@ struct 
LinalgFillCast PatternRewriter &rewriter) const override { auto loc = castOp.getLoc(); auto fillOp = castOp.getInput().getDefiningOp(); - if (!fillOp) + if (!fillOp) { return failure(); + } Type toElementType = getElementTypeOrSelf(castOp.getCastedType()); Value fillInput = fillOp.value(); diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp index b756902b9b10..58398be0f896 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp @@ -180,6 +180,8 @@ void buildGlobalOptimizationPassPipeline( transformOptions.aggressiveTransposePropagation; options.enableConvolutionPropagation = transformOptions.propagateTransposesThroughConv; + options.enableSinkTransposeThroughPad = + transformOptions.sinkTransposeThroughPad; options.enableAttentionVTranspose = clEnableAttentionVTranspose; options.enableEdgeReshapePropagation = @@ -271,10 +273,10 @@ void buildGlobalOptimizationPassPipeline( exportParametersOptions)); } - if (!transformOptions.parameterSplatExportFile.empty()) { + if (!transformOptions.parameterSplatPath.empty()) { IREE::IO::Parameters::GenerateSplatParameterArchivePassOptions generateSplatOptions; - generateSplatOptions.filePath = transformOptions.parameterSplatExportFile; + generateSplatOptions.filePath = transformOptions.parameterSplatPath; mainPassManager.addPass( IREE::IO::Parameters::createGenerateSplatParameterArchivePass( generateSplatOptions)); diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.h b/compiler/src/iree/compiler/GlobalOptimization/Passes.h index cf6a89135990..d1e1c925c1b7 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.h +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.h @@ -49,9 +49,9 @@ struct TransformOptions : public PassPipelineOptions { llvm::cl::desc("Minimum size of constants to export as parameters."), llvm::cl::init(0), }; - Option 
parameterSplatExportFile{ + Option parameterSplatPath{ *this, - "parameter-splat-export-file", + "parameter-splat-path", llvm::cl::desc("File path to create a splat parameter archive out of all " "parameters in the module."), llvm::cl::init(""), @@ -71,6 +71,12 @@ struct TransformOptions : public PassPipelineOptions { "Enables propagation of transpose ops through convolutions"), llvm::cl::init(false), }; + Option sinkTransposeThroughPad{ + *this, + "sink-transpose-through-pad", + llvm::cl::desc("Enables sinking transpose through pad operations"), + llvm::cl::init(false), + }; Option outerDimConcat{ *this, "outer-dim-concat", diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.td b/compiler/src/iree/compiler/GlobalOptimization/Passes.td index 4d89a5ed2273..1bf312fee2c7 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.td +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.td @@ -121,6 +121,8 @@ def PropagateLinalgTransposePass : /*default=*/"false", "enable propagation through convolutions">, Option<"enableEdgeReshapePropagation", "enable-edge-reshape-propagation", "bool", /*default=*/"false", "Enable propagation of reshapes on the edges of the program">, + Option<"enableSinkTransposeThroughPad", "enable-sink-transpose-through-pad", "bool", + /*default=*/"false", "Enable sinking transpose through pad operations">, ]; } diff --git a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp index 2e73a25774b4..d7b80f988238 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp @@ -285,8 +285,9 @@ class FuseTransposeWithProducerLinalgOp rewriter.replaceOp(transposeOp, newGenericOp->getResult(resultIndex)); for (auto [oldRes, newRes] : llvm::zip_equal(genericOp.getResults(), newGenericOp->getResults())) { - if 
(oldRes.getResultNumber() == resultIndex) + if (oldRes.getResultNumber() == resultIndex) { continue; + } rewriter.replaceAllUsesWith(oldRes, newRes); } return success(); @@ -589,6 +590,52 @@ class SinkTransposeThroughExpandShape bool enableEdgeReshapePropagation = true; }; +// Sinks a transpose through a tensor.pad. +class SinkTransposeThroughPad : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::PadOp padOp, + PatternRewriter &rewriter) const override { + if (!IREE::Flow::isNonNullAndOutsideDispatch(padOp)) { + return failure(); + } + Value source = padOp.getSource(); + auto transposeOp = source.getDefiningOp(); + if (!transposeOp) { + return failure(); + } + + Block &block = padOp.getRegion().front(); + if (llvm::any_of(block.getArguments(), [](BlockArgument blockArg) { + return blockArg.getNumUses(); + })) { + return failure(); + } + + auto invPerm = invertPermutationVector(transposeOp.getPermutation()); + SmallVector lowSizes = padOp.getMixedLowPad(); + SmallVector highSizes = padOp.getMixedHighPad(); + applyPermutationToVector(lowSizes, invPerm); + applyPermutationToVector(highSizes, invPerm); + + RankedTensorType oldPaddedType = cast(padOp.getType()); + RankedTensorType newPaddedType = oldPaddedType.clone( + applyPermutation(oldPaddedType.getShape(), invPerm)); + + auto newPadOp = tensor::PadOp::create( + rewriter, padOp.getLoc(), newPaddedType, transposeOp.getInput(), + lowSizes, highSizes, padOp.getNofold()); + rewriter.cloneRegionBefore(padOp.getRegion(), newPadOp.getRegion(), + newPadOp.getRegion().begin()); + + Value newTransposeOp = + createTranspose(rewriter, newPadOp, transposeOp.getPermutation()); + rewriter.replaceOp(padOp, newTransposeOp); + return success(); + } +}; + // Fuses a transpose with the input of a linalg.generic op or contraction op. // Contraction ops are generalized and then treated as a generic. 
For example, // @@ -1292,6 +1339,9 @@ void PropagateLinalgTransposePass::runOnOperation() { sinkingPatterns.insert(context); sinkingPatterns.insert( context, enableEdgeReshapePropagation); + if (enableSinkTransposeThroughPad) { + sinkingPatterns.insert(context); + } sinkingPatterns.insert( context, enableAggressivePropagation, enableConvolutionPropagation); sinkingPatterns.insert(context); diff --git a/compiler/src/iree/compiler/GlobalOptimization/QuantizedConvToConv.cpp b/compiler/src/iree/compiler/GlobalOptimization/QuantizedConvToConv.cpp index 4614f7b56c8b..e516cbf559c7 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/QuantizedConvToConv.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/QuantizedConvToConv.cpp @@ -103,8 +103,9 @@ void GetDynamicDym(ImplicitLocOpBuilder &builder, int64_t dim) { ShapedType ty = cast(value.getType()); dims.push_back(ty.getDimSize(dim)); - if (ty && ty.isDynamicDim(dim)) + if (ty && ty.isDynamicDim(dim)) { dynDims.push_back(tensor::DimOp::create(builder, value, dim)); + } } Value multiplyDims(ImplicitLocOpBuilder &builder, Value value, @@ -178,8 +179,9 @@ struct QuantizedConvToConv // Materialize a length-1 dimension at the end of the summation. SmallVector reassociationMap(3); - for (int i = 0; i < 3; i++) + for (int i = 0; i < 3; i++) { reassociationMap[i].push_back(builder.getAffineDimExpr(i)); + } reassociationMap.back().push_back(builder.getAffineDimExpr(3)); auto expandTy = diff --git a/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp index 638ca1f64f05..462a6caf3417 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp @@ -83,8 +83,9 @@ raiseTensorExtractToInput(linalg::GenericOp linalgOp, RewriterBase &rewriter) { // Restrict to cases where the constant is 0. 
This is because handling // constants other than 0 in indexing map, may cause problems in the // lowering pipeline later. - if (constantIndex.getLimitedValue() != 0) + if (constantIndex.getLimitedValue() != 0) { return failure(); + } exprs.push_back(getAffineConstantExpr(0, rewriter.getContext())); continue; } @@ -306,8 +307,9 @@ class NamedImplicitCastOpConversion : public OpInterfaceRewritePattern { } if (!llvm::all_of(producer.getIndexingMapsArray(), - [](AffineMap map) { return map.isIdentity(); })) + [](AffineMap map) { return map.isIdentity(); })) { return false; + } std::optional castOp = getDefiningNonI1ExtendingCastOp(operand.get()); @@ -319,8 +321,9 @@ class NamedImplicitCastOpConversion : public OpInterfaceRewritePattern { // preferred to fuse those with producers (and the consumer fusion is // arguably the less canonical form). auto canFoldCast = [&]() { - if (isa(*castOp)) + if (isa(*castOp)) { return true; + } // Signed operations can only be folded with (implicitly) signed // linalg named ops if (isa(*castOp)) { diff --git a/compiler/src/iree/compiler/GlobalOptimization/Utils.cpp b/compiler/src/iree/compiler/GlobalOptimization/Utils.cpp index d46faef5f3b4..4b025f71c87d 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Utils.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Utils.cpp @@ -112,8 +112,9 @@ Value sumReduceDimensionSubset(ImplicitLocOpBuilder &rewriter, Value val, llvm::SmallVector staticSizes; SmallVector dynSizes; for (int i = 0, s = is_reduction.size(); i < s; i++) { - if (is_reduction[i]) + if (is_reduction[i]) { continue; + } staticSizes.push_back(ty.getDimSize(i)); if (ty.isDynamicDim(i)) { diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel index 76233b25c578..6a2a09440819 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel +++ b/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel @@ -15,6 +15,7 
@@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "cleanup_numeric_narrowing.mlir", "conv1x1_to_matmul.mlir", diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir index d9be97a6ba8e..49def08acc02 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir @@ -4,6 +4,7 @@ // RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-propagate-linalg-transpose{test-bubbling-only=true}))" --split-input-file %s | FileCheck %s --check-prefix=BUBBLE // RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-propagate-linalg-transpose{enable-aggressive-propagation-through-conv=true}))" --split-input-file %s | FileCheck %s --check-prefix=CONV // RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-propagate-linalg-transpose{enable-edge-reshape-propagation=true}))" %s -o - --split-input-file | FileCheck %s --check-prefix=ENABLE-EDGE-PROP +// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-propagate-linalg-transpose{enable-sink-transpose-through-pad=true}))" --split-input-file %s | FileCheck %s --check-prefix=SINK-PAD util.func public @specialize_transpose_op(%arg0 : tensor<1x2x3xf32>, %empty : tensor<3x2x1xf32>) -> tensor<3x2x1xf32> { @@ -1040,6 +1041,7 @@ util.func public @bubble_transpose_through_truncf_and_fuse_with_conv( // BUBBLE: linalg.generic // BUBBLE: } -> tensor<16x2x2x4xbf16> // BUBBLE-NOT: linalg.transpose +// BUBBLE: util.return // With enable-aggressive-propagation-through-conv, transpose is fully fused with conv. 
// CONV-LABEL: util.func public @bubble_transpose_through_truncf_and_fuse_with_conv @@ -1050,3 +1052,58 @@ util.func public @bubble_transpose_through_truncf_and_fuse_with_conv( // CONV: } -> tensor<16x2x2x4xbf16> // CONV-NOT: linalg.transpose // CONV: util.return %[[TRUNCF]] + +// ----- + +util.func public @sink_transpose_through_pad(%arg0: tensor<16x64x64x128xf16>) -> tensor<16x128x66x66xf16> { + %cst = arith.constant 0.000000e+00 : f16 + %empty = tensor.empty() : tensor<16x128x64x64xf16> + %transposed = linalg.transpose ins(%arg0 : tensor<16x64x64x128xf16>) outs(%empty : tensor<16x128x64x64xf16>) permutation = [0, 3, 1, 2] + %padded = tensor.pad %transposed low[0, 0, 1, 1] high[0, 0, 1, 1] { + ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index): + tensor.yield %cst : f16 + } : tensor<16x128x64x64xf16> to tensor<16x128x66x66xf16> + util.return %padded : tensor<16x128x66x66xf16> +} +// With enable-sink-transpose-through-pad=true, transpose sinks through pad. +// SINK-PAD-LABEL: util.func public @sink_transpose_through_pad +// SINK-PAD: %[[PAD:.+]] = tensor.pad +// SINK-PAD: %[[TRANSPOSE:.+]] = linalg.transpose +// SINK-PAD-SAME: ins(%[[PAD]] +// SINK-PAD: util.return %[[TRANSPOSE]] + +// Without the flag, transpose does not sink through pad. 
+// SINK-LABEL: util.func public @sink_transpose_through_pad +// SINK: %[[TRANSPOSE:.+]] = linalg.transpose +// SINK: %[[PAD:.+]] = tensor.pad %[[TRANSPOSE]] +// SINK: util.return %[[PAD]] + +// ----- + +util.func public @sink_transpose_through_expand_shape_and_pad(%arg0: tensor<16x2x48x32x288xbf16>) -> tensor<16x3x96x4x48x32xbf16> { + %cst = arith.constant 0.000000e+00 : bf16 + %empty = tensor.empty() : tensor<16x288x2x48x32xbf16> + %transposed = linalg.transpose ins(%arg0 : tensor<16x2x48x32x288xbf16>) outs(%empty : tensor<16x288x2x48x32xbf16>) permutation = [0, 4, 1, 2, 3] + %expanded = tensor.expand_shape %transposed [[0], [1, 2], [3], [4], [5]] output_shape [16, 3, 96, 2, 48, 32] : tensor<16x288x2x48x32xbf16> into tensor<16x3x96x2x48x32xbf16> + %padded = tensor.pad %expanded low[0, 0, 0, 1, 0, 0] high[0, 0, 0, 1, 0, 0] { + ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index): + tensor.yield %cst : bf16 + } : tensor<16x3x96x2x48x32xbf16> to tensor<16x3x96x4x48x32xbf16> + util.return %padded : tensor<16x3x96x4x48x32xbf16> +} +// With enable-sink-transpose-through-pad=true, transpose sinks through both +// expand_shape and pad. +// SINK-PAD-LABEL: util.func public @sink_transpose_through_expand_shape_and_pad +// SINK-PAD: %[[EXPAND:.+]] = tensor.expand_shape %arg0 +// SINK-PAD: %[[PAD:.+]] = tensor.pad %[[EXPAND]] +// SINK-PAD: %[[TRANSPOSE:.+]] = linalg.transpose +// SINK-PAD-SAME: ins(%[[PAD]] +// SINK-PAD: util.return %[[TRANSPOSE]] + +// Without the flag, transpose sinks through expand_shape but not pad. 
+// SINK-LABEL: util.func public @sink_transpose_through_expand_shape_and_pad +// SINK: %[[EXPAND:.+]] = tensor.expand_shape %arg0 +// SINK: %[[TRANSPOSE:.+]] = linalg.transpose +// SINK-SAME: ins(%[[EXPAND]] +// SINK: %[[PAD:.+]] = tensor.pad %[[TRANSPOSE]] +// SINK: util.return %[[PAD]] diff --git a/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp b/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp index 08cc77b31eee..13a75494d08a 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp +++ b/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp @@ -32,15 +32,17 @@ class AutoInputConversionPipelinePass final }; void AutoInputConversionPipelinePass::runOnOperation() { - if (!pipelineExtensions) + if (!pipelineExtensions) { return; + } mlir::ModuleOp moduleOp = getOperation(); llvm::StringSet<> detectedTypeMnemonics; pipelineExtensions->populateDetectedCustomInputConversionTypes( moduleOp, detectedTypeMnemonics); - if (detectedTypeMnemonics.empty()) + if (detectedTypeMnemonics.empty()) { return; + } if (detectedTypeMnemonics.getNumItems() > 1) { // TODO(scotttodd): handle multiple typeMnemonics (use all?) 
diff --git a/compiler/src/iree/compiler/InputConversion/Common/ConvertPrimitiveType.cpp b/compiler/src/iree/compiler/InputConversion/Common/ConvertPrimitiveType.cpp index 48080d1c0fea..b6480b0b797e 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/ConvertPrimitiveType.cpp +++ b/compiler/src/iree/compiler/InputConversion/Common/ConvertPrimitiveType.cpp @@ -43,8 +43,9 @@ Value convertRankedFloat(OpBuilder &builder, Type type, ValueRange inputs, Location loc) { Type eTy = getElementTypeOrSelf(type); Type inputETy = getElementTypeOrSelf(inputs[0].getType()); - if (!isa(getElementTypeOrSelf(type))) + if (!isa(getElementTypeOrSelf(type))) { return nullptr; + } if (inputETy.getIntOrFloatBitWidth() > eTy.getIntOrFloatBitWidth()) { return arith::TruncFOp::create(builder, loc, type, inputs[0]); @@ -61,8 +62,9 @@ Value convertRankedInteger(OpBuilder &builder, Type type, ValueRange inputs, Location loc) { Type eTy = getElementTypeOrSelf(type); Type inputETy = getElementTypeOrSelf(inputs[0].getType()); - if (!isa(getElementTypeOrSelf(type))) + if (!isa(getElementTypeOrSelf(type))) { return nullptr; + } bool isUnsigned = eTy.isUnsignedInteger(); int64_t inBitwidth = inputETy.getIntOrFloatBitWidth(); @@ -89,8 +91,9 @@ struct PrimitiveTypeConverter : public TypeConverter { explicit PrimitiveTypeConverter() { addConversion([](Type type) { return type; }); addConversion([&](SourceType type) -> Type { - if (!isSourceType(type)) + if (!isSourceType(type)) { return type; + } return getTargetType(type); }); addConversion([&](ComplexType type) { @@ -302,21 +305,25 @@ struct ConvertTypesPass : public Base { return typeConverter.isLegal(globalOp.getGlobalType()); } else if (auto funcOp = dyn_cast(op)) { for (Type type : funcOp.getArgumentTypes()) { - if (!typeConverter.isLegal(type)) + if (!typeConverter.isLegal(type)) { return false; + } } for (Type type : funcOp.getResultTypes()) { - if (!typeConverter.isLegal(type)) + if (!typeConverter.isLegal(type)) { return false; + } } 
} for (Type type : op->getResultTypes()) { - if (!typeConverter.isLegal(type)) + if (!typeConverter.isLegal(type)) { return false; + } } for (Type type : op->getOperandTypes()) { - if (!typeConverter.isLegal(type)) + if (!typeConverter.isLegal(type)) { return false; + } } return true; }); diff --git a/compiler/src/iree/compiler/InputConversion/Common/ImportMLProgram.cpp b/compiler/src/iree/compiler/InputConversion/Common/ImportMLProgram.cpp index a4c591fa945d..06eb0bf203f6 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/ImportMLProgram.cpp +++ b/compiler/src/iree/compiler/InputConversion/Common/ImportMLProgram.cpp @@ -95,8 +95,9 @@ class MLProgramGlobalOpPattern matchAndRewrite(ml_program::GlobalOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type newType = typeConverter->convertType(srcOp.getType()); - if (!newType) + if (!newType) { return failure(); + } std::map externs; @@ -115,12 +116,14 @@ class MLProgramGlobalOpPattern globalOp.setVisibility(SymbolTable::Visibility::Private); globalOp->setDialectAttrs(srcOp->getDialectAttrs()); - if (isExtern) + if (isExtern) { externGlobals.emplace_back(srcOp.getName(), newType); + } // No more work needed if not public global. - if (visibility != SymbolTable::Visibility::Public) + if (visibility != SymbolTable::Visibility::Public) { return success(); + } ModuleOp module = srcOp->getParentOfType(); @@ -140,12 +143,15 @@ class MLProgramGlobalOpPattern StringRef s = format; // Verify only single replacement of 0th index. s = s.drop_until([](char c) { return c == '{'; }); - if (s.empty() || !s.consume_front("{")) + if (s.empty() || !s.consume_front("{")) { return failure(); - if (!s.consume_front("0")) + } + if (!s.consume_front("0")) { return failure(); - if (!s.consume_front("}")) + } + if (!s.consume_front("}")) { return failure(); + } s = s.drop_until([](char c) { return c == '{'; }); return success(s.empty()); }; @@ -157,15 +163,17 @@ class MLProgramGlobalOpPattern v ? 
dyn_cast_if_present(v.get("get")) : nullptr; { const std::string getFormat = get ? get.str() : "global${0}$get"; - if (failed(verifyFormat(getFormat))) + if (failed(verifyFormat(getFormat))) { return failure(); + } getterName = llvm::formatv(getFormat.c_str(), globalOp.getSymName()); } auto set = v ? dyn_cast_if_present(v.get("set")) : nullptr; { const std::string setFormat = set ? set.str() : "global${0}$set"; - if (failed(verifyFormat(setFormat))) + if (failed(verifyFormat(setFormat))) { return failure(); + } setterName = llvm::formatv(setFormat.c_str(), globalOp.getSymName()); } @@ -258,12 +266,15 @@ void ImportMLProgramPass::runOnOperation() { ONE_TO_ONE(ml_program::GlobalLoadConstOp, IREE::Util::GlobalLoadOp); ONE_TO_ONE(ml_program::GlobalStoreOp, IREE::Util::GlobalStoreOp); - if (failed(applyFullConversion(getOperation(), target, std::move(patterns)))) + if (failed( + applyFullConversion(getOperation(), target, std::move(patterns)))) { signalPassFailure(); + } if (!externGlobals.empty() && - failed(createExternInitFunction(getOperation(), externGlobals))) + failed(createExternInitFunction(getOperation(), externGlobals))) { signalPassFailure(); + } } } // namespace mlir::iree_compiler::InputConversion diff --git a/compiler/src/iree/compiler/InputConversion/Common/SanitizeModuleNames.cpp b/compiler/src/iree/compiler/InputConversion/Common/SanitizeModuleNames.cpp index e5c966c9ef75..b9ab257650c4 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/SanitizeModuleNames.cpp +++ b/compiler/src/iree/compiler/InputConversion/Common/SanitizeModuleNames.cpp @@ -30,8 +30,9 @@ class SanitizeModuleNamesPass final mlir::ModuleOp moduleOp = getOperation(); auto optionalName = moduleOp.getName(); - if (!optionalName.has_value()) + if (!optionalName.has_value()) { return; + } auto name = optionalName.value(); moduleOp.setName(sanitizeSymbolName(name)); diff --git a/compiler/src/iree/compiler/InputConversion/Common/test/BUILD.bazel 
b/compiler/src/iree/compiler/InputConversion/Common/test/BUILD.bazel index 3cb0af85328c..918a3d5a6671 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/test/BUILD.bazel +++ b/compiler/src/iree/compiler/InputConversion/Common/test/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "demote_f32_to_f16.mlir", "demote_f64_to_f32.mlir", diff --git a/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp b/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp index ae76182a3526..07f8f02a5789 100644 --- a/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp +++ b/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp @@ -37,8 +37,9 @@ struct OptionalCheckImportConversion : public VMImportOpConversion { rewriter.setInsertionPointToStart(callBlock); auto results = rewriteToCall(op, adaptor, this->importOp, *this->getTypeConverter(), rewriter); - if (!results.has_value()) + if (!results.has_value()) { return failure(); + } rewriter.replaceOp(op, results.value()); IREE::VM::BranchOp::create(rewriter, op.getLoc(), followingBlock); return success(); diff --git a/compiler/src/iree/compiler/Modules/Check/IR/BUILD.bazel b/compiler/src/iree/compiler/Modules/Check/IR/BUILD.bazel index 3f2bea0aef55..0a9d1dc304bd 100644 --- a/compiler/src/iree/compiler/Modules/Check/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/Check/IR/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "CheckBase.td", "CheckOps.td", diff --git a/compiler/src/iree/compiler/Modules/Check/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/Check/test/BUILD.bazel index 8f8f49a2a7ec..6f031429a476 100644 --- a/compiler/src/iree/compiler/Modules/Check/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/Check/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs 
= enforce_glob( + # keep sorted [ "canonicalize.mlir", "ops.mlir", diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALInlineToVM/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALInlineToVM/test/BUILD.bazel index 315246764f8e..accd28591098 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALInlineToVM/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALInlineToVM/test/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ ], include = ["*.mlir"], diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/Patterns.cpp b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/Patterns.cpp index ac0f7b91b1ee..b8e9e556bb6e 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/Patterns.cpp +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/Patterns.cpp @@ -27,9 +27,10 @@ struct ElementTypeOpConversion ConversionPatternRewriter &rewriter) const override { auto value = IREE::HAL::ElementTypeOp::getTypeValue(op.getTypeAttr().getValue()); - if (!value.has_value()) + if (!value.has_value()) { return rewriter.notifyMatchFailure(op.getLoc(), "unsupported element type"); + } rewriter.replaceOpWithNewOp(op, value.value(), 32); return success(); } @@ -42,9 +43,10 @@ struct EncodingTypeOpConversion matchAndRewrite(IREE::HAL::EncodingTypeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto value = IREE::HAL::EncodingTypeOp::getTypeValue(op.getEncodingAttr()); - if (!value.has_value()) + if (!value.has_value()) { return rewriter.notifyMatchFailure(op.getLoc(), "unsupported encoding type"); + } rewriter.replaceOpWithNewOp(op, value.value(), 32); return success(); } diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/test/BUILD.bazel 
b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/test/BUILD.bazel index a872428de0ba..0b0b3144856e 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "buffer_ops.mlir", "buffer_view_ops.mlir", diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp index fe756c89031f..af7cd6099fe4 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp @@ -664,10 +664,12 @@ struct GlobalTimepointConversionPattern matchAndRewrite(IREE::Util::GlobalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto initialValue = op.getInitialValue(); - if (!initialValue.has_value()) + if (!initialValue.has_value()) { return failure(); - if (!isa(*initialValue)) + } + if (!isa(*initialValue)) { return failure(); + } rewriter.modifyOpInPlace( op, [&]() { op.setInitialValueAttr(rewriter.getI64IntegerAttr(0)); }); return success(); diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/test/BUILD.bazel index 4c424a92daff..744d2d40971e 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "cmd_ops.mlir", "debug_ops.mlir", diff --git 
a/compiler/src/iree/compiler/Modules/HAL/Inline/IR/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Inline/IR/BUILD.bazel index 6543c578a5dc..52bdf367df85 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/IR/BUILD.bazel @@ -18,6 +18,7 @@ exports_files(["HALInlineOps.td"]) iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "HALInlineBase.td", "HALInlineOps.td", diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/IR/HALInlineOps.cpp b/compiler/src/iree/compiler/Modules/HAL/Inline/IR/HALInlineOps.cpp index aae94b66615c..90a4b1b5cf9a 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/IR/HALInlineOps.cpp +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/IR/HALInlineOps.cpp @@ -103,8 +103,9 @@ void BufferStorageOp::getAsmResultNames( OpFoldResult BufferStorageOp::fold(FoldAdaptor operands) { auto *definingOp = getBuffer().getDefiningOp(); - if (!definingOp) + if (!definingOp) { return {}; + } if (auto sourceOp = dyn_cast_if_present( definingOp)) { return sourceOp.getStorage(); @@ -168,8 +169,9 @@ struct FoldBufferViewCreateSubspan needsUpdate = true; } rewriter.restoreInsertionPoint(ip); - if (!needsUpdate) + if (!needsUpdate) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.getSourceBufferMutable().assign(newSourceBuffer); op.getSourceOffsetMutable().assign(newSourceOffset); diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Inline/IR/test/BUILD.bazel index 5cf787ee14bd..892b95fb1904 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "buffer_folding.mlir", ], diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/test/BUILD.bazel 
b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/test/BUILD.bazel index 647c5145a792..637ea14df48d 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "inline_executables.mlir", ], diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/HALLoaderToVM/Patterns.cpp b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/HALLoaderToVM/Patterns.cpp index 31f4008b6a7d..e420f5a2f27b 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/HALLoaderToVM/Patterns.cpp +++ b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/HALLoaderToVM/Patterns.cpp @@ -26,8 +26,9 @@ namespace { // Casts |value| to i32 if it is not already. static Value castToI32(Value value, OpBuilder &builder) { - if (value.getType().isInteger(32)) + if (value.getType().isInteger(32)) { return value; + } return builder.createOrFold( value.getLoc(), builder.getI32Type(), value); } diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/HALLoaderToVM/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/HALLoaderToVM/test/BUILD.bazel index e2feb4e13e13..5a047972de3b 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/HALLoaderToVM/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/HALLoaderToVM/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "executable_ops.mlir", ], diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/test/BUILD.bazel index 04aeda54bf77..3a5ddbbb438f 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/test/BUILD.bazel 
+++ b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "cmd_ops.mlir", ], diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/IR/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Loader/IR/BUILD.bazel index 74739f603623..0f575dec174e 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Loader/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/HAL/Loader/IR/BUILD.bazel @@ -18,6 +18,7 @@ exports_files(["HALLoaderOps.td"]) iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "HALLoaderBase.td", "HALLoaderOps.td", diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/IR/HALLoaderOps.cpp b/compiler/src/iree/compiler/Modules/HAL/Loader/IR/HALLoaderOps.cpp index 9bb81dee2e91..5c5b0c43642e 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Loader/IR/HALLoaderOps.cpp +++ b/compiler/src/iree/compiler/Modules/HAL/Loader/IR/HALLoaderOps.cpp @@ -187,8 +187,9 @@ struct FoldBindingSubspansIntoDispatchOp bindingBuffers.push_back(subspanOp.getSource()); bindingOffsets.push_back(newOffset); } - if (!didChangeAny) + if (!didChangeAny) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.getBindingBuffersMutable().assign(bindingBuffers); op.getBindingOffsetsMutable().assign(bindingOffsets); diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Loader/IR/test/BUILD.bazel index ad58a2475fbd..8e3ed66d6a33 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Loader/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/HAL/Loader/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "dispatch_folding.mlir", ], diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/test/BUILD.bazel 
b/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/test/BUILD.bazel index a445578a0001..9e19da435caa 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "materialize_executables.mlir", ], diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/test/BUILD.bazel index a58f8e507078..6530aa2037ee 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "parameter_ops.mlir", ], diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/BUILD.bazel index a58f8e507078..6530aa2037ee 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "parameter_ops.mlir", ], diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/BUILD.bazel b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/BUILD.bazel index affb48d05d05..6c6b395f2d55 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/BUILD.bazel @@ -18,6 +18,7 @@ exports_files(["IOParametersOps.td"]) iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "IOParametersBase.td", 
"IOParametersOps.td", diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/test/BUILD.bazel index a58f8e507078..6530aa2037ee 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "parameter_ops.mlir", ], diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ArchiveUtils.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ArchiveUtils.cpp index 7922adde3439..53417ae46986 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ArchiveUtils.cpp +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ArchiveUtils.cpp @@ -14,12 +14,14 @@ namespace mlir::iree_compiler::IREE::IO::Parameters { LogicalResult handleRuntimeError(Operation *op, iree_status_t status, StringRef failureMessage) { - if (iree_status_is_ok(status)) + if (iree_status_is_ok(status)) { return success(); + } iree_host_size_t buffer_length = 0; if (!iree_status_format(status, /*buffer_capacity=*/0, - /*buffer=*/nullptr, &buffer_length)) + /*buffer=*/nullptr, &buffer_length)) { return op->emitError() << failureMessage; + } std::string message; message.reserve(buffer_length); message.resize(buffer_length - 1); diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ArchiveUtils.h b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ArchiveUtils.h index 254bc58aa57a..a199759263f7 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ArchiveUtils.h +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ArchiveUtils.h @@ -52,10 +52,11 @@ using ScopePath = std::pair; // If no `scope=` was specified the resulting scope string will be empty. 
static inline ScopePath splitScopePath(StringRef scopePath) { size_t i = scopePath.find_first_of('='); - if (i == StringRef::npos) + if (i == StringRef::npos) { return ScopePath("", scopePath); - else + } else { return ScopePath(scopePath.substr(0, i), scopePath.substr(i + 1)); + } } // Helper to interpret iree status messages and print the error message. diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp index 539a850ec8af..58680214a2ac 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp @@ -91,15 +91,17 @@ struct ExportParametersPass MLIRContext *context = &getContext(); // Nothing to do if no path specified. - if (scopePath.empty()) + if (scopePath.empty()) { return; + } auto [scope, path] = splitScopePath(scopePath); // Create a builder used to accumulate the parameters. ModuleOp moduleOp = getOperation(); auto builder = createArchiveBuilder(moduleOp); - if (failed(builder)) + if (failed(builder)) { return signalPassFailure(); + } // Accumulate globals that match the pass options and add them to the index. SmallVector constantGlobalOps; @@ -109,31 +111,36 @@ struct ExportParametersPass auto serializableAttr = dyn_cast_if_present( globalOp.getGlobalInitialValue()); - if (!serializableAttr) + if (!serializableAttr) { continue; + } // Check that the serialized size of the attribute is at least as big as // the pass configured minimum storage size. int64_t storageSize = serializableAttr.getStorageSize(); - if (storageSize < minimumSize) + if (storageSize < minimumSize) { continue; + } // Add the entry with a type based on its contents. 
- if (failed(addEntry(globalOp, serializableAttr, builder->get()))) + if (failed(addEntry(globalOp, serializableAttr, builder->get()))) { return signalPassFailure(); + } constantGlobalOps.push_back(globalOp); } // Early exit if no parameterizable globals are present. - if (constantGlobalOps.empty()) + if (constantGlobalOps.empty()) { return; + } // Create the parameter archive file opened for writing. auto fileStreamIndexOr = createParameterIndex(moduleOp, std::move(builder.value()), path); - if (failed(fileStreamIndexOr)) + if (failed(fileStreamIndexOr)) { return signalPassFailure(); + } auto [file, stream, index] = *std::move(fileStreamIndexOr); // Serialize parameters to the file. diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/GenerateSplatParameterArchive.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/GenerateSplatParameterArchive.cpp index a319589404b2..e7e2b16de58c 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/GenerateSplatParameterArchive.cpp +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/GenerateSplatParameterArchive.cpp @@ -70,14 +70,16 @@ struct GenerateSplatParameterArchivePass void runOnOperation() override { // Nothing to do if no path specified. - if (filePath.empty()) + if (filePath.empty()) { return; + } // Create a builder used to accumulate the parameters. ModuleOp moduleOp = getOperation(); auto builder = createArchiveBuilder(moduleOp); - if (failed(builder)) + if (failed(builder)) { return signalPassFailure(); + } // Find all parameters in the module and add them to the builder. // NOTE: there may be no parameters but we still will create the archive @@ -86,8 +88,9 @@ struct GenerateSplatParameterArchivePass for (auto [loc, parameterAttr] : parameterAttrs) { // Only support types we can meaningfully generate splats for. 
auto shapedType = dyn_cast(parameterAttr.getType()); - if (!shapedType) + if (!shapedType) { continue; + } // TODO: support other patterns/generators. auto elementAttr = getDefaultSplatAttr(shapedType.getElementType()); @@ -122,8 +125,9 @@ struct GenerateSplatParameterArchivePass // Create the parameter archive file. auto fileStreamIndexOr = createParameterIndex(moduleOp, std::move(builder.value()), filePath); - if (failed(fileStreamIndexOr)) + if (failed(fileStreamIndexOr)) { return signalPassFailure(); + } auto [file, stream, index] = *std::move(fileStreamIndexOr); // Commit the written file. diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ImportParameters.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ImportParameters.cpp index 91b6e690819b..c05ba32a614f 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ImportParameters.cpp +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ImportParameters.cpp @@ -76,8 +76,9 @@ loadParameterIndex(ModuleOp moduleOp, StringRef path, iree_io_parameter_index_t *parameterIndex) { // Open the archive file (hopefully mapping it). auto fileHandle = openArchiveFile(moduleOp, path); - if (failed(fileHandle)) + if (failed(fileHandle)) { return failure(); + } // Parse the archive as a particular format. 
iree_allocator_t hostAllocator = iree_allocator_system(); @@ -103,8 +104,9 @@ class ParameterIndices { iree_io_parameter_index_t *lookupOrCreate(ModuleOp moduleOp, StringRef scope) { iree_allocator_t hostAllocator = iree_allocator_system(); - if (iree_io_parameter_index_t *existing = lookup(scope)) + if (iree_io_parameter_index_t *existing = lookup(scope)) { return existing; + } iree_io_parameter_index_t *parameterIndexPtr = nullptr; if (failed(handleRuntimeError( moduleOp, @@ -133,8 +135,9 @@ loadParameterArchives(ModuleOp moduleOp, ArrayRef scopePaths) { for (auto &scopePath : scopePaths) { auto [scope, path] = splitScopePath(scopePath); auto *parameterIndex = parameterIndices.lookupOrCreate(moduleOp, scope); - if (failed(loadParameterIndex(moduleOp, path, parameterIndex))) + if (failed(loadParameterIndex(moduleOp, path, parameterIndex))) { return failure(); + } } return parameterIndices; } @@ -143,12 +146,14 @@ loadParameterArchives(ModuleOp moduleOp, ArrayRef scopePaths) { // data as stored in the file. static bool isTypeSupported(Type type) { auto shapedType = dyn_cast(type); - if (!shapedType) + if (!shapedType) { return false; + } auto elementType = shapedType.getElementType(); // NOTE: packed types not yet supported. - if (!elementType.isIntOrFloat()) + if (!elementType.isIntOrFloat()) { return false; + } const unsigned logicalBitWidth = elementType.getIntOrFloatBitWidth(); switch (logicalBitWidth) { case 8: @@ -280,29 +285,34 @@ struct ImportParametersPass void runOnOperation() override { // Nothing to do if no path specified. - if (scopePaths.empty()) + if (scopePaths.empty()) { return; + } // Open the archive file (hopefully mapping it) and parse the index. ModuleOp moduleOp = getOperation(); auto parameterIndices = loadParameterArchives(moduleOp, scopePaths); - if (failed(parameterIndices)) + if (failed(parameterIndices)) { return signalPassFailure(); + } // Decide whether to import a particular parameter. 
DenseSet importKeys; - for (auto &key : keys) + for (auto &key : keys) { importKeys.insert(key); + } auto shouldImportParameter = [&](IREE::Flow::NamedParameterAttr parameterAttr) -> bool { // Always try to import explicitly named parameters. - if (importKeys.contains(parameterAttr.getKey().getValue())) + if (importKeys.contains(parameterAttr.getKey().getValue())) { return true; // key match + } // If a maximum size is specified use that to limit what we import // (users may want to bring in small parameters but leave the big ones // out). - if (maximumSize && parameterAttr.getStorageSize() <= maximumSize) + if (maximumSize && parameterAttr.getStorageSize() <= maximumSize) { return true; // <= max size + } // Default to not importing. return false; }; @@ -312,14 +322,16 @@ struct ImportParametersPass // Only inspect parameter globals. auto parameterAttr = dyn_cast_if_present( globalOp.getGlobalInitialValue()); - if (!parameterAttr) + if (!parameterAttr) { continue; + } // Lookup the parameter index for the scope. auto scope = parameterAttr.getScope().getValue(); auto *parameterIndex = parameterIndices->lookup(scope); - if (!parameterIndex) + if (!parameterIndex) { continue; + } // See if the parameter is present in the scope (we may have only been // provided as partial index). @@ -351,8 +363,9 @@ struct ImportParametersPass auto valueOr = importParameter( fullName, cast(globalOp.getGlobalType()), parameterAttr, entry); - if (failed(valueOr)) + if (failed(valueOr)) { return signalPassFailure(); + } // Replace the initial value with the constant. 
globalOp.setGlobalInitialValue(*valueOr); diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/BUILD.bazel index eb9ea3fe8998..241c1800eb40 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "export_parameters.mlir", "generate_splat_parameter_archive.mlir", diff --git a/compiler/src/iree/compiler/Pipelines/Options.cpp b/compiler/src/iree/compiler/Pipelines/Options.cpp index 89d72888b6f3..eb88f483bdd2 100644 --- a/compiler/src/iree/compiler/Pipelines/Options.cpp +++ b/compiler/src/iree/compiler/Pipelines/Options.cpp @@ -11,6 +11,7 @@ IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::BindingOptions); IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::InputDialectOptions); IREE_DEFINE_COMPILER_OPTION_FLAGS( mlir::iree_compiler::GlobalOptimizationOptions); +IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::ParameterOptions); IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::SchedulingOptions); IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::PreprocessingOptions); IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::GlobalPipelineOptions); @@ -161,6 +162,84 @@ void PreprocessingOptions::bindOptions(OptionsBinder &binder) { llvm::cl::cat(category)); } +void ParameterOptions::bindOptions(OptionsBinder &binder) { + static llvm::cl::OptionCategory category("IREE Parameter Options"); + + // Parameter import/export options. + binder.list( + "iree-parameter-import", importPaths, + llvm::cl::desc("File paths to archives to import parameters from with an " + "optional `scope=` prefix."), + llvm::cl::cat(category)); + binder.list( + "iree-parameter-import-keys", importKeys, + llvm::cl::desc("List of parameter keys to import. 
Any matching keys from " + "any scope will be imported."), + llvm::cl::cat(category)); + binder.opt( + "iree-parameter-import-maximum-size", importMaximumSize, + llvm::cl::desc("Maximum size of parameters to import or 0 to disable " + "automatic import."), + llvm::cl::cat(category)); + + binder.opt( + "iree-parameter-export", exportPath, + llvm::cl::desc("File path to an archive to export parameters to with an " + "optional `scope=` prefix."), + llvm::cl::cat(category)); + binder.opt( + "iree-parameter-export-minimum-size", exportMinimumSize, + llvm::cl::desc("Minimum size of constants to export to the parameter " + "archive."), + llvm::cl::cat(category)); + + binder.opt( + "iree-parameter-splat", splatPath, + llvm::cl::desc("File path to create a parameter archive of splat values " + "from all parameter backed globals."), + llvm::cl::cat(category)); + + // Parameter encoder options. + binder.opt( + "iree-parameter-encoder-mode", encoderMode, + llvm::cl::desc("Controls how the encoder manages parameters."), + llvm::cl::values( + clEnumValN(ParameterEncoderMode::Consolidate, "consolidate", + "Merge all encoded and original parameters into a single " + "consolidated scope."), + clEnumValN(ParameterEncoderMode::Overlay, "overlay", + "Only produce encoded parameters and leave original " + "parameters untouched.")), + llvm::cl::cat(category)); + + binder.opt( + "iree-parameter-encoder-output-file", encoderOutputFile, + llvm::cl::desc(".mlir/.mlirbc file path to write the split parameter " + "encoder module to (empty = disabled)."), + llvm::cl::cat(category)); + + binder.opt( + "iree-parameter-encoder-output-scope", encoderOutputScope, + llvm::cl::desc("Parameter scope for the encoder output parameters."), + llvm::cl::cat(category)); + + // Deprecated flags aliasing the new ones above. 
+ binder.opt( + "iree-opt-export-parameters", exportPath, + deprecated("use --iree-parameter-export= instead"), + llvm::cl::Hidden, + llvm::cl::desc("File path to an archive to export parameters to with an " + "optional `scope=` prefix."), + llvm::cl::cat(category)); + binder.opt( + "iree-opt-splat-parameters", splatPath, + deprecated("use --iree-parameter-splat= instead"), llvm::cl::Hidden, + llvm::cl::desc( + "File path to create a parameter archive of splat values out of all " + "parameter backed globals."), + llvm::cl::cat(category)); +} + void GlobalOptimizationOptions::bindOptions(OptionsBinder &binder) { static llvm::cl::OptionCategory category( "IREE options for controlling global optimizations."); @@ -183,6 +262,11 @@ void GlobalOptimizationOptions::bindOptions(OptionsBinder &binder) { llvm::cl::desc( "Enables propagation of transpose ops through convolutions."), llvm::cl::cat(category)); + binder.opt( + "iree-global-opt-enable-sink-transpose-through-pad", + sinkTransposeThroughPad, + llvm::cl::desc("Enables sinking transpose through pad operations."), + llvm::cl::cat(category)); binder.opt("iree-opt-outer-dim-concat", outerDimConcat, {init_at_opt(llvm::OptimizationLevel::O0, false), init_at_opt(llvm::OptimizationLevel::O1, true)}, @@ -211,39 +295,6 @@ void GlobalOptimizationOptions::bindOptions(OptionsBinder &binder) { "information has been extracted."), llvm::cl::cat(category)); - binder.list( - "iree-opt-import-parameters", parameterImportPaths, - llvm::cl::desc("File paths to archives to import parameters from with an " - "optional `scope=` prefix."), - llvm::cl::cat(category)); - binder.list("iree-opt-import-parameter-keys", - parameterImportKeys, - llvm::cl::desc("List of parameter keys to import."), - llvm::cl::cat(category)); - binder.opt("iree-opt-import-parameter-maximum-size", - parameterImportMaximumSize, - llvm::cl::desc("Maximum size of parameters to import."), - llvm::cl::cat(category)); - - binder.opt( - "iree-opt-export-parameters", 
parameterExportPath, - llvm::cl::desc("File path to an archive to export parameters to with an " - "optional `scope=` prefix."), - llvm::cl::cat(category)); - binder.opt( - "iree-opt-export-parameter-minimum-size", parameterExportMinimumSize, - llvm::cl::desc( - "Minimum size of constants to export to the archive created in " - "`iree-opt-export-parameter-archive-export-file`."), - llvm::cl::cat(category)); - - binder.opt( - "iree-opt-splat-parameters", parameterSplatExportFile, - llvm::cl::desc( - "File path to create a parameter archive of splat values out of all " - "parameter backed globals."), - llvm::cl::cat(category)); - binder.opt( "iree-opt-generalize-matmul", generalizeMatmul, {init_at_opt(llvm::OptimizationLevel::O0, false), diff --git a/compiler/src/iree/compiler/Pipelines/Options.h b/compiler/src/iree/compiler/Pipelines/Options.h index eee593f24feb..811bd5b17f1a 100644 --- a/compiler/src/iree/compiler/Pipelines/Options.h +++ b/compiler/src/iree/compiler/Pipelines/Options.h @@ -105,28 +105,58 @@ struct PreprocessingOptions { using FromFlags = OptionsFromFlags; }; -// Options controlling high level optimizations. -struct GlobalOptimizationOptions { - llvm::OptimizationLevel optLevel = llvm::OptimizationLevel::O0; +// Defines the mode for parameter encoding. +enum class ParameterEncoderMode { + // Merge all encoded and original parameters into a single consolidated scope. + Consolidate = 0, + // Only produce encoded parameters and leave original parameters untouched. + Overlay = 1, +}; + +// Options controlling parameter management (import/export and encoding). +struct ParameterOptions { + //===--------------------------------------------------------------------===// + // Parameter Import/Export + //===--------------------------------------------------------------------===// // File paths to archives to import parameters from with an optional // `scope=` prefix. 
- std::vector<std::string> parameterImportPaths; + std::vector<std::string> importPaths; // List of parameter keys to import. Any matching keys from any scope will be // imported. - std::vector<std::string> parameterImportKeys; + std::vector<std::string> importKeys; // Maximum size of parameters to import or 0 to disable automatic import. - int64_t parameterImportMaximumSize = 0; + int64_t importMaximumSize = 0; // File path to an archive to export parameters to with an optional // `scope=` prefix. - std::string parameterExportPath; + std::string exportPath; // Minimum size of constants to export as parameters. - int64_t parameterExportMinimumSize = 0; + int64_t exportMinimumSize = 0; // File path to create a splat parameter archive out of all parameters in the // module. - std::string parameterSplatExportFile = ""; + std::string splatPath = ""; + + //===--------------------------------------------------------------------===// + // Parameter Encoder + //===--------------------------------------------------------------------===// + + // Controls how the encoder manages parameters. + ParameterEncoderMode encoderMode = ParameterEncoderMode::Consolidate; + // .mlir/.mlirbc file path to write the split parameter encoder module to + // (empty = disabled). + std::string encoderOutputFile; + // Parameter scope for the encoder output parameters. + std::string encoderOutputScope = "encoded"; + + void bindOptions(OptionsBinder &binder); + using FromFlags = OptionsFromFlags<ParameterOptions>; +}; + +// Options controlling high level optimizations. +struct GlobalOptimizationOptions { + llvm::OptimizationLevel optLevel = llvm::OptimizationLevel::O0; // Enables aggressive propagation of transposes to the inputs of named ops, // rewriting named ops as fused generics. @@ -135,6 +165,9 @@ struct GlobalOptimizationOptions { // Enables propagation of transpose ops through convolutions. bool propagateTransposesThroughConv = false; + // Enables sinking transpose through pad operations.
+ bool sinkTransposeThroughPad = false; + // Enables transposing all concatenations to the outer most dimension. bool outerDimConcat = false; diff --git a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp index 36a6dadd8eab..c8b1cace4003 100644 --- a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp +++ b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp @@ -58,15 +58,17 @@ IREEVMPipelineHooks::operator IREE::HAL::PipelineHooks() const { auto beforePhase = this->beforePhase; halHooks.beforePhase = [beforePhase](IREE::HAL::PipelinePhase phase, OpPassManager &passManager) { - if (beforePhase) + if (beforePhase) { beforePhase(getIREEVMPipelinePhase(phase), passManager); + } }; auto afterPhase = this->afterPhase; halHooks.afterPhase = [afterPhase](IREE::HAL::PipelinePhase phase, OpPassManager &passManager) { - if (afterPhase) + if (afterPhase) { afterPhase(getIREEVMPipelinePhase(phase), passManager); + } }; return halHooks; @@ -76,6 +78,7 @@ void buildIREEPrecompileTransformPassPipeline( const IREE::HAL::TargetRegistry &targetRegistry, GlobalPipelineOptions pipelineOptions, BindingOptions bindingOptions, InputDialectOptions inputOptions, PreprocessingOptions preprocessingOptions, + ParameterOptions parameterOptions, GlobalOptimizationOptions globalOptimizationOptions, DispatchCreationOptions dispatchCreationOptions, SchedulingOptions schedulingOptions, @@ -89,8 +92,9 @@ void buildIREEPrecompileTransformPassPipeline( if (compileFrom < IREEVMPipelinePhase::Input) { // late-entry auto inputType = inputOptions.parseInputTypeMnemonic(); IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "Input"); - if (hooks.beforePhase) + if (hooks.beforePhase) { hooks.beforePhase(IREEVMPipelinePhase::Input, passManager); + } if (hooks.pipelineExtensions) { hooks.pipelineExtensions->extendInputConversionPreprocessingPassPipeline( passManager, inputType); @@ -131,18 +135,21 @@ void buildIREEPrecompileTransformPassPipeline( 
InputConversion::buildCommonInputConversionPassPipeline( passManager, inputTransformOptions); - if (hooks.afterPhase) + if (hooks.afterPhase) { hooks.afterPhase(IREEVMPipelinePhase::Input, passManager); + } IREE_TRACE_ADD_END_FRAME_PASS(passManager, "Input"); } - if (compileTo == IREEVMPipelinePhase::Input) + if (compileTo == IREEVMPipelinePhase::Input) { return; // early-exit + } // Now that inputs are legalized, generate wrapper for entry functions. if (compileFrom < IREEVMPipelinePhase::ABI) { // late-entry IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "ABI"); - if (hooks.beforePhase) + if (hooks.beforePhase) { hooks.beforePhase(IREEVMPipelinePhase::ABI, passManager); + } IREE::ABI::InvocationOptions invocationOptions; invocationOptions.invocationModel = schedulingOptions.executionModel == @@ -155,12 +162,14 @@ void buildIREEPrecompileTransformPassPipeline( if (bindingOptions.tflite) { IREE::TFLite::buildTransformPassPipeline(passManager); } - if (hooks.afterPhase) + if (hooks.afterPhase) { hooks.afterPhase(IREEVMPipelinePhase::ABI, passManager); + } IREE_TRACE_ADD_END_FRAME_PASS(passManager, "ABI"); } - if (compileTo == IREEVMPipelinePhase::ABI) + if (compileTo == IREEVMPipelinePhase::ABI) { return; // early-exit + } // If the user specified a set of target devices we attach them to the module // IR so that they are available for all passes that may want to use this @@ -175,22 +184,20 @@ void buildIREEPrecompileTransformPassPipeline( halAssignmentOptions); GlobalOptimization::TransformOptions globalTransformOptions; - globalTransformOptions.parameterImportPaths = - globalOptimizationOptions.parameterImportPaths; - globalTransformOptions.parameterImportKeys = - globalOptimizationOptions.parameterImportKeys; + globalTransformOptions.parameterImportPaths = parameterOptions.importPaths; + globalTransformOptions.parameterImportKeys = parameterOptions.importKeys; globalTransformOptions.parameterImportMaximumSize = - 
globalOptimizationOptions.parameterImportMaximumSize; - globalTransformOptions.parameterExportPath = - globalOptimizationOptions.parameterExportPath; + parameterOptions.importMaximumSize; + globalTransformOptions.parameterExportPath = parameterOptions.exportPath; globalTransformOptions.parameterExportMinimumSize = - globalOptimizationOptions.parameterExportMinimumSize; - globalTransformOptions.parameterSplatExportFile = - globalOptimizationOptions.parameterSplatExportFile; + parameterOptions.exportMinimumSize; + globalTransformOptions.parameterSplatPath = parameterOptions.splatPath; globalTransformOptions.aggressiveTransposePropagation = globalOptimizationOptions.aggressiveTransposePropagation; globalTransformOptions.propagateTransposesThroughConv = globalOptimizationOptions.propagateTransposesThroughConv; + globalTransformOptions.sinkTransposeThroughPad = + globalOptimizationOptions.sinkTransposeThroughPad; globalTransformOptions.outerDimConcat = globalOptimizationOptions.outerDimConcat; // The pipeline option has higher priority. @@ -229,16 +236,19 @@ void buildIREEPrecompileTransformPassPipeline( default: if (compileFrom < IREEVMPipelinePhase::Preprocessing) { // late-entry. 
IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "Preprocessing"); - if (hooks.beforePhase) + if (hooks.beforePhase) { hooks.beforePhase(IREEVMPipelinePhase::Preprocessing, passManager); + } Preprocessing::buildPreprocessingPassPipeline( passManager, preprocessingOptions, hooks.pipelineExtensions); - if (hooks.afterPhase) + if (hooks.afterPhase) { hooks.afterPhase(IREEVMPipelinePhase::Preprocessing, passManager); + } IREE_TRACE_ADD_END_FRAME_PASS(passManager, "Preprocessing"); } - if (compileTo == IREEVMPipelinePhase::Preprocessing) + if (compileTo == IREEVMPipelinePhase::Preprocessing) { return; // early-exit + } if (compileFrom < IREEVMPipelinePhase::GlobalOptimization) { // late-entry // This pass pipeline recursively invokes the compiler if constEval is @@ -256,20 +266,23 @@ void buildIREEPrecompileTransformPassPipeline( } else { IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "GlobalOptimization"); } - if (hooks.beforePhase) + if (hooks.beforePhase) { hooks.beforePhase(IREEVMPipelinePhase::GlobalOptimization, passManager); + } GlobalOptimization::buildGlobalOptimizationPassPipeline( passManager, globalTransformOptions); - if (hooks.afterPhase) + if (hooks.afterPhase) { hooks.afterPhase(IREEVMPipelinePhase::GlobalOptimization, passManager); + } if (globalOptimizationOptions.constEval) { IREE_TRACE_ADD_END_FRAME_PASS(passManager, "GlobalOptimizationConst"); } else { IREE_TRACE_ADD_END_FRAME_PASS(passManager, "GlobalOptimization"); } } - if (compileTo == IREEVMPipelinePhase::GlobalOptimization) + if (compileTo == IREEVMPipelinePhase::GlobalOptimization) { return; // early-exit + } break; } @@ -279,6 +292,7 @@ void buildIREEVMTransformPassPipeline( const IREE::HAL::TargetRegistry &targetRegistry, GlobalPipelineOptions pipelineOptions, BindingOptions bindingOptions, InputDialectOptions inputOptions, PreprocessingOptions preprocessingOptions, + ParameterOptions parameterOptions, GlobalOptimizationOptions globalOptimizationOptions, DispatchCreationOptions 
dispatchCreationOptions, SchedulingOptions schedulingOptions, @@ -288,12 +302,13 @@ void buildIREEVMTransformPassPipeline( IREEVMPipelinePhase compileTo) { buildIREEPrecompileTransformPassPipeline( targetRegistry, pipelineOptions, bindingOptions, inputOptions, - preprocessingOptions, globalOptimizationOptions, dispatchCreationOptions, - schedulingOptions, halTargetOptions, hooks, passManager, compileFrom, - compileTo); + preprocessingOptions, parameterOptions, globalOptimizationOptions, + dispatchCreationOptions, schedulingOptions, halTargetOptions, hooks, + passManager, compileFrom, compileTo); - if (compileTo <= IREEVMPipelinePhase::GlobalOptimization) + if (compileTo <= IREEVMPipelinePhase::GlobalOptimization) { return; // early-exit + } IREE::Stream::TransformOptions streamOptions; // TODO(benvanik): find a way to share the enums w/o circular deps. @@ -304,6 +319,18 @@ void buildIREEVMTransformPassPipeline( (IREE::Stream::DumpOutputFormat)schedulingOptions.dumpStatisticsFormat; streamOptions.dumpStatisticsFile = schedulingOptions.dumpStatisticsFile; + // Set parameter encoder options. These are mapped to + // SplitParameterEncoderPassOptions when the pass is created in + // Stream/Transforms/Passes.cpp. + if (!parameterOptions.encoderOutputFile.empty()) { + streamOptions.parameterEncoderMode = + (IREE::Stream::ParameterEncoderMode)parameterOptions.encoderMode; + streamOptions.parameterEncoderOutputFile = + parameterOptions.encoderOutputFile; + streamOptions.parameterEncoderOutputScope = + parameterOptions.encoderOutputScope; + } + switch (schedulingOptions.executionModel) { case SchedulingOptions::ExecutionModel::HostOnly: // No flow/stream processing (implies no tensors). 
@@ -342,42 +369,51 @@ void buildIREEVMTransformPassPipeline( pipelineOptions.constExprHoisting; if (compileFrom < IREEVMPipelinePhase::DispatchCreation) { // late-entry IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "DispatchCreation"); - if (hooks.beforePhase) + if (hooks.beforePhase) { hooks.beforePhase(IREEVMPipelinePhase::DispatchCreation, passManager); + } DispatchCreation::buildDispatchCreationPassPipeline( passManager, dispatchTransformOptions); - if (hooks.afterPhase) + if (hooks.afterPhase) { hooks.afterPhase(IREEVMPipelinePhase::DispatchCreation, passManager); + } IREE_TRACE_ADD_END_FRAME_PASS(passManager, "DispatchCreation"); } - if (compileTo == IREEVMPipelinePhase::DispatchCreation) + if (compileTo == IREEVMPipelinePhase::DispatchCreation) { return; // early-exit + } IREE::Flow::TransformOptions flowOptions; if (compileFrom < IREEVMPipelinePhase::Flow) { // late-entry IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "Flow"); - if (hooks.beforePhase) + if (hooks.beforePhase) { hooks.beforePhase(IREEVMPipelinePhase::Flow, passManager); + } IREE::Flow::buildFlowTransformPassPipeline(passManager, flowOptions); - if (hooks.afterPhase) + if (hooks.afterPhase) { hooks.afterPhase(IREEVMPipelinePhase::Flow, passManager); + } IREE_TRACE_ADD_END_FRAME_PASS(passManager, "Flow"); } - if (compileTo == IREEVMPipelinePhase::Flow) + if (compileTo == IREEVMPipelinePhase::Flow) { return; // early-exit + } if (compileFrom < IREEVMPipelinePhase::Stream) { // late-entry IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "Stream"); - if (hooks.beforePhase) + if (hooks.beforePhase) { hooks.beforePhase(IREEVMPipelinePhase::Stream, passManager); + } IREE::Stream::buildStreamTransformPassPipeline(passManager, streamOptions); - if (hooks.afterPhase) + if (hooks.afterPhase) { hooks.afterPhase(IREEVMPipelinePhase::Stream, passManager); + } IREE_TRACE_ADD_END_FRAME_PASS(passManager, "Stream"); } - if (compileTo == IREEVMPipelinePhase::Stream) + if (compileTo == IREEVMPipelinePhase::Stream) { 
return; // early-exit + } break; } @@ -388,8 +424,9 @@ void buildIREEVMTransformPassPipeline( if (compileFrom < IREEVMPipelinePhase::HAL) { // late-entry IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "HAL"); - if (hooks.beforePhase) + if (hooks.beforePhase) { hooks.beforePhase(IREEVMPipelinePhase::HAL, passManager); + } switch (schedulingOptions.executionModel) { case SchedulingOptions::ExecutionModel::HostOnly: // No HAL required. @@ -410,8 +447,9 @@ void buildIREEVMTransformPassPipeline( passManager, targetRegistry, halTargetOptions); break; } - if (hooks.afterPhase) + if (hooks.afterPhase) { hooks.afterPhase(IREEVMPipelinePhase::HAL, passManager); + } IREE_TRACE_ADD_END_FRAME_PASS(passManager, "HAL"); } if (compileTo == IREEVMPipelinePhase::HAL || @@ -421,15 +459,18 @@ void buildIREEVMTransformPassPipeline( if (compileFrom < IREEVMPipelinePhase::VM) { // late-entry IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "VM"); - if (hooks.beforePhase) + if (hooks.beforePhase) { hooks.beforePhase(IREEVMPipelinePhase::VM, passManager); + } IREE::VM::buildVMTransformPassPipeline(passManager, vmTargetOptions); - if (hooks.afterPhase) + if (hooks.afterPhase) { hooks.afterPhase(IREEVMPipelinePhase::VM, passManager); + } IREE_TRACE_ADD_END_FRAME_PASS(passManager, "VM"); } - if (compileTo == IREEVMPipelinePhase::VM) + if (compileTo == IREEVMPipelinePhase::VM) { return; // early-exit + } } void buildDefaultIREEVMTransformPassPipeline(OpPassManager &passManager) { @@ -449,7 +490,8 @@ void buildDefaultIREEVMTransformPassPipeline(OpPassManager &passManager) { IREE::HAL::TargetRegistry::getGlobal(), GlobalPipelineOptions::FromFlags::get(), BindingOptions::FromFlags::get(), InputDialectOptions::FromFlags::get(), - PreprocessingOptions::FromFlags::get(), highLevelOptimizations, + PreprocessingOptions::FromFlags::get(), + ParameterOptions::FromFlags::get(), highLevelOptimizations, DispatchCreationOptions::FromFlags::get(), SchedulingOptions::FromFlags::get(), 
IREE::HAL::TargetOptions::FromFlags::get(), diff --git a/compiler/src/iree/compiler/Pipelines/Pipelines.h b/compiler/src/iree/compiler/Pipelines/Pipelines.h index 104cc42875b5..c17f9bd893f0 100644 --- a/compiler/src/iree/compiler/Pipelines/Pipelines.h +++ b/compiler/src/iree/compiler/Pipelines/Pipelines.h @@ -103,6 +103,7 @@ void buildIREEPrecompileTransformPassPipeline( const IREE::HAL::TargetRegistry &targetRegistry, GlobalPipelineOptions pipelineOptions, BindingOptions bindingOptions, InputDialectOptions inputOptions, PreprocessingOptions preprocessingOptions, + ParameterOptions parameterOptions, GlobalOptimizationOptions highLevelOptimizationOptions, DispatchCreationOptions dispatchCreationOptions, SchedulingOptions schedulingOptions, @@ -120,6 +121,7 @@ void buildIREEVMTransformPassPipeline( const IREE::HAL::TargetRegistry &targetRegistry, GlobalPipelineOptions pipelineOptions, BindingOptions bindingOptions, InputDialectOptions inputOptions, PreprocessingOptions preprocessingOptions, + ParameterOptions parameterOptions, GlobalOptimizationOptions highLevelOptimizationOptions, DispatchCreationOptions dispatchCreationOptions, SchedulingOptions schedulingOptions, diff --git a/compiler/src/iree/compiler/PluginAPI/PluginManager.cpp b/compiler/src/iree/compiler/PluginAPI/PluginManager.cpp index 3f37b9faeb72..3ac9025a2029 100644 --- a/compiler/src/iree/compiler/PluginAPI/PluginManager.cpp +++ b/compiler/src/iree/compiler/PluginAPI/PluginManager.cpp @@ -136,8 +136,9 @@ LogicalResult PluginManagerSession::initializePlugins() { } // Skip if already initialized. - if (!initializedIds.insert(it.first()).second) + if (!initializedIds.insert(it.first()).second) { continue; + } if (options.printPluginInfo) { llvm::errs() << "[IREE plugins]: Initializing default '" << it.first() @@ -156,8 +157,9 @@ LogicalResult PluginManagerSession::initializePlugins() { } // Skip if already initialized. 
- if (!initializedIds.insert(pluginId).second) + if (!initializedIds.insert(pluginId).second) { continue; + } if (options.printPluginInfo) { llvm::errs() << "[IREE plugins]: Initializing plugin '" << pluginId @@ -187,8 +189,9 @@ void PluginManagerSession::registerDialects(DialectRegistry ®istry) { LogicalResult PluginManagerSession::activatePlugins(MLIRContext *context) { for (auto *s : initializedSessions) { - if (failed(s->activate(context))) + if (failed(s->activate(context))) { return failure(); + } } return success(); } diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ApplyPDLPatterns.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ApplyPDLPatterns.cpp index 627e1a1af948..f5db486146b3 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/ApplyPDLPatterns.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/ApplyPDLPatterns.cpp @@ -320,12 +320,14 @@ createFlowDispatchOp(PatternRewriter &rewriter, SymbolRefAttr exportOp, // Get the dynamic dims for the operands. 
for (auto operand : operands) { auto tensorType = dyn_cast(operand.getType()); - if (!tensorType) + if (!tensorType) { continue; + } for (auto [index, shape] : llvm::enumerate(tensorType.getShape())) { - if (ShapedType::isStatic(shape)) + if (ShapedType::isStatic(shape)) { continue; + } Value dim = tensor::DimOp::create(rewriter, loc, operand, index); operandDynamicDims.push_back(dim); @@ -352,8 +354,9 @@ getDynamicResultDims(PatternRewriter &rewriter, ValueRange givenResultDims) { SmallVector mixedValues = getAsOpFoldResult(givenResultDims); for (auto ofr : mixedValues) { auto value = dyn_cast(ofr); - if (!value) + if (!value) { continue; + } dynamicResultDims.push_back(value); } return dynamicResultDims; diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel index 99c945c87c23..4662d4b2b7f9 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel @@ -47,7 +47,6 @@ iree_compiler_cc_library( "PadLinalgOps.cpp", "PadToIntrinsics.cpp", "Passes.cpp", - "SinkTransposeThroughPad.cpp", "TransposeMatmul.cpp", ], hdrs = [ diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt index 1d8165618518..3471f1db47a1 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt @@ -38,7 +38,6 @@ iree_cc_library( "PadLinalgOps.cpp" "PadToIntrinsics.cpp" "Passes.cpp" - "SinkTransposeThroughPad.cpp" "TransposeMatmul.cpp" DEPS ::PassesIncGen diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp index 267e19e605d3..86f407760831 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp +++ 
b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp @@ -30,15 +30,17 @@ static bool hasAllOneValues(DenseIntElementsAttr attr) { static Value createAdd(Location loc, Value x, Value y, bool isInt, OpBuilder &builder) { - if (isInt) + if (isInt) { return arith::AddIOp::create(builder, loc, x, y); + } return arith::AddFOp::create(builder, loc, x, y); } static Value createMul(Location loc, Value x, Value y, bool isInt, OpBuilder &builder) { - if (isInt) + if (isInt) { return arith::MulIOp::create(builder, loc, x, y); + } return arith::MulFOp::create(builder, loc, x, y); } @@ -255,11 +257,12 @@ class ConvertDepthwiseConv2DNhwcHwc final } // TODO: Support dilation. - if (!hasAllOneValues(convOp.getDilations())) + if (!hasAllOneValues(convOp.getDilations())) { return rewriter.notifyMatchFailure(convOp, [](Diagnostic &diag) { diag << "[unimplemented] " << "expected no dilations (expected dilations to all be one)."; }); + } auto loc = convOp.getLoc(); @@ -415,11 +418,12 @@ class ConvertConv2DNchwFchw final } // TODO: Support dilation. - if (!hasAllOneValues(convOp.getDilations())) + if (!hasAllOneValues(convOp.getDilations())) { return rewriter.notifyMatchFailure(convOp, [](Diagnostic &diag) { diag << "[unimplemented] " << "expected no dilations (expected dilations to all be one)."; }); + } Value input = convOp.getInputs()[0]; Value filter = convOp.getInputs()[1]; diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvFilterToChannelsLast.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvFilterToChannelsLast.cpp index 8b03a4bee45e..d1eda2597e1c 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvFilterToChannelsLast.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvFilterToChannelsLast.cpp @@ -122,6 +122,13 @@ struct ConvertGenericFilterToFhwc : public OpRewritePattern { return failure(); } + // Require non-empty filter, input and output channel dimensions. 
+ if (convolutionDims->outputChannel.empty() || + convolutionDims->inputChannel.empty() || + convolutionDims->filterLoop.empty()) { + return failure(); + } + OpOperand *input = linalgOp.getDpsInputOperand(0); OpOperand *filter = linalgOp.getDpsInputOperand(1); OpOperand *output = linalgOp.getDpsInitOperand(0); @@ -161,11 +168,10 @@ struct ConvertGenericFilterToFhwc : public OpRewritePattern { return positions; }; - // Don't transpose when the input is in batch-last layout (e.g., CHWN). + // Don't transpose when the input is not in batch-first layout (e.g., CHWN). SmallVector batchInputPos = getDimPositions(convolutionDims->batch, inputMap); - if (!batchInputPos.empty() && - batchInputPos.back() == inputShape.size() - 1) { + if (!batchInputPos.empty() && batchInputPos.front() != 0) { return failure(); } @@ -262,8 +268,9 @@ struct ConvertGenericFilterToFhwc : public OpRewritePattern { FailureOr reorderOp = linalg::interchangeGenericOp(rewriter, genericOp, interchange); - if (failed(reorderOp)) + if (failed(reorderOp)) { return failure(); + } rewriter.replaceOp(linalgOp, reorderOp->getResults()); return success(); diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp index ee49524e2942..de71cfd01c34 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp @@ -522,10 +522,12 @@ class GeneralizeOuterUnitDimsPackOp final LogicalResult matchAndRewrite(linalg::PackOp packOp, PatternRewriter &rewriter) const override { - if (!packOp.getOuterDimsPerm().empty()) + if (!packOp.getOuterDimsPerm().empty()) { return failure(); - if (packOp.getPaddingValue()) + } + if (packOp.getPaddingValue()) { return failure(); + } RankedTensorType destType = cast(packOp.getDest().getType()); @@ -572,8 +574,9 @@ class GeneralizeOuterUnitDimsPackOp final int64_t
nTiled = 0; for (int64_t srcIdx = 0; srcIdx < srcRank; srcIdx++) { reassocationIndices.push_back({srcIdx + nTiled}); - while (innerDims.contains(srcIdx + nTiled)) + while (innerDims.contains(srcIdx + nTiled)) { reassocationIndices.back().push_back(srcIdx + ++nTiled); + } } rewriter.replaceOpWithNewOp( @@ -603,8 +606,9 @@ class GeneralizeOuterUnitDimsUnPackOp final LogicalResult matchAndRewrite(linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const override { - if (!unpackOp.getOuterDimsPerm().empty()) + if (!unpackOp.getOuterDimsPerm().empty()) { return failure(); + } RankedTensorType srcType = cast(unpackOp.getSource().getType()); diff --git a/compiler/src/iree/compiler/Preprocessing/Common/InterpreterPass.cpp b/compiler/src/iree/compiler/Preprocessing/Common/InterpreterPass.cpp index eaf16845b335..cd2de6158f9e 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/InterpreterPass.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/InterpreterPass.cpp @@ -30,8 +30,9 @@ class InterpreterPass // pass finishes. 
OwningOpRef transformModule; if (failed(transform::detail::assembleTransformLibraryFromPaths( - context, transformSpecPath, transformModule))) + context, transformSpecPath, transformModule))) { return signalPassFailure(); + } Operation *payloadRoot = getOperation(); Operation *transformEntryPoint = transform::detail::findTransformEntryPoint( getOperation(), *transformModule, "__transform_main"); diff --git a/compiler/src/iree/compiler/Preprocessing/Common/PadLinalgOps.cpp b/compiler/src/iree/compiler/Preprocessing/Common/PadLinalgOps.cpp index 93fcd02cedfd..95ac26cbde7f 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/PadLinalgOps.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/PadLinalgOps.cpp @@ -30,8 +30,9 @@ class PadMatmulOp : public OpInterfaceRewritePattern { Operation *op = linalgOp.getOperation(); const bool isBatchMatmul = isa(op); const bool isMatmul = isa(op); - if (!isBatchMatmul && !isMatmul) + if (!isBatchMatmul && !isMatmul) { return failure(); + } Location loc = linalgOp.getLoc(); Value lhs = linalgOp.getDpsInputOperand(0)->get(); @@ -42,11 +43,13 @@ class PadMatmulOp : public OpInterfaceRewritePattern { auto rhsType = dyn_cast(rhs.getType()); auto resultType = dyn_cast(result.getType()); - if (!lhsType || !rhsType) + if (!lhsType || !rhsType) { return failure(); + } - if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) + if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) { return failure(); + } auto lhsShape = lhsType.getShape(); auto rhsShape = rhsType.getShape(); @@ -63,13 +66,15 @@ class PadMatmulOp : public OpInterfaceRewritePattern { int paddingForN = newNSize - N; int paddingForK = newKSize - K; - if (paddingForM == 0 && paddingForN == 0 && paddingForK == 0) + if (paddingForM == 0 && paddingForN == 0 && paddingForK == 0) { return failure(); + } auto getFullShape = [&](ArrayRef dims) { SmallVector shape; - if (isBatchMatmul) + if (isBatchMatmul) { shape.push_back(B); + } llvm::append_range(shape, dims); 
return shape; }; diff --git a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp index ef293c9dd97b..68dd88fb8f38 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp @@ -138,8 +138,9 @@ expandMapsAndIterators(SmallVector &expandedMaps, map = map.shiftDims(1, expandDstDim); std::optional maybeDim = map.getResultPosition( getAffineDimExpr(expandSrcDim, map.getContext())); - if (!maybeDim) + if (!maybeDim) { continue; + } map = map.insertResult(getAffineDimExpr(expandDstDim, map.getContext()), maybeDim.value() + 1); } @@ -158,8 +159,9 @@ getIntrinsics(linalg::LinalgOp linalgOp, // For LIT testing, also directly search TargetAttr around the op. target = getGPUTargetAttr(linalgOp); } - if (!target) + if (!target) { return {}; + } IREE::GPU::MMAOpsArrayAttr mmaKinds = target.getWgp().getMma(); @@ -176,8 +178,9 @@ padConvOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp, // Early exit if cannot find intrinsics or if multiple executable targets. SmallVector intrinsics = getIntrinsics(linalgOp, executableTargets); - if (intrinsics.empty()) + if (intrinsics.empty()) { return; + } // Check that conv has met conditions to go down mfma. SmallVector bounds = linalgOp.getStaticLoopRanges(); @@ -348,8 +351,9 @@ static void padContractionLikeOp( // Early exit if cannot find intrinsics or if multiple executable targets. 
SmallVector intrinsics = getIntrinsics(linalgOp, executableTargets); - if (intrinsics.empty()) + if (intrinsics.empty()) { return; + } Location loc = linalgOp.getLoc(); @@ -377,8 +381,9 @@ static void padContractionLikeOp( auto operandMap = linalgOp.getMatchingIndexingMap(operand); std::optional maybeDim = operandMap.getResultPosition( getAffineDimExpr(targetDim, operandMap.getContext())); - if (maybeDim) + if (maybeDim) { return std::pair{operand->get(), maybeDim.value()}; + } } return std::nullopt; }; @@ -405,8 +410,9 @@ static void padContractionLikeOp( OpFoldResult mSizeExpr = rewriter.getIndexAttr(mSize); if (ShapedType::isDynamic(mSize)) { auto mOperandDimPair = getSrcOperandAndDim(mDim); - if (!mOperandDimPair) + if (!mOperandDimPair) { return; + } auto [mOperand, mOperandDim] = mOperandDimPair.value(); mSizeExpr = tensor::DimOp::create(rewriter, loc, mOperand, mOperandDim) .getResult(); @@ -419,8 +425,9 @@ static void padContractionLikeOp( OpFoldResult nSizeExpr = rewriter.getIndexAttr(nSize); if (ShapedType::isDynamic(nSize)) { auto nOperandDimPair = getSrcOperandAndDim(nDim); - if (!nOperandDimPair) + if (!nOperandDimPair) { return; + } auto [nOperand, nOperandDim] = nOperandDimPair.value(); nSizeExpr = tensor::DimOp::create(rewriter, loc, nOperand, nOperandDim) .getResult(); @@ -433,8 +440,9 @@ static void padContractionLikeOp( OpFoldResult kSizeExpr = rewriter.getIndexAttr(kSize); if (ShapedType::isDynamic(kSize)) { auto kOperandDimPair = getSrcOperandAndDim(kDim); - if (!kOperandDimPair) + if (!kOperandDimPair) { return; + } auto [kOperand, kOperandDim] = kOperandDimPair.value(); kSizeExpr = tensor::DimOp::create(rewriter, loc, kOperand, kOperandDim) .getResult(); @@ -474,14 +482,16 @@ static void padContractionLikeOp( auto getOperandPadding = [&](AffineMap operandMap) -> SmallVector { auto operandRank = operandMap.getNumResults(); - if (operandRank == 0) + if (operandRank == 0) { return {}; + } SmallVector operandPadding(operandRank, zero); for (auto 
[targetDim, targetPad] : llvm::zip(mnkDim, mnkPadding)) { std::optional maybeDim = operandMap.getResultPosition( getAffineDimExpr(targetDim, operandMap.getContext())); - if (!maybeDim) + if (!maybeDim) { continue; + } operandPadding[maybeDim.value()] = targetPad; } return operandPadding; @@ -541,13 +551,14 @@ static void padContractionLikeOp( SmallVector offsets(resultRank, zero), strides(resultRank, one), sizes; for (auto [dimIdx, dimSize] : llvm::enumerate(resultShape)) { - if (ShapedType::isDynamic(dimSize)) + if (ShapedType::isDynamic(dimSize)) { sizes.push_back( tensor::DimOp::create(rewriter, loc, linalgOp.getDpsInitOperand(0)->get(), dimIdx) .getResult()); - else + } else { sizes.push_back(rewriter.getIndexAttr(dimSize)); + } } rewriter.replaceOpWithNewOp(linalgOp, paddedCompute, offsets, sizes, strides); diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td index 68be70b0d00d..defec3e3d13a 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td @@ -172,13 +172,4 @@ def GeneralizeLinalgMatMulPass : ]; } -def SinkTransposeThroughPadPass : - InterfacePass<"iree-preprocessing-sink-transpose-through-pad", "mlir::FunctionOpInterface"> { - let summary = "Sink linalg transpose ops through tensor pad ops"; - let dependentDialects = [ - "mlir::linalg::LinalgDialect", - "mlir::tensor::TensorDialect", - ]; -} - #endif // IREE_PREPROCESSING_COMMON_PASSES diff --git a/compiler/src/iree/compiler/Preprocessing/Common/SinkTransposeThroughPad.cpp b/compiler/src/iree/compiler/Preprocessing/Common/SinkTransposeThroughPad.cpp deleted file mode 100644 index 4d1937a66f8f..000000000000 --- a/compiler/src/iree/compiler/Preprocessing/Common/SinkTransposeThroughPad.cpp +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2025 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. 
-// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" -#include "iree/compiler/Preprocessing/Common/Passes.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Utils/IndexingUtils.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir::iree_compiler::Preprocessing { - -#define GEN_PASS_DEF_SINKTRANSPOSETHROUGHPADPASS -#include "iree/compiler/Preprocessing/Common/Passes.h.inc" - -static Value createTransposeInit(OpBuilder &builder, Value source, - ArrayRef perm) { - SmallVector mixedSizes = - tensor::getMixedSizes(builder, source.getLoc(), source); - applyPermutationToVector(mixedSizes, perm); - Type elemType = cast(source.getType()).getElementType(); - Value empty = - tensor::EmptyOp::create(builder, source.getLoc(), mixedSizes, elemType) - .getResult(); - return empty; -} - -static Value createTranspose(OpBuilder &builder, Value source, - ArrayRef perm) { - Value empty = createTransposeInit(builder, source, perm); - return linalg::TransposeOp::create(builder, source.getLoc(), source, empty, - perm) - ->getResult(0); -} - -// Sinks a transpose through a tensor.pad -class SinkTransposeThroughPadOp : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::PadOp padOp, - PatternRewriter &rewriter) const override { - if (!IREE::Flow::isNonNullAndOutsideDispatch(padOp)) { - return failure(); - } - Value source = padOp.getSource(); - auto transposeOp = source.getDefiningOp(); - if (!transposeOp) { - return failure(); - } - - Block &block = padOp.getRegion().front(); - if (llvm::any_of(block.getArguments(), [](BlockArgument blockArg) { - return blockArg.getNumUses(); - })) { - return failure(); - } - - auto invPerm = invertPermutationVector(transposeOp.getPermutation()); - SmallVector 
lowSizes = padOp.getMixedLowPad(); - SmallVector highSizes = padOp.getMixedHighPad(); - applyPermutationToVector(lowSizes, invPerm); - applyPermutationToVector(highSizes, invPerm); - - RankedTensorType oldPaddedType = cast(padOp.getType()); - RankedTensorType newPaddedType = oldPaddedType.clone( - applyPermutation(oldPaddedType.getShape(), invPerm)); - auto newPadOp = tensor::PadOp::create( - rewriter, padOp.getLoc(), newPaddedType, transposeOp.getInput(), - lowSizes, highSizes, padOp.getNofold()); - rewriter.cloneRegionBefore(padOp.getRegion(), newPadOp.getRegion(), - newPadOp.getRegion().begin()); - Value newTransposeOp = - createTranspose(rewriter, newPadOp, transposeOp.getPermutation()); - rewriter.replaceOp(padOp, newTransposeOp); - return success(); - } -}; - -namespace { -struct SinkTransposeThroughPadPass - : public impl::SinkTransposeThroughPadPassBase< - SinkTransposeThroughPadPass> { - void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - patterns.insert(&getContext()); - if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { - getOperation().emitError(getPassName()) << " failed to converge."; - return signalPassFailure(); - } - } -}; -} // namespace - -} // namespace mlir::iree_compiler::Preprocessing diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel index d5ac29ee01cc..60f73989ce8a 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "attr_based_pipeline.mlir", "conv2d_to_img2col.mlir", @@ -29,7 +30,6 @@ iree_lit_test_suite( "pdl_example.mlir", "preprocessing_match_ops.mlir", "transform_symbol_importing.mlir", - "sink_transpose_through_pad.mlir", "transpose_matmul.mlir", ], include = ["*.mlir"], diff --git 
a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt index 598ddbcb7692..f0e2bad5da0a 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt @@ -26,7 +26,6 @@ iree_lit_test_suite( "pad_to_intrinsics_wmma.mlir" "pdl_example.mlir" "preprocessing_match_ops.mlir" - "sink_transpose_through_pad.mlir" "transform_symbol_importing.mlir" "transpose_matmul.mlir" TOOLS diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/conv_filter_to_channels_last.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/conv_filter_to_channels_last.mlir index 16fc3177e99f..36e118dc314c 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/test/conv_filter_to_channels_last.mlir +++ b/compiler/src/iree/compiler/Preprocessing/Common/test/conv_filter_to_channels_last.mlir @@ -163,6 +163,24 @@ util.func public @conv_2d_chwn_chwf_no_transpose(%arg0: tensor<16x26x18x288xf32> // ----- +#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d2 + d5, d3 + d6)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d0, d5, d6)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> +util.func public @conv_2d_cnhw_cfhw_no_transpose(%arg0: tensor<16x288x26x18xf32>, %arg1: tensor<16x288x24x16xf32>, %arg2: tensor<288x288x3x3xf32>) -> tensor<288x288x3x3xf32> { + %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<16x288x26x18xf32>, tensor<16x288x24x16xf32>) outs(%arg2 : tensor<288x288x3x3xf32>) { + ^bb0(%in: f32, %in_3: f32, %out: f32): + %12 = arith.mulf %in, %in_3 : f32 + %13 = arith.addf %out, %12 : f32 + linalg.yield %13 : f32 + } -> tensor<288x288x3x3xf32> + util.return %0 : tensor<288x288x3x3xf32> +} + +// CHECK-FHWC-LABEL: 
@conv_2d_cnhw_cfhw_no_transpose +// CHECK-FHWC-NOT: linalg.transpose + +// ----- + #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d4)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/sink_transpose_through_pad.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/sink_transpose_through_pad.mlir deleted file mode 100644 index 223100802681..000000000000 --- a/compiler/src/iree/compiler/Preprocessing/Common/test/sink_transpose_through_pad.mlir +++ /dev/null @@ -1,17 +0,0 @@ -// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-preprocessing-sink-transpose-through-pad))" --split-input-file %s | FileCheck %s - -util.func public @sink_pad_through_transpose(%arg0 : tensor<16x64x64x128xf16>) -> (tensor<16x128x66x66xf16>) { - %2 = tensor.empty() : tensor<16x128x64x64xf16> - %cst = arith.constant 0.000000e+00 : f16 - %transposed = linalg.transpose ins(%arg0 : tensor<16x64x64x128xf16>) outs(%2 : tensor<16x128x64x64xf16>) permutation = [0, 3, 1, 2] - %padded = tensor.pad %transposed low[0, 0, 1, 1] high[0, 0, 1, 1] { - ^bb0(%arg5: index, %arg6: index, %arg7: index, %arg8: index): - tensor.yield %cst : f16 - } : tensor<16x128x64x64xf16> to tensor<16x128x66x66xf16> - util.return %padded : tensor<16x128x66x66xf16> -} -// CHECK-LABEL: util.func public @sink_pad_through_transpose -// CHECK: %[[PAD:.+]] = tensor.pad -// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose -// CHECK-SAME: ins(%[[PAD]] -// CHECK: util.return %[[TRANSPOSE]] diff --git a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel index 9a8958a13f22..a49cab6a6e63 100644 --- a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel @@ 
-16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "PreprocessingExtensionsOps.td", ], diff --git a/compiler/src/iree/compiler/Reducer/Framework/Delta.cpp b/compiler/src/iree/compiler/Reducer/Framework/Delta.cpp index 7bab516d8e80..18b1ff3f5581 100644 --- a/compiler/src/iree/compiler/Reducer/Framework/Delta.cpp +++ b/compiler/src/iree/compiler/Reducer/Framework/Delta.cpp @@ -102,8 +102,9 @@ void Delta::runDeltaPass(DeltaFunc deltaFunc, StringRef message) { for (Chunk chunk : maybeInteresting) { FailureOr result = checkChunk(chunk, deltaFunc, maybeInteresting, uninterestingChunks); - if (failed(result)) + if (failed(result)) { continue; + } // Removing this chunk is still interesting. Mark this chunk as // uninteresting. diff --git a/compiler/src/iree/compiler/Reducer/Framework/WorkItem.h b/compiler/src/iree/compiler/Reducer/Framework/WorkItem.h index 01e7e07757a2..39476e66be55 100644 --- a/compiler/src/iree/compiler/Reducer/Framework/WorkItem.h +++ b/compiler/src/iree/compiler/Reducer/Framework/WorkItem.h @@ -22,8 +22,9 @@ class WorkItem { /// TODO(Groverkss): Ownership of module should be conveyed here via /// mlir::OwningOpReference. 
void replaceModule(ModuleOp newModule) { - if (root) + if (root) { root->erase(); + } root = newModule; } diff --git a/compiler/src/iree/compiler/Reducer/Strategies/ReduceLinalgOnTensorsDelta.cpp b/compiler/src/iree/compiler/Reducer/Strategies/ReduceLinalgOnTensorsDelta.cpp index c2eef672c2cc..4d66462d37c5 100644 --- a/compiler/src/iree/compiler/Reducer/Strategies/ReduceLinalgOnTensorsDelta.cpp +++ b/compiler/src/iree/compiler/Reducer/Strategies/ReduceLinalgOnTensorsDelta.cpp @@ -28,8 +28,9 @@ void mlir::iree_compiler::Reducer::reduceLinalgOnTensorsDelta( SmallVector linalgOps; SmallVector keepOps; module.walk([&](linalg::LinalgOp op) { - if (!op.hasPureTensorSemantics()) + if (!op.hasPureTensorSemantics()) { return; + } // Op should have at least one tensor input, otherwise the operation is // already a fill-like operation. // TODO(Groverkss): Explore if we can remove in this case too. @@ -41,14 +42,17 @@ void mlir::iree_compiler::Reducer::reduceLinalgOnTensorsDelta( } } - if (!hasAtleastOneTensorInput) + if (!hasAtleastOneTensorInput) { return; + } // There should be only 1 tensor output. - if (op.getNumDpsInits() != 1) + if (op.getNumDpsInits() != 1) { return; - if (!isa(op.getDpsInitOperand(0)->get().getType())) + } + if (!isa(op.getDpsInitOperand(0)->get().getType())) { return; + } if (!chunker.shouldFeatureBeKept()) { linalgOps.push_back(op); @@ -84,8 +88,9 @@ void mlir::iree_compiler::Reducer::reduceLinalgOnTensorsDelta( if (outType.hasStaticShape()) { for (auto *input : linalgOp.getDpsInputOperands()) { auto inType = dyn_cast(input->get().getType()); - if (!inType) + if (!inType) { continue; + } // Check if we can replace an input directly with the output. if (inType == outType) { @@ -124,6 +129,7 @@ void mlir::iree_compiler::Reducer::reduceLinalgOnTensorsDelta( pm.addPass(createCanonicalizerPass()); // Remove dead globals. 
pm.addPass(createSymbolDCEPass()); - if (failed(pm.run(module))) + if (failed(pm.run(module))) { return; + } } diff --git a/compiler/src/iree/compiler/Tools/iree_compile_lib.cc b/compiler/src/iree/compiler/Tools/iree_compile_lib.cc index 378fd373d26c..1ebfc8d7747e 100644 --- a/compiler/src/iree/compiler/Tools/iree_compile_lib.cc +++ b/compiler/src/iree/compiler/Tools/iree_compile_lib.cc @@ -52,9 +52,10 @@ struct BytecodeVersionParser : public llvm::cl::parser> { bool parse(llvm::cl::Option &O, StringRef /*argName*/, StringRef arg, std::optional &v) { long long w; - if (llvm::getAsSignedInteger(arg, 10, w)) + if (llvm::getAsSignedInteger(arg, 10, w)) { return O.error("Invalid argument '" + arg + "', only integer is supported."); + } v = w; return false; } @@ -264,30 +265,37 @@ int mlir::iree_compiler::runIreecMain(int argc, char **argv) { remarksOutputFile.c_str()); } - if (!ireeCompilerInvocationParseSource(r.inv, source)) + if (!ireeCompilerInvocationParseSource(r.inv, source)) { return false; + } // Switch on compileMode to choose a pipeline to run. switch (compileMode) { case CompileMode::std: - if (!ireeCompilerInvocationPipeline(r.inv, IREE_COMPILER_PIPELINE_STD)) + if (!ireeCompilerInvocationPipeline(r.inv, IREE_COMPILER_PIPELINE_STD)) { return false; + } break; case CompileMode::vm: + if (!ireeCompilerInvocationPipeline(r.inv, IREE_COMPILER_PIPELINE_VM)) { + return false; + } break; case CompileMode::hal_executable: { // Compiling a HAL executable, it is only valid to output in that form. 
outputFormat = OutputFormat::hal_executable; if (!ireeCompilerInvocationPipeline( - r.inv, IREE_COMPILER_PIPELINE_HAL_EXECUTABLE)) + r.inv, IREE_COMPILER_PIPELINE_HAL_EXECUTABLE)) { return false; + } break; } case CompileMode::precompile: { outputFormat = OutputFormat::precompile; if (!ireeCompilerInvocationPipeline(r.inv, - IREE_COMPILER_PIPELINE_PRECOMPILE)) + IREE_COMPILER_PIPELINE_PRECOMPILE)) { return false; + } break; } default: @@ -371,8 +379,9 @@ int mlir::iree_compiler::runIreecMain(int argc, char **argv) { return 1; } } else { - if (!processBuffer(s.source)) + if (!processBuffer(s.source)) { return 1; + } } ireeCompilerOutputKeep(s.output); diff --git a/compiler/src/iree/compiler/Utils/BUILD.bazel b/compiler/src/iree/compiler/Utils/BUILD.bazel index 36f0d197021a..050bed8656ff 100644 --- a/compiler/src/iree/compiler/Utils/BUILD.bazel +++ b/compiler/src/iree/compiler/Utils/BUILD.bazel @@ -18,6 +18,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "CommonTypeConstraints.td", "DocMetadata.td", @@ -34,6 +35,7 @@ iree_compiler_cc_library( name = "Utils", srcs = [ "ConversionUtils.cpp", + "EncodingUtils.cpp", "EquivalenceUtils.cpp", "FlatbufferUtils.cpp", "Indexing.cpp", @@ -49,9 +51,9 @@ iree_compiler_cc_library( hdrs = [ "ConversionUtils.h", "EmbeddedDataDirectory.h", + "EncodingUtils.h", "EquivalenceUtils.h", "FlatbufferUtils.h", - "Folding.h", "Indexing.h", "IntegerSet.h", "ModuleUtils.h", @@ -77,6 +79,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:BytecodeWriter", "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FunctionInterfaces", diff --git a/compiler/src/iree/compiler/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Utils/CMakeLists.txt index 6dd027fc39a4..7842841b068e 100644 --- a/compiler/src/iree/compiler/Utils/CMakeLists.txt +++ 
b/compiler/src/iree/compiler/Utils/CMakeLists.txt @@ -16,9 +16,9 @@ iree_cc_library( HDRS "ConversionUtils.h" "EmbeddedDataDirectory.h" + "EncodingUtils.h" "EquivalenceUtils.h" "FlatbufferUtils.h" - "Folding.h" "Indexing.h" "IntegerSet.h" "ModuleUtils.h" @@ -35,6 +35,7 @@ iree_cc_library( "TracingUtils.h" SRCS "ConversionUtils.cpp" + "EncodingUtils.cpp" "EquivalenceUtils.cpp" "FlatbufferUtils.cpp" "Indexing.cpp" @@ -52,6 +53,7 @@ iree_cc_library( MLIRAffineDialect MLIRAnalysis MLIRArithDialect + MLIRBytecodeWriter MLIRFuncDialect MLIRFunctionInterfaces MLIRIR diff --git a/compiler/src/iree/compiler/Utils/ConversionUtils.cpp b/compiler/src/iree/compiler/Utils/ConversionUtils.cpp index 78fd1a93960d..3966e160830d 100644 --- a/compiler/src/iree/compiler/Utils/ConversionUtils.cpp +++ b/compiler/src/iree/compiler/Utils/ConversionUtils.cpp @@ -45,8 +45,9 @@ LogicalResult verifyAllOperationsAreLegal(Operation *op, illegalOps.insert(op); } }); - if (illegalOps.empty()) + if (illegalOps.empty()) { return success(); + } emitLegalizationErrors(op->getLoc(), illegalOps); return failure(); } @@ -60,14 +61,16 @@ Attribute convertAttribute(Location loc, Attribute oldAttr, // Return the same attribute if it doesn't have a type. auto typedOldAttr = dyn_cast(oldAttr); - if (!typedOldAttr) + if (!typedOldAttr) { return oldAttr; + } // Convert the attribute type - if it's the same then it's already legal. 
auto oldType = typedOldAttr.getType(); auto newType = typeConverter.convertType(oldType); - if (oldType == newType) + if (oldType == newType) { return typedOldAttr; + } if (auto intAttr = dyn_cast(typedOldAttr)) { APInt value = intAttr.getValue(); diff --git a/compiler/src/iree/compiler/Utils/EncodingUtils.cpp b/compiler/src/iree/compiler/Utils/EncodingUtils.cpp new file mode 100644 index 000000000000..d6a5c51d6291 --- /dev/null +++ b/compiler/src/iree/compiler/Utils/EncodingUtils.cpp @@ -0,0 +1,86 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Utils/EncodingUtils.h" + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpImplementation.h" + +namespace mlir::iree_compiler { + +/// Parse a list of integer values and/or dynamic values ('?') +FailureOr> parseDynamicI64IntegerList(AsmParser &parser) { + SmallVector integerVals; + if (failed(parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&] { + int64_t value = ShapedType::kDynamic; + if (failed(parser.parseOptionalQuestion()) && + failed(parser.parseInteger(value))) { + return failure(); + } + integerVals.push_back(value); + return success(); + }))) { + return failure(); + } + return integerVals; +} + +/// Print a list of integer values and/or dynamic values ('?') +void printDynamicI64IntegerList(AsmPrinter &printer, ArrayRef vals) { + printer << "["; + llvm::interleaveComma(vals, printer, [&](int64_t val) { + if (ShapedType::isDynamic(val)) { + printer << "?"; + } else { + printer << val; + } + }); + printer << "]"; +} + +/// Parse a list of integer values and/or dynamic values ('?') into an ArrayAttr +ParseResult parseDynamicI64ArrayAttr(AsmParser &parser, ArrayAttr &attr) { + FailureOr> integerVals = + parseDynamicI64IntegerList(parser); + if (failed(integerVals)) { + return failure(); + } + 
auto integerValsAttr = + llvm::map_to_vector(integerVals.value(), [&](int64_t val) -> Attribute { + return IntegerAttr::get(IntegerType::get(parser.getContext(), 64), val); + }); + attr = ArrayAttr::get(parser.getContext(), integerValsAttr); + return success(); +} + +/// Print an ArrayAttr of integer values and/or dynamic values ('?') +void printDynamicI64ArrayAttr(AsmPrinter &printer, ArrayAttr attrs) { + SmallVector intVals = llvm::map_to_vector( + attrs, [&](Attribute attr) { return cast(attr).getInt(); }); + return printDynamicI64IntegerList(printer, intVals); +} + +/// Parse a list of integer values and/or dynamic values ('?') into a +/// DenseI64ArrayAttr +ParseResult parseDynamicI64DenseArrayAttr(AsmParser &parser, + DenseI64ArrayAttr &attr) { + FailureOr> integerVals = + parseDynamicI64IntegerList(parser); + if (failed(integerVals)) { + return failure(); + } + attr = DenseI64ArrayAttr::get(parser.getContext(), *integerVals); + return success(); +} + +/// Print a DenseI64ArrayAttr as a list of integer values and/or dynamic values +/// ('?') +void printDynamicI64DenseArrayAttr(AsmPrinter &printer, + DenseI64ArrayAttr attr) { + printDynamicI64IntegerList(printer, attr.asArrayRef()); +} + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Utils/EncodingUtils.h b/compiler/src/iree/compiler/Utils/EncodingUtils.h new file mode 100644 index 000000000000..a6e52f07ab8c --- /dev/null +++ b/compiler/src/iree/compiler/Utils/EncodingUtils.h @@ -0,0 +1,39 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_UTILS_ENCODINGUTILS_H_ +#define IREE_COMPILER_UTILS_ENCODINGUTILS_H_ + +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/BuiltinAttributes.h" + +namespace mlir::iree_compiler { + +/// Parse a list of integer values and/or dynamic values ('?') +FailureOr> parseDynamicI64IntegerList(AsmParser &parser); + +/// Print a list of integer values and/or dynamic values ('?') +void printDynamicI64IntegerList(AsmPrinter &printer, ArrayRef vals); + +/// Parse a list of integer values and/or dynamic values ('?') into an ArrayAttr +ParseResult parseDynamicI64ArrayAttr(AsmParser &parser, ArrayAttr &attr); + +/// Print an ArrayAttr of integer values and/or dynamic values ('?') +void printDynamicI64ArrayAttr(AsmPrinter &printer, ArrayAttr attrs); + +/// Parse a list of integer values and/or dynamic values ('?') into a +/// DenseI64ArrayAttr +ParseResult parseDynamicI64DenseArrayAttr(AsmParser &parser, + DenseI64ArrayAttr &attr); + +/// Print a DenseI64ArrayAttr as a list of integer values and/or dynamic values +/// ('?') +void printDynamicI64DenseArrayAttr(AsmPrinter &printer, DenseI64ArrayAttr attr); + +} // namespace mlir::iree_compiler + +#endif // IREE_COMPILER_UTILS_ENCODINGUTILS_H_ diff --git a/compiler/src/iree/compiler/Utils/EquivalenceUtils.cpp b/compiler/src/iree/compiler/Utils/EquivalenceUtils.cpp index debfe478d4f6..ea0c9bbc8681 100644 --- a/compiler/src/iree/compiler/Utils/EquivalenceUtils.cpp +++ b/compiler/src/iree/compiler/Utils/EquivalenceUtils.cpp @@ -21,14 +21,18 @@ OperationEquivalenceCache::OperationEquivalenceCache(MLIRContext *context) StringAttr::get(context, SymbolTable::getSymbolAttrName())) {} OperationEquivalenceCache::~OperationEquivalenceCache() { - for (auto *mapping : mappingFreeList) + for (auto *mapping : mappingFreeList) { delete mapping; - for (auto region : regions) + } + for (auto region : regions) { delete region.second; - for 
(auto block : blocks) + } + for (auto block : blocks) { delete block.second; - for (auto op : ops) + } + for (auto op : ops) { delete op.second; + } } bool OperationEquivalenceCache::isSymbolAttrName(StringAttr name) const { @@ -52,8 +56,9 @@ OperationEquivalenceCache::acquireMapping() { OperationEquivalenceCache::RegionEntry & OperationEquivalenceCache::getRegion(Region *region) { auto it = regions.find(region); - if (it != regions.end()) + if (it != regions.end()) { return *it->second; + } RegionEntry *entry = new RegionEntry(); for (Block &block : region->getBlocks()) { llvm::ReversePostOrderTraversal traversal(&block); @@ -66,8 +71,9 @@ OperationEquivalenceCache::getRegion(Region *region) { OperationEquivalenceCache::BlockEntry & OperationEquivalenceCache::getBlock(Block *block) { auto it = blocks.find(block); - if (it != blocks.end()) + if (it != blocks.end()) { return *it->second; + } BlockEntry *entry = new BlockEntry(); entry->count = block->getOperations().size(); blocks[block] = entry; @@ -77,8 +83,9 @@ OperationEquivalenceCache::getBlock(Block *block) { OperationEquivalenceCache::OperationEntry & OperationEquivalenceCache::getOp(Operation *op) { auto it = ops.find(op); - if (it != ops.end()) + if (it != ops.end()) { return *it->second; + } OperationEntry *entry = new OperationEntry(); entry->attrs.append(op->getRawDictionaryAttrs().getValue()); if (op->getPropertiesStorageSize()) { @@ -95,8 +102,9 @@ bool compare_ranges(Range &&lhs, Range &&rhs, Pred pred) { auto lhsEnd = lhs.end(); auto rhsEnd = rhs.end(); while (lhsIt != lhsEnd && rhsIt != rhsEnd) { - if (!pred(*lhsIt++, *rhsIt++)) + if (!pred(*lhsIt++, *rhsIt++)) { return false; + } } if ((lhsIt == lhsEnd) != (rhsIt == rhsEnd)) { // Block count mismatch. 
We do this here so that we avoid the O(n) scan @@ -157,18 +165,21 @@ bool isStructurallyEquivalentTo(OperationEquivalenceCache &cache, Region &lhs, Region &rhs, IRMapping &mapping) { auto &lhsRegionEntry = cache.getRegion(&lhs); auto &rhsRegionEntry = cache.getRegion(&rhs); - if (lhsRegionEntry.blocks.size() != rhsRegionEntry.blocks.size()) + if (lhsRegionEntry.blocks.size() != rhsRegionEntry.blocks.size()) { return false; + } // Map blocks and their arguments so that we can compare their use by ops. for (auto [lhsBlock, rhsBlock] : llvm::zip_equal(lhsRegionEntry.blocks, rhsRegionEntry.blocks)) { - if (lhsBlock->getNumArguments() != rhsBlock->getNumArguments()) + if (lhsBlock->getNumArguments() != rhsBlock->getNumArguments()) { return false; + } for (auto [lhsArg, rhsArg] : llvm::zip_equal(lhsBlock->getArguments(), rhsBlock->getArguments())) { - if (lhsArg.getType() != rhsArg.getType()) + if (lhsArg.getType() != rhsArg.getType()) { return false; + } mapping.map(lhsArg, rhsArg); } mapping.map(lhsBlock, rhsBlock); @@ -180,13 +191,15 @@ bool isStructurallyEquivalentTo(OperationEquivalenceCache &cache, Region &lhs, llvm::zip_equal(lhsRegionEntry.blocks, rhsRegionEntry.blocks)) { const auto &lhsBlockEntry = cache.getBlock(lhsBlock); const auto &rhsBlockEntry = cache.getBlock(rhsBlock); - if (lhsBlockEntry.count != rhsBlockEntry.count) + if (lhsBlockEntry.count != rhsBlockEntry.count) { return false; + } for (auto [lhsOp, rhsOp] : llvm::zip_equal(lhsBlock->getOperations(), rhsBlock->getOperations())) { - if (!isStructurallyEquivalentTo(cache, lhsOp, rhsOp, mapping)) + if (!isStructurallyEquivalentTo(cache, lhsOp, rhsOp, mapping)) { return false; + } } } @@ -210,13 +223,15 @@ static bool isStructurallyEquivalentTo(OperationEquivalenceCache &cache, auto &rhsEntry = cache.getOp(&rhs); // TODO(#3996): symbol mapping; for now allow them to differ unconditionally. 
- if (lhsEntry.attrs.getAttrs().size() != rhsEntry.attrs.getAttrs().size()) + if (lhsEntry.attrs.getAttrs().size() != rhsEntry.attrs.getAttrs().size()) { return false; + } for (auto [lhsAttr, rhsAttr] : llvm::zip_equal(lhsEntry.attrs, rhsEntry.attrs)) { if (!cache.isSymbolAttrName(lhsAttr.getName())) { - if (lhsAttr != rhsAttr) + if (lhsAttr != rhsAttr) { return false; + } } } @@ -224,8 +239,9 @@ static bool isStructurallyEquivalentTo(OperationEquivalenceCache &cache, // in the mapping already from the parent region to do the lhs->rhs mapping. for (auto [lhsSuccessor, rhsSuccessor] : llvm::zip_equal(lhs.getSuccessors(), rhs.getSuccessors())) { - if (rhsSuccessor != parentMapping.lookup(lhsSuccessor)) + if (rhsSuccessor != parentMapping.lookup(lhsSuccessor)) { return false; + } } // Ensure result types match first and add to the block and value mapping. @@ -234,8 +250,9 @@ static bool isStructurallyEquivalentTo(OperationEquivalenceCache &cache, // exit prior to the full traversal. for (auto [lhsValue, rhsValue] : llvm::zip_equal(lhs.getResults(), rhs.getResults())) { - if (lhsValue.getType() != rhsValue.getType()) + if (lhsValue.getType() != rhsValue.getType()) { return false; + } parentMapping.map(lhsValue, rhsValue); } @@ -243,10 +260,12 @@ static bool isStructurallyEquivalentTo(OperationEquivalenceCache &cache, // these values they should already be defined in the mapping. for (auto [lhsValue, rhsValue] : llvm::zip_equal(lhs.getOperands(), rhs.getOperands())) { - if (lhsValue.getType() != rhsValue.getType()) + if (lhsValue.getType() != rhsValue.getType()) { return false; - if (rhsValue != parentMapping.lookup(lhsValue)) + } + if (rhsValue != parentMapping.lookup(lhsValue)) { return false; + } } // Recurse into regions. 
diff --git a/compiler/src/iree/compiler/Utils/FlatbufferUtils.h b/compiler/src/iree/compiler/Utils/FlatbufferUtils.h index 697556e16df0..eddd13b92274 100644 --- a/compiler/src/iree/compiler/Utils/FlatbufferUtils.h +++ b/compiler/src/iree/compiler/Utils/FlatbufferUtils.h @@ -61,8 +61,9 @@ class FlatbufferBuilder { auto stringRefs = llvm::map_to_vector<8>(Range, [&](StringRef value) { return flatbuffers_string_create(*this, value.data(), value.size()); }); - if (stringRefs.empty()) + if (stringRefs.empty()) { return 0; + } return flatbuffers_string_vec_create(*this, stringRefs.data(), stringRefs.size()); } @@ -70,8 +71,9 @@ class FlatbufferBuilder { // Creates an offset vector with the given values. The source values will not // be modified. flatbuffers_vec_ref_t createOffsetVec(ArrayRef values) { - if (values.empty()) + if (values.empty()) { return 0; + } return flatcc_builder_create_offset_vector(*this, values.data(), values.size()); } @@ -81,8 +83,9 @@ class FlatbufferBuilder { // serialization but be much faster. flatbuffers_vec_ref_t createOffsetVecDestructive(SmallVectorImpl &values) { - if (values.empty()) + if (values.empty()) { return 0; + } return flatcc_builder_create_offset_vector_direct(*this, values.data(), values.size()); } @@ -90,8 +93,9 @@ class FlatbufferBuilder { // Creates an [int32] vec with the contents of the given range. template flatbuffers_int32_vec_ref_t createInt32Vec(RangeTy &&Range) { - if (Range.empty()) + if (Range.empty()) { return 0; + } flatbuffers_int32_vec_start(*this); for (int32_t v : Range) { flatbuffers_int32_vec_push_create(*this, v); diff --git a/compiler/src/iree/compiler/Utils/Folding.h b/compiler/src/iree/compiler/Utils/Folding.h deleted file mode 100644 index 02184cf9f084..000000000000 --- a/compiler/src/iree/compiler/Utils/Folding.h +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2024 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. 
-// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_COMPILER_UTILS_FOLDING_H_ -#define IREE_COMPILER_UTILS_FOLDING_H_ - -#include -#include - -#include "llvm/ADT/STLExtras.h" -#include "mlir/IR/OpDefinition.h" - -namespace mlir::iree_compiler { - -// Convert a `Value` or an `Attribute` range to a range of `OpFoldResult`. -template -void toOpFoldResults(Range &&range, OutIt outIt) { - llvm::transform(std::forward(range), outIt, - [](auto v) { return OpFoldResult(v); }); -} - -template -SmallVector toOpFoldResults(Range &&range) { - SmallVector res; - toOpFoldResults(std::forward(range), std::back_inserter(res)); - return res; -} - -} // namespace mlir::iree_compiler - -#endif // IREE_COMPILER_UTILS_FOLDING_H_ diff --git a/compiler/src/iree/compiler/Utils/Indexing.cpp b/compiler/src/iree/compiler/Utils/Indexing.cpp index c321c20cc3a2..b27ff4cfd55d 100644 --- a/compiler/src/iree/compiler/Utils/Indexing.cpp +++ b/compiler/src/iree/compiler/Utils/Indexing.cpp @@ -33,8 +33,9 @@ LogicalResult basisFromSizesStrides(ArrayRef sizes, stride = 1; size = 1; } - if (stride % previousSizes != 0) + if (stride % previousSizes != 0) { return failure(); + } // Handle casis like threads = {4, 8}, strides = {1, 16}, which need an // extra basis element. @@ -56,8 +57,9 @@ LogicalResult basisFromSizesStrides(ArrayRef sizes, size_t basisLength = basis.size(); dimToResult.assign(numDims, ~0); for (auto [reverseBasisPos, dimPos] : llvm::enumerate(basisEntryToDim)) { - if (!dimPos) + if (!dimPos) { continue; + } // There's an extra overflow term at the front of the delineraize results, // so this subtraction lands in the [1, basisLength] range we need it // to be in. 
diff --git a/compiler/src/iree/compiler/Utils/ModuleUtils.cpp b/compiler/src/iree/compiler/Utils/ModuleUtils.cpp index f4972a5aeb21..c57dc3ba5420 100644 --- a/compiler/src/iree/compiler/Utils/ModuleUtils.cpp +++ b/compiler/src/iree/compiler/Utils/ModuleUtils.cpp @@ -7,13 +7,17 @@ #include "iree/compiler/Utils/ModuleUtils.h" #include "iree/compiler/Utils/StringUtils.h" +#include "llvm/Support/FileSystem.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/Path.h" +#include "llvm/Support/ToolOutputFile.h" +#include "mlir/Bytecode/BytecodeWriter.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Parser/Parser.h" +#include "mlir/Support/FileUtilities.h" #include "mlir/Support/LLVM.h" namespace mlir::iree_compiler { @@ -27,22 +31,26 @@ std::optional findFirstFileLoc(Location baseLoc) { // Recurse through fused locations. for (auto &childLoc : loc.getLocations()) { auto childResult = findFirstFileLoc(childLoc); - if (childResult) + if (childResult) { return childResult; + } } } else if (auto loc = dyn_cast(baseLoc)) { // First check caller... auto callerResult = findFirstFileLoc(loc.getCaller()); - if (callerResult) + if (callerResult) { return callerResult; + } // Then check callee... auto calleeResult = findFirstFileLoc(loc.getCallee()); - if (calleeResult) + if (calleeResult) { return calleeResult; + } } else if (auto loc = dyn_cast(baseLoc)) { auto childResult = findFirstFileLoc(loc.getChildLoc()); - if (childResult) + if (childResult) { return childResult; + } } else if (auto loc = dyn_cast(baseLoc)) { // TODO(scotttodd): Use loc.fallbackLocation()? 
} else if (auto loc = dyn_cast(baseLoc)) { @@ -54,8 +62,9 @@ std::optional findFirstFileLoc(Location baseLoc) { std::string guessModuleName(mlir::ModuleOp moduleOp, StringRef defaultName) { std::string moduleName = moduleOp.getName().value_or("").str(); - if (!moduleName.empty()) + if (!moduleName.empty()) { return moduleName; + } auto loc = findFirstFileLoc(moduleOp.getLoc()); if (loc.has_value()) { return sanitizeSymbolName( @@ -148,8 +157,9 @@ LogicalResult mergeModuleInto(Operation *sourceModuleOp, // Resolve conflicts and move the op. for (auto &sourceOp : sourceOps) { - if (sourceOp->hasTrait()) + if (sourceOp->hasTrait()) { continue; + } if (auto symbolOp = dyn_cast(sourceOp)) { auto symbolName = symbolOp.getName(); @@ -221,4 +231,34 @@ LogicalResult mergeSourceModuleInto(Location loc, StringRef source, return mergeModuleInto(*sourceModuleRef, targetOp, targetBuilder); } +LogicalResult writeModule(mlir::ModuleOp moduleOp, StringRef path) { + // Ensure the parent paths exist. + llvm::sys::fs::create_directories(llvm::sys::path::parent_path(path)); + + // Attempt to open file - should succeed as long as permissions are ok. + std::string error; + auto file = mlir::openOutputFile(path, &error); + if (!file) { + return mlir::emitError(moduleOp.getLoc()) + << "while dumping to '" << path << "': " << error << "\n"; + } + + // If going to binary serialize out and otherwise print as text. + if (llvm::sys::path::extension(path) == ".mlirbc") { + BytecodeWriterConfig config; + if (failed(mlir::writeBytecodeToFile(moduleOp, file->os(), config))) { + return mlir::emitError(moduleOp.getLoc()) + << "failed to serialize module to '" << path << "'\n"; + } + } else { + OpPrintingFlags flags; + moduleOp.print(file->os(), flags); + } + + // Keep the temporary file after the write succeeds. 
+ file->keep(); + + return success(); +} + } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Utils/ModuleUtils.h b/compiler/src/iree/compiler/Utils/ModuleUtils.h index 882266bfaa35..18b0765b2f6b 100644 --- a/compiler/src/iree/compiler/Utils/ModuleUtils.h +++ b/compiler/src/iree/compiler/Utils/ModuleUtils.h @@ -38,6 +38,10 @@ LogicalResult mergeSourceModuleInto(Location loc, StringRef source, Operation *targetOp, OpBuilder &targetBuilder); +// Writes |moduleOp| to the file at |path|. +// The module will be written as MLIR text unless it has the .mlirbc extension. +LogicalResult writeModule(mlir::ModuleOp moduleOp, StringRef path); + } // namespace mlir::iree_compiler #endif // IREE_COMPILER_UTILS_MODULEUTILS_H_ diff --git a/compiler/src/iree/compiler/Utils/OptionUtils.cpp b/compiler/src/iree/compiler/Utils/OptionUtils.cpp index 9d0678f81b67..3ed11228ef47 100644 --- a/compiler/src/iree/compiler/Utils/OptionUtils.cpp +++ b/compiler/src/iree/compiler/Utils/OptionUtils.cpp @@ -78,10 +78,12 @@ llvm::SmallVector OptionsBinder::printArguments(bool nonDefaultOnly) { llvm::SmallVector values; for (auto &[flag, info] : getOptionsStorage()) { - if (!info.print) + if (!info.print) { continue; - if (nonDefaultOnly && !info.isDefault()) + } + if (nonDefaultOnly && !info.isDefault()) { continue; + } std::string s; llvm::raw_string_ostream os(s); diff --git a/compiler/src/iree/compiler/Utils/OptionUtils.h b/compiler/src/iree/compiler/Utils/OptionUtils.h index c6fa25cff27b..bb707fca7c6b 100644 --- a/compiler/src/iree/compiler/Utils/OptionUtils.h +++ b/compiler/src/iree/compiler/Utils/OptionUtils.h @@ -36,8 +36,9 @@ struct opt_initializer { : parentName(parentName), init(val), optLevel(opt) {} void apply(const llvm::OptimizationLevel inLevel, Ty &val) const { assert(inLevel.getSizeLevel() == 0 && "size level not implemented"); - if (inLevel.getSpeedupLevel() >= optLevel.getSpeedupLevel()) + if (inLevel.getSpeedupLevel() >= optLevel.getSpeedupLevel()) { val = 
init; + } } /// Append to the description string of the flag. @@ -69,6 +70,18 @@ struct opt_scope { } }; +// Modifier to mark an option as deprecated with a warning message. +// When the option is parsed, a deprecation warning will be printed to stderr. +// The apply method is a no-op since OptionsBinder handles the deprecation +// warning through the callback mechanism, but it's required because LLVM's +// applicator will try to call it when the modifier is forwarded. +struct deprecated { + llvm::StringRef message; + explicit deprecated(llvm::StringRef msg) : message(msg) {} + template + void apply(Opt &) const {} +}; + // Base class that can bind named options to fields of structs. // // Typically use by adding the following to your struct: @@ -98,7 +111,8 @@ class OptionsBinder { template void opt(llvm::StringRef name, V &value, Mods... Ms) { - auto [changedCallback, clCallback] = makeChangedCallback(); + const deprecated *dep = filterDeprecated(Ms...); + auto [changedCallback, clCallback] = makeChangedCallback(name, dep); OptionInfo &info = getOptionsStorage()[name]; if (!scope) { // Bind global options. @@ -188,8 +202,9 @@ class OptionsBinder { void restoreOptimizationDefaults() { for (auto &[_, info] : getOptionsStorage()) { - if (info.optOverrides) + if (info.optOverrides) { info.optOverrides->restoreBackup(); + } } } @@ -402,14 +417,25 @@ class OptionsBinder { // Returns a pair of callbacks, the first returns if the option has been // parsed and the second is passed to llvm::cl to track if the option has been - // parsed. + // parsed. If a deprecation message is provided, it will be printed to stderr + // when the option is parsed. template static std::pair> - makeChangedCallback() { + makeChangedCallback(llvm::StringRef name = "", + const deprecated *dep = nullptr) { std::shared_ptr changed = std::make_shared(false); + // Capture name and message by value for lambda lifetime. + std::string optName = name.str(); + std::string depMsg = dep ? 
dep->message.str() : ""; return std::pair{ [changed]() -> bool { return *changed; }, - llvm::cl::cb([changed](const V &) { *changed = true; })}; + llvm::cl::cb([changed, optName, depMsg](const V &) { + *changed = true; + if (!depMsg.empty()) { + llvm::errs() << "warning: --" << optName << " is deprecated; " + << depMsg << "\n"; + } + })}; } // Scalar default specialization. @@ -439,14 +465,15 @@ class OptionsBinder { return [optionName, values](llvm::raw_ostream &os) { os << "--" << optionName << "="; for (auto it : llvm::enumerate(*values)) { - if (it.index() > 0) + if (it.index() > 0) { os << ","; + } os << it.value(); } }; } - // Finds the description in args + // Finds the description in args. template static llvm::cl::desc &filterDescription(Args &...args) { llvm::cl::desc *result = nullptr; @@ -454,8 +481,9 @@ class OptionsBinder { [&] { if constexpr (std::is_same_v, llvm::cl::desc>) { assert(!result && "Multiple llvm::cl::desc in args"); - if (!result) + if (!result) { result = &args; + } } }(), ...); @@ -463,6 +491,20 @@ class OptionsBinder { return *result; } + // Extracts deprecated modifier from args (returns nullptr if not found). 
+ template + static const deprecated *filterDeprecated(const Args &...args) { + const deprecated *result = nullptr; + ( + [&] { + if constexpr (std::is_same_v, deprecated>) { + result = &args; + } + }(), + ...); + return result; + } + std::unique_ptr scope; OptionsStorage localOptions; diff --git a/compiler/src/iree/compiler/Utils/ToolUtils.cpp b/compiler/src/iree/compiler/Utils/ToolUtils.cpp index 6a8524074359..67860cf3d7c6 100644 --- a/compiler/src/iree/compiler/Utils/ToolUtils.cpp +++ b/compiler/src/iree/compiler/Utils/ToolUtils.cpp @@ -109,8 +109,9 @@ std::string findToolFromExecutableDir(SmallVector toolNames) { static std::string getCurrentDylibPath() { #if __linux__ || __APPLE__ Dl_info dlInfo; - if (dladdr((void *)getCurrentDylibPath, &dlInfo) == 0) + if (dladdr((void *)getCurrentDylibPath, &dlInfo) == 0) { return {}; + } return (dlInfo.dli_fname); #elif defined(WIN32) HMODULE hm = NULL; @@ -145,8 +146,9 @@ static std::string getCurrentDylibPath() { std::string findToolFromDylibDir(SmallVector toolNames) { const auto &normalizedToolNames = normalizeToolNames(toolNames); std::string dylibPath = getCurrentDylibPath(); - if (dylibPath.empty()) + if (dylibPath.empty()) { return {}; + } SmallString<256> dylibDir(dylibPath); llvm::sys::path::remove_filename(dylibDir); @@ -240,18 +242,21 @@ std::string findTool(SmallVector toolNames) { // TODO(benvanik): add a test for IREE_[toolName]_PATH. std::string dylibDirPath = findToolFromDylibDir(toolNames); - if (!dylibDirPath.empty()) + if (!dylibDirPath.empty()) { return dylibDirPath; + } // Search the install or build dir. std::string executableDirPath = findToolFromExecutableDir(toolNames); - if (!executableDirPath.empty()) + if (!executableDirPath.empty()) { return executableDirPath; + } // Currently fall back on searching the environment. 
std::string environmentPath = findToolInEnvironment(toolNames); - if (!environmentPath.empty()) + if (!environmentPath.empty()) { return environmentPath; + } return ""; } @@ -263,14 +268,16 @@ std::string findTool(std::string toolName) { std::string findPlatformLibDirectory(StringRef platformName) { std::string dylibPath = getCurrentDylibPath(); - if (dylibPath.empty()) + if (dylibPath.empty()) { return {}; + } SmallString<256> path(dylibPath); llvm::sys::path::remove_filename(path); llvm::sys::path::append(path, "iree_platform_libs", platformName); - if (!llvm::sys::fs::is_directory(path)) + if (!llvm::sys::fs::is_directory(path)) { return {}; + } llvm::sys::fs::make_absolute(path); (void)llvm::sys::path::remove_dots(path, /*remove_dot_dot=*/true); diff --git a/docs/website/docs/developers/debugging/fuzzing.md b/docs/website/docs/developers/debugging/fuzzing.md new file mode 100644 index 000000000000..4d533aecc914 --- /dev/null +++ b/docs/website/docs/developers/debugging/fuzzing.md @@ -0,0 +1,182 @@ +--- +icon: material/bug +--- + +# Fuzzing with libFuzzer + +[libFuzzer](https://llvm.org/docs/LibFuzzer.html) is a coverage-guided fuzzing +engine provided by LLVM. It generates random inputs and mutates them based on +code coverage feedback to find crashes, hangs, and memory errors. + +IREE provides build infrastructure for creating libFuzzer-based fuzz targets +that integrate with the existing build system. + +## When to use fuzzing + +Fuzzing is most effective for: + +- Parsers and decoders (UTF-8, binary formats, etc.) +- Serialization/deserialization code +- Input validation logic +- Any code that processes untrusted or external data + +## Enabling fuzzing builds + +### Bazel + +```shell +bazel build --config=fuzzer //runtime/src/iree/base/internal:unicode_fuzz +``` + +The `--config=fuzzer` flag enables coverage instrumentation and ASan. 
+ +### CMake + +```shell +cmake -B build -DIREE_ENABLE_FUZZING=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo +cmake --build build --target unicode_fuzz +``` + +Fuzzing automatically enables ASan. Fuzz targets are excluded from the default +`all` target and must be built explicitly. + +## Running fuzz targets + +Fuzz targets are standalone executables that accept libFuzzer arguments: + +```shell +# Run indefinitely (Ctrl+C to stop) +./build/runtime/src/iree/base/internal/unicode_fuzz + +# Run for 60 seconds +./build/runtime/src/iree/base/internal/unicode_fuzz -max_total_time=60 + +# Use a corpus directory (recommended) +mkdir -p corpus/unicode +./build/runtime/src/iree/base/internal/unicode_fuzz corpus/unicode/ +``` + +### Common options + +Option | Description +------ | ----------- +`-max_total_time=N` | Stop after N seconds +`-max_len=N` | Maximum input size in bytes +`-timeout=N` | Per-input timeout in seconds (0 to disable) +`-jobs=N` | Run N parallel fuzzing jobs +`-workers=N` | Number of worker processes for parallel fuzzing +`-dict=file` | Use a dictionary file for structured inputs +`-seed=N` | Use specific random seed for reproducibility + +See [libFuzzer documentation](https://llvm.org/docs/LibFuzzer.html) for all +options. + +## Writing fuzz targets + +Fuzz targets implement the `LLVMFuzzerTestOneInput` function: + +```cpp +// my_fuzz.cc +#include +#include + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + // Process the fuzzer-generated input + my_function_under_test(data, size); + return 0; // Always return 0 +} +``` + +### Adding to the build system + +In `BUILD.bazel`: + +```python +load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_fuzz") + +iree_runtime_cc_fuzz( + name = "my_fuzz", + srcs = ["my_fuzz.cc"], + deps = [ + ":my_library", + ], +) +``` + +Then run `python build_tools/bazel_to_cmake/bazel_to_cmake.py` to generate the +CMake equivalent. 
+ +## Best practices + +### Maintain a corpus + +Store interesting inputs in a corpus directory. The fuzzer uses existing corpus +entries as seeds for mutation: + +```shell +mkdir -p corpus/my_fuzz +./my_fuzz corpus/my_fuzz/ -max_total_time=3600 +``` + +After finding bugs, minimize the corpus to remove redundant entries: + +```shell +mkdir corpus/my_fuzz_minimized +./my_fuzz -merge=1 corpus/my_fuzz_minimized/ corpus/my_fuzz/ +``` + +### Add unit tests for found bugs + +When fuzzing discovers a crash: + +1. Minimize the reproducer: `./my_fuzz -minimize_crash=1 crash-xxx` +2. Add the minimized input as a unit test case +3. Fix the bug +4. Verify the fix with the original crash input + +This prevents regressions and documents the bug. + +### Use dictionaries for structured formats + +For inputs with specific syntax (protocols, file formats), provide a dictionary: + +```text +# my_dict.txt +"keyword1" +"keyword2" +"\x00\x01\x02" +``` + +```shell +./my_fuzz -dict=my_dict.txt corpus/ +``` + +## Troubleshooting + +### Fuzzer runs slowly + +- Ensure `CMAKE_BUILD_TYPE=RelWithDebInfo` or `Release` (Debug is very slow) +- Check that the target doesn't do excessive I/O or allocations per iteration +- Use `-jobs=N` for parallel fuzzing on multi-core machines + +### Out of memory + +- Limit input size with `-max_len=N` +- Add early returns for oversized inputs in your fuzz target +- Use `-rss_limit_mb=N` to set memory limits + +### No new coverage + +- Verify the target actually processes the input +- Check that coverage instrumentation is enabled (`-fsanitize=fuzzer-no-link`) +- Try seeding with representative inputs in the corpus + +### Timeout errors + +libFuzzer kills inputs that take too long (default 1200 seconds). 
If you see +`ALARM: working on the last Unit for N seconds` followed by a timeout: + +- Use `-timeout=N` to adjust the per-input timeout (in seconds) +- Use `-timeout=0` to disable timeouts entirely (useful for debugging) +- Check if certain inputs cause algorithmic complexity issues (e.g., pathological + regex patterns, deeply nested structures) diff --git a/docs/website/docs/developers/general/contributing.md b/docs/website/docs/developers/general/contributing.md index 74b82047832c..069b5b76a5c3 100644 --- a/docs/website/docs/developers/general/contributing.md +++ b/docs/website/docs/developers/general/contributing.md @@ -278,7 +278,7 @@ Access to repositories is divided into tiers following the | Tier | Description | Team links | | ---- | ----------- | --------- | Triage | **New project members should typically start here**
:material-check: Can be [assigned issues](https://docs.github.com/en/issues/tracking-your-work-with-issues/assigning-issues-and-pull-requests-to-other-github-users)
:material-check: Can apply labels to issues / PRs
:material-check: Can run workflows [without approval](https://docs.github.com/en/actions/managing-workflow-runs/approving-workflow-runs-from-public-forks) |
  • [iree-triage](https://github.com/orgs/iree-org/teams/iree-triage)
    (access to most repositories)
-Write | **Established contributors can request this access**
:material-check: Can [merge approved pull requests](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/incorporating-changes-from-a-pull-request/merging-a-pull-request)
:material-check: Can create branches
:material-check: Can [re-run workflows](https://docs.github.com/en/actions/managing-workflow-runs-and-deployments/managing-workflow-runs/re-running-workflows-and-jobs) |
  • [iree-write](https://github.com/orgs/iree-org/teams/iree-write)
    (access to most repositories)
  • [iree-turbine-write](https://github.com/orgs/iree-org/teams/iree-turbine-write)
    (access to iree-turbine)
+Write | **Established contributors can request this access**
:material-check: Can [merge approved pull requests](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/incorporating-changes-from-a-pull-request/merging-a-pull-request)
:material-check: Can create branches
:material-check: Can [re-run workflows](https://docs.github.com/en/actions/managing-workflow-runs-and-deployments/managing-workflow-runs/re-running-workflows-and-jobs) |
  • [iree-write](https://github.com/orgs/iree-org/teams/iree-write)
    (access to most repositories)
  • [iree-turbine-write](https://github.com/orgs/iree-org/teams/iree-turbine-write)
    (access to iree-turbine)
  • [iree-fusilli-write](https://github.com/orgs/iree-org/teams/iree-fusilli-write)
    (access to fusilli)
Maintain/Admin | :material-check: Can [edit repository settings](https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features)
:material-check: Can push to [protected branches](https://docs.github.com/en/repositories/configuring-branches-and-merges-in-your-repository/managing-protected-branches/about-protected-branches) | Added case-by-case All access tiers first require joining the diff --git a/docs/website/mkdocs.yml b/docs/website/mkdocs.yml index 115bbd6c8e3d..fb3507fdee7b 100644 --- a/docs/website/mkdocs.yml +++ b/docs/website/mkdocs.yml @@ -245,6 +245,7 @@ nav: - "developers/debugging/model-development.md" - "developers/debugging/releases.md" - "developers/debugging/sanitizers.md" + - "developers/debugging/fuzzing.md" - "Performance": - "developers/performance/benchmarking.md" - "developers/performance/profiling.md" diff --git a/runtime/src/iree/base/BUILD.bazel b/runtime/src/iree/base/BUILD.bazel index 5f83948d351d..1691f95047cc 100644 --- a/runtime/src/iree/base/BUILD.bazel +++ b/runtime/src/iree/base/BUILD.bazel @@ -6,7 +6,7 @@ # Common types and utilities used in the IREE codebase. -load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_library", "iree_runtime_cc_test") +load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_fuzz", "iree_runtime_cc_library", "iree_runtime_cc_test") package( default_visibility = ["//visibility:public"], @@ -112,6 +112,14 @@ iree_runtime_cc_test( ], ) +iree_runtime_cc_fuzz( + name = "string_view_fuzz", + srcs = ["string_view_fuzz.cc"], + deps = [ + ":base", + ], +) + iree_runtime_cc_test( name = "string_view_test", srcs = ["string_view_test.cc"], diff --git a/runtime/src/iree/base/CMakeLists.txt b/runtime/src/iree/base/CMakeLists.txt index 8d0b1e8123fa..56144d624c5c 100644 --- a/runtime/src/iree/base/CMakeLists.txt +++ b/runtime/src/iree/base/CMakeLists.txt @@ -148,6 +148,15 @@ iree_cc_test( iree::testing::gtest_main ) +iree_cc_fuzz( + NAME + string_view_fuzz + SRCS + "string_view_fuzz.cc" + DEPS + ::base +) + iree_cc_test( NAME string_view_test diff --git a/runtime/src/iree/base/internal/BUILD.bazel 
b/runtime/src/iree/base/internal/BUILD.bazel index 5d0b2a47c068..fde8859615a7 100644 --- a/runtime/src/iree/base/internal/BUILD.bazel +++ b/runtime/src/iree/base/internal/BUILD.bazel @@ -8,7 +8,7 @@ # These are not part of the IREE API. Though they may be used by external # projects their API may change at any time. -load("//build_tools/bazel:build_defs.oss.bzl", "iree_cmake_extra_content", "iree_runtime_cc_binary", "iree_runtime_cc_library", "iree_runtime_cc_test") +load("//build_tools/bazel:build_defs.oss.bzl", "iree_cmake_extra_content", "iree_runtime_cc_binary", "iree_runtime_cc_fuzz", "iree_runtime_cc_library", "iree_runtime_cc_test") load("//build_tools/bazel:cc_binary_benchmark.bzl", "cc_binary_benchmark") load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite") @@ -367,6 +367,14 @@ iree_runtime_cc_library( ], ) +iree_runtime_cc_fuzz( + name = "unicode_fuzz", + srcs = ["unicode_fuzz.cc"], + deps = [ + ":unicode", + ], +) + iree_runtime_cc_test( name = "unicode_test", srcs = ["unicode_test.cc"], diff --git a/runtime/src/iree/base/internal/CMakeLists.txt b/runtime/src/iree/base/internal/CMakeLists.txt index 1d2e5ae15790..41daa557021b 100644 --- a/runtime/src/iree/base/internal/CMakeLists.txt +++ b/runtime/src/iree/base/internal/CMakeLists.txt @@ -395,6 +395,15 @@ iree_cc_library( PUBLIC ) +iree_cc_fuzz( + NAME + unicode_fuzz + SRCS + "unicode_fuzz.cc" + DEPS + ::unicode +) + iree_cc_test( NAME unicode_test diff --git a/runtime/src/iree/base/internal/math.h b/runtime/src/iree/base/internal/math.h index eb8c9b17ecfb..bcbbcc729557 100644 --- a/runtime/src/iree/base/internal/math.h +++ b/runtime/src/iree/base/internal/math.h @@ -334,6 +334,15 @@ static inline float iree_math_make_f32_from_bits(uint32_t src, int exp_bits, (src_exp >> src_exp_shift) - src_exp_bias - src_mantissa_bits); } +// Helper for rounding to nearest-even. Does not right-shift. Returns the +// biased value suitable for right-shifting. 
+static inline uint32_t bias_to_nearest_even(uint32_t input, int shift_amount) { + uint32_t even_bit = 1u << shift_amount; + uint32_t odd_bit = even_bit >> 1; + uint32_t bias = (input & even_bit) ? (odd_bit) : (odd_bit - 1); + return input + bias; +} + // Generic conversion from f32 to any less-than-32-bit floating-point format, // rounding to nearest-even. The return value is typed as a uint32_t for // genericity but occupies only the bottom (1 + exp_bits + mantissa_bits) bits. @@ -370,8 +379,8 @@ static inline uint32_t iree_math_truncate_f32_to_bits_rounding_to_nearest_even( // can remain nonzero. This happens only with the bf16 type. // Just divide the mantissa (rounding shift). int shift_amount = f32_mantissa_bits - dst_mantissa_bits; - uint32_t rounding_term = 1 << (shift_amount - 1); - dst_mantissa = (f32_mantissa + rounding_term) >> shift_amount; + dst_mantissa = + bias_to_nearest_even(f32_mantissa, shift_amount) >> shift_amount; } // The destination type has fewer exponent bits, so f32 subnormal values // become exactly zero. Leave the mantissa zero. @@ -398,21 +407,18 @@ static inline uint32_t iree_math_truncate_f32_to_bits_rounding_to_nearest_even( dst_mantissa = 0; } else { // Source f32 value is normal so has an implied 1... leading bit. - int effective_f32_mantissa = (1 << f32_mantissa_bits) + f32_mantissa; - // Add this term to achieve rounding to nearest instead of truncation - // towards zero. - int rounding_term = 1 << (shift_amount - 1); - // Finally compute the destination mantissa as a rounded right shift. - dst_mantissa = (effective_f32_mantissa + rounding_term) >> shift_amount; + uint32_t effective_f32_mantissa = + (1u << f32_mantissa_bits) + f32_mantissa; + dst_mantissa = + bias_to_nearest_even(effective_f32_mantissa, shift_amount) >> + shift_amount; } } else { // Normal case. // Implement round-to-nearest-even, by adding a bias before truncating. 
- int even_bit = 1u << (f32_mantissa_bits - dst_mantissa_bits); - int odd_bit = even_bit >> 1; + int shift_amount = f32_mantissa_bits - dst_mantissa_bits; uint32_t biased_f32_mantissa = - f32_mantissa + - ((f32_mantissa & even_bit) ? (odd_bit) : (odd_bit - 1)); + bias_to_nearest_even(f32_mantissa, shift_amount); // Adding the bias may cause an exponent increment. if (biased_f32_mantissa > f32_mantissa_mask) { // Note: software implementations that try to be fast tend to get this diff --git a/runtime/src/iree/base/internal/math_test.cc b/runtime/src/iree/base/internal/math_test.cc index 0c2f05008ca2..222555da7f47 100644 --- a/runtime/src/iree/base/internal/math_test.cc +++ b/runtime/src/iree/base/internal/math_test.cc @@ -313,6 +313,22 @@ TEST(BF16ConversionTest, F32ToBF16) { EXPECT_EQ(0xff80, iree_math_f32_to_bf16(-FLT_MAX)); EXPECT_EQ(0x0080, iree_math_f32_to_bf16(FLT_MIN)); EXPECT_EQ(0x8080, iree_math_f32_to_bf16(-FLT_MIN)); + // Test some round-to-nearest-even. F32->BF16 is interesting because F32 + // denormals can round to nonzero BF16 denormals. 
+ EXPECT_EQ(0x0000, iree_math_f32_to_bf16(FLT_MIN * 1.0f / 256.f)); + EXPECT_EQ(0x0001, iree_math_f32_to_bf16(FLT_MIN * 2.0f / 256.f)); + EXPECT_EQ(0x0002, iree_math_f32_to_bf16(FLT_MIN * 3.0f / 256.f)); + EXPECT_EQ(0x0002, iree_math_f32_to_bf16(FLT_MIN * 4.0f / 256.f)); + EXPECT_EQ(0x0002, iree_math_f32_to_bf16(FLT_MIN * 5.0f / 256.f)); + EXPECT_EQ(0x0003, iree_math_f32_to_bf16(FLT_MIN * 6.0f / 256.f)); + EXPECT_EQ(0x0004, iree_math_f32_to_bf16(FLT_MIN * 7.0f / 256.f)); + EXPECT_EQ(0x8000, iree_math_f32_to_bf16(FLT_MIN * -1.0f / 256.f)); + EXPECT_EQ(0x8001, iree_math_f32_to_bf16(FLT_MIN * -2.0f / 256.f)); + EXPECT_EQ(0x8002, iree_math_f32_to_bf16(FLT_MIN * -3.0f / 256.f)); + EXPECT_EQ(0x8002, iree_math_f32_to_bf16(FLT_MIN * -4.0f / 256.f)); + EXPECT_EQ(0x8002, iree_math_f32_to_bf16(FLT_MIN * -5.0f / 256.f)); + EXPECT_EQ(0x8003, iree_math_f32_to_bf16(FLT_MIN * -6.0f / 256.f)); + EXPECT_EQ(0x8004, iree_math_f32_to_bf16(FLT_MIN * -7.0f / 256.f)); } TEST(BF16ConversionTest, Denormals) { @@ -503,6 +519,22 @@ TEST(F8E4M3FNConversionTest, F32ToF8E4M3FN) { EXPECT_EQ(0x7A, iree_math_f32_to_f8e4m3fn(304.0f)); EXPECT_EQ(0x7A, iree_math_f32_to_f8e4m3fn(336.0f)); EXPECT_EQ(0x7C, iree_math_f32_to_f8e4m3fn(368.0f)); + // Test round-to-nearest-even for denormals. 
+ EXPECT_EQ(0x00, iree_math_f32_to_f8e4m3fn(0.5f / 512.f)); + EXPECT_EQ(0x01, iree_math_f32_to_f8e4m3fn(1.f / 512.f)); + EXPECT_EQ(0x02, iree_math_f32_to_f8e4m3fn(1.5f / 512.f)); + EXPECT_EQ(0x02, iree_math_f32_to_f8e4m3fn(2.f / 512.f)); + EXPECT_EQ(0x02, iree_math_f32_to_f8e4m3fn(2.5f / 512.f)); + EXPECT_EQ(0x03, iree_math_f32_to_f8e4m3fn(3.f / 512.f)); + EXPECT_EQ(0x04, iree_math_f32_to_f8e4m3fn(3.5f / 512.f)); + EXPECT_EQ(0x80, iree_math_f32_to_f8e4m3fn(-0.5f / 512.f)); + EXPECT_EQ(0x81, iree_math_f32_to_f8e4m3fn(-1.f / 512.f)); + EXPECT_EQ(0x82, iree_math_f32_to_f8e4m3fn(-1.5f / 512.f)); + EXPECT_EQ(0x82, iree_math_f32_to_f8e4m3fn(-2.f / 512.f)); + EXPECT_EQ(0x82, iree_math_f32_to_f8e4m3fn(-2.5f / 512.f)); + EXPECT_EQ(0x83, iree_math_f32_to_f8e4m3fn(-3.f / 512.f)); + EXPECT_EQ(0x84, iree_math_f32_to_f8e4m3fn(-3.5f / 512.f)); + // Important case to test: overflow due to rounding to nearest-even of 465 // to 512, while 464 gets rounded to nearest-even 448, not overflowing. EXPECT_EQ(0x7E, iree_math_f32_to_f8e4m3fn(464.f)); diff --git a/runtime/src/iree/base/internal/unicode_fuzz.cc b/runtime/src/iree/base/internal/unicode_fuzz.cc new file mode 100644 index 000000000000..69e02147fd04 --- /dev/null +++ b/runtime/src/iree/base/internal/unicode_fuzz.cc @@ -0,0 +1,217 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// Fuzz target for Unicode utilities: UTF-8 decoding/validation, category +// classification, case folding, and composition. Includes invariant assertions +// that crash on consistency violations. +// +// See https://iree.dev/developers/debugging/fuzzing/ for build and run info. + +#include +#include + +#include "iree/base/internal/unicode.h" + +// Invariant assertion that crashes on failure. +// We use __builtin_trap() to get a clean crash for the fuzzer to detect. 
+#define FUZZ_ASSERT(condition) \ + do { \ + if (!(condition)) { \ + __builtin_trap(); \ + } \ + } while (0) + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + iree_string_view_t input = + iree_make_string_view(reinterpret_cast(data), size); + + // Test UTF-8 validation and counting. + (void)iree_unicode_utf8_validate(input); + (void)iree_unicode_utf8_codepoint_count(input); + + // Test incomplete tail detection. + (void)iree_unicode_utf8_incomplete_tail_length(input.data, input.size); + + // Decode all codepoints and test classification/transformation functions. + iree_host_size_t position = 0; + while (position < input.size) { + uint32_t codepoint = iree_unicode_utf8_decode(input, &position); + + // Category classification. + (void)iree_unicode_category(codepoint); + (void)iree_unicode_is_letter(codepoint); + (void)iree_unicode_is_mark(codepoint); + (void)iree_unicode_is_number(codepoint); + (void)iree_unicode_is_punctuation(codepoint); + (void)iree_unicode_is_symbol(codepoint); + (void)iree_unicode_is_separator(codepoint); + (void)iree_unicode_is_other(codepoint); + (void)iree_unicode_is_whitespace(codepoint); + (void)iree_unicode_is_control(codepoint); + (void)iree_unicode_is_cjk(codepoint); + (void)iree_unicode_is_hiragana(codepoint); + (void)iree_unicode_is_katakana(codepoint); + (void)iree_unicode_is_hangul(codepoint); + + // Case folding. + (void)iree_unicode_to_lower(codepoint); + (void)iree_unicode_to_upper(codepoint); + + // NFD decomposition. + (void)iree_unicode_nfd_base(codepoint); + + // Canonical Combining Class. + (void)iree_unicode_ccc(codepoint); + + // UTF-8 encoding (roundtrip test). 
+ char encode_buffer[4]; + (void)iree_unicode_utf8_encode(codepoint, encode_buffer); + (void)iree_unicode_utf8_encoded_length(codepoint); + } + + //===--------------------------------------------------------------------===// + // Direct codepoint testing (raw byte interpretation) + //===--------------------------------------------------------------------===// + // Interpret every 4 bytes as a raw uint32_t codepoint to test the full + // codepoint space including invalid ranges (>0x10FFFF, surrogates). + // This exercises table lookup binary search with boundary values. + for (size_t i = 0; i + 4 <= size; i += 4) { + uint32_t codepoint = (static_cast(data[i]) << 24) | + (static_cast(data[i + 1]) << 16) | + (static_cast(data[i + 2]) << 8) | + static_cast(data[i + 3]); + + // Test all classification functions on arbitrary codepoint values. + iree_unicode_category_t category = iree_unicode_category(codepoint); + bool is_letter = iree_unicode_is_letter(codepoint); + bool is_mark = iree_unicode_is_mark(codepoint); + bool is_number = iree_unicode_is_number(codepoint); + bool is_punctuation = iree_unicode_is_punctuation(codepoint); + bool is_symbol = iree_unicode_is_symbol(codepoint); + bool is_separator = iree_unicode_is_separator(codepoint); + bool is_other = iree_unicode_is_other(codepoint); + (void)iree_unicode_is_whitespace(codepoint); + (void)iree_unicode_is_control(codepoint); + (void)iree_unicode_is_cjk(codepoint); + (void)iree_unicode_is_hiragana(codepoint); + (void)iree_unicode_is_katakana(codepoint); + (void)iree_unicode_is_hangul(codepoint); + + // Invariant: category classification consistency. + // If is_X returns true, the corresponding category bit must be set. 
+ if (is_letter) { + FUZZ_ASSERT((category & IREE_UNICODE_CATEGORY_LETTER) != 0); + } + if (is_mark) { + FUZZ_ASSERT((category & IREE_UNICODE_CATEGORY_MARK) != 0); + } + if (is_number) { + FUZZ_ASSERT((category & IREE_UNICODE_CATEGORY_NUMBER) != 0); + } + if (is_punctuation) { + FUZZ_ASSERT((category & IREE_UNICODE_CATEGORY_PUNCTUATION) != 0); + } + if (is_symbol) { + FUZZ_ASSERT((category & IREE_UNICODE_CATEGORY_SYMBOL) != 0); + } + if (is_separator) { + FUZZ_ASSERT((category & IREE_UNICODE_CATEGORY_SEPARATOR) != 0); + } + if (is_other) { + FUZZ_ASSERT((category & IREE_UNICODE_CATEGORY_OTHER) != 0); + } + + // Test case folding and NFD. + uint32_t lower = iree_unicode_to_lower(codepoint); + uint32_t upper = iree_unicode_to_upper(codepoint); + uint32_t nfd = iree_unicode_nfd_base(codepoint); + (void)iree_unicode_ccc(codepoint); + + // Invariant: case folding idempotency. + // Applying the same case operation twice should yield the same result. + FUZZ_ASSERT(iree_unicode_to_lower(lower) == lower); + FUZZ_ASSERT(iree_unicode_to_upper(upper) == upper); + + // Note: NFD decomposition may be multi-level (e.g., ẳ → ạ → a), so + // nfd_base is NOT necessarily idempotent. Instead, verify it converges + // to a fixed point within a reasonable number of steps. + uint32_t nfd_current = nfd; + for (int depth = 0; depth < 10; ++depth) { + uint32_t nfd_next = iree_unicode_nfd_base(nfd_current); + if (nfd_next == nfd_current) break; // Reached fixed point. + nfd_current = nfd_next; + } + // After at most 10 iterations, we must have reached a fixed point. + FUZZ_ASSERT(iree_unicode_nfd_base(nfd_current) == nfd_current); + + // Invariant: encode/decode roundtrip for valid codepoints. + int encoded_length = iree_unicode_utf8_encoded_length(codepoint); + if (encoded_length > 0) { + char encode_buffer[4]; + int actual_length = iree_unicode_utf8_encode(codepoint, encode_buffer); + + // Invariant: encoded_length and encode must agree. 
+ FUZZ_ASSERT(encoded_length == actual_length); + + // Decode what we just encoded and verify roundtrip. + iree_string_view_t encoded = iree_make_string_view( + encode_buffer, static_cast(actual_length)); + iree_host_size_t decode_position = 0; + uint32_t decoded = iree_unicode_utf8_decode(encoded, &decode_position); + + // Invariant: roundtrip must recover the original codepoint. + FUZZ_ASSERT(decoded == codepoint); + FUZZ_ASSERT(decode_position == + static_cast(actual_length)); + } + } + + //===--------------------------------------------------------------------===// + // Composition testing with status verification + //===--------------------------------------------------------------------===// + // Test composition on valid UTF-8 sequences, verifying status codes. + if (iree_unicode_utf8_validate(input)) { + // Allocate output buffer (composition can only shrink). + char* compose_buffer = new char[size + 1]; + iree_host_size_t out_length = 0; + iree_status_t status = + iree_unicode_compose(input, compose_buffer, size + 1, &out_length); + + // Status must be OK or RESOURCE_EXHAUSTED (for very long combining seqs). + // Any other status indicates a bug in the compose function. + FUZZ_ASSERT(iree_status_is_ok(status) || + iree_status_code(status) == IREE_STATUS_RESOURCE_EXHAUSTED); + + if (iree_status_is_ok(status)) { + // Invariant: output length must not exceed input length. + // Composition can only shrink (combining base + mark -> precomposed). + FUZZ_ASSERT(out_length <= input.size); + + // The output should also be valid UTF-8. + iree_string_view_t output = + iree_make_string_view(compose_buffer, out_length); + FUZZ_ASSERT(iree_unicode_utf8_validate(output)); + } else { + iree_status_ignore(status); + } + delete[] compose_buffer; + } + + // Test pairwise composition with interpreted codepoints. 
+ if (size >= 8) { + uint32_t base = (static_cast(data[0]) << 24) | + (static_cast(data[1]) << 16) | + (static_cast(data[2]) << 8) | + static_cast(data[3]); + uint32_t combining = (static_cast(data[4]) << 24) | + (static_cast(data[5]) << 16) | + (static_cast(data[6]) << 8) | + static_cast(data[7]); + (void)iree_unicode_compose_pair(base, combining); + } + + return 0; +} diff --git a/runtime/src/iree/base/string_view.c b/runtime/src/iree/base/string_view.c index 583b58b0b53b..349196c125bf 100644 --- a/runtime/src/iree/base/string_view.c +++ b/runtime/src/iree/base/string_view.c @@ -96,7 +96,7 @@ IREE_API_EXPORT iree_host_size_t iree_string_view_find_last_of( for (iree_host_size_t i = 0; i < s.size; ++i) { lookup_table[(uint8_t)s.data[i]] = true; } - pos = iree_min(pos, value.size) + 1; + pos = iree_min(pos, value.size - 1) + 1; iree_host_size_t i = pos; while (i != 0) { --i; @@ -261,28 +261,81 @@ static bool iree_string_view_match_pattern_impl(iree_string_view_t value, return true; } char pattern_char = pattern.data[0]; - if (pattern_char == '*' && pattern.size > 1 && - iree_string_view_is_empty(value)) { + + // Normalize wildcard sequences to avoid exponential backtracking. + // A sequence like *?*?* is equivalent to "match 2+ chars then match rest". + // We coalesce all * and ? into a single * with a minimum char requirement. + if (pattern_char == '*' || pattern_char == '?') { + iree_host_size_t min_chars = 0; + iree_host_size_t skip = 0; + bool has_star = false; + while (skip < pattern.size) { + char c = pattern.data[skip]; + if (c == '*') { + has_star = true; + ++skip; + } else if (c == '?') { + ++min_chars; + ++skip; + } else { + break; + } + } + + // Remaining pattern after wildcards. + iree_string_view_t rest = + iree_string_view_substr(pattern, skip, IREE_STRING_VIEW_NPOS); + + if (!has_star) { + // Only ? wildcards - must match exactly min_chars characters. 
+ if (value.size < min_chars) return false; + return iree_string_view_match_pattern_impl( + iree_string_view_substr(value, min_chars, IREE_STRING_VIEW_NPOS), + rest); + } + + // Has * - must match at least min_chars, possibly more. + if (value.size < min_chars) return false; + + // Empty rest means * matches everything remaining. + if (iree_string_view_is_empty(rest)) return true; + + // Try matching rest at each position from min_chars to end. + for (iree_host_size_t i = min_chars; i <= value.size; ++i) { + if (iree_string_view_match_pattern_impl( + iree_string_view_substr(value, i, IREE_STRING_VIEW_NPOS), rest)) { + return true; + } + } return false; - } else if (pattern_char == '*' && pattern.size == 1) { - return true; - } else if (pattern_char == '?' || value.data[0] == pattern_char) { - return iree_string_view_match_pattern_impl( - iree_string_view_substr(value, 1, IREE_STRING_VIEW_NPOS), - iree_string_view_substr(pattern, 1, IREE_STRING_VIEW_NPOS)); - } else if (pattern_char == '*') { - return iree_string_view_match_pattern_impl( - value, - iree_string_view_substr(pattern, 1, IREE_STRING_VIEW_NPOS)) || - iree_string_view_match_pattern_impl( - iree_string_view_substr(value, 1, IREE_STRING_VIEW_NPOS), - pattern); } - return false; + + // Literal character - must match exactly. + if (iree_string_view_is_empty(value) || value.data[0] != pattern_char) { + return false; + } + return iree_string_view_match_pattern_impl( + iree_string_view_substr(value, 1, IREE_STRING_VIEW_NPOS), + iree_string_view_substr(pattern, 1, IREE_STRING_VIEW_NPOS)); } +// Maximum wildcards allowed in a pattern to prevent pathological matching. +// 16 is enough for any reasonable glob (e.g., "*foo*bar*baz*") while avoiding +// O(n^2) blowup on patterns like "?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*". 
+#define IREE_STRING_VIEW_MAX_PATTERN_WILDCARDS 16 + IREE_API_EXPORT bool iree_string_view_match_pattern( iree_string_view_t value, iree_string_view_t pattern) { + // Count wildcards and reject patterns with too many. + iree_host_size_t wildcard_count = 0; + for (iree_host_size_t i = 0; i < pattern.size; ++i) { + if (pattern.data[i] == '*' || pattern.data[i] == '?') { + ++wildcard_count; + } + } + if (wildcard_count > IREE_STRING_VIEW_MAX_PATTERN_WILDCARDS) { + return false; + } return iree_string_view_match_pattern_impl(value, pattern); } diff --git a/runtime/src/iree/base/string_view_fuzz.cc b/runtime/src/iree/base/string_view_fuzz.cc new file mode 100644 index 000000000000..43841b633d27 --- /dev/null +++ b/runtime/src/iree/base/string_view_fuzz.cc @@ -0,0 +1,185 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// Fuzz target for string parsing utilities: integer/float parsing, device size +// parsing with units, bitfield parsing, pattern matching, and hex byte parsing. +// +// See https://iree.dev/developers/debugging/fuzzing/ for build and run info. + +#include +#include + +#include "iree/base/api.h" + +// Sample bitfield mapping table for fuzzing iree_bitfield_parse. +// Uses realistic flag names similar to actual IREE usage. +static const iree_bitfield_string_mapping_t kTestBitfieldMappings[] = { + {0x7, IREE_SVL("ALL")}, // Combined flag (A|B|C). + {0x1, IREE_SVL("READ")}, // Bit 0. + {0x2, IREE_SVL("WRITE")}, // Bit 1. + {0x4, IREE_SVL("EXECUTE")}, // Bit 2. + {0x8, IREE_SVL("DISCARD")}, // Bit 3. + {0x10, IREE_SVL("MAPPABLE")}, // Bit 4. + {0x20, IREE_SVL("COHERENT")}, // Bit 5. 
+}; + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + iree_string_view_t input = + iree_make_string_view(reinterpret_cast(data), size); + + //===--------------------------------------------------------------------===// + // Integer parsing (signed and unsigned, various bases) + //===--------------------------------------------------------------------===// + + { + int32_t value_i32 = 0; + (void)iree_string_view_atoi_int32(input, &value_i32); + (void)iree_string_view_atoi_int32_base(input, 10, &value_i32); + (void)iree_string_view_atoi_int32_base(input, 16, &value_i32); + (void)iree_string_view_atoi_int32_base(input, 2, &value_i32); + } + + { + uint32_t value_u32 = 0; + (void)iree_string_view_atoi_uint32(input, &value_u32); + (void)iree_string_view_atoi_uint32_base(input, 10, &value_u32); + (void)iree_string_view_atoi_uint32_base(input, 16, &value_u32); + (void)iree_string_view_atoi_uint32_base(input, 2, &value_u32); + } + + { + int64_t value_i64 = 0; + (void)iree_string_view_atoi_int64(input, &value_i64); + (void)iree_string_view_atoi_int64_base(input, 10, &value_i64); + (void)iree_string_view_atoi_int64_base(input, 16, &value_i64); + (void)iree_string_view_atoi_int64_base(input, 2, &value_i64); + } + + { + uint64_t value_u64 = 0; + (void)iree_string_view_atoi_uint64(input, &value_u64); + (void)iree_string_view_atoi_uint64_base(input, 10, &value_u64); + (void)iree_string_view_atoi_uint64_base(input, 16, &value_u64); + (void)iree_string_view_atoi_uint64_base(input, 2, &value_u64); + } + + //===--------------------------------------------------------------------===// + // Floating point parsing + //===--------------------------------------------------------------------===// + + { + float value_f32 = 0.0f; + (void)iree_string_view_atof(input, &value_f32); + } + + { + double value_f64 = 0.0; + (void)iree_string_view_atod(input, &value_f64); + } + + //===--------------------------------------------------------------------===// + // Device 
size parsing with units (e.g., "1kb", "2mib", "3gb") + //===--------------------------------------------------------------------===// + + { + iree_device_size_t device_size = 0; + iree_status_t status = + iree_string_view_parse_device_size(input, &device_size); + iree_status_ignore(status); + } + + //===--------------------------------------------------------------------===// + // Bitfield parsing + //===--------------------------------------------------------------------===// + + { + uint32_t bitfield_value = 0; + iree_status_t status = + iree_bitfield_parse(input, IREE_ARRAYSIZE(kTestBitfieldMappings), + kTestBitfieldMappings, &bitfield_value); + iree_status_ignore(status); + } + + //===--------------------------------------------------------------------===// + // Pattern matching (wildcard patterns with * and ?) + //===--------------------------------------------------------------------===// + + // Use the first half as value and second half as pattern. + if (size >= 2) { + size_t mid = size / 2; + iree_string_view_t value = + iree_make_string_view(reinterpret_cast(data), mid); + iree_string_view_t pattern = iree_make_string_view( + reinterpret_cast(data + mid), size - mid); + (void)iree_string_view_match_pattern(value, pattern); + } + + // Also test pattern matching with specific patterns that stress recursion. + (void)iree_string_view_match_pattern(input, IREE_SV("*")); + (void)iree_string_view_match_pattern(input, IREE_SV("?*?")); + (void)iree_string_view_match_pattern(input, IREE_SV("***")); + + //===--------------------------------------------------------------------===// + // Hex byte parsing + //===--------------------------------------------------------------------===// + + // Parse up to 64 bytes of hex data. + { + uint8_t hex_buffer[64] = {0}; + (void)iree_string_view_parse_hex_bytes(input, sizeof(hex_buffer), + hex_buffer); + } + + // Try parsing various sizes to test boundary conditions. 
+ for (size_t parse_size = 1; parse_size <= 8; ++parse_size) { + uint8_t small_buffer[8] = {0}; + (void)iree_string_view_parse_hex_bytes(input, parse_size, small_buffer); + } + + //===--------------------------------------------------------------------===// + // String view operations that process the data + //===--------------------------------------------------------------------===// + + (void)iree_string_view_trim(input); + + // Split operations with various split characters. + { + iree_string_view_t lhs, rhs; + (void)iree_string_view_split(input, '|', &lhs, &rhs); + (void)iree_string_view_split(input, '=', &lhs, &rhs); + (void)iree_string_view_split(input, ',', &lhs, &rhs); + (void)iree_string_view_split(input, ':', &lhs, &rhs); + } + + // Find operations. + if (size > 0) { + char search_char = static_cast(data[0]); + (void)iree_string_view_find_char(input, search_char, 0); + + if (size > 1) { + iree_string_view_t search_set = + iree_make_string_view(reinterpret_cast(data), size / 2); + (void)iree_string_view_find_first_of(input, search_set, 0); + (void)iree_string_view_find_last_of(input, search_set, SIZE_MAX); + } + } + + // Comparison operations. 
+ if (size >= 2) { + size_t mid = size / 2; + iree_string_view_t left = + iree_make_string_view(reinterpret_cast(data), mid); + iree_string_view_t right = iree_make_string_view( + reinterpret_cast(data + mid), size - mid); + (void)iree_string_view_equal(left, right); + (void)iree_string_view_equal_case(left, right); + (void)iree_string_view_compare(left, right); + (void)iree_string_view_starts_with(left, right); + (void)iree_string_view_ends_with(left, right); + } + + return 0; +} diff --git a/runtime/src/iree/base/string_view_test.cc b/runtime/src/iree/base/string_view_test.cc index de623aa7495f..890a90101f9c 100644 --- a/runtime/src/iree/base/string_view_test.cc +++ b/runtime/src/iree/base/string_view_test.cc @@ -670,4 +670,80 @@ TEST(StringViewTest, ParseDeviceSizeInvalid) { EXPECT_THAT(ParseDeviceSize("abc"), StatusIs(StatusCode::kInvalidArgument)); } +TEST(StringViewTest, MatchPattern) { + auto match = [](const char* value, const char* pattern) -> bool { + return iree_string_view_match_pattern(iree_make_cstring_view(value), + iree_make_cstring_view(pattern)); + }; + + // Empty patterns and values. + EXPECT_TRUE(match("", "")); + EXPECT_FALSE(match("a", "")); + EXPECT_FALSE(match("", "a")); + + // Exact matches. + EXPECT_TRUE(match("abc", "abc")); + EXPECT_FALSE(match("abc", "abd")); + EXPECT_FALSE(match("abc", "ab")); + EXPECT_FALSE(match("ab", "abc")); + + // Single character wildcard (?). + EXPECT_TRUE(match("a", "?")); + EXPECT_TRUE(match("abc", "a?c")); + EXPECT_TRUE(match("abc", "???")); + EXPECT_FALSE(match("ab", "???")); + EXPECT_FALSE(match("abcd", "???")); + + // Multi-character wildcard (*). + EXPECT_TRUE(match("", "*")); + EXPECT_TRUE(match("a", "*")); + EXPECT_TRUE(match("abc", "*")); + EXPECT_TRUE(match("abc", "a*")); + EXPECT_TRUE(match("abc", "*c")); + EXPECT_TRUE(match("abc", "a*c")); + EXPECT_TRUE(match("abxyzc", "a*c")); + EXPECT_FALSE(match("abc", "a*d")); + + // Combined wildcards. 
+ EXPECT_TRUE(match("abc", "?*")); + EXPECT_TRUE(match("abc", "*?")); + EXPECT_TRUE(match("abc", "?*?")); + EXPECT_TRUE(match("abcdef", "a?c*f")); + + // Consecutive wildcards (tests coalescing to avoid exponential backtracking). + EXPECT_TRUE(match("abc", "**")); + EXPECT_TRUE(match("abc", "***")); + EXPECT_TRUE(match("abc", "a**c")); + EXPECT_TRUE(match("abc", "**c")); + EXPECT_TRUE(match("abc", "a**")); + + // Pathological pattern that would cause exponential backtracking without + // coalescing: many wildcards followed by a non-matching suffix. + // This must complete in reasonable time (milliseconds, not seconds). + EXPECT_FALSE(match("aaaaaaaaaaaaaaaaaaaab", "**************c")); + EXPECT_TRUE(match("aaaaaaaaaaaaaaaaaaaab", "**************b")); + + // Alternating ?* patterns - also pathological without normalization. + // ?* means "1 or more chars", ?*?* means "2 or more chars", etc. + EXPECT_TRUE(match("ab", "?*")); + EXPECT_TRUE(match("abc", "?*?")); + EXPECT_TRUE(match("abc", "?*?*")); + EXPECT_TRUE(match("abcd", "?*?*")); + EXPECT_FALSE(match("a", "?*?*")); // Need at least 2 chars. + + // Pathological alternating patterns - must complete quickly. + EXPECT_FALSE(match("aaaaaaaaaaaaaaaaaaaab", "?*?*?*?*?*?*?*c")); + EXPECT_TRUE(match("aaaaaaaaaaaaaaaaaaaab", "?*?*?*?*?*?*?*b")); + EXPECT_FALSE(match("aaaaaaaaaaaaaaaaaaaab", "*?*?*?*?*?*?*?c")); + EXPECT_TRUE(match("aaaaaaaaaaaaaaaaaaaab", "*?*?*?*?*?*?*?b")); + + // Patterns with too many wildcards are rejected (returns false). + // Limit is 16 wildcards to prevent O(n^2) blowup. 
+ EXPECT_TRUE(match("abcdefghijklmnop", "????????????????")); // 16 - ok + EXPECT_FALSE( + match("abcdefghijklmnopq", "?????????????????")); // 17 - rejected + EXPECT_FALSE( + match("anything", "?*?*?*?*?*?*?*?*?*")); // 18 wildcards - rejected +} + } // namespace diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_bf16.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_bf16.c index 97eae890a3de..55f30291ca2e 100644 --- a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_bf16.c +++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_bf16.c @@ -7,7 +7,8 @@ #include "iree/builtins/ukernel/arch/x86_64/common_x86_64.h" #include "iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_internal.h" -#if defined(IREE_UK_COMPILER_CLANG) && !defined(IREE_UK_COMPILER_MSVC) +#if defined(IREE_UK_COMPILER_CLANG) && !defined(IREE_UK_COMPILER_MSVC) && \ + !IREE_UK_COMPILER_CLANG_VERSION_AT_LEAST(20, 0) // This inline-asm function is a work-around for: // 1. https://github.com/llvm/llvm-project/issues/68117 // Summary: LLVM crash affecting Clang 16-17. Fixed in Clang 18. 
diff --git a/runtime/src/iree/hal/utils/file_transfer.c b/runtime/src/iree/hal/utils/file_transfer.c index 6089af66c9b2..6be18daa8e44 100644 --- a/runtime/src/iree/hal/utils/file_transfer.c +++ b/runtime/src/iree/hal/utils/file_transfer.c @@ -593,7 +593,8 @@ static iree_status_t iree_hal_transfer_operation_launch_read( for (iree_host_size_t i = 0; i < operation->worker_count; ++i) { iree_hal_transfer_worker_t* worker = &operation->workers[i]; alloca_semaphore_list.semaphores[i] = worker->semaphore; - alloca_semaphore_list.payload_values[i] = ++worker->pending_timepoint; + uint64_t signal_timepoint = ++worker->pending_timepoint; + alloca_semaphore_list.payload_values[i] = signal_timepoint; } IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_device_queue_alloca( @@ -666,16 +667,17 @@ static iree_status_t iree_hal_transfer_worker_copy_buffer_to_staging( IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)transfer_length); // Timeline increments by one. + uint64_t wait_timepoint = worker->pending_timepoint; iree_hal_semaphore_list_t wait_semaphore_list = { .count = 1, .semaphores = &worker->semaphore, - .payload_values = &worker->pending_timepoint, + .payload_values = &wait_timepoint, }; - ++worker->pending_timepoint; + uint64_t signal_timepoint = ++worker->pending_timepoint; iree_hal_semaphore_list_t signal_semaphore_list = { .count = 1, .semaphores = &worker->semaphore, - .payload_values = &worker->pending_timepoint, + .payload_values = &signal_timepoint, }; // Track the pending copy operation so we know where to place it in the file. @@ -692,8 +694,7 @@ static iree_status_t iree_hal_transfer_worker_copy_buffer_to_staging( // Wait for the copy to complete so we can write it to the file. 
if (iree_status_is_ok(status)) { status = iree_loop_wait_one( - loop, - iree_hal_semaphore_await(worker->semaphore, worker->pending_timepoint), + loop, iree_hal_semaphore_await(worker->semaphore, signal_timepoint), iree_infinite_timeout(), iree_hal_transfer_worker_copy_staging_to_file, worker); } @@ -785,7 +786,8 @@ static iree_status_t iree_hal_transfer_operation_launch_write( for (iree_host_size_t i = 0; i < operation->worker_count; ++i) { iree_hal_transfer_worker_t* worker = &operation->workers[i]; alloca_semaphore_list.semaphores[i] = worker->semaphore; - alloca_semaphore_list.payload_values[i] = ++worker->pending_timepoint; + uint64_t signal_timepoint = ++worker->pending_timepoint; + alloca_semaphore_list.payload_values[i] = signal_timepoint; } IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_device_queue_alloca( diff --git a/runtime/src/iree/io/formats/irpa/irpa_builder.c b/runtime/src/iree/io/formats/irpa/irpa_builder.c index e1459e898874..b8ca7001c598 100644 --- a/runtime/src/iree/io/formats/irpa/irpa_builder.c +++ b/runtime/src/iree/io/formats/irpa/irpa_builder.c @@ -48,6 +48,14 @@ iree_io_parameter_archive_builder_storage_offset( iree_io_parameter_archive_builder_storage_alignment(builder)); } +IREE_API_EXPORT iree_io_physical_size_t +iree_io_parameter_archive_builder_header_size( + const iree_io_parameter_archive_builder_t* builder) { + IREE_ASSERT_ARGUMENT(builder); + return (iree_io_physical_size_t) + iree_io_parameter_archive_builder_storage_offset(builder); +} + IREE_API_EXPORT iree_io_physical_size_t iree_io_parameter_archive_builder_total_size( const iree_io_parameter_archive_builder_t* builder) { diff --git a/runtime/src/iree/io/formats/irpa/irpa_builder.h b/runtime/src/iree/io/formats/irpa/irpa_builder.h index 46d9a7f11888..767c6d5c69e1 100644 --- a/runtime/src/iree/io/formats/irpa/irpa_builder.h +++ b/runtime/src/iree/io/formats/irpa/irpa_builder.h @@ -57,6 +57,13 @@ IREE_API_EXPORT void iree_io_parameter_archive_builder_deinitialize( 
IREE_API_EXPORT bool iree_io_parameter_archive_builder_is_empty( const iree_io_parameter_archive_builder_t* builder); +// Returns the size required to store the parameter archive header and +// associated metadata (excluding parameters). Adding new parameters will +// invalidate this value. +IREE_API_EXPORT iree_io_physical_size_t +iree_io_parameter_archive_builder_header_size( + const iree_io_parameter_archive_builder_t* builder); + // Returns the total file size required to store the parameter archive header // and contents of all added parameters. Adding new parameters will invalidate // this value. diff --git a/runtime/src/iree/tooling/context_util.c b/runtime/src/iree/tooling/context_util.c index aa3b9941d580..51a87d2b9347 100644 --- a/runtime/src/iree/tooling/context_util.c +++ b/runtime/src/iree/tooling/context_util.c @@ -468,7 +468,8 @@ static iree_status_t iree_tooling_resolve_module_dependency_callback( } else if (iree_string_view_equal(dependency->name, IREE_SV("io_parameters"))) { IREE_RETURN_IF_ERROR(iree_tooling_create_parameters_module_from_flags( - state->instance, state->host_allocator, &module)); + state->instance, /*additional_provider_count=*/0, + /*additional_providers=*/NULL, state->host_allocator, &module)); } else { // Defer to the generic module resolver registry. 
IREE_RETURN_IF_ERROR(iree_tooling_resolve_module_dependency( diff --git a/runtime/src/iree/tooling/parameter_util.c b/runtime/src/iree/tooling/parameter_util.c index 9499abe6e4d9..34dd964aa168 100644 --- a/runtime/src/iree/tooling/parameter_util.c +++ b/runtime/src/iree/tooling/parameter_util.c @@ -125,8 +125,9 @@ iree_status_t iree_tooling_build_parameter_indices_from_flags( } iree_status_t iree_tooling_create_parameters_module_from_flags( - iree_vm_instance_t* instance, iree_allocator_t host_allocator, - iree_vm_module_t** out_module) { + iree_vm_instance_t* instance, iree_host_size_t additional_provider_count, + iree_io_parameter_provider_t** additional_providers, + iree_allocator_t host_allocator, iree_vm_module_t** out_module) { IREE_TRACE_ZONE_BEGIN(z0); iree_io_scope_map_t scope_map; @@ -136,9 +137,9 @@ iree_status_t iree_tooling_create_parameters_module_from_flags( iree_status_t status = iree_tooling_build_parameter_indices_from_flags(&scope_map); - // Create one provider per scope. - iree_host_size_t provider_count = 0; - iree_io_parameter_provider_t** providers = + // Create one provider per scope from flags. + iree_host_size_t flag_provider_count = 0; + iree_io_parameter_provider_t** flag_providers = (iree_io_parameter_provider_t**)iree_alloca( scope_map.count * sizeof(iree_io_parameter_provider_t*)); if (iree_status_is_ok(status)) { @@ -146,21 +147,36 @@ iree_status_t iree_tooling_create_parameters_module_from_flags( status = iree_io_parameter_index_provider_create( scope_map.entries[i]->scope, scope_map.entries[i]->index, IREE_IO_PARAMETER_INDEX_PROVIDER_DEFAULT_MAX_CONCURRENT_OPERATIONS, - host_allocator, &providers[i]); + host_allocator, &flag_providers[i]); if (!iree_status_is_ok(status)) break; - ++provider_count; + ++flag_provider_count; } } - // Create the module with the list of providers. + // Merge flag-created providers with additional providers. 
+ iree_host_size_t total_provider_count = + flag_provider_count + additional_provider_count; + iree_io_parameter_provider_t** all_providers = + (iree_io_parameter_provider_t**)iree_alloca( + total_provider_count * sizeof(iree_io_parameter_provider_t*)); + for (iree_host_size_t i = 0; i < flag_provider_count; ++i) { + all_providers[i] = flag_providers[i]; + } + for (iree_host_size_t i = 0; i < additional_provider_count; ++i) { + all_providers[flag_provider_count + i] = additional_providers[i]; + } + + // Create the module with the merged list of providers. if (iree_status_is_ok(status)) { - status = iree_io_parameters_module_create( - instance, provider_count, providers, host_allocator, out_module); + status = iree_io_parameters_module_create(instance, total_provider_count, + all_providers, host_allocator, + out_module); } // Cleanup (module owns providers which own indices/etc). - for (iree_host_size_t i = 0; i < provider_count; ++i) { - iree_io_parameter_provider_release(providers[i]); + // Only release flag providers - additional providers are owned by caller. + for (iree_host_size_t i = 0; i < flag_provider_count; ++i) { + iree_io_parameter_provider_release(flag_providers[i]); } iree_io_scope_map_deinitialize(&scope_map); diff --git a/runtime/src/iree/tooling/parameter_util.h b/runtime/src/iree/tooling/parameter_util.h index a0e46c12fd0e..f1633416d0e4 100644 --- a/runtime/src/iree/tooling/parameter_util.h +++ b/runtime/src/iree/tooling/parameter_util.h @@ -21,10 +21,17 @@ typedef struct iree_io_scope_map_t iree_io_scope_map_t; iree_status_t iree_tooling_build_parameter_indices_from_flags( iree_io_scope_map_t* scope_map); +typedef struct iree_io_parameter_provider_t iree_io_parameter_provider_t; + // Builds an I/O parameters module based on the runtime flags provided. +// If |additional_provider_count| is non-zero then |additional_providers| +// contains providers that will be added to the module in addition to those +// parsed from --parameters= flags. 
Additional providers are retained by the +// module and can be released by the caller after this call returns. iree_status_t iree_tooling_create_parameters_module_from_flags( - iree_vm_instance_t* instance, iree_allocator_t host_allocator, - iree_vm_module_t** out_module); + iree_vm_instance_t* instance, iree_host_size_t additional_provider_count, + iree_io_parameter_provider_t** additional_providers, + iree_allocator_t host_allocator, iree_vm_module_t** out_module); #ifdef __cplusplus } // extern "C" diff --git a/runtime/src/iree/vm/bytecode/BUILD.bazel b/runtime/src/iree/vm/bytecode/BUILD.bazel index 177ecd1fcf3a..7eeddb1ccb25 100644 --- a/runtime/src/iree/vm/bytecode/BUILD.bazel +++ b/runtime/src/iree/vm/bytecode/BUILD.bazel @@ -55,11 +55,7 @@ if(IREE_BUILD_COMPILER) iree_runtime_cc_test( name = "module_test", - srcs = [ - "dispatch_async_test.cc", - "dispatch_test.cc", - "module_test.cc", - ], + srcs = ["module_test.cc"], deps = [ ":module", ":module_test_module_c", @@ -67,9 +63,6 @@ iree_runtime_cc_test( "//runtime/src/iree/testing:gtest", "//runtime/src/iree/testing:gtest_main", "//runtime/src/iree/vm", - "//runtime/src/iree/vm/test:all_bytecode_modules_c", - "//runtime/src/iree/vm/test:async_bytecode_modules_c", - "//runtime/src/iree/vm/test:async_ops_test_module", ], ) diff --git a/runtime/src/iree/vm/bytecode/CMakeLists.txt b/runtime/src/iree/vm/bytecode/CMakeLists.txt index 8d0250024cbc..2293aadf4e58 100644 --- a/runtime/src/iree/vm/bytecode/CMakeLists.txt +++ b/runtime/src/iree/vm/bytecode/CMakeLists.txt @@ -41,8 +41,6 @@ iree_cc_test( NAME module_test SRCS - "dispatch_async_test.cc" - "dispatch_test.cc" "module_test.cc" DEPS ::module @@ -51,9 +49,6 @@ iree_cc_test( iree::testing::gtest iree::testing::gtest_main iree::vm - iree::vm::test::all_bytecode_modules_c - iree::vm::test::async_bytecode_modules_c - iree::vm::test::async_ops_test_module ) iree_bytecode_module( diff --git a/runtime/src/iree/vm/bytecode/dispatch_async_test.cc 
b/runtime/src/iree/vm/bytecode/dispatch_async_test.cc deleted file mode 100644 index c117791c737e..000000000000 --- a/runtime/src/iree/vm/bytecode/dispatch_async_test.cc +++ /dev/null @@ -1,834 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -// Tests covering the dispatch logic for individual ops. -// -// iree/vm/test/async_ops.mlir contains the functions used here for testing. We -// avoid defining the IR inline here so that we can run this test on platforms -// that we can't run the full MLIR compiler stack on. - -#include "iree/base/api.h" -#include "iree/testing/gtest.h" -#include "iree/testing/status_matchers.h" -#include "iree/vm/api.h" -#include "iree/vm/bytecode/module.h" - -// Compiled module embedded here to avoid file IO: -#include "iree/vm/test/async_bytecode_modules.h" - -// Native test module for yieldable imports. -#include "iree/vm/test/async_ops_test_module.h" - -namespace iree { -namespace { - -using iree::testing::status::StatusIs; - -class VMBytecodeDispatchAsyncTest : public ::testing::Test { - protected: - void SetUp() override { - IREE_TRACE_SCOPE(); - const iree_file_toc_t* file = async_bytecode_modules_c_create(); - - IREE_CHECK_OK(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, - iree_allocator_system(), &instance_)); - - // Create native yieldable_test module (required by async_ops imports). - IREE_CHECK_OK(yieldable_test_module_create( - instance_, iree_allocator_system(), &native_module_)); - - IREE_CHECK_OK(iree_vm_bytecode_module_create( - instance_, IREE_VM_BYTECODE_MODULE_FLAG_NONE, - iree_const_byte_span_t{reinterpret_cast(file->data), - static_cast(file->size)}, - iree_allocator_null(), iree_allocator_system(), &bytecode_module_)); - - // Native module first for import resolution. 
- std::vector modules = {native_module_, bytecode_module_}; - IREE_CHECK_OK(iree_vm_context_create_with_modules( - instance_, IREE_VM_CONTEXT_FLAG_NONE, modules.size(), modules.data(), - iree_allocator_system(), &context_)); - } - - void TearDown() override { - IREE_TRACE_SCOPE(); - iree_vm_module_release(bytecode_module_); - iree_vm_module_release(native_module_); - iree_vm_context_release(context_); - iree_vm_instance_release(instance_); - } - - iree_vm_instance_t* instance_ = nullptr; - iree_vm_context_t* context_ = nullptr; - iree_vm_module_t* native_module_ = nullptr; - iree_vm_module_t* bytecode_module_ = nullptr; -}; - -// Tests a simple straight-line yield sequence that requires 3 resumes. -// See iree/vm/test/async_ops.mlir > @yield_sequence -TEST_F(VMBytecodeDispatchAsyncTest, YieldSequence) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("yield_sequence"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - uint32_t arg_value = 97; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_value, sizeof(arg_value)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // 0/3 - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - - // 1/3 - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // 2/3 - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // 3/3 - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - - ASSERT_EQ(ret_value, arg_value + 
3); - - iree_vm_stack_deinitialize(stack); -} - -// Tests a yield with data-dependent control, ensuring that we run the -// alternating branches and pass along branch args on resume. -// See iree/vm/test/async_ops.mlir > @yield_divergent -TEST_F(VMBytecodeDispatchAsyncTest, YieldDivergent) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("yield_divergent"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - // result = %arg0 ? %arg1 : %arg2 - struct { - uint32_t arg0; - uint32_t arg1; - uint32_t arg2; - } arg_values = { - 0, - 100, - 200, - }; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_values, sizeof(arg_values)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // arg0=0: result = %arg0 ? %arg1 : %arg2 => %arg2 - arg_values.arg0 = 0; - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - ASSERT_EQ(ret_value, arg_values.arg2); - - // arg0=1: result = %arg0 ? 
%arg1 : %arg2 => %arg1 - arg_values.arg0 = 1; - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - ASSERT_EQ(ret_value, arg_values.arg1); - - iree_vm_stack_deinitialize(stack); -} - -//===----------------------------------------------------------------------===// -// CallYieldable tests -//===----------------------------------------------------------------------===// - -class VMBytecodeDispatchCallYieldableTest : public ::testing::Test { - protected: - void SetUp() override { - IREE_TRACE_SCOPE(); - const iree_file_toc_t* file = async_bytecode_modules_c_create(); - - IREE_CHECK_OK(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, - iree_allocator_system(), &instance_)); - - // Create native yieldable_test module (required by async_ops imports). - IREE_CHECK_OK(yieldable_test_module_create( - instance_, iree_allocator_system(), &native_module_)); - - IREE_CHECK_OK(iree_vm_bytecode_module_create( - instance_, IREE_VM_BYTECODE_MODULE_FLAG_NONE, - iree_const_byte_span_t{reinterpret_cast(file->data), - static_cast(file->size)}, - iree_allocator_null(), iree_allocator_system(), &bytecode_module_)); - - // Native module first for import resolution. 
- std::vector modules = {native_module_, bytecode_module_}; - IREE_CHECK_OK(iree_vm_context_create_with_modules( - instance_, IREE_VM_CONTEXT_FLAG_NONE, modules.size(), modules.data(), - iree_allocator_system(), &context_)); - } - - void TearDown() override { - IREE_TRACE_SCOPE(); - iree_vm_module_release(bytecode_module_); - iree_vm_module_release(native_module_); - iree_vm_context_release(context_); - iree_vm_instance_release(instance_); - } - - iree_vm_instance_t* instance_ = nullptr; - iree_vm_context_t* context_ = nullptr; - iree_vm_module_t* native_module_ = nullptr; - iree_vm_module_t* bytecode_module_ = nullptr; -}; - -// Tests calling an internal function that yields 4 times via vm.call.yieldable. -// See iree/vm/test/call_yieldable_ops.mlir > @call_yieldable_internal -TEST_F(VMBytecodeDispatchCallYieldableTest, CallYieldableInternal) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_yieldable_internal"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(nullptr, 0); // No arguments - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // The callee yields 3 times, so we need 3 resumes. 
- // begin -> 1st yield -> DEFERRED - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - - // resume -> 2nd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> 3rd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> return -> OK - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - - // Result should be 4 (0 + 4 increments across 4 basic blocks) - ASSERT_EQ(ret_value, 4u); - - iree_vm_stack_deinitialize(stack); -} - -// Tests calling an internal yieldable function with an argument. -// See iree/vm/test/call_yieldable_ops.mlir > @call_yieldable_with_arg -TEST_F(VMBytecodeDispatchCallYieldableTest, CallYieldableWithArg) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_yieldable_with_arg"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - uint32_t arg_value = 42; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_value, sizeof(arg_value)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // The callee yields 1 time. 
- // 0/1 - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - - // 1/1 - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - - // Result should be arg_value + 1 - ASSERT_EQ(ret_value, arg_value + 1); - - iree_vm_stack_deinitialize(stack); -} - -//===----------------------------------------------------------------------===// -// CallYieldable to Imports tests -//===----------------------------------------------------------------------===// -// Tests vm.call.yieldable calling native module functions that yield. - -class VMBytecodeDispatchCallYieldableImportTest : public ::testing::Test { - protected: - void SetUp() override { - IREE_TRACE_SCOPE(); - const iree_file_toc_t* file = async_bytecode_modules_c_create(); - - IREE_CHECK_OK(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, - iree_allocator_system(), &instance_)); - - // Create native yieldable_test module. - IREE_CHECK_OK(yieldable_test_module_create( - instance_, iree_allocator_system(), &native_module_)); - - // Create bytecode module that imports from native module. - IREE_CHECK_OK(iree_vm_bytecode_module_create( - instance_, IREE_VM_BYTECODE_MODULE_FLAG_NONE, - iree_const_byte_span_t{reinterpret_cast(file->data), - static_cast(file->size)}, - iree_allocator_null(), iree_allocator_system(), &bytecode_module_)); - - // Create context with both modules (native first for import resolution). 
- std::vector modules = {native_module_, bytecode_module_}; - IREE_CHECK_OK(iree_vm_context_create_with_modules( - instance_, IREE_VM_CONTEXT_FLAG_NONE, modules.size(), modules.data(), - iree_allocator_system(), &context_)); - } - - void TearDown() override { - IREE_TRACE_SCOPE(); - iree_vm_module_release(bytecode_module_); - iree_vm_module_release(native_module_); - iree_vm_context_release(context_); - iree_vm_instance_release(instance_); - } - - iree_vm_instance_t* instance_ = nullptr; - iree_vm_context_t* context_ = nullptr; - iree_vm_module_t* native_module_ = nullptr; - iree_vm_module_t* bytecode_module_ = nullptr; -}; - -// Tests calling a yieldable import that yields 3 times. -// This exercises Bug 1 fix: PC must be saved at instruction start, not after -// decode. -TEST_F(VMBytecodeDispatchCallYieldableImportTest, YieldableImportYields3) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_yieldable_import_yields_3"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - uint32_t arg_value = 100; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_value, sizeof(arg_value)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // The import yields 3 times. 
- // begin -> 1st yield -> DEFERRED - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - - // resume -> 2nd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> 3rd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> return -> OK - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - - // Result should be arg + 3 - ASSERT_EQ(ret_value, arg_value + 3); - - iree_vm_stack_deinitialize(stack); -} - -// Tests calling a yieldable import that yields 0 times (synchronous). -TEST_F(VMBytecodeDispatchCallYieldableImportTest, YieldableImportYields0) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_yieldable_import_yields_0"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - uint32_t arg_value = 42; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_value, sizeof(arg_value)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // No yields, should complete immediately. - IREE_ASSERT_OK( - function.module->begin_call(function.module->self, stack, call)); - - // Result should be arg + 0 - ASSERT_EQ(ret_value, arg_value); - - iree_vm_stack_deinitialize(stack); -} - -// Tests calling a yieldable import after an internal function call. -// This exercises Bug 2 fix: return_registers must be cleared after internal -// call. 
-TEST_F(VMBytecodeDispatchCallYieldableImportTest, YieldableAfterInternal) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_yieldable_after_internal"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - uint32_t arg_value = 5; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_value, sizeof(arg_value)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // The function: - // 1. Calls internal_add_10(arg) -> arg + 10 - // 2. Calls yieldable import yield_n(arg+10, 2) which yields 2 times - // Expected: 2 yields, result = (arg + 10) + 2 - - // begin -> internal call completes, import yields 1st time -> DEFERRED - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - - // resume -> 2nd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> return -> OK - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - - // Result should be (arg + 10) + 2 = arg + 12 - ASSERT_EQ(ret_value, arg_value + 12); - - iree_vm_stack_deinitialize(stack); -} - -// Tests two sequential yieldable import calls in the same function. -// This catches bugs where the second call sees stale state from the first. 
-TEST_F(VMBytecodeDispatchCallYieldableImportTest, YieldableImportSequential) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_yieldable_import_sequential"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - uint32_t arg_value = 10; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_value, sizeof(arg_value)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // First import yields 2 times, second import yields 3 times = 5 total yields. - // begin -> 1st import, 1st yield -> DEFERRED - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - - // resume -> 1st import, 2nd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> 1st import done, 2nd import, 1st yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> 2nd import, 2nd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> 2nd import, 3rd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> 2nd import done, return -> OK - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - - // Result should be (arg + 2) + 3 = arg + 5 - ASSERT_EQ(ret_value, arg_value + 5); - - iree_vm_stack_deinitialize(stack); -} - -// Tests a 
yieldable import nested inside an internal yieldable function. -// This is the most complex frame stack scenario. -TEST_F(VMBytecodeDispatchCallYieldableImportTest, YieldableImportNested) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_nested_yieldable"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - uint32_t arg_value = 50; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_value, sizeof(arg_value)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // Sequence: 1 yield (internal) + 2 yields (import) + 1 yield (internal) = 4 - // begin -> internal 1st yield -> DEFERRED - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - - // resume -> import 1st yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> import 2nd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> import done, internal 2nd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> internal done, return -> OK - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - - // Result should be ((arg + 1) + 2) + 1 = arg + 4 - ASSERT_EQ(ret_value, arg_value + 4); - - iree_vm_stack_deinitialize(stack); -} - -// Tests a yieldable import with many yields to catch state accumulation bugs. 
-TEST_F(VMBytecodeDispatchCallYieldableImportTest, YieldableImportStress) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_yieldable_import_stress"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - uint32_t arg_value = 1000; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_value, sizeof(arg_value)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // 10 yields total. - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - - for (int i = 1; i < 10; ++i) { - ASSERT_THAT(function.module->resume_call(function.module->self, stack, - call.results), - StatusIs(StatusCode::kDeferred)) - << "Expected DEFERRED at resume " << i; - } - - // Final resume should complete. - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - - // Result should be arg + 10 - ASSERT_EQ(ret_value, arg_value + 10); - - iree_vm_stack_deinitialize(stack); -} - -//===----------------------------------------------------------------------===// -// CallVariadicYieldable to Imports tests -//===----------------------------------------------------------------------===// -// Tests vm.call.variadic.yieldable calling native module functions that yield. - -// Tests calling a variadic yieldable import with 2 args and 3 yields. 
-TEST_F(VMBytecodeDispatchCallYieldableImportTest, VariadicYieldable2Args) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_variadic_yieldable_2args"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - struct { - uint32_t arg0; - uint32_t arg1; - } arg_values = {10, 20}; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_values, sizeof(arg_values)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // The import sums the variadic args and yields 3 times. - // begin -> 1st yield -> DEFERRED - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - - // resume -> 2nd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> 3rd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> return -> OK - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - - // Result should be (arg0 + arg1) + 3 = 10 + 20 + 3 = 33 - ASSERT_EQ(ret_value, arg_values.arg0 + arg_values.arg1 + 3); - - iree_vm_stack_deinitialize(stack); -} - -// Tests calling a variadic yieldable import with 0 yields (synchronous). 
-TEST_F(VMBytecodeDispatchCallYieldableImportTest, VariadicYieldable0Yields) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_variadic_yieldable_0yields"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - struct { - uint32_t arg0; - uint32_t arg1; - uint32_t arg2; - } arg_values = {5, 10, 15}; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_values, sizeof(arg_values)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // No yields, should complete immediately. - IREE_ASSERT_OK( - function.module->begin_call(function.module->self, stack, call)); - - // Result should be arg0 + arg1 + arg2 = 5 + 10 + 15 = 30 - ASSERT_EQ(ret_value, arg_values.arg0 + arg_values.arg1 + arg_values.arg2); - - iree_vm_stack_deinitialize(stack); -} - -// Tests calling a variadic yieldable import with 1 arg. -TEST_F(VMBytecodeDispatchCallYieldableImportTest, VariadicYieldable1Arg) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_variadic_yieldable_1arg"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - uint32_t arg_value = 100; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_value, sizeof(arg_value)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // 2 yields. 
- // begin -> 1st yield -> DEFERRED - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - - // resume -> 2nd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> return -> OK - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - - // Result should be arg0 + 2 = 100 + 2 = 102 - ASSERT_EQ(ret_value, arg_value + 2); - - iree_vm_stack_deinitialize(stack); -} - -// Tests calling a variadic yieldable import with empty variadic list. -TEST_F(VMBytecodeDispatchCallYieldableImportTest, VariadicYieldableEmpty) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_variadic_yieldable_empty"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(nullptr, 0); // No arguments - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // 1 yield. - // begin -> 1st yield -> DEFERRED - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - - // resume -> return -> OK - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - - // Result should be 0 + 1 = 1 - ASSERT_EQ(ret_value, 1u); - - iree_vm_stack_deinitialize(stack); -} - -// Tests two sequential variadic yieldable calls. 
-TEST_F(VMBytecodeDispatchCallYieldableImportTest, VariadicYieldableSequential) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_variadic_yieldable_sequential"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - struct { - uint32_t arg0; - uint32_t arg1; - uint32_t arg2; - } arg_values = {10, 20, 5}; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_values, sizeof(arg_values)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // First variadic: 2 yields, second variadic: 1 yield = 3 yields total. - // begin -> 1st call, 1st yield -> DEFERRED - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - - // resume -> 1st call, 2nd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> 1st call done, 2nd call, 1st yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> 2nd call done, return -> OK - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - - // Result should be: - // First call: sum(arg0, arg1) + 2 yields = (10 + 20) + 2 = 32 - // Second call: sum(32, arg2) + 1 yield = (32 + 5) + 1 = 38 - ASSERT_EQ(ret_value, 38u); - - iree_vm_stack_deinitialize(stack); -} - -} // namespace -} // namespace iree diff --git a/runtime/src/iree/vm/bytecode/dispatch_test.cc b/runtime/src/iree/vm/bytecode/dispatch_test.cc deleted file mode 100644 index 21361bfe2a0a..000000000000 --- 
a/runtime/src/iree/vm/bytecode/dispatch_test.cc +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -// Tests covering the dispatch logic for individual ops. -// -// iree/vm/test/*.mlir contains the functions used here for testing. We -// avoid defining the IR inline here so that we can run this test on platforms -// that we can't run the full MLIR compiler stack on. - -#include "iree/base/api.h" -#include "iree/testing/gtest.h" -#include "iree/vm/api.h" -#include "iree/vm/bytecode/module.h" - -// Compiled module embedded here to avoid file IO: -#include "iree/vm/test/all_bytecode_modules.h" - -namespace { - -struct TestParams { - const struct iree_file_toc_t& module_file; - std::string function_name; -}; - -std::ostream& operator<<(std::ostream& os, const TestParams& params) { - std::string name{params.module_file.name}; - auto name_sv = iree_make_string_view(name.data(), name.size()); - iree_string_view_replace_char(name_sv, ':', '_'); - iree_string_view_replace_char(name_sv, '.', '_'); - return os << name << "_" << params.function_name; -} - -std::vector GetModuleTestParams() { - std::vector test_params; - - iree_vm_instance_t* instance = NULL; - IREE_CHECK_OK(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, - iree_allocator_system(), &instance)); - - const struct iree_file_toc_t* module_file_toc = - all_bytecode_modules_c_create(); - for (size_t i = 0; i < all_bytecode_modules_c_size(); ++i) { - const auto& module_file = module_file_toc[i]; - iree_vm_module_t* module = nullptr; - IREE_CHECK_OK(iree_vm_bytecode_module_create( - instance, IREE_VM_BYTECODE_MODULE_FLAG_NONE, - iree_const_byte_span_t{ - reinterpret_cast(module_file.data), - static_cast(module_file.size)}, - iree_allocator_null(), iree_allocator_system(), &module)); - 
iree_vm_module_signature_t signature = iree_vm_module_signature(module); - test_params.reserve(test_params.size() + signature.export_function_count); - for (int i = 0; i < signature.export_function_count; ++i) { - iree_vm_function_t function; - IREE_CHECK_OK(iree_vm_module_lookup_function_by_ordinal( - module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &function)); - iree_string_view_t function_name = iree_vm_function_name(&function); - test_params.push_back( - {module_file, std::string(function_name.data, function_name.size)}); - } - iree_vm_module_release(module); - } - - iree_vm_instance_release(instance); - - return test_params; -} - -class VMBytecodeDispatchTest - : public ::testing::Test, - public ::testing::WithParamInterface { - protected: - virtual void SetUp() { - const auto& test_params = GetParam(); - - IREE_CHECK_OK(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, - iree_allocator_system(), &instance_)); - - IREE_CHECK_OK(iree_vm_bytecode_module_create( - instance_, IREE_VM_BYTECODE_MODULE_FLAG_NONE, - iree_const_byte_span_t{ - reinterpret_cast(test_params.module_file.data), - static_cast(test_params.module_file.size)}, - iree_allocator_null(), iree_allocator_system(), &bytecode_module_)); - - std::vector modules = {bytecode_module_}; - IREE_CHECK_OK(iree_vm_context_create_with_modules( - instance_, IREE_VM_CONTEXT_FLAG_NONE, modules.size(), modules.data(), - iree_allocator_system(), &context_)); - } - - virtual void TearDown() { - iree_vm_module_release(bytecode_module_); - iree_vm_context_release(context_); - iree_vm_instance_release(instance_); - } - - iree_status_t RunFunction(const char* function_name) { - iree_vm_function_t function; - IREE_CHECK_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - iree_make_cstring_view(function_name), &function)); - - iree_vm_invocation_flags_t flags = IREE_VM_INVOCATION_FLAG_NONE; - // NOTE: adding this bit makes it easy to debug issues on stdout: - // flags |= 
IREE_VM_INVOCATION_FLAG_TRACE_EXECUTION; - return iree_vm_invoke(context_, function, flags, - /*policy=*/nullptr, /*inputs=*/nullptr, - /*outputs=*/nullptr, iree_allocator_system()); - } - - iree_vm_instance_t* instance_ = nullptr; - iree_vm_context_t* context_ = nullptr; - iree_vm_module_t* bytecode_module_ = nullptr; -}; - -TEST_P(VMBytecodeDispatchTest, Check) { - const auto& test_params = GetParam(); - bool expect_failure = test_params.function_name.find("fail_") == 0; - - iree_status_t status = RunFunction(test_params.function_name.c_str()); - if (iree_status_is_ok(status)) { - if (expect_failure) { - GTEST_FAIL() << "Function expected failure but succeeded"; - } else { - GTEST_SUCCEED(); - } - } else { - if (expect_failure) { - iree_status_ignore(status); - GTEST_SUCCEED(); - } else { - GTEST_FAIL() << "Function expected success but failed with error: " - << iree::Status(std::move(status)).ToString(); - } - } -} - -INSTANTIATE_TEST_SUITE_P(VMIRFunctions, VMBytecodeDispatchTest, - ::testing::ValuesIn(GetModuleTestParams()), - ::testing::PrintToStringParamName()); - -} // namespace diff --git a/runtime/src/iree/vm/list.c b/runtime/src/iree/vm/list.c index d78bc852fe80..89f0fca5ee7c 100644 --- a/runtime/src/iree/vm/list.c +++ b/runtime/src/iree/vm/list.c @@ -860,6 +860,9 @@ static iree_status_t iree_vm_list_get_ref(const iree_vm_list_t* list, return iree_make_status(IREE_STATUS_FAILED_PRECONDITION); } iree_vm_list_ref_op(ref_mode, &variant->ref, out_value); + if (ref_mode == IREE_VM_LIST_REF_MOVE) { + variant->type = iree_vm_make_undefined_type_def(); + } break; } default: @@ -978,7 +981,7 @@ static iree_status_t iree_vm_list_get_variant(const iree_vm_list_t* list, "index %" PRIhsz " out of bounds (%" PRIhsz ")", i, list->count); } - iree_vm_variant_reset(out_variant); + *out_variant = iree_vm_variant_empty(); uintptr_t element_ptr = (uintptr_t)list->storage + i * list->element_size; switch (list->storage_mode) { case IREE_VM_LIST_STORAGE_MODE_VALUE: { @@ -998,6 
+1001,9 @@ static iree_status_t iree_vm_list_get_variant(const iree_vm_list_t* list, out_variant->type = variant->type; if (iree_vm_type_def_is_ref(variant->type)) { iree_vm_list_ref_op(ref_mode, &variant->ref, &out_variant->ref); + if (ref_mode == IREE_VM_LIST_REF_MOVE) { + variant->type = iree_vm_make_undefined_type_def(); + } } else { memcpy(out_variant->value_storage, variant->value_storage, sizeof(variant->value_storage)); diff --git a/runtime/src/iree/vm/list_test.cc b/runtime/src/iree/vm/list_test.cc index d75c02f18eb0..b95375f9a864 100644 --- a/runtime/src/iree/vm/list_test.cc +++ b/runtime/src/iree/vm/list_test.cc @@ -273,6 +273,61 @@ TEST_F(VMListTest, GetRefRetainOrMove) { iree_vm_list_release(list); } +// Tests that moving a ref from a variant list properly marks the slot as empty. +TEST_F(VMListTest, VariantListRefMoveMarksSlotEmpty) { + // Create a variant list (stores any type). + iree_vm_type_def_t element_type = iree_vm_make_undefined_type_def(); + iree_vm_list_t* list = nullptr; + IREE_ASSERT_OK(iree_vm_list_create(element_type, /*initial_capacity=*/1, + iree_allocator_system(), &list)); + IREE_ASSERT_OK(iree_vm_list_resize(list, 1)); + + // Set a ref into the variant slot. + iree_vm_ref_t ref_a = MakeRef(1.0f); + IREE_ASSERT_OK(iree_vm_list_set_ref_move(list, 0, &ref_a)); + + // Verify the slot contains a ref. + { + iree_vm_variant_t variant; + IREE_ASSERT_OK(iree_vm_list_get_variant_assign(list, 0, &variant)); + EXPECT_TRUE(iree_vm_variant_is_ref(variant)); + EXPECT_FALSE(iree_vm_variant_is_empty(variant)); + } + + // Move the ref out of the variant list. + iree_vm_ref_t moved{0}; + IREE_ASSERT_OK( + iree_vm_list_get_ref_retain_or_move(list, 0, /*is_move=*/true, &moved)); + EXPECT_TRUE(test_a_isa(moved)); + iree_vm_ref_release(&moved); + + // Verify the slot is now empty (type should be undefined/variant). 
+ { + iree_vm_variant_t variant; + IREE_ASSERT_OK(iree_vm_list_get_variant_assign(list, 0, &variant)); + EXPECT_TRUE(iree_vm_variant_is_empty(variant)) + << "After move, variant slot should be empty"; + } + + // Also test get_variant_move marks the slot empty. + { + iree_vm_ref_t ref_b = MakeRef(2.0f); + IREE_ASSERT_OK(iree_vm_list_set_ref_move(list, 0, &ref_b)); + + iree_vm_variant_t moved_variant; + IREE_ASSERT_OK(iree_vm_list_get_variant_move(list, 0, &moved_variant)); + EXPECT_TRUE(iree_vm_variant_is_ref(moved_variant)); + iree_vm_ref_release(&moved_variant.ref); + + iree_vm_variant_t after_move; + IREE_ASSERT_OK(iree_vm_list_get_variant_assign(list, 0, &after_move)); + EXPECT_TRUE(iree_vm_variant_is_empty(after_move)) + << "After get_variant_move, slot should be empty"; + } + + iree_vm_list_release(list); +} + // Tests simple variant list usage, mainly just for demonstration. // Stores any heterogeneous element type, equivalent to `!vm.list`. TEST_F(VMListTest, UsageVariant) { diff --git a/runtime/src/iree/vm/native_module_packing.h b/runtime/src/iree/vm/native_module_packing.h index 49860f16d22e..794bb2668471 100644 --- a/runtime/src/iree/vm/native_module_packing.h +++ b/runtime/src/iree/vm/native_module_packing.h @@ -329,6 +329,41 @@ static inline params_ptr_t align_ptr(params_ptr_t ptr) { return ptr; } +// Computes the effective alignment for a parameter type. +// Only 8-byte types (i64, f64, ref) require special alignment; everything +// else uses the minimum 4-byte alignment. The primary template works for all +// types since alignof() is valid for any complete type. +template +struct ParamAlignment { + static constexpr iree_host_size_t value = alignof(T) > sizeof(int32_t) + ? alignof(T) + : sizeof(int32_t); +}; + +// Computes the maximum alignment across a parameter pack. 
+template +struct MaxParamAlignment; + +template <> +struct MaxParamAlignment<> { + static constexpr iree_host_size_t value = sizeof(int32_t); +}; + +template +struct MaxParamAlignment { + static constexpr iree_host_size_t value = + ParamAlignment::type>::value; +}; + +template +struct MaxParamAlignment { + static constexpr iree_host_size_t value = + (ParamAlignment::type>::value > + MaxParamAlignment::value) + ? ParamAlignment::type>::value + : MaxParamAlignment::value; +}; + template struct ParamUnpack; template <> @@ -364,31 +399,36 @@ struct Unpacker { typename impl::remove_cvref::type>::storage_type()...); Status status; params_ptr_t ptr = storage.data; - ApplyLoad(status, ptr, params, + params_ptr_t limit = storage.data + storage.data_length; + ApplyLoad(status, ptr, limit, params, std::make_index_sequence()); IREE_RETURN_IF_ERROR(std::move(status)); - // Note: we check > instead of != because alignment padding can leave - // unused bytes at the end of the buffer. - params_ptr_t limit = storage.data + storage.data_length; - if (IREE_UNLIKELY(ptr > limit)) { - return iree_make_status( - IREE_STATUS_INVALID_ARGUMENT, - "argument buffer unpacking failure; consumed %" PRIhsz - " bytes beyond %" PRIhsz " byte buffer", - (reinterpret_cast(ptr) - reinterpret_cast(limit)), - storage.data_length); + // Verify remaining bytes are valid trailing alignment padding. + // Buffer sizes are computed with trailing padding to max_alignment, so + // unconsumed bytes must be less than max_alignment. This catches cases + // where the caller provided more data than expected (trailing garbage). 
+ constexpr iree_host_size_t max_alignment = + impl::MaxParamAlignment::value; + iree_host_size_t remaining = static_cast(limit - ptr); + if (IREE_UNLIKELY(ptr > limit || remaining >= max_alignment)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "argument buffer unpacking failure; %" PRIhsz + " bytes remaining in %" PRIhsz + " byte buffer (max valid padding: %" PRIhsz ")", + remaining, storage.data_length, + max_alignment - 1); } return std::move(params); } private: template - static void ApplyLoad(Status& status, params_ptr_t& ptr, T&& params, - std::index_sequence) { + static void ApplyLoad(Status& status, params_ptr_t& ptr, params_ptr_t limit, + T&& params, std::index_sequence) { impl::order_sequence{ (impl::ParamUnpack>::type>::type>:: - Load(status, ptr, std::get(params)), + Load(status, ptr, limit, std::get(params)), 0)...}; } }; @@ -397,8 +437,15 @@ struct Unpacker { template struct ParamUnpack> { using storage_type = T; - static void Load(Status& status, params_ptr_t& ptr, storage_type& out_param) { + static void Load(Status& status, params_ptr_t& ptr, params_ptr_t limit, + storage_type& out_param) { + if (!status.ok()) return; ptr = align_ptr(ptr); + if (IREE_UNLIKELY(ptr + sizeof(T) > limit)) { + status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "argument buffer overflow reading primitive"); + return; + } out_param = *reinterpret_cast(ptr); ptr += sizeof(T); } @@ -408,8 +455,15 @@ struct ParamUnpack> { template <> struct ParamUnpack { using storage_type = opaque_ref; - static void Load(Status& status, params_ptr_t& ptr, storage_type& out_param) { + static void Load(Status& status, params_ptr_t& ptr, params_ptr_t limit, + storage_type& out_param) { + if (!status.ok()) return; ptr = align_ptr(ptr); + if (IREE_UNLIKELY(ptr + sizeof(iree_vm_ref_t) > limit)) { + status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "argument buffer overflow reading ref"); + return; + } iree_vm_ref_retain(reinterpret_cast(ptr), &out_param); ptr += 
sizeof(iree_vm_ref_t); } @@ -421,8 +475,15 @@ struct ParamUnpack { template struct ParamUnpack> { using storage_type = ref; - static void Load(Status& status, params_ptr_t& ptr, storage_type& out_param) { + static void Load(Status& status, params_ptr_t& ptr, params_ptr_t limit, + storage_type& out_param) { + if (!status.ok()) return; ptr = align_ptr(ptr); + if (IREE_UNLIKELY(ptr + sizeof(iree_vm_ref_t) > limit)) { + status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "argument buffer overflow reading ref"); + return; + } auto* reg_ptr = reinterpret_cast(ptr); ptr += sizeof(iree_vm_ref_t); if (reg_ptr->type == ref_type_descriptor::type()) { @@ -447,8 +508,15 @@ struct ParamUnpack> { template struct ParamUnpack> { using storage_type = ref; - static void Load(Status& status, params_ptr_t& ptr, storage_type& out_param) { + static void Load(Status& status, params_ptr_t& ptr, params_ptr_t limit, + storage_type& out_param) { + if (!status.ok()) return; ptr = align_ptr(ptr); + if (IREE_UNLIKELY(ptr + sizeof(iree_vm_ref_t) > limit)) { + status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "argument buffer overflow reading ref"); + return; + } auto* reg_ptr = reinterpret_cast(ptr); ptr += sizeof(iree_vm_ref_t); if (reg_ptr->type == ref_type_descriptor::type()) { @@ -474,8 +542,15 @@ struct ParamUnpack> { template struct ParamUnpack::value>> { using storage_type = T*; - static void Load(Status& status, params_ptr_t& ptr, storage_type& out_param) { + static void Load(Status& status, params_ptr_t& ptr, params_ptr_t limit, + storage_type& out_param) { + if (!status.ok()) return; ptr = align_ptr(ptr); + if (IREE_UNLIKELY(ptr + sizeof(iree_vm_ref_t) > limit)) { + status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "argument buffer overflow reading ref"); + return; + } auto* reg_ptr = reinterpret_cast(ptr); ptr += sizeof(iree_vm_ref_t); if (reg_ptr->type == ref_type_descriptor::type()) { @@ -500,8 +575,15 @@ struct ParamUnpack::value>> { template <> struct 
ParamUnpack { using storage_type = iree_string_view_t; - static void Load(Status& status, params_ptr_t& ptr, storage_type& out_param) { + static void Load(Status& status, params_ptr_t& ptr, params_ptr_t limit, + storage_type& out_param) { + if (!status.ok()) return; ptr = align_ptr(ptr); + if (IREE_UNLIKELY(ptr + sizeof(iree_vm_ref_t) > limit)) { + status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "argument buffer overflow reading ref"); + return; + } auto* reg_ptr = reinterpret_cast(ptr); ptr += sizeof(iree_vm_ref_t); if (reg_ptr->type == ref_type_descriptor::type()) { @@ -530,8 +612,15 @@ struct ParamUnpack { template <> struct ParamUnpack { using storage_type = std::string_view; - static void Load(Status& status, params_ptr_t& ptr, storage_type& out_param) { + static void Load(Status& status, params_ptr_t& ptr, params_ptr_t limit, + storage_type& out_param) { + if (!status.ok()) return; ptr = align_ptr(ptr); + if (IREE_UNLIKELY(ptr + sizeof(iree_vm_ref_t) > limit)) { + status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "argument buffer overflow reading ref"); + return; + } auto* reg_ptr = reinterpret_cast(ptr); ptr += sizeof(iree_vm_ref_t); if (reg_ptr->type == ref_type_descriptor::type()) { @@ -563,9 +652,10 @@ template struct ParamUnpack> { using element_type = typename impl::remove_cvref::type; using storage_type = std::array; - static void Load(Status& status, params_ptr_t& ptr, storage_type& out_param) { + static void Load(Status& status, params_ptr_t& ptr, params_ptr_t limit, + storage_type& out_param) { for (size_t i = 0; i < S; ++i) { - ParamUnpack::Load(status, ptr, out_param[i]); + ParamUnpack::Load(status, ptr, limit, out_param[i]); } } }; @@ -574,16 +664,17 @@ struct ParamUnpack> { template struct ParamUnpack> { using storage_type = std::tuple::type...>; - static void Load(Status& status, params_ptr_t& ptr, storage_type& out_param) { - UnpackTuple(status, ptr, out_param, + static void Load(Status& status, params_ptr_t& ptr, 
params_ptr_t limit, + storage_type& out_param) { + UnpackTuple(status, ptr, limit, out_param, std::make_index_sequence()); } template - static void UnpackTuple(Status& status, params_ptr_t& ptr, + static void UnpackTuple(Status& status, params_ptr_t& ptr, params_ptr_t limit, storage_type& params, std::index_sequence) { impl::order_sequence{ (ParamUnpack>::type>:: - Load(status, ptr, std::get(params)), + Load(status, ptr, limit, std::get(params)), 0)...}; } }; @@ -596,12 +687,20 @@ template struct ParamUnpack, enable_if_not_primitive> { using element_type = typename impl::remove_cvref::type; using storage_type = std::vector; - static void Load(Status& status, params_ptr_t& ptr, storage_type& out_param) { + static void Load(Status& status, params_ptr_t& ptr, params_ptr_t limit, + storage_type& out_param) { + if (!status.ok()) return; + if (IREE_UNLIKELY(ptr + sizeof(int32_t) > limit)) { + status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "argument buffer overflow reading span count"); + return; + } iree_host_size_t count = *reinterpret_cast(ptr); ptr += sizeof(int32_t); out_param.resize(count); for (iree_host_size_t i = 0; i < count; ++i) { - ParamUnpack::Load(status, ptr, out_param[i]); + ParamUnpack::Load(status, ptr, limit, out_param[i]); + if (!status.ok()) return; } } }; @@ -612,10 +711,24 @@ template struct ParamUnpack, enable_if_primitive> { using element_type = U; using storage_type = iree::span; - static void Load(Status& status, params_ptr_t& ptr, storage_type& out_param) { + static void Load(Status& status, params_ptr_t& ptr, params_ptr_t limit, + storage_type& out_param) { + if (!status.ok()) return; + if (IREE_UNLIKELY(ptr + sizeof(int32_t) > limit)) { + status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "argument buffer overflow reading span count"); + return; + } iree_host_size_t count = *reinterpret_cast(ptr); ptr += sizeof(int32_t); ptr = align_ptr(ptr); + if (IREE_UNLIKELY(ptr + sizeof(element_type) * count > limit)) { + status = 
iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "argument buffer overflow reading span elements" + " (count=%" PRIhsz ")", + count); + return; + } out_param = iree::span(reinterpret_cast(ptr), count); ptr += sizeof(element_type) * count; diff --git a/runtime/src/iree/vm/test/BUILD.bazel b/runtime/src/iree/vm/test/BUILD.bazel index 824492e31626..a81b2d09f6f4 100644 --- a/runtime/src/iree/vm/test/BUILD.bazel +++ b/runtime/src/iree/vm/test/BUILD.bazel @@ -34,6 +34,7 @@ iree_c_embed_data( ":assignment_ops_f32.vmfb", ":assignment_ops_f64.vmfb", ":assignment_ops_i64.vmfb", + ":async_ops.vmfb", ":buffer_ops.vmfb", ":call_ops.vmfb", ":comparison_ops.vmfb", @@ -61,25 +62,6 @@ iree_c_embed_data( h_file_output = "all_bytecode_modules.h", ) -iree_c_embed_data( - name = "async_bytecode_modules_c", - srcs = [ - ":async_ops.vmfb", - ], - c_file_output = "async_bytecode_modules.c", - flatten = True, - h_file_output = "async_bytecode_modules.h", -) - -iree_runtime_cc_library( - name = "async_ops_test_module", - hdrs = ["async_ops_test_module.h"], - deps = [ - "//runtime/src/iree/base", - "//runtime/src/iree/vm", - ], -) - iree_bytecode_module( name = "arithmetic_ops", src = "arithmetic_ops.mlir", diff --git a/runtime/src/iree/vm/test/CMakeLists.txt b/runtime/src/iree/vm/test/CMakeLists.txt index bdc7c787a8e1..8f00924e395f 100644 --- a/runtime/src/iree/vm/test/CMakeLists.txt +++ b/runtime/src/iree/vm/test/CMakeLists.txt @@ -26,6 +26,7 @@ iree_c_embed_data( "assignment_ops_f32.vmfb" "assignment_ops_f64.vmfb" "assignment_ops_i64.vmfb" + "async_ops.vmfb" "buffer_ops.vmfb" "call_ops.vmfb" "comparison_ops.vmfb" @@ -55,30 +56,6 @@ iree_c_embed_data( PUBLIC ) -iree_c_embed_data( - NAME - async_bytecode_modules_c - SRCS - "async_ops.vmfb" - C_FILE_OUTPUT - "async_bytecode_modules.c" - H_FILE_OUTPUT - "async_bytecode_modules.h" - FLATTEN - PUBLIC -) - -iree_cc_library( - NAME - async_ops_test_module - HDRS - "async_ops_test_module.h" - DEPS - iree::base - iree::vm - PUBLIC -) - 
iree_bytecode_module( NAME arithmetic_ops diff --git a/runtime/src/iree/vm/test/arithmetic_ops.mlir b/runtime/src/iree/vm/test/arithmetic_ops.mlir index 4ec12e1e6266..294180e934f2 100644 --- a/runtime/src/iree/vm/test/arithmetic_ops.mlir +++ b/runtime/src/iree/vm/test/arithmetic_ops.mlir @@ -7,7 +7,7 @@ vm.module @arithmetic_ops { vm.export @test_add_i32 vm.func @test_add_i32() { %c1 = vm.const.i32 1 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %v = vm.add.i32 %c1dno, %c1dno : i32 %c2 = vm.const.i32 2 vm.check.eq %v, %c2, "1+1=2" : i32 @@ -17,9 +17,9 @@ vm.module @arithmetic_ops { vm.export @test_sub_i32 vm.func @test_sub_i32() { %c1 = vm.const.i32 3 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.i32 2 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 %v = vm.sub.i32 %c1dno, %c2dno : i32 %c3 = vm.const.i32 1 vm.check.eq %v, %c3, "3-2=1" : i32 @@ -29,7 +29,7 @@ vm.module @arithmetic_ops { vm.export @test_mul_i32 vm.func @test_mul_i32() { %c1 = vm.const.i32 2 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %v = vm.mul.i32 %c1dno, %c1dno : i32 %c2 = vm.const.i32 4 vm.check.eq %v, %c2, "2*2=4" : i32 @@ -39,9 +39,9 @@ vm.module @arithmetic_ops { vm.export @test_div_i32s vm.func @test_div_i32s() { %c1 = vm.const.i32 4 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.i32 -2 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 %v = vm.div.i32.s %c1dno, %c2dno : i32 %c3 = vm.const.i32 -2 vm.check.eq %v, %c3, "4/-2=-2" : i32 @@ -51,9 +51,9 @@ vm.module @arithmetic_ops { vm.export @test_div_i32u vm.func @test_div_i32u() { %c1 = vm.const.i32 4 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.i32 2 - %c2dno = util.optimization_barrier 
%c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 %v = vm.div.i32.u %c1dno, %c2dno : i32 %c3 = vm.const.i32 2 vm.check.eq %v, %c3, "4/2=2" : i32 @@ -63,9 +63,9 @@ vm.module @arithmetic_ops { vm.export @test_rem_i32s vm.func @test_rem_i32s() { %c1 = vm.const.i32 -3 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.i32 -2 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 %v = vm.rem.i32.s %c1dno, %c2dno : i32 %c3 = vm.const.i32 -1 vm.check.eq %v, %c3, "-3%-2=-1" : i32 @@ -75,9 +75,9 @@ vm.module @arithmetic_ops { vm.export @test_rem_i32u vm.func @test_rem_i32u() { %c1 = vm.const.i32 3 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.i32 2 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 %v = vm.rem.i32.u %c1dno, %c2dno : i32 %c3 = vm.const.i32 1 vm.check.eq %v, %c3, "3%2=1" : i32 @@ -87,11 +87,11 @@ vm.module @arithmetic_ops { vm.export @test_fma_i32 vm.func @test_fma_i32() { %c2 = vm.const.i32 2 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 %c3 = vm.const.i32 3 - %c3dno = util.optimization_barrier %c3 : i32 + %c3dno = vm.optimization_barrier %c3 : i32 %c5 = vm.const.i32 5 - %c5dno = util.optimization_barrier %c5 : i32 + %c5dno = vm.optimization_barrier %c5 : i32 %v = vm.fma.i32 %c2dno, %c3dno, %c5dno : i32 %c11 = vm.const.i32 11 vm.check.eq %v, %c11, "2*3+5=11" : i32 @@ -101,7 +101,7 @@ vm.module @arithmetic_ops { vm.export @test_abs_i32 vm.func @test_abs_i32() { %cn1 = vm.const.i32 -1 - %cn1dno = util.optimization_barrier %cn1 : i32 + %cn1dno = vm.optimization_barrier %cn1 : i32 %v = vm.abs.i32 %cn1dno : i32 %c1 = vm.const.i32 1 vm.check.eq %v, %c1, "abs(-1)=1" : i32 @@ -111,9 +111,9 @@ vm.module @arithmetic_ops { vm.export @test_min_i32s vm.func @test_min_i32s() { %cn3 = vm.const.i32 -3 - %cn3dno = 
util.optimization_barrier %cn3 : i32 + %cn3dno = vm.optimization_barrier %cn3 : i32 %c2 = vm.const.i32 2 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 %v = vm.min.i32.s %cn3dno, %c2dno : i32 vm.check.eq %v, %cn3, "smin(-3,2)=-3" : i32 vm.return @@ -122,9 +122,9 @@ vm.module @arithmetic_ops { vm.export @test_min_i32u vm.func @test_min_i32u() { %cn3 = vm.const.i32 -3 - %cn3dno = util.optimization_barrier %cn3 : i32 + %cn3dno = vm.optimization_barrier %cn3 : i32 %c2 = vm.const.i32 2 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 %v = vm.min.i32.u %cn3dno, %c2dno : i32 vm.check.eq %v, %c2, "umin(-3,2)=2" : i32 vm.return @@ -133,9 +133,9 @@ vm.module @arithmetic_ops { vm.export @test_max_i32s vm.func @test_max_i32s() { %cn3 = vm.const.i32 -3 - %cn3dno = util.optimization_barrier %cn3 : i32 + %cn3dno = vm.optimization_barrier %cn3 : i32 %c2 = vm.const.i32 2 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 %v = vm.max.i32.s %cn3dno, %c2dno : i32 vm.check.eq %v, %c2, "smax(-3,2)=2" : i32 vm.return @@ -144,9 +144,9 @@ vm.module @arithmetic_ops { vm.export @test_max_i32u vm.func @test_max_i32u() { %cn3 = vm.const.i32 -3 - %cn3dno = util.optimization_barrier %cn3 : i32 + %cn3dno = vm.optimization_barrier %cn3 : i32 %c2 = vm.const.i32 2 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 %v = vm.max.i32.u %cn3dno, %c2dno : i32 vm.check.eq %v, %cn3, "umax(-3,2)=-3" : i32 vm.return @@ -155,7 +155,7 @@ vm.module @arithmetic_ops { vm.export @test_not_i32 vm.func @test_not_i32() { %c1 = vm.const.i32 0 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %v = vm.not.i32 %c1dno : i32 %c2 = vm.const.i32 -1 vm.check.eq %v, %c2, "~0=-1" : i32 @@ -165,9 +165,9 @@ vm.module @arithmetic_ops { vm.export @test_and_i32 vm.func @test_and_i32() { %c1 = vm.const.i32 5 - %c1dno = 
util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.i32 3 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 %v = vm.and.i32 %c1dno, %c2dno : i32 %c3 = vm.const.i32 1 vm.check.eq %v, %c3, "5&3=1" : i32 @@ -177,9 +177,9 @@ vm.module @arithmetic_ops { vm.export @test_or_i32 vm.func @test_or_i32() { %c1 = vm.const.i32 5 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.i32 3 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 %v = vm.or.i32 %c1dno, %c2dno : i32 %c3 = vm.const.i32 7 vm.check.eq %v, %c3, "5|3=7" : i32 @@ -189,9 +189,9 @@ vm.module @arithmetic_ops { vm.export @test_xor_i32 vm.func @test_xor_i32() { %c1 = vm.const.i32 5 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.i32 3 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 %v = vm.xor.i32 %c1dno, %c2dno : i32 %c3 = vm.const.i32 6 vm.check.eq %v, %c3, "5^3=6" : i32 @@ -201,7 +201,7 @@ vm.module @arithmetic_ops { vm.export @test_ctlz_i32_const_zero vm.func @test_ctlz_i32_const_zero() { %c = vm.const.i32 0 - %cdno = util.optimization_barrier %c : i32 + %cdno = vm.optimization_barrier %c : i32 %actual = vm.ctlz.i32 %cdno : i32 %expected = vm.const.i32 32 vm.check.eq %actual, %expected, "ctlz(0)=32" : i32 @@ -211,7 +211,7 @@ vm.module @arithmetic_ops { vm.export @test_ctlz_i32_const_1 vm.func @test_ctlz_i32_const_1() { %c = vm.const.i32 1 - %cdno = util.optimization_barrier %c : i32 + %cdno = vm.optimization_barrier %c : i32 %actual = vm.ctlz.i32 %cdno : i32 %expected = vm.const.i32 31 vm.check.eq %actual, %expected, "ctlz(1)=31" : i32 @@ -221,7 +221,7 @@ vm.module @arithmetic_ops { vm.export @test_ctlz_i32_const_ffffffff vm.func @test_ctlz_i32_const_ffffffff() { %c = vm.const.i32 0xFFFFFFFF - %cdno = util.optimization_barrier %c : 
i32 + %cdno = vm.optimization_barrier %c : i32 %actual = vm.ctlz.i32 %cdno : i32 %expected = vm.const.i32 0 vm.check.eq %actual, %expected, "ctlz(0xFFFFFFFF)=0" : i32 diff --git a/runtime/src/iree/vm/test/arithmetic_ops_f32.mlir b/runtime/src/iree/vm/test/arithmetic_ops_f32.mlir index 2d3fd2ecaf4e..17fefb772796 100644 --- a/runtime/src/iree/vm/test/arithmetic_ops_f32.mlir +++ b/runtime/src/iree/vm/test/arithmetic_ops_f32.mlir @@ -7,7 +7,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_add_f32 vm.func @test_add_f32() { %c1 = vm.const.f32 1.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.add.f32 %c1dno, %c1dno : f32 %c2 = vm.const.f32 3.0 vm.check.eq %v, %c2, "1.5+1.5=3" : f32 @@ -17,9 +17,9 @@ vm.module @arithmetic_ops_f32 { vm.export @test_sub_f32 vm.func @test_sub_f32() { %c1 = vm.const.f32 3.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %c2 = vm.const.f32 2.5 - %c2dno = util.optimization_barrier %c2 : f32 + %c2dno = vm.optimization_barrier %c2 : f32 %v = vm.sub.f32 %c1dno, %c2dno : f32 %c3 = vm.const.f32 0.5 vm.check.eq %v, %c3, "3.0-2.5=0.5" : f32 @@ -29,7 +29,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_mul_f32 vm.func @test_mul_f32() { %c1 = vm.const.f32 2.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.mul.f32 %c1dno, %c1dno : f32 %c2 = vm.const.f32 6.25 vm.check.eq %v, %c2, "2.5*2.5=6.25" : f32 @@ -39,9 +39,9 @@ vm.module @arithmetic_ops_f32 { vm.export @test_div_f32 vm.func @test_div_f32() { %c1 = vm.const.f32 4.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %c2 = vm.const.f32 -2.0 - %c2dno = util.optimization_barrier %c2 : f32 + %c2dno = vm.optimization_barrier %c2 : f32 %v = vm.div.f32 %c1dno, %c2dno : f32 %c3 = vm.const.f32 -2.0 vm.check.eq %v, %c3, "4.0/-2.0=-2.0" : f32 @@ -51,9 +51,9 @@ vm.module @arithmetic_ops_f32 { vm.export @test_rem_f32 
vm.func @test_rem_f32() { %c1 = vm.const.f32 -3.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %c2 = vm.const.f32 -2.0 - %c2dno = util.optimization_barrier %c2 : f32 + %c2dno = vm.optimization_barrier %c2 : f32 %v = vm.rem.f32 %c1dno, %c2dno : f32 %c3 = vm.const.f32 1.0 vm.check.eq %v, %c3, "-3.0%-2.0=1.0" : f32 @@ -63,11 +63,11 @@ vm.module @arithmetic_ops_f32 { vm.export @test_fma_f32 vm.func @test_fma_f32() { %c2 = vm.const.f32 2.0 - %c2dno = util.optimization_barrier %c2 : f32 + %c2dno = vm.optimization_barrier %c2 : f32 %c3 = vm.const.f32 3.0 - %c3dno = util.optimization_barrier %c3 : f32 + %c3dno = vm.optimization_barrier %c3 : f32 %c5 = vm.const.f32 5.0 - %c5dno = util.optimization_barrier %c5 : f32 + %c5dno = vm.optimization_barrier %c5 : f32 %v = vm.fma.f32 %c2dno, %c3dno, %c5dno : f32 %c11 = vm.const.f32 11.0 vm.check.eq %v, %c11, "2.0*3.0+5.0=11.0" : f32 @@ -77,7 +77,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_abs_f32 vm.func @test_abs_f32() { %c1 = vm.const.f32 -1.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.abs.f32 %c1dno : f32 %c2 = vm.const.f32 1.0 vm.check.eq %v, %c2, "abs(-1.0)=1.0" : f32 @@ -87,7 +87,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_neg_f32 vm.func @test_neg_f32() { %c1 = vm.const.f32 -1.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.neg.f32 %c1dno : f32 %c2 = vm.const.f32 1.0 vm.check.eq %v, %c2, "neg(-1.0)=1.0" : f32 @@ -97,7 +97,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_ceil_f32 vm.func @test_ceil_f32() { %c1 = vm.const.f32 1.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.ceil.f32 %c1dno : f32 %c2 = vm.const.f32 2.0 vm.check.eq %v, %c2, "ceil(1.5)=2.0" : f32 @@ -107,7 +107,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_floor_f32 vm.func @test_floor_f32() { %c15 = vm.const.f32 1.5 - %c15dno 
= util.optimization_barrier %c15 : f32 + %c15dno = vm.optimization_barrier %c15 : f32 %v = vm.floor.f32 %c15dno : f32 %c1 = vm.const.f32 1.0 vm.check.eq %v, %c1, "floor(1.5)=1.0" : f32 @@ -117,7 +117,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_round_f32 vm.func @test_round_f32() { %c15 = vm.const.f32 1.5 - %c15dno = util.optimization_barrier %c15 : f32 + %c15dno = vm.optimization_barrier %c15 : f32 %v = vm.round.f32 %c15dno : f32 %c2 = vm.const.f32 2.0 vm.check.eq %v, %c2, "round(1.5)=2.0" : f32 @@ -127,7 +127,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_round_f32_even vm.func @test_round_f32_even() { %c15 = vm.const.f32 1.5 - %c15dno = util.optimization_barrier %c15 : f32 + %c15dno = vm.optimization_barrier %c15 : f32 %v = vm.round.f32.even %c15dno : f32 %c2 = vm.const.f32 2.0 vm.check.eq %v, %c2, "roundeven(1.5)=2.0" : f32 @@ -137,9 +137,9 @@ vm.module @arithmetic_ops_f32 { vm.export @test_min_f32 vm.func @test_min_f32() { %cn3 = vm.const.f32 -3.0 - %cn3dno = util.optimization_barrier %cn3 : f32 + %cn3dno = vm.optimization_barrier %cn3 : f32 %cn2 = vm.const.f32 -2.0 - %cn2dno = util.optimization_barrier %cn2 : f32 + %cn2dno = vm.optimization_barrier %cn2 : f32 %v = vm.min.f32 %cn3dno, %cn2dno : f32 vm.check.eq %v, %cn3, "min(-3.0,-2.0)=-3.0" : f32 vm.return @@ -148,9 +148,9 @@ vm.module @arithmetic_ops_f32 { vm.export @test_max_f32 vm.func @test_max_f32() { %cn3 = vm.const.f32 -3.0 - %cn3dno = util.optimization_barrier %cn3 : f32 + %cn3dno = vm.optimization_barrier %cn3 : f32 %cn2 = vm.const.f32 -2.0 - %cn2dno = util.optimization_barrier %cn2 : f32 + %cn2dno = vm.optimization_barrier %cn2 : f32 %v = vm.max.f32 %cn3dno, %cn2dno : f32 vm.check.eq %v, %cn2, "max(-3.0,-2.0)=-2.0" : f32 vm.return @@ -159,7 +159,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_atan_f32 vm.func @test_atan_f32() { %c1 = vm.const.f32 1.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.atan.f32 %c1dno : f32 %c2 = 
vm.const.f32 0.7853981633974483 vm.check.eq %v, %c2, "atan(1.0)=0.7853981633974483" : f32 @@ -169,9 +169,9 @@ vm.module @arithmetic_ops_f32 { vm.export @test_atan2_f32 vm.func @test_atan2_f32() { %c1 = vm.const.f32 1.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %c2 = vm.const.f32 0.0 - %c2dno = util.optimization_barrier %c2 : f32 + %c2dno = vm.optimization_barrier %c2 : f32 %v = vm.atan2.f32 %c1dno, %c2dno : f32 %c3 = vm.const.f32 1.5707963267948966 vm.check.eq %v, %c3, "atan2(1.0,0.0)=1.5707963267948966" : f32 @@ -181,7 +181,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_cos_f32 vm.func @test_cos_f32() { %c1 = vm.const.f32 0.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cos.f32 %c1dno : f32 %c2 = vm.const.f32 0.8775825618903728 vm.check.eq %v, %c2, "cos(0.5)=0.8775825618903728" : f32 @@ -191,7 +191,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_sin_f32 vm.func @test_sin_f32() { %c1 = vm.const.f32 0.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.sin.f32 %c1dno : f32 %c2 = vm.const.f32 0.479425538604203 vm.check.eq %v, %c2, "sin(0.5)=0.479425538604203" : f32 @@ -201,7 +201,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_exp_f32 vm.func @test_exp_f32() { %c1 = vm.const.f32 1.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.exp.f32 %c1dno : f32 %c2 = vm.const.f32 2.718281828459045 vm.check.eq %v, %c2, "exp(1.0)=2.718281828459045" : f32 @@ -211,7 +211,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_exp2_f32 vm.func @test_exp2_f32() { %c1 = vm.const.f32 2.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.exp2.f32 %c1dno : f32 %c2 = vm.const.f32 4.0 vm.check.eq %v, %c2, "exp(2.0)=4.0" : f32 @@ -221,7 +221,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_expm1_f32 vm.func 
@test_expm1_f32() { %c1 = vm.const.f32 2.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.expm1.f32 %c1dno : f32 %c2 = vm.const.f32 6.38905609893065 vm.check.eq %v, %c2, "expm1(2.0)=6.38905609893065" : f32 @@ -231,7 +231,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_log_f32 vm.func @test_log_f32() { %c1 = vm.const.f32 10.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.log.f32 %c1dno : f32 %c2 = vm.const.f32 2.302585092994046 vm.check.eq %v, %c2, "log(10.0)=2.302585092994046" : f32 @@ -241,7 +241,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_log10_f32 vm.func @test_log10_f32() { %c1 = vm.const.f32 10.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.log10.f32 %c1dno : f32 %c2 = vm.const.f32 1.0 vm.check.eq %v, %c2, "log10(10.0)=1.0" : f32 @@ -251,7 +251,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_log1p_f32 vm.func @test_log1p_f32() { %c1 = vm.const.f32 10.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.log1p.f32 %c1dno : f32 %c2 = vm.const.f32 2.3978952727983707 vm.check.eq %v, %c2, "log1p(10.0)=2.3978952727983707" : f32 @@ -261,7 +261,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_log2_f32 vm.func @test_log2_f32() { %c1 = vm.const.f32 10.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.log2.f32 %c1dno : f32 %c2 = vm.const.f32 3.321928094887362 vm.check.eq %v, %c2, "log2(10.0)=3.321928094887362" : f32 @@ -271,9 +271,9 @@ vm.module @arithmetic_ops_f32 { vm.export @test_pow_f32 vm.func @test_pow_f32() { %c1 = vm.const.f32 3.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %c2 = vm.const.f32 2.0 - %c2dno = util.optimization_barrier %c2 : f32 + %c2dno = vm.optimization_barrier %c2 : f32 %v = vm.pow.f32 %c1dno, %c2dno : f32 %c3 = 
vm.const.f32 9.0 vm.check.eq %v, %c3, "pow(3.0,2.0)=9.0" : f32 @@ -283,7 +283,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_rsqrt_f32 vm.func @test_rsqrt_f32() { %c1 = vm.const.f32 4.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.rsqrt.f32 %c1dno : f32 %c2 = vm.const.f32 0.5 vm.check.eq %v, %c2, "rsqrt(4.0)=0.5" : f32 @@ -293,7 +293,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_sqrt_f32 vm.func @test_sqrt_f32() { %c1 = vm.const.f32 4.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.sqrt.f32 %c1dno : f32 %c2 = vm.const.f32 2.0 vm.check.eq %v, %c2, "sqrt(4.0)=2.0" : f32 @@ -303,7 +303,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_tanh_f32 vm.func @test_tanh_f32() { %c1 = vm.const.f32 0.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.tanh.f32 %c1dno : f32 %c2 = vm.const.f32 0.46211715726000974 vm.check.eq %v, %c2, "tanh(0.5)=0.46211715726000974" : f32 @@ -314,7 +314,7 @@ vm.module @arithmetic_ops_f32 { // vm.export @test_erf_f32 // vm.func @test_erf_f32() { // %c1 = vm.const.f32 0.5 - // %c1dno = util.optimization_barrier %c1 : f32 + // %c1dno = vm.optimization_barrier %c1 : f32 // %v = vm.erf.f32 %c1dno : f32 // %c2 = vm.const.f32 0.520499945 // vm.check.eq %v, %c2, "erf(0.5)=0.520499945" : f32 diff --git a/runtime/src/iree/vm/test/arithmetic_ops_f64.mlir b/runtime/src/iree/vm/test/arithmetic_ops_f64.mlir index 78c4df9b7086..91384cba6c4f 100644 --- a/runtime/src/iree/vm/test/arithmetic_ops_f64.mlir +++ b/runtime/src/iree/vm/test/arithmetic_ops_f64.mlir @@ -7,7 +7,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_add_f64 vm.func @test_add_f64() { %c1 = vm.const.f64 1.5 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.add.f64 %c1dno, %c1dno : f64 %c2 = vm.const.f64 3.0 vm.check.eq %v, %c2, "1.5+1.5=3" : f64 @@ -17,9 +17,9 @@ vm.module 
@arithmetic_ops_f64 { vm.export @test_sub_f64 vm.func @test_sub_f64() { %c1 = vm.const.f64 3.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %c2 = vm.const.f64 2.5 - %c2dno = util.optimization_barrier %c2 : f64 + %c2dno = vm.optimization_barrier %c2 : f64 %v = vm.sub.f64 %c1dno, %c2dno : f64 %c3 = vm.const.f64 0.5 vm.check.eq %v, %c3, "3.0-2.5=0.5" : f64 @@ -29,7 +29,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_mul_f64 vm.func @test_mul_f64() { %c1 = vm.const.f64 2.5 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.mul.f64 %c1dno, %c1dno : f64 %c2 = vm.const.f64 6.25 vm.check.eq %v, %c2, "2.5*2.5=6.25" : f64 @@ -39,9 +39,9 @@ vm.module @arithmetic_ops_f64 { vm.export @test_div_f64 vm.func @test_div_f64() { %c1 = vm.const.f64 4.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %c2 = vm.const.f64 -2.0 - %c2dno = util.optimization_barrier %c2 : f64 + %c2dno = vm.optimization_barrier %c2 : f64 %v = vm.div.f64 %c1dno, %c2dno : f64 %c3 = vm.const.f64 -2.0 vm.check.eq %v, %c3, "4.0/-2.0=-2.0" : f64 @@ -51,9 +51,9 @@ vm.module @arithmetic_ops_f64 { vm.export @test_rem_f64 vm.func @test_rem_f64() { %c1 = vm.const.f64 -3.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %c2 = vm.const.f64 -2.0 - %c2dno = util.optimization_barrier %c2 : f64 + %c2dno = vm.optimization_barrier %c2 : f64 %v = vm.rem.f64 %c1dno, %c2dno : f64 %c3 = vm.const.f64 1.0 vm.check.eq %v, %c3, "-3.0%-2.0=1.0" : f64 @@ -63,11 +63,11 @@ vm.module @arithmetic_ops_f64 { vm.export @test_fma_f64 vm.func @test_fma_f64() { %c2 = vm.const.f64 2.0 - %c2dno = util.optimization_barrier %c2 : f64 + %c2dno = vm.optimization_barrier %c2 : f64 %c3 = vm.const.f64 3.0 - %c3dno = util.optimization_barrier %c3 : f64 + %c3dno = vm.optimization_barrier %c3 : f64 %c5 = vm.const.f64 5.0 - %c5dno = util.optimization_barrier %c5 : f64 + 
%c5dno = vm.optimization_barrier %c5 : f64 %v = vm.fma.f64 %c2dno, %c3dno, %c5dno : f64 %c11 = vm.const.f64 11.0 vm.check.eq %v, %c11, "2.0*3.0+5.0=11.0" : f64 @@ -77,7 +77,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_abs_f64 vm.func @test_abs_f64() { %c1 = vm.const.f64 -1.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.abs.f64 %c1dno : f64 %c2 = vm.const.f64 1.0 vm.check.eq %v, %c2, "abs(-1.0)=1.0" : f64 @@ -87,7 +87,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_neg_f64 vm.func @test_neg_f64() { %c1 = vm.const.f64 -1.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.neg.f64 %c1dno : f64 %c2 = vm.const.f64 1.0 vm.check.eq %v, %c2, "neg(-1.0)=1.0" : f64 @@ -97,7 +97,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_ceil_f64 vm.func @test_ceil_f64() { %c1 = vm.const.f64 1.5 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.ceil.f64 %c1dno : f64 %c2 = vm.const.f64 2.0 vm.check.eq %v, %c2, "ceil(1.5)=2.0" : f64 @@ -107,7 +107,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_floor_f64 vm.func @test_floor_f64() { %c15 = vm.const.f64 1.5 - %c15dno = util.optimization_barrier %c15 : f64 + %c15dno = vm.optimization_barrier %c15 : f64 %v = vm.floor.f64 %c15dno : f64 %c1 = vm.const.f64 1.0 vm.check.eq %v, %c1, "floor(1.5)=1.0" : f64 @@ -117,7 +117,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_round_f64 vm.func @test_round_f64() { %c15 = vm.const.f64 1.5 - %c15dno = util.optimization_barrier %c15 : f64 + %c15dno = vm.optimization_barrier %c15 : f64 %v = vm.round.f64 %c15dno : f64 %c2 = vm.const.f64 2.0 vm.check.eq %v, %c2, "round(1.5)=2.0" : f64 @@ -127,7 +127,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_round_f64_even vm.func @test_round_f64_even() { %c15 = vm.const.f64 1.5 - %c15dno = util.optimization_barrier %c15 : f64 + %c15dno = vm.optimization_barrier %c15 : f64 %v = 
vm.round.f64.even %c15dno : f64 %c2 = vm.const.f64 2.0 vm.check.eq %v, %c2, "roundeven(1.5)=2.0" : f64 @@ -137,9 +137,9 @@ vm.module @arithmetic_ops_f64 { vm.export @test_min_f64 vm.func @test_min_f64() { %cn3 = vm.const.f64 -3.0 - %cn3dno = util.optimization_barrier %cn3 : f64 + %cn3dno = vm.optimization_barrier %cn3 : f64 %cn2 = vm.const.f64 -2.0 - %cn2dno = util.optimization_barrier %cn2 : f64 + %cn2dno = vm.optimization_barrier %cn2 : f64 %v = vm.min.f64 %cn3dno, %cn2dno : f64 vm.check.eq %v, %cn3, "min(-3.0,-2.0)=-3.0" : f64 vm.return @@ -148,9 +148,9 @@ vm.module @arithmetic_ops_f64 { vm.export @test_max_f64 vm.func @test_max_f64() { %cn3 = vm.const.f64 -3.0 - %cn3dno = util.optimization_barrier %cn3 : f64 + %cn3dno = vm.optimization_barrier %cn3 : f64 %cn2 = vm.const.f64 -2.0 - %cn2dno = util.optimization_barrier %cn2 : f64 + %cn2dno = vm.optimization_barrier %cn2 : f64 %v = vm.max.f64 %cn3dno, %cn2dno : f64 vm.check.eq %v, %cn2, "max(-3.0,-2.0)=-2.0" : f64 vm.return @@ -159,7 +159,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_atan_f64 vm.func @test_atan_f64() { %c1 = vm.const.f64 1.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.atan.f64 %c1dno : f64 %c2 = vm.const.f64 0.7853981633974483 vm.check.eq %v, %c2, "atan(1.0)=0.7853981633974483" : f64 @@ -169,9 +169,9 @@ vm.module @arithmetic_ops_f64 { vm.export @test_atan2_f64 vm.func @test_atan2_f64() { %c1 = vm.const.f64 1.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %c2 = vm.const.f64 0.0 - %c2dno = util.optimization_barrier %c2 : f64 + %c2dno = vm.optimization_barrier %c2 : f64 %v = vm.atan2.f64 %c1dno, %c2dno : f64 %c3 = vm.const.f64 1.5707963267948966 vm.check.eq %v, %c3, "atan2(1.0,0.0)=1.5707963267948966" : f64 @@ -181,7 +181,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_cos_f64 vm.func @test_cos_f64() { %c1 = vm.const.f64 0.5 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = 
vm.optimization_barrier %c1 : f64 %v = vm.cos.f64 %c1dno : f64 %c2 = vm.const.f64 0.8775825618903728 vm.check.eq %v, %c2, "cos(0.5)=0.8775825618903728" : f64 @@ -191,7 +191,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_sin_f64 vm.func @test_sin_f64() { %c1 = vm.const.f64 0.5 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.sin.f64 %c1dno : f64 %c2 = vm.const.f64 0.479425538604203 vm.check.eq %v, %c2, "sin(0.5)=0.479425538604203" : f64 @@ -201,7 +201,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_exp_f64 vm.func @test_exp_f64() { %c1 = vm.const.f64 1.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.exp.f64 %c1dno : f64 %c2 = vm.const.f64 2.718281828459045 vm.check.eq %v, %c2, "exp(1.0)=2.718281828459045" : f64 @@ -211,7 +211,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_exp2_f64 vm.func @test_exp2_f64() { %c1 = vm.const.f64 2.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.exp2.f64 %c1dno : f64 %c2 = vm.const.f64 4.0 vm.check.eq %v, %c2, "exp(2.0)=4.0" : f64 @@ -221,7 +221,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_expm1_f64 vm.func @test_expm1_f64() { %c1 = vm.const.f64 2.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.expm1.f64 %c1dno : f64 %c2 = vm.const.f64 6.38905609893065 vm.check.eq %v, %c2, "expm1(2.0)=6.38905609893065" : f64 @@ -231,7 +231,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_log_f64 vm.func @test_log_f64() { %c1 = vm.const.f64 10.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.log.f64 %c1dno : f64 %c2 = vm.const.f64 2.302585092994046 vm.check.eq %v, %c2, "log(10.0)=2.302585092994046" : f64 @@ -241,7 +241,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_log10_f64 vm.func @test_log10_f64() { %c1 = vm.const.f64 10.0 - %c1dno = 
util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.log10.f64 %c1dno : f64 %c2 = vm.const.f64 1.0 vm.check.eq %v, %c2, "log10(10.0)=1.0" : f64 @@ -251,7 +251,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_log1p_f64 vm.func @test_log1p_f64() { %c1 = vm.const.f64 10.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.log1p.f64 %c1dno : f64 %c2 = vm.const.f64 2.3978952727983707 vm.check.eq %v, %c2, "log1p(10.0)=2.3978952727983707" : f64 @@ -261,7 +261,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_log2_f64 vm.func @test_log2_f64() { %c1 = vm.const.f64 10.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.log2.f64 %c1dno : f64 %c2 = vm.const.f64 3.321928094887362 vm.check.eq %v, %c2, "log2(10.0)=3.321928094887362" : f64 @@ -271,9 +271,9 @@ vm.module @arithmetic_ops_f64 { vm.export @test_pow_f64 vm.func @test_pow_f64() { %c1 = vm.const.f64 3.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %c2 = vm.const.f64 2.0 - %c2dno = util.optimization_barrier %c2 : f64 + %c2dno = vm.optimization_barrier %c2 : f64 %v = vm.pow.f64 %c1dno, %c2dno : f64 %c3 = vm.const.f64 9.0 vm.check.eq %v, %c3, "pow(3.0,2.0)=9.0" : f64 @@ -283,7 +283,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_rsqrt_f64 vm.func @test_rsqrt_f64() { %c1 = vm.const.f64 4.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.rsqrt.f64 %c1dno : f64 %c2 = vm.const.f64 0.5 vm.check.eq %v, %c2, "rsqrt(4.0)=0.5" : f64 @@ -293,7 +293,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_sqrt_f64 vm.func @test_sqrt_f64() { %c1 = vm.const.f64 4.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.sqrt.f64 %c1dno : f64 %c2 = vm.const.f64 2.0 vm.check.eq %v, %c2, "sqrt(4.0)=2.0" : f64 @@ -303,7 +303,7 @@ vm.module @arithmetic_ops_f64 { 
vm.export @test_tanh_f64 vm.func @test_tanh_f64() { %c1 = vm.const.f64 0.5 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.tanh.f64 %c1dno : f64 %c2 = vm.const.f64 0.46211715726000974 vm.check.eq %v, %c2, "tanh(0.5)=0.46211715726000974" : f64 @@ -314,7 +314,7 @@ vm.module @arithmetic_ops_f64 { // vm.export @test_erf_f64 // vm.func @test_erf_f64() { // %c1 = vm.const.f64 0.5 - // %c1dno = util.optimization_barrier %c1 : f64 + // %c1dno = vm.optimization_barrier %c1 : f64 // %v = vm.erf.f64 %c1dno : f64 // %c2 = vm.const.f64 0.520499945 // vm.check.eq %v, %c2, "erf(0.5)=0.520499945" : f64 diff --git a/runtime/src/iree/vm/test/arithmetic_ops_i64.mlir b/runtime/src/iree/vm/test/arithmetic_ops_i64.mlir index b6cc8a2653c6..5658d9b158d8 100644 --- a/runtime/src/iree/vm/test/arithmetic_ops_i64.mlir +++ b/runtime/src/iree/vm/test/arithmetic_ops_i64.mlir @@ -7,7 +7,7 @@ vm.module @arithmetic_ops_i64 { vm.export @test_add_i64 vm.func @test_add_i64() { %c1 = vm.const.i64 1 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %v = vm.add.i64 %c1dno, %c1dno : i64 %c2 = vm.const.i64 2 vm.check.eq %v, %c2, "1+1=2" : i64 @@ -17,9 +17,9 @@ vm.module @arithmetic_ops_i64 { vm.export @test_sub_i64 vm.func @test_sub_i64() { %c1 = vm.const.i64 3 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %c2 = vm.const.i64 2 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %v = vm.sub.i64 %c1dno, %c2dno : i64 %c3 = vm.const.i64 1 vm.check.eq %v, %c3, "3-2=1" : i64 @@ -29,7 +29,7 @@ vm.module @arithmetic_ops_i64 { vm.export @test_mul_i64 vm.func @test_mul_i64() { %c1 = vm.const.i64 2 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %v = vm.mul.i64 %c1dno, %c1dno : i64 %c2 = vm.const.i64 4 vm.check.eq %v, %c2, "2*2=4" : i64 @@ -39,9 +39,9 @@ vm.module @arithmetic_ops_i64 { vm.export 
@test_div_i64s vm.func @test_div_i64s() { %c1 = vm.const.i64 4 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %c2 = vm.const.i64 -2 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %v = vm.div.i64.s %c1dno, %c2dno : i64 %c3 = vm.const.i64 -2 vm.check.eq %v, %c3, "4/-2=-2" : i64 @@ -51,9 +51,9 @@ vm.module @arithmetic_ops_i64 { vm.export @test_div_i64u vm.func @test_div_i64u() { %c1 = vm.const.i64 4 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %c2 = vm.const.i64 2 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %v = vm.div.i64.u %c1dno, %c2dno : i64 %c3 = vm.const.i64 2 vm.check.eq %v, %c3, "4/2=2" : i64 @@ -63,9 +63,9 @@ vm.module @arithmetic_ops_i64 { vm.export @test_rem_i64s vm.func @test_rem_i64s() { %c1 = vm.const.i64 -3 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %c2 = vm.const.i64 -2 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %v = vm.rem.i64.s %c1dno, %c2dno : i64 %c3 = vm.const.i64 -1 vm.check.eq %v, %c3, "-3%-2=-1" : i64 @@ -75,9 +75,9 @@ vm.module @arithmetic_ops_i64 { vm.export @test_rem_i64u vm.func @test_rem_i64u() { %c1 = vm.const.i64 3 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %c2 = vm.const.i64 2 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %v = vm.rem.i64.u %c1dno, %c2dno : i64 %c3 = vm.const.i64 1 vm.check.eq %v, %c3, "3%2=1" : i64 @@ -87,11 +87,11 @@ vm.module @arithmetic_ops_i64 { vm.export @test_fma_i64 vm.func @test_fma_i64() { %c2 = vm.const.i64 2 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %c3 = vm.const.i64 3 - %c3dno = util.optimization_barrier %c3 : i64 + %c3dno = vm.optimization_barrier %c3 : i64 %c5 = vm.const.i64 5 - %c5dno 
= util.optimization_barrier %c5 : i64 + %c5dno = vm.optimization_barrier %c5 : i64 %v = vm.fma.i64 %c2dno, %c3dno, %c5dno : i64 %c11 = vm.const.i64 11 vm.check.eq %v, %c11, "2*3+5=11" : i64 @@ -101,7 +101,7 @@ vm.module @arithmetic_ops_i64 { vm.export @test_abs_i64 vm.func @test_abs_i64() { %cn1 = vm.const.i64 -1 - %cn1dno = util.optimization_barrier %cn1 : i64 + %cn1dno = vm.optimization_barrier %cn1 : i64 %v = vm.abs.i64 %cn1dno : i64 %c1 = vm.const.i64 1 vm.check.eq %v, %c1, "abs(-1)=1" : i64 @@ -111,9 +111,9 @@ vm.module @arithmetic_ops_i64 { vm.export @test_min_i64s vm.func @test_min_i64s() { %cn3 = vm.const.i64 -3 - %cn3dno = util.optimization_barrier %cn3 : i64 + %cn3dno = vm.optimization_barrier %cn3 : i64 %c2 = vm.const.i64 2 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %v = vm.min.i64.s %cn3dno, %c2dno : i64 vm.check.eq %v, %cn3, "smin(-3,2)=-3" : i64 vm.return @@ -122,9 +122,9 @@ vm.module @arithmetic_ops_i64 { vm.export @test_min_i64u vm.func @test_min_i64u() { %cn3 = vm.const.i64 -3 - %cn3dno = util.optimization_barrier %cn3 : i64 + %cn3dno = vm.optimization_barrier %cn3 : i64 %c2 = vm.const.i64 2 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %v = vm.min.i64.u %cn3dno, %c2dno : i64 vm.check.eq %v, %c2, "umin(-3,2)=2" : i64 vm.return @@ -133,9 +133,9 @@ vm.module @arithmetic_ops_i64 { vm.export @test_max_i64s vm.func @test_max_i64s() { %cn3 = vm.const.i64 -3 - %cn3dno = util.optimization_barrier %cn3 : i64 + %cn3dno = vm.optimization_barrier %cn3 : i64 %c2 = vm.const.i64 2 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %v = vm.max.i64.s %cn3dno, %c2dno : i64 vm.check.eq %v, %c2, "smax(-3,2)=2" : i64 vm.return @@ -144,9 +144,9 @@ vm.module @arithmetic_ops_i64 { vm.export @test_max_i64u vm.func @test_max_i64u() { %cn3 = vm.const.i64 -3 - %cn3dno = util.optimization_barrier %cn3 : i64 + %cn3dno = vm.optimization_barrier 
%cn3 : i64 %c2 = vm.const.i64 2 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %v = vm.max.i64.u %cn3dno, %c2dno : i64 vm.check.eq %v, %cn3, "umax(-3,2)=-3" : i64 vm.return @@ -155,7 +155,7 @@ vm.module @arithmetic_ops_i64 { vm.export @test_not_i64 vm.func @test_not_i64() { %c1 = vm.const.i64 0 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %v = vm.not.i64 %c1dno : i64 %c2 = vm.const.i64 -1 vm.check.eq %v, %c2, "~0=-1" : i64 @@ -165,9 +165,9 @@ vm.module @arithmetic_ops_i64 { vm.export @test_and_i64 vm.func @test_and_i64() { %c1 = vm.const.i64 5 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %c2 = vm.const.i64 3 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %v = vm.and.i64 %c1dno, %c2dno : i64 %c3 = vm.const.i64 1 vm.check.eq %v, %c3, "5&3=1" : i64 @@ -177,9 +177,9 @@ vm.module @arithmetic_ops_i64 { vm.export @test_or_i64 vm.func @test_or_i64() { %c1 = vm.const.i64 5 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %c2 = vm.const.i64 3 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %v = vm.or.i64 %c1dno, %c2dno : i64 %c3 = vm.const.i64 7 vm.check.eq %v, %c3, "5|3=7" : i64 @@ -189,9 +189,9 @@ vm.module @arithmetic_ops_i64 { vm.export @test_xor_i64 vm.func @test_xor_i64() { %c1 = vm.const.i64 5 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %c2 = vm.const.i64 3 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %v = vm.xor.i64 %c1dno, %c2dno : i64 %c3 = vm.const.i64 6 vm.check.eq %v, %c3, "5^3=6" : i64 @@ -201,7 +201,7 @@ vm.module @arithmetic_ops_i64 { vm.export @test_ctlz_i64_const_zero vm.func @test_ctlz_i64_const_zero() { %c = vm.const.i64 0 - %cdno = util.optimization_barrier %c : i64 + %cdno = 
vm.optimization_barrier %c : i64 %actual = vm.ctlz.i64 %cdno : i64 %expected = vm.const.i64 64 vm.check.eq %actual, %expected, "ctlz(0)=64" : i64 @@ -211,7 +211,7 @@ vm.module @arithmetic_ops_i64 { vm.export @test_ctlz_i64_const_1 vm.func @test_ctlz_i64_const_1() { %c = vm.const.i64 1 - %cdno = util.optimization_barrier %c : i64 + %cdno = vm.optimization_barrier %c : i64 %actual = vm.ctlz.i64 %cdno : i64 %expected = vm.const.i64 63 vm.check.eq %actual, %expected, "ctlz(1)=63" : i64 @@ -221,7 +221,7 @@ vm.module @arithmetic_ops_i64 { vm.export @test_ctlz_i64_const_ffffffffffffffff vm.func @test_ctlz_i64_const_ffffffffffffffff() { %c = vm.const.i64 0xFFFFFFFFFFFFFFFF - %cdno = util.optimization_barrier %c : i64 + %cdno = vm.optimization_barrier %c : i64 %actual = vm.ctlz.i64 %cdno : i64 %expected = vm.const.i64 0 vm.check.eq %actual, %expected, "ctlz(0xFFFFFFFFFFFFFFFF)=0" : i64 diff --git a/runtime/src/iree/vm/test/assignment_ops.mlir b/runtime/src/iree/vm/test/assignment_ops.mlir index 891165da8bc3..c9fc08005d63 100644 --- a/runtime/src/iree/vm/test/assignment_ops.mlir +++ b/runtime/src/iree/vm/test/assignment_ops.mlir @@ -7,9 +7,9 @@ vm.module @assignment_ops { vm.export @test_select_i32 vm.func @test_select_i32() { %c0 = vm.const.i32 0 - %c0dno = util.optimization_barrier %c0 : i32 + %c0dno = vm.optimization_barrier %c0 : i32 %c1 = vm.const.i32 1 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %v1 = vm.select.i32 %c0dno, %c0dno, %c1dno : i32 vm.check.eq %v1, %c1, "0 ? 0 : 1 = 1" : i32 %v2 = vm.select.i32 %c1dno, %c0dno, %c1dno : i32 @@ -24,7 +24,7 @@ vm.module @assignment_ops { %c1 = vm.const.i32 1 %list1 = vm.list.alloc %c1 : (i32) -> !vm.list %cond = vm.const.i32 0 - %cond_dno = util.optimization_barrier %cond : i32 + %cond_dno = vm.optimization_barrier %cond : i32 %list = vm.select.ref %cond_dno, %list0, %list1 : !vm.list vm.check.eq %list, %list1, "0 ? 
list0 : list1 = list1" : !vm.list vm.return @@ -41,17 +41,17 @@ vm.module @assignment_ops { %c300 = vm.const.i32 300 %i0 = vm.const.i32 0 - %i0_dno = util.optimization_barrier %i0 : i32 + %i0_dno = vm.optimization_barrier %i0 : i32 %v0 = vm.switch.i32 %i0_dno[%c100, %c200] else %c300 : i32 vm.check.eq %v0, %c100, "index 0 is 100" : i32 %i1 = vm.const.i32 1 - %i1_dno = util.optimization_barrier %i1 : i32 + %i1_dno = vm.optimization_barrier %i1 : i32 %v1 = vm.switch.i32 %i1_dno[%c100, %c200] else %c300 : i32 vm.check.eq %v1, %c200, "index 1 is 200" : i32 %i2 = vm.const.i32 2 - %i2_dno = util.optimization_barrier %i2 : i32 + %i2_dno = vm.optimization_barrier %i2 : i32 %v2 = vm.switch.i32 %i2_dno[%c100, %c200] else %c300 : i32 vm.check.eq %v2, %c300, "index 2 (out of bounds) is default 300" : i32 diff --git a/runtime/src/iree/vm/test/assignment_ops_f32.mlir b/runtime/src/iree/vm/test/assignment_ops_f32.mlir index 6a0246c16f71..5a368da575fb 100644 --- a/runtime/src/iree/vm/test/assignment_ops_f32.mlir +++ b/runtime/src/iree/vm/test/assignment_ops_f32.mlir @@ -7,9 +7,9 @@ vm.module @assignment_ops_f32 { vm.export @test_select_f32 vm.func @test_select_f32() { %c0 = vm.const.i32 0 - %c0dno = util.optimization_barrier %c0 : i32 + %c0dno = vm.optimization_barrier %c0 : i32 %c1 = vm.const.i32 1 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.f32 0.0 %c3 = vm.const.f32 1.0 %v1 = vm.select.f32 %c0dno, %c2, %c3 : f32 @@ -30,17 +30,17 @@ vm.module @assignment_ops_f32 { %c300 = vm.const.f32 300.0 %i0 = vm.const.i32 0 - %i0_dno = util.optimization_barrier %i0 : i32 + %i0_dno = vm.optimization_barrier %i0 : i32 %v0 = vm.switch.f32 %i0_dno[%c100, %c200] else %c300 : f32 vm.check.eq %v0, %c100, "index 0 is 100" : f32 %i1 = vm.const.i32 1 - %i1_dno = util.optimization_barrier %i1 : i32 + %i1_dno = vm.optimization_barrier %i1 : i32 %v1 = vm.switch.f32 %i1_dno[%c100, %c200] else %c300 : f32 vm.check.eq %v1, %c200, "index 1 is 
200" : f32 %i2 = vm.const.i32 2 - %i2_dno = util.optimization_barrier %i2 : i32 + %i2_dno = vm.optimization_barrier %i2 : i32 %v2 = vm.switch.f32 %i2_dno[%c100, %c200] else %c300 : f32 vm.check.eq %v2, %c300, "index 2 (out of bounds) is default 300" : f32 diff --git a/runtime/src/iree/vm/test/assignment_ops_f64.mlir b/runtime/src/iree/vm/test/assignment_ops_f64.mlir index 7f9d6443f22b..13f6c3820607 100644 --- a/runtime/src/iree/vm/test/assignment_ops_f64.mlir +++ b/runtime/src/iree/vm/test/assignment_ops_f64.mlir @@ -7,9 +7,9 @@ vm.module @assignment_ops_f64 { vm.export @test_select_f64 vm.func @test_select_f64() { %c0 = vm.const.i32 0 - %c0dno = util.optimization_barrier %c0 : i32 + %c0dno = vm.optimization_barrier %c0 : i32 %c1 = vm.const.i32 1 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.f64 0.0 %c3 = vm.const.f64 1.0 %v1 = vm.select.f64 %c0dno, %c2, %c3 : f64 @@ -30,17 +30,17 @@ vm.module @assignment_ops_f64 { %c300 = vm.const.f64 300.0 %i0 = vm.const.i32 0 - %i0_dno = util.optimization_barrier %i0 : i32 + %i0_dno = vm.optimization_barrier %i0 : i32 %v0 = vm.switch.f64 %i0_dno[%c100, %c200] else %c300 : f64 vm.check.eq %v0, %c100, "index 0 is 100" : f64 %i1 = vm.const.i32 1 - %i1_dno = util.optimization_barrier %i1 : i32 + %i1_dno = vm.optimization_barrier %i1 : i32 %v1 = vm.switch.f64 %i1_dno[%c100, %c200] else %c300 : f64 vm.check.eq %v1, %c200, "index 1 is 200" : f64 %i2 = vm.const.i32 2 - %i2_dno = util.optimization_barrier %i2 : i32 + %i2_dno = vm.optimization_barrier %i2 : i32 %v2 = vm.switch.f64 %i2_dno[%c100, %c200] else %c300 : f64 vm.check.eq %v2, %c300, "index 2 (out of bounds) is default 300" : f64 diff --git a/runtime/src/iree/vm/test/assignment_ops_i64.mlir b/runtime/src/iree/vm/test/assignment_ops_i64.mlir index a0d9bc18f03f..c2bd579ed7e9 100644 --- a/runtime/src/iree/vm/test/assignment_ops_i64.mlir +++ b/runtime/src/iree/vm/test/assignment_ops_i64.mlir @@ -7,9 +7,9 @@ vm.module 
@assignment_ops_i64 { vm.export @test_select_i64 vm.func @test_select_i64() { %c0 = vm.const.i32 0 - %c0dno = util.optimization_barrier %c0 : i32 + %c0dno = vm.optimization_barrier %c0 : i32 %c1 = vm.const.i32 1 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.i64 0 %c3 = vm.const.i64 1 %v1 = vm.select.i64 %c0dno, %c2, %c3 : i64 @@ -30,17 +30,17 @@ vm.module @assignment_ops_i64 { %c300 = vm.const.i64 300 %i0 = vm.const.i32 0 - %i0_dno = util.optimization_barrier %i0 : i32 + %i0_dno = vm.optimization_barrier %i0 : i32 %v0 = vm.switch.i64 %i0_dno[%c100, %c200] else %c300 : i64 vm.check.eq %v0, %c100, "index 0 is 100" : i64 %i1 = vm.const.i32 1 - %i1_dno = util.optimization_barrier %i1 : i32 + %i1_dno = vm.optimization_barrier %i1 : i32 %v1 = vm.switch.i64 %i1_dno[%c100, %c200] else %c300 : i64 vm.check.eq %v1, %c200, "index 1 is 200" : i64 %i2 = vm.const.i32 2 - %i2_dno = util.optimization_barrier %i2 : i32 + %i2_dno = vm.optimization_barrier %i2 : i32 %v2 = vm.switch.i64 %i2_dno[%c100, %c200] else %c300 : i64 vm.check.eq %v2, %c300, "index 2 (out of bounds) is default 300" : i64 diff --git a/runtime/src/iree/vm/test/async_ops.mlir b/runtime/src/iree/vm/test/async_ops.mlir index 75772158c5e7..ef75abe7e549 100644 --- a/runtime/src/iree/vm/test/async_ops.mlir +++ b/runtime/src/iree/vm/test/async_ops.mlir @@ -1,47 +1,68 @@ -// Tested by iree/vm/bytecode/dispatch_async_test.cc. - vm.module @async_ops { //===--------------------------------------------------------------------===// // vm.yield //===--------------------------------------------------------------------===// // Tests a simple straight-line yield sequence that requires 3 resumes. - // - // Expects a result of %arg0 + 3. - vm.export @yield_sequence - vm.func @yield_sequence(%arg0: i32) -> i32 { + // Starts with 100, adds 1 three times across yields, expects 103. 
+ vm.export @test_yield_sequence + vm.func @test_yield_sequence() { %c1 = vm.const.i32 1 - %y0 = vm.add.i32 %arg0, %c1 : i32 - %y0_dno = util.optimization_barrier %y0 : i32 + %c100 = vm.const.i32 100 + %c100_dno = vm.optimization_barrier %c100 : i32 + %y0 = vm.add.i32 %c100_dno, %c1 : i32 + %y0_dno = vm.optimization_barrier %y0 : i32 vm.yield ^bb1 ^bb1: %y1 = vm.add.i32 %y0_dno, %c1 : i32 - %y1_dno = util.optimization_barrier %y1 : i32 + %y1_dno = vm.optimization_barrier %y1 : i32 vm.yield ^bb2 ^bb2: %y2 = vm.add.i32 %y1_dno, %c1 : i32 - %y2_dno = util.optimization_barrier %y2 : i32 + %y2_dno = vm.optimization_barrier %y2 : i32 vm.yield ^bb3 ^bb3: - vm.return %y2_dno : i32 + %c103 = vm.const.i32 103 + vm.check.eq %y2_dno, %c103, "100+1+1+1=103" : i32 + vm.return + } + + // Tests a yield with data-dependent control flow (true branch). + vm.export @test_yield_divergent_true + vm.func @test_yield_divergent_true() { + %c1 = vm.const.i32 1 + %c100 = vm.const.i32 100 + %c200 = vm.const.i32 200 + %cond = vm.cmp.nz.i32 %c1 : i32 + vm.cond_br %cond, ^true, ^false + ^true: + %v_true = vm.optimization_barrier %c100 : i32 + vm.yield ^check(%v_true : i32) + ^false: + %v_false = vm.optimization_barrier %c200 : i32 + vm.yield ^check(%v_false : i32) + ^check(%result : i32): + vm.check.eq %result, %c100, "cond=1 selects true branch" : i32 + vm.return } - // Tests a yield with data-dependent control, ensuring that we run the - // alternating branches and pass along branch args on resume. - // - // Expects a result of %arg0 ? %arg1 : %arg2. - vm.export @yield_divergent - vm.func @yield_divergent(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 { - %cond = vm.cmp.nz.i32 %arg0 : i32 + // Tests a yield with data-dependent control flow (false branch). 
+ vm.export @test_yield_divergent_false + vm.func @test_yield_divergent_false() { + %c0 = vm.const.i32 0 + %c100 = vm.const.i32 100 + %c200 = vm.const.i32 200 + %cond = vm.cmp.nz.i32 %c0 : i32 vm.cond_br %cond, ^true, ^false ^true: - %arg1_dno = util.optimization_barrier %arg1 : i32 - vm.yield ^bb3(%arg1_dno : i32) + %v_true = vm.optimization_barrier %c100 : i32 + vm.yield ^check(%v_true : i32) ^false: - %arg2_dno = util.optimization_barrier %arg2 : i32 - vm.yield ^bb3(%arg2_dno: i32) - ^bb3(%result : i32): - vm.return %result : i32 + %v_false = vm.optimization_barrier %c200 : i32 + vm.yield ^check(%v_false : i32) + ^check(%result : i32): + vm.check.eq %result, %c200, "cond=0 selects false branch" : i32 + vm.return } //===--------------------------------------------------------------------===// @@ -53,15 +74,15 @@ vm.module @async_ops { vm.func private @yield_counter(%start : i32) -> i32 { %c1 = vm.const.i32 1 %v0 = vm.add.i32 %start, %c1 : i32 - %v0_dno = util.optimization_barrier %v0 : i32 + %v0_dno = vm.optimization_barrier %v0 : i32 vm.yield ^y1 ^y1: %v1 = vm.add.i32 %v0_dno, %c1 : i32 - %v1_dno = util.optimization_barrier %v1 : i32 + %v1_dno = vm.optimization_barrier %v1 : i32 vm.yield ^y2 ^y2: %v2 = vm.add.i32 %v1_dno, %c1 : i32 - %v2_dno = util.optimization_barrier %v2 : i32 + %v2_dno = vm.optimization_barrier %v2 : i32 vm.yield ^y3 ^y3: %v3 = vm.add.i32 %v2_dno, %c1 : i32 @@ -69,33 +90,37 @@ vm.module @async_ops { } // Tests calling an internal yieldable function. - // The callee yields 4 times, so we need 4 resumes. - // Expects result of 0 + 4 = 4. - vm.export @call_yieldable_internal attributes {emitc.exclude} - vm.func @call_yieldable_internal() -> i32 { + // The callee yields 4 times. Expects result of 0 + 4 = 4. 
+ vm.export @test_call_yieldable_internal attributes {emitc.exclude} + vm.func @test_call_yieldable_internal() { %c0 = vm.const.i32 0 vm.call.yieldable @yield_counter(%c0) : (i32) -> ^resume(i32) ^resume(%result : i32): - vm.return %result : i32 + %c4 = vm.const.i32 4 + vm.check.eq %result, %c4, "0+4=4" : i32 + vm.return } // Internal function that takes an input and yields once, returning input + 1. vm.func private @yield_add_one(%arg0: i32) -> i32 { %c1 = vm.const.i32 1 %result = vm.add.i32 %arg0, %c1 : i32 - %result_dno = util.optimization_barrier %result : i32 + %result_dno = vm.optimization_barrier %result : i32 vm.yield ^done ^done: vm.return %result_dno : i32 } // Tests calling an internal yieldable function with an argument. - // Expects result of %arg0 + 1. - vm.export @call_yieldable_with_arg attributes {emitc.exclude} - vm.func @call_yieldable_with_arg(%arg0: i32) -> i32 { - vm.call.yieldable @yield_add_one(%arg0) : (i32) -> ^resume(i32) + // Expects result of 42 + 1 = 43. + vm.export @test_call_yieldable_with_arg attributes {emitc.exclude} + vm.func @test_call_yieldable_with_arg() { + %c42 = vm.const.i32 42 + vm.call.yieldable @yield_add_one(%c42) : (i32) -> ^resume(i32) ^resume(%result : i32): - vm.return %result : i32 + %c43 = vm.const.i32 43 + vm.check.eq %result, %c43, "42+1=43" : i32 + vm.return } //===--------------------------------------------------------------------===// @@ -112,42 +137,51 @@ vm.module @async_ops { vm.import private @yieldable_test.yield_variadic_sum(%args : i32 ..., %yield_count : i32) -> i32 attributes {vm.yield} // Test: call yieldable import with 3 yields. - // Expected: 3 DEFERRED returns, then OK with result = arg + 3 - vm.export @call_yieldable_import_yields_3 attributes {emitc.exclude} - vm.func @call_yieldable_import_yields_3(%arg0 : i32) -> i32 { + // Expects 100 + 3 = 103. 
+ vm.export @test_call_yieldable_import_yields_3 attributes {emitc.exclude} + vm.func @test_call_yieldable_import_yields_3() { + %c100 = vm.const.i32 100 %c3 = vm.const.i32 3 - vm.call.yieldable @yieldable_test.yield_n(%arg0, %c3) : (i32, i32) -> ^resume(i32) + vm.call.yieldable @yieldable_test.yield_n(%c100, %c3) : (i32, i32) -> ^resume(i32) ^resume(%result : i32): - vm.return %result : i32 + %c103 = vm.const.i32 103 + vm.check.eq %result, %c103, "100+3=103" : i32 + vm.return } // Test: call yieldable import with 0 yields (synchronous). - // Expected: immediate OK with result = arg - vm.export @call_yieldable_import_yields_0 attributes {emitc.exclude} - vm.func @call_yieldable_import_yields_0(%arg0 : i32) -> i32 { + // Expects immediate return with 42. + vm.export @test_call_yieldable_import_yields_0 attributes {emitc.exclude} + vm.func @test_call_yieldable_import_yields_0() { + %c42 = vm.const.i32 42 %c0 = vm.const.i32 0 - vm.call.yieldable @yieldable_test.yield_n(%arg0, %c0) : (i32, i32) -> ^resume(i32) + vm.call.yieldable @yieldable_test.yield_n(%c42, %c0) : (i32, i32) -> ^resume(i32) ^resume(%result : i32): - vm.return %result : i32 + vm.check.eq %result, %c42, "42+0=42" : i32 + vm.return } // Test: call yieldable import after internal function call. - // This exercises Bug 2 fix: return_registers must be cleared after internal call. + // This exercises return_registers clearing after internal call. vm.func private @internal_add_10(%x : i32) -> i32 { %c10 = vm.const.i32 10 %r = vm.add.i32 %x, %c10 : i32 vm.return %r : i32 } - vm.export @call_yieldable_after_internal attributes {emitc.exclude} - vm.func @call_yieldable_after_internal(%arg0 : i32) -> i32 { + // Expects (5 + 10) + 2 = 17. + vm.export @test_call_yieldable_after_internal attributes {emitc.exclude} + vm.func @test_call_yieldable_after_internal() { + %c5 = vm.const.i32 5 // First call an internal function (sets return_registers). 
- %v1 = vm.call @internal_add_10(%arg0) : (i32) -> i32 - // Then call yieldable import (should see return_registers == NULL for begin). + %v1 = vm.call @internal_add_10(%c5) : (i32) -> i32 + // Then call yieldable import. %c2 = vm.const.i32 2 vm.call.yieldable @yieldable_test.yield_n(%v1, %c2) : (i32, i32) -> ^resume(i32) ^resume(%result : i32): - vm.return %result : i32 + %c17 = vm.const.i32 17 + vm.check.eq %result, %c17, "(5+10)+2=17" : i32 + vm.return } //===--------------------------------------------------------------------===// @@ -155,31 +189,31 @@ vm.module @async_ops { //===--------------------------------------------------------------------===// // Test: two sequential yieldable import calls in the same function. - // This catches bugs where the second call sees stale state from the first. - // Expected: 2 yields from first call + 3 yields from second call = 5 total - // Result: (arg + 2) + 3 = arg + 5 - vm.export @call_yieldable_import_sequential attributes {emitc.exclude} - vm.func @call_yieldable_import_sequential(%arg0 : i32) -> i32 { + // Expects (10 + 2) + 3 = 15. 
+ vm.export @test_call_yieldable_import_sequential attributes {emitc.exclude} + vm.func @test_call_yieldable_import_sequential() { + %c10 = vm.const.i32 10 %c2 = vm.const.i32 2 %c3 = vm.const.i32 3 - // First yieldable import: yields 2 times, returns arg + 2 - vm.call.yieldable @yieldable_test.yield_n(%arg0, %c2) : (i32, i32) -> ^after_first(i32) + // First yieldable import: yields 2 times, returns 10 + 2 = 12 + vm.call.yieldable @yieldable_test.yield_n(%c10, %c2) : (i32, i32) -> ^after_first(i32) ^after_first(%v1 : i32): - // Second yieldable import: yields 3 times, returns v1 + 3 = arg + 5 + // Second yieldable import: yields 3 times, returns 12 + 3 = 15 vm.call.yieldable @yieldable_test.yield_n(%v1, %c3) : (i32, i32) -> ^done(i32) ^done(%result : i32): - vm.return %result : i32 + %c15 = vm.const.i32 15 + vm.check.eq %result, %c15, "(10+2)+3=15" : i32 + vm.return } // Test: yieldable import nested inside an internal yieldable function. // The internal function yields before and after calling the import. - // This creates the most complex frame stack scenario. vm.func private @yield_then_import_then_yield(%arg0 : i32) -> i32 { %c1 = vm.const.i32 1 %c2 = vm.const.i32 2 // Add 1 before yield %v0 = vm.add.i32 %arg0, %c1 : i32 - %v0_dno = util.optimization_barrier %v0 : i32 + %v0_dno = vm.optimization_barrier %v0 : i32 vm.yield ^after_first_yield ^after_first_yield: // Call yieldable import (yields 2 times) @@ -187,31 +221,34 @@ vm.module @async_ops { ^after_import(%v1 : i32): // Add 1 after import %v2 = vm.add.i32 %v1, %c1 : i32 - %v2_dno = util.optimization_barrier %v2 : i32 + %v2_dno = vm.optimization_barrier %v2 : i32 vm.yield ^final ^final: vm.return %v2_dno : i32 } - // Export that calls the nested yieldable function. 
- // Expected sequence: 1 yield (internal) + 2 yields (import) + 1 yield (internal) = 4 yields - // Result: ((arg + 1) + 2) + 1 = arg + 4 - vm.export @call_nested_yieldable attributes {emitc.exclude} - vm.func @call_nested_yieldable(%arg0 : i32) -> i32 { - vm.call.yieldable @yield_then_import_then_yield(%arg0) : (i32) -> ^resume(i32) + // Expects ((50 + 1) + 2) + 1 = 54. + vm.export @test_call_nested_yieldable attributes {emitc.exclude} + vm.func @test_call_nested_yieldable() { + %c50 = vm.const.i32 50 + vm.call.yieldable @yield_then_import_then_yield(%c50) : (i32) -> ^resume(i32) ^resume(%result : i32): - vm.return %result : i32 + %c54 = vm.const.i32 54 + vm.check.eq %result, %c54, "((50+1)+2)+1=54" : i32 + vm.return } // Test: stress test with many yields to catch state accumulation bugs. - // Calls yieldable import with high yield count. - // Expected: 10 yields, result = arg + 10 - vm.export @call_yieldable_import_stress attributes {emitc.exclude} - vm.func @call_yieldable_import_stress(%arg0 : i32) -> i32 { + // Expects 1000 + 10 = 1010. + vm.export @test_call_yieldable_import_stress attributes {emitc.exclude} + vm.func @test_call_yieldable_import_stress() { + %c1000 = vm.const.i32 1000 %c10 = vm.const.i32 10 - vm.call.yieldable @yieldable_test.yield_n(%arg0, %c10) : (i32, i32) -> ^resume(i32) + vm.call.yieldable @yieldable_test.yield_n(%c1000, %c10) : (i32, i32) -> ^resume(i32) ^resume(%result : i32): - vm.return %result : i32 + %c1010 = vm.const.i32 1010 + vm.check.eq %result, %c1010, "1000+10=1010" : i32 + vm.return } //===--------------------------------------------------------------------===// @@ -219,58 +256,75 @@ vm.module @async_ops { //===--------------------------------------------------------------------===// // Test: call variadic yieldable import with 2 args and 3 yields. 
- // Expected: 3 DEFERRED returns, then OK with result = (arg0 + arg1) + 3 - vm.export @call_variadic_yieldable_2args attributes {emitc.exclude} - vm.func @call_variadic_yieldable_2args(%arg0 : i32, %arg1 : i32) -> i32 { + // Expects (10 + 20) + 3 = 33. + vm.export @test_call_variadic_yieldable_2args attributes {emitc.exclude} + vm.func @test_call_variadic_yieldable_2args() { + %c10 = vm.const.i32 10 + %c20 = vm.const.i32 20 %c3 = vm.const.i32 3 - vm.call.variadic.yieldable @yieldable_test.yield_variadic_sum(%arg0, %arg1, %c3) {segment_sizes = dense<[2, 1]> : vector<2xi16>, segment_types = [i32, i32]} : (i32, i32, i32) -> ^resume(i32) + vm.call.variadic.yieldable @yieldable_test.yield_variadic_sum(%c10, %c20, %c3) {segment_sizes = dense<[2, 1]> : vector<2xi16>, segment_types = [i32, i32]} : (i32, i32, i32) -> ^resume(i32) ^resume(%result : i32): - vm.return %result : i32 + %c33 = vm.const.i32 33 + vm.check.eq %result, %c33, "(10+20)+3=33" : i32 + vm.return } // Test: call variadic yieldable import with 0 yields (synchronous). - // Expected: immediate OK with result = arg0 + arg1 + arg2 - vm.export @call_variadic_yieldable_0yields attributes {emitc.exclude} - vm.func @call_variadic_yieldable_0yields(%arg0 : i32, %arg1 : i32, %arg2 : i32) -> i32 { + // Expects 5 + 10 + 15 = 30. 
+ vm.export @test_call_variadic_yieldable_0yields attributes {emitc.exclude} + vm.func @test_call_variadic_yieldable_0yields() { + %c5 = vm.const.i32 5 + %c10 = vm.const.i32 10 + %c15 = vm.const.i32 15 %c0 = vm.const.i32 0 - vm.call.variadic.yieldable @yieldable_test.yield_variadic_sum(%arg0, %arg1, %arg2, %c0) {segment_sizes = dense<[3, 1]> : vector<2xi16>, segment_types = [i32, i32]} : (i32, i32, i32, i32) -> ^resume(i32) + vm.call.variadic.yieldable @yieldable_test.yield_variadic_sum(%c5, %c10, %c15, %c0) {segment_sizes = dense<[3, 1]> : vector<2xi16>, segment_types = [i32, i32]} : (i32, i32, i32, i32) -> ^resume(i32) ^resume(%result : i32): - vm.return %result : i32 + %c30 = vm.const.i32 30 + vm.check.eq %result, %c30, "5+10+15+0=30" : i32 + vm.return } // Test: call variadic yieldable import with single arg. - // Expected: 2 yields, result = arg0 + 2 - vm.export @call_variadic_yieldable_1arg attributes {emitc.exclude} - vm.func @call_variadic_yieldable_1arg(%arg0 : i32) -> i32 { + // Expects 100 + 2 = 102. + vm.export @test_call_variadic_yieldable_1arg attributes {emitc.exclude} + vm.func @test_call_variadic_yieldable_1arg() { + %c100 = vm.const.i32 100 %c2 = vm.const.i32 2 - vm.call.variadic.yieldable @yieldable_test.yield_variadic_sum(%arg0, %c2) {segment_sizes = dense<[1, 1]> : vector<2xi16>, segment_types = [i32, i32]} : (i32, i32) -> ^resume(i32) + vm.call.variadic.yieldable @yieldable_test.yield_variadic_sum(%c100, %c2) {segment_sizes = dense<[1, 1]> : vector<2xi16>, segment_types = [i32, i32]} : (i32, i32) -> ^resume(i32) ^resume(%result : i32): - vm.return %result : i32 + %c102 = vm.const.i32 102 + vm.check.eq %result, %c102, "100+2=102" : i32 + vm.return } // Test: call variadic yieldable import with empty variadic list. - // Expected: 1 yield, result = 0 + 1 = 1 - vm.export @call_variadic_yieldable_empty attributes {emitc.exclude} - vm.func @call_variadic_yieldable_empty() -> i32 { + // Expects 0 + 1 = 1. 
+ vm.export @test_call_variadic_yieldable_empty attributes {emitc.exclude} + vm.func @test_call_variadic_yieldable_empty() { %c1 = vm.const.i32 1 vm.call.variadic.yieldable @yieldable_test.yield_variadic_sum(%c1) {segment_sizes = dense<[0, 1]> : vector<2xi16>, segment_types = [i32, i32]} : (i32) -> ^resume(i32) ^resume(%result : i32): - vm.return %result : i32 + vm.check.eq %result, %c1, "0+1=1" : i32 + vm.return } // Test: two sequential variadic yieldable calls. - // Expected: 2 yields from first + 1 yield from second = 3 yields total - // Result: ((arg0 + arg1) + 2) + (arg2) + 1 = arg0 + arg1 + arg2 + 3 - vm.export @call_variadic_yieldable_sequential attributes {emitc.exclude} - vm.func @call_variadic_yieldable_sequential(%arg0 : i32, %arg1 : i32, %arg2 : i32) -> i32 { + // Expects ((10 + 20) + 2) + (32 + 5) + 1 = 38. + vm.export @test_call_variadic_yieldable_sequential attributes {emitc.exclude} + vm.func @test_call_variadic_yieldable_sequential() { %c1 = vm.const.i32 1 %c2 = vm.const.i32 2 - // First variadic yieldable: sum(arg0, arg1) + 2 yields - vm.call.variadic.yieldable @yieldable_test.yield_variadic_sum(%arg0, %arg1, %c2) {segment_sizes = dense<[2, 1]> : vector<2xi16>, segment_types = [i32, i32]} : (i32, i32, i32) -> ^after_first(i32) + %c5 = vm.const.i32 5 + %c10 = vm.const.i32 10 + %c20 = vm.const.i32 20 + // First variadic yieldable: sum(10, 20) + 2 yields = 32 + vm.call.variadic.yieldable @yieldable_test.yield_variadic_sum(%c10, %c20, %c2) {segment_sizes = dense<[2, 1]> : vector<2xi16>, segment_types = [i32, i32]} : (i32, i32, i32) -> ^after_first(i32) ^after_first(%v1 : i32): - // Second variadic yieldable: sum(v1, arg2) + 1 yield - vm.call.variadic.yieldable @yieldable_test.yield_variadic_sum(%v1, %arg2, %c1) {segment_sizes = dense<[2, 1]> : vector<2xi16>, segment_types = [i32, i32]} : (i32, i32, i32) -> ^done(i32) + // Second variadic yieldable: sum(32, 5) + 1 yield = 38 + vm.call.variadic.yieldable @yieldable_test.yield_variadic_sum(%v1, %c5, 
%c1) {segment_sizes = dense<[2, 1]> : vector<2xi16>, segment_types = [i32, i32]} : (i32, i32, i32) -> ^done(i32) ^done(%result : i32): - vm.return %result : i32 + %c38 = vm.const.i32 38 + vm.check.eq %result, %c38, "((10+20)+2)+((32+5)+1)=38" : i32 + vm.return } } diff --git a/runtime/src/iree/vm/test/buffer_ops.mlir b/runtime/src/iree/vm/test/buffer_ops.mlir index 74eebabcade9..b0114d78d1cb 100644 --- a/runtime/src/iree/vm/test/buffer_ops.mlir +++ b/runtime/src/iree/vm/test/buffer_ops.mlir @@ -16,8 +16,8 @@ vm.module @buffer_ops { vm.func @test_compare() { %rodata_a = vm.const.ref.rodata @rodata_cmp_3xi32_a : !vm.buffer %rodata_b = vm.const.ref.rodata @rodata_cmp_3xi32_b : !vm.buffer - %rodata_a_dno = util.optimization_barrier %rodata_a : !vm.buffer - %rodata_b_dno = util.optimization_barrier %rodata_b : !vm.buffer + %rodata_a_dno = vm.optimization_barrier %rodata_a : !vm.buffer + %rodata_b_dno = vm.optimization_barrier %rodata_b : !vm.buffer %c0 = vm.const.i64 0 %length = vm.buffer.length %rodata_a_dno : !vm.buffer -> i64 @@ -37,8 +37,8 @@ vm.module @buffer_ops { vm.func @test_compare_empty() { %rodata_a = vm.const.ref.rodata @rodata_cmp_3xi32_a : !vm.buffer %rodata_b = vm.const.ref.rodata @rodata_cmp_3xi32_b : !vm.buffer - %rodata_a_dno = util.optimization_barrier %rodata_a : !vm.buffer - %rodata_b_dno = util.optimization_barrier %rodata_b : !vm.buffer + %rodata_a_dno = vm.optimization_barrier %rodata_a : !vm.buffer + %rodata_b_dno = vm.optimization_barrier %rodata_b : !vm.buffer %c0 = vm.const.i64 0 %c2 = vm.const.i64 2 @@ -59,7 +59,7 @@ vm.module @buffer_ops { %c128 = vm.const.i64 128 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %c128, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer %buf_length = vm.buffer.length %buf_dno : !vm.buffer -> i64 @@ -74,7 +74,7 @@ vm.module @buffer_ops { %c0 = vm.const.i64 0 %alignment = 
vm.const.i32 16 %buf = vm.buffer.alloc %c0, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer %buf_length = vm.buffer.length %buf_dno : !vm.buffer -> i64 @@ -98,7 +98,7 @@ vm.module @buffer_ops { %c8 = vm.const.i64 8 %alignment = vm.const.i32 16 %buf = vm.buffer.clone %rodata, %c4, %c8, %alignment : !vm.buffer -> !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer // Compare the cloned range to the original. @@ -116,14 +116,14 @@ vm.module @buffer_ops { %c0 = vm.const.i64 0 %alignment = vm.const.i32 16 %buf0 = vm.buffer.alloc %c0, %alignment : !vm.buffer - %buf0_dno = util.optimization_barrier %buf0 : !vm.buffer + %buf0_dno = vm.optimization_barrier %buf0 : !vm.buffer vm.check.nz %buf0_dno, "!null" : !vm.buffer %buf0_length = vm.buffer.length %buf0_dno : !vm.buffer -> i64 vm.check.eq %c0, %buf0_length, "buffer length == 0" : i64 // Clone it all (or, clone nothing?). %buf1 = vm.buffer.clone %buf0_dno, %c0, %c0, %alignment : !vm.buffer -> !vm.buffer - %buf1_dno = util.optimization_barrier %buf1 : !vm.buffer + %buf1_dno = vm.optimization_barrier %buf1 : !vm.buffer vm.check.nz %buf1_dno, "!null" : !vm.buffer %buf1_length = vm.buffer.length %buf1_dno : !vm.buffer -> i64 vm.check.eq %c0, %buf1_length, "buffer length == 0" : i64 @@ -136,7 +136,7 @@ vm.module @buffer_ops { vm.func @fail_clone_out_of_range() { // Fetch source .rodata blob. %rodata = vm.const.ref.rodata @rodata_3xi32 : !vm.buffer - %rodata_dno = util.optimization_barrier %rodata : !vm.buffer + %rodata_dno = vm.optimization_barrier %rodata : !vm.buffer vm.check.nz %rodata_dno, "!null" : !vm.buffer // Try to clone off the end of the buffer. @@ -162,7 +162,7 @@ vm.module @buffer_ops { // Allocate target buffer. 
%alignment = vm.const.i32 16 %buf = vm.buffer.alloc %rodata_length, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer // Copy the entire contents. @@ -185,7 +185,7 @@ vm.module @buffer_ops { %c4 = vm.const.i64 4 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %c4, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer // Copy the middle 4-byte element. @@ -208,7 +208,7 @@ vm.module @buffer_ops { %c128 = vm.const.i64 128 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %c128, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer // Try to clone off the end of the source buffer. @@ -225,7 +225,7 @@ vm.module @buffer_ops { %c128 = vm.const.i64 128 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %c128, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer // Try to clone off the end of the source buffer. @@ -244,7 +244,7 @@ vm.module @buffer_ops { %c8 = vm.const.i64 8 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %c8, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer // Try to clone off the end of the target buffer. 
@@ -261,7 +261,7 @@ vm.module @buffer_ops { %c8 = vm.const.i64 8 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %c8, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer // Try to clone off the end of the target buffer. @@ -286,7 +286,7 @@ vm.module @buffer_ops { %buffer_size = vm.mul.i64 %num_elements, %element_size : i64 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %buffer_size, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer // Fill the middle two elements. @@ -315,7 +315,7 @@ vm.module @buffer_ops { %buffer_size = vm.mul.i64 %num_elements, %element_size : i64 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %buffer_size, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer // Fill the middle two elements. @@ -344,7 +344,7 @@ vm.module @buffer_ops { %buffer_size = vm.mul.i64 %num_elements, %element_size : i64 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %buffer_size, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer // Fill the middle two elements. @@ -373,7 +373,7 @@ vm.module @buffer_ops { %buffer_size = vm.mul.i64 %num_elements, %element_size : i64 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %buffer_size, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer // Fill the middle two elements. 
@@ -402,7 +402,7 @@ vm.module @buffer_ops { %buffer_size = vm.mul.i64 %num_elements, %element_size : i64 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %buffer_size, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer // Fill the middle two elements. @@ -583,12 +583,12 @@ vm.module @buffer_ops { vm.export @test_store_i8 vm.func @test_store_i8() { %ref = vm.const.ref.rodata @test_store_i8_ref : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %ref_length = vm.buffer.length %ref_dno : !vm.buffer -> i64 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %ref_length, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer %c0 = vm.const.i64 0 %e0 = vm.const.i32 0 @@ -617,12 +617,12 @@ vm.module @buffer_ops { vm.export @test_store_i16 vm.func @test_store_i16() { %ref = vm.const.ref.rodata @test_store_i16_ref : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %ref_length = vm.buffer.length %ref_dno : !vm.buffer -> i64 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %ref_length, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer %c0 = vm.const.i64 0 %e0 = vm.const.i32 0 @@ -651,12 +651,12 @@ vm.module @buffer_ops { vm.export @test_store_i32 vm.func @test_store_i32() { %ref = vm.const.ref.rodata @test_store_i32_ref : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %ref_length = vm.buffer.length %ref_dno : !vm.buffer -> i64 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %ref_length, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : 
!vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer %c0 = vm.const.i64 0 %e0 = vm.const.i32 0 diff --git a/runtime/src/iree/vm/test/bytecode/BUILD.bazel b/runtime/src/iree/vm/test/bytecode/BUILD.bazel new file mode 100644 index 000000000000..1d1500715a51 --- /dev/null +++ b/runtime/src/iree/vm/test/bytecode/BUILD.bazel @@ -0,0 +1,37 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:build_defs.oss.bzl", "iree_cmake_extra_content", "iree_runtime_cc_test") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_cmake_extra_content( + content = """ +if(NOT IREE_BUILD_COMPILER OR NOT IREE_BUILD_TESTS) + return() +endif() +""", + inline = True, +) + +iree_runtime_cc_test( + name = "bytecode_module_test", + srcs = ["bytecode_module_test.cc"], + deps = [ + "//runtime/src/iree/base", + "//runtime/src/iree/testing:gtest", + "//runtime/src/iree/testing:gtest_main", + "//runtime/src/iree/vm", + "//runtime/src/iree/vm/bytecode:module", + "//runtime/src/iree/vm/test:all_bytecode_modules_c", + "//runtime/src/iree/vm/testing:test_runner", + "//runtime/src/iree/vm/testing:yieldable_test_module", + ], +) diff --git a/runtime/src/iree/vm/test/bytecode/CMakeLists.txt b/runtime/src/iree/vm/test/bytecode/CMakeLists.txt new file mode 100644 index 000000000000..4ac6fb773f02 --- /dev/null +++ b/runtime/src/iree/vm/test/bytecode/CMakeLists.txt @@ -0,0 +1,33 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# runtime/src/iree/vm/test/bytecode/BUILD.bazel # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. 
# +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +if(NOT IREE_BUILD_COMPILER OR NOT IREE_BUILD_TESTS) + return() +endif() + +iree_cc_test( + NAME + bytecode_module_test + SRCS + "bytecode_module_test.cc" + DEPS + iree::base + iree::testing::gtest + iree::testing::gtest_main + iree::vm + iree::vm::bytecode::module + iree::vm::test::all_bytecode_modules_c + iree::vm::testing::test_runner + iree::vm::testing::yieldable_test_module +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/runtime/src/iree/vm/test/bytecode/bytecode_module_test.cc b/runtime/src/iree/vm/test/bytecode/bytecode_module_test.cc new file mode 100644 index 000000000000..c9be80b79ca8 --- /dev/null +++ b/runtime/src/iree/vm/test/bytecode/bytecode_module_test.cc @@ -0,0 +1,94 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include "iree/base/api.h" +#include "iree/vm/api.h" +#include "iree/vm/bytecode/module.h" +#include "iree/vm/test/all_bytecode_modules.h" +#include "iree/vm/testing/test_runner.h" +#include "iree/vm/testing/yieldable_test_module.h" + +IREE_VM_TEST_RUNNER_STATIC_STORAGE(); + +namespace iree::vm::testing { +namespace { + +std::vector GetBytecodeTestParams() { + std::vector test_params; + + // Prerequisite factory for modules that import from yieldable_test. 
+ auto yieldable_test_factory = [](iree_vm_instance_t* inst, + iree_vm_module_t** out_mod) { + return yieldable_test_module_create(inst, iree_allocator_system(), out_mod); + }; + + iree_vm_instance_t* instance = nullptr; + IREE_CHECK_OK(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, + iree_allocator_system(), &instance)); + + const struct iree_file_toc_t* module_file_toc = + all_bytecode_modules_c_create(); + for (size_t i = 0; i < all_bytecode_modules_c_size(); ++i) { + const auto& module_file = module_file_toc[i]; + std::string module_name(module_file.name); + + iree_vm_module_t* module = nullptr; + IREE_CHECK_OK(iree_vm_bytecode_module_create( + instance, IREE_VM_BYTECODE_MODULE_FLAG_NONE, + iree_const_byte_span_t{ + reinterpret_cast(module_file.data), + static_cast(module_file.size)}, + iree_allocator_null(), iree_allocator_system(), &module)); + + iree_vm_module_signature_t signature = iree_vm_module_signature(module); + for (iree_host_size_t j = 0; j < signature.export_function_count; ++j) { + iree_vm_function_t function; + IREE_CHECK_OK(iree_vm_module_lookup_function_by_ordinal( + module, IREE_VM_FUNCTION_LINKAGE_EXPORT, j, &function)); + iree_string_view_t function_name = iree_vm_function_name(&function); + std::string fn_name(function_name.data, function_name.size); + + // Capture module data for lambda. 
+ const void* data = module_file.data; + iree_host_size_t size = module_file.size; + + std::vector prereqs; + prereqs.push_back(yieldable_test_factory); + + test_params.push_back({ + module_name, + fn_name, + [data, size](iree_vm_instance_t* inst, iree_vm_module_t** out_mod) { + return iree_vm_bytecode_module_create( + inst, IREE_VM_BYTECODE_MODULE_FLAG_NONE, + iree_const_byte_span_t{reinterpret_cast(data), + static_cast(size)}, + iree_allocator_null(), iree_allocator_system(), out_mod); + }, + /*expects_failure=*/fn_name.find("fail_") == 0, + /*prerequisite_modules=*/prereqs, + }); + } + iree_vm_module_release(module); + } + + iree_vm_instance_release(instance); + return test_params; +} + +class VMBytecodeTest : public VMTestRunner<> {}; + +IREE_VM_TEST_F(VMBytecodeTest) + +INSTANTIATE_TEST_SUITE_P(bytecode, VMBytecodeTest, + ::testing::ValuesIn(GetBytecodeTestParams()), + ::testing::PrintToStringParamName()); + +} // namespace +} // namespace iree::vm::testing diff --git a/runtime/src/iree/vm/test/call_ops.mlir b/runtime/src/iree/vm/test/call_ops.mlir index d79bafa258dd..d103c72b050f 100644 --- a/runtime/src/iree/vm/test/call_ops.mlir +++ b/runtime/src/iree/vm/test/call_ops.mlir @@ -42,10 +42,10 @@ vm.module @call_ops { vm.func private @test_call_r_v_preserve_ref() { %ref = vm.const.ref.zero : !vm.buffer %unused = vm.const.ref.rodata @buffer : !vm.buffer - %unusued_dno_1 = util.optimization_barrier %unused : !vm.buffer + %unusued_dno_1 = vm.optimization_barrier %unused : !vm.buffer vm.check.nz %unused : !vm.buffer vm.call @_r_v_preserve_reg(%ref, %unused) : (!vm.buffer, !vm.buffer) -> () - %unusued_dno_2 = util.optimization_barrier %unused : !vm.buffer + %unusued_dno_2 = vm.optimization_barrier %unused : !vm.buffer vm.check.nz %unusued_dno_2 : !vm.buffer vm.return } @@ -61,7 +61,7 @@ vm.module @call_ops { vm.export @test_call_v_r vm.func @test_call_v_r() { %ref = vm.const.ref.zero : !vm.ref - %ref_dno = util.optimization_barrier %ref : !vm.ref + %ref_dno = 
vm.optimization_barrier %ref : !vm.ref %res = vm.call @_v_r() : () -> (!vm.ref) vm.check.eq %ref_dno, %res, "_v_r()=NULL" : !vm.ref vm.return @@ -91,21 +91,21 @@ vm.module @call_ops { vm.func @_r_v(%arg : !vm.ref) attributes {inlining_policy = #util.inline.never} { %ref = vm.const.ref.zero : !vm.ref - %ref_dno = util.optimization_barrier %ref : !vm.ref + %ref_dno = vm.optimization_barrier %ref : !vm.ref vm.check.eq %arg, %ref_dno, "Expected %arg to be NULL" : !vm.ref vm.return } vm.func @_r_v_reuse_reg(%arg : !vm.ref, %unused : !vm.ref) attributes {inlining_policy = #util.inline.never} { %ref = vm.const.ref.zero : !vm.ref - %ref_dno = util.optimization_barrier %ref : !vm.ref + %ref_dno = vm.optimization_barrier %ref : !vm.ref vm.check.eq %arg, %ref_dno, "Expected %arg to be NULL" : !vm.ref vm.return } vm.func @_r_v_preserve_reg(%arg1 : !vm.ref, %arg2 : !vm.ref) attributes {inlining_policy = #util.inline.never} { %ref = vm.const.ref.zero : !vm.ref - %ref_dno = util.optimization_barrier %ref : !vm.ref + %ref_dno = vm.optimization_barrier %ref : !vm.ref vm.check.eq %arg1, %ref_dno, "Expected %arg1 to be NULL" : !vm.ref vm.check.nz %arg2, "Expected %arg2 to be not NULL" : !vm.ref vm.return diff --git a/runtime/src/iree/vm/test/comparison_ops.mlir b/runtime/src/iree/vm/test/comparison_ops.mlir index f0095452e806..a25c92207c9c 100644 --- a/runtime/src/iree/vm/test/comparison_ops.mlir +++ b/runtime/src/iree/vm/test/comparison_ops.mlir @@ -7,9 +7,9 @@ vm.module @comparison_ops { vm.export @test_cmp_lt_s_0 vm.func @test_cmp_lt_s_0() { %lhs = vm.const.i32 2 - %lhs_dno = util.optimization_barrier %lhs : i32 + %lhs_dno = vm.optimization_barrier %lhs : i32 %rhs = vm.const.i32 -2 - %rhs_dno = util.optimization_barrier %rhs : i32 + %rhs_dno = vm.optimization_barrier %rhs : i32 %actual = vm.cmp.lt.i32.s %lhs_dno, %rhs_dno : i32 %expected = vm.const.i32 0 vm.check.eq %actual, %expected, "2 < -2" : i32 @@ -19,9 +19,9 @@ vm.module @comparison_ops { vm.export @test_cmp_lt_s_1 vm.func 
@test_cmp_lt_s_1() { %lhs = vm.const.i32 -2 - %lhs_dno = util.optimization_barrier %lhs : i32 + %lhs_dno = vm.optimization_barrier %lhs : i32 %rhs = vm.const.i32 2 - %rhs_dno = util.optimization_barrier %rhs : i32 + %rhs_dno = vm.optimization_barrier %rhs : i32 %actual = vm.cmp.lt.i32.s %lhs_dno, %rhs_dno : i32 %expected = vm.const.i32 1 vm.check.eq %actual, %expected, "-2 < 2" : i32 @@ -32,9 +32,9 @@ vm.module @comparison_ops { vm.export @test_cmp_lt_s_2 vm.func @test_cmp_lt_s_2() { %lhs = vm.const.i32 4294967295 - %lhs_dno = util.optimization_barrier %lhs : i32 + %lhs_dno = vm.optimization_barrier %lhs : i32 %rhs = vm.const.i32 2 - %rhs_dno = util.optimization_barrier %rhs : i32 + %rhs_dno = vm.optimization_barrier %rhs : i32 %actual = vm.cmp.lt.i32.s %lhs_dno, %rhs_dno : i32 %expected = vm.const.i32 1 vm.check.eq %actual, %expected, "4294967295 (UINT_MAX) < 2" : i32 @@ -48,9 +48,9 @@ vm.module @comparison_ops { vm.export @test_cmp_lt_u_0 vm.func @test_cmp_lt_u_0() { %lhs = vm.const.i32 2 - %lhs_dno = util.optimization_barrier %lhs : i32 + %lhs_dno = vm.optimization_barrier %lhs : i32 %rhs = vm.const.i32 -2 - %rhs_dno = util.optimization_barrier %rhs : i32 + %rhs_dno = vm.optimization_barrier %rhs : i32 %actual = vm.cmp.lt.i32.u %lhs_dno, %rhs_dno : i32 %expected = vm.const.i32 1 vm.check.eq %actual, %expected, "2 < -2 (as unsigned)" : i32 @@ -60,9 +60,9 @@ vm.module @comparison_ops { vm.export @test_cmp_lt_u_1 vm.func @test_cmp_lt_u_1() { %lhs = vm.const.i32 -2 - %lhs_dno = util.optimization_barrier %lhs : i32 + %lhs_dno = vm.optimization_barrier %lhs : i32 %rhs = vm.const.i32 2 - %rhs_dno = util.optimization_barrier %rhs : i32 + %rhs_dno = vm.optimization_barrier %rhs : i32 %actual = vm.cmp.lt.i32.u %lhs_dno, %rhs_dno : i32 %expected = vm.const.i32 0 vm.check.eq %actual, %expected, "-2 < 2 (as unsigned)" : i32 @@ -72,9 +72,9 @@ vm.module @comparison_ops { vm.export @test_cmp_lt_u_2 vm.func @test_cmp_lt_u_2() { %lhs = vm.const.i32 4294967295 - %lhs_dno = 
util.optimization_barrier %lhs : i32 + %lhs_dno = vm.optimization_barrier %lhs : i32 %rhs = vm.const.i32 2 - %rhs_dno = util.optimization_barrier %rhs : i32 + %rhs_dno = vm.optimization_barrier %rhs : i32 %actual = vm.cmp.lt.i32.u %lhs_dno, %rhs_dno : i32 %expected = vm.const.i32 0 vm.check.eq %actual, %expected, "4294967295 (UINT_MAX) < 2 (as unsigned)" : i32 @@ -94,9 +94,9 @@ vm.module @comparison_ops { %false = vm.const.i32 0 %cn2 = vm.const.i32 -2 - %cn2_dno = util.optimization_barrier %cn2 : i32 + %cn2_dno = vm.optimization_barrier %cn2 : i32 %c2 = vm.const.i32 2 - %c2_dno = util.optimization_barrier %c2 : i32 + %c2_dno = vm.optimization_barrier %c2 : i32 %cmp_0 = vm.cmp.lte.i32.s %cn2_dno, %c2_dno : i32 vm.check.eq %cmp_0, %true, "-2 <= 2" : i32 @@ -121,9 +121,9 @@ vm.module @comparison_ops { %false = vm.const.i32 0 %cn2 = vm.const.i32 -2 - %cn2_dno = util.optimization_barrier %cn2 : i32 + %cn2_dno = vm.optimization_barrier %cn2 : i32 %c2 = vm.const.i32 2 - %c2_dno = util.optimization_barrier %c2 : i32 + %c2_dno = vm.optimization_barrier %c2 : i32 %cmp_0 = vm.cmp.gt.i32.s %cn2_dno, %c2_dno : i32 vm.check.eq %cmp_0, %false, "-2 > 2" : i32 @@ -148,9 +148,9 @@ vm.module @comparison_ops { %false = vm.const.i32 0 %cn2 = vm.const.i32 -2 - %cn2_dno = util.optimization_barrier %cn2 : i32 + %cn2_dno = vm.optimization_barrier %cn2 : i32 %c2 = vm.const.i32 2 - %c2_dno = util.optimization_barrier %c2 : i32 + %c2_dno = vm.optimization_barrier %c2 : i32 %cmp_0 = vm.cmp.gte.i32.s %cn2_dno, %c2_dno : i32 vm.check.eq %cmp_0, %false, "-2 >= 2" : i32 diff --git a/runtime/src/iree/vm/test/comparison_ops_f32.mlir b/runtime/src/iree/vm/test/comparison_ops_f32.mlir index 363a02e50638..3d074eddf1ac 100644 --- a/runtime/src/iree/vm/test/comparison_ops_f32.mlir +++ b/runtime/src/iree/vm/test/comparison_ops_f32.mlir @@ -7,9 +7,9 @@ vm.module @comparison_ops_f32 { vm.export @test_cmp_lt_0_f32 vm.func @test_cmp_lt_0_f32() { %lhs = vm.const.f32 4.0 - %lhs_dno = util.optimization_barrier 
%lhs : f32 + %lhs_dno = vm.optimization_barrier %lhs : f32 %rhs = vm.const.f32 -4.0 - %rhs_dno = util.optimization_barrier %rhs : f32 + %rhs_dno = vm.optimization_barrier %rhs : f32 %actual = vm.cmp.lt.f32.o %lhs_dno, %rhs_dno : f32 %expected = vm.const.i32 0 vm.check.eq %actual, %expected, "4.0 < -4.0" : i32 @@ -19,9 +19,9 @@ vm.module @comparison_ops_f32 { vm.export @test_cmp_lt_1_f32 vm.func @test_cmp_lt_1_f32() { %lhs = vm.const.f32 -4.0 - %lhs_dno = util.optimization_barrier %lhs : f32 + %lhs_dno = vm.optimization_barrier %lhs : f32 %rhs = vm.const.f32 4.0 - %rhs_dno = util.optimization_barrier %rhs : f32 + %rhs_dno = vm.optimization_barrier %rhs : f32 %actual = vm.cmp.lt.f32.o %lhs_dno, %rhs_dno : f32 %expected = vm.const.i32 1 vm.check.eq %actual, %expected, "-4.0 < 4.0" : i32 @@ -41,9 +41,9 @@ vm.module @comparison_ops_f32 { %false = vm.const.i32 0 %cn2 = vm.const.f32 -2.0 - %cn2_dno = util.optimization_barrier %cn2 : f32 + %cn2_dno = vm.optimization_barrier %cn2 : f32 %c2 = vm.const.f32 2.0 - %c2_dno = util.optimization_barrier %c2 : f32 + %c2_dno = vm.optimization_barrier %c2 : f32 %cmp_0 = vm.cmp.eq.f32.near %cn2_dno, %c2_dno : f32 vm.check.eq %cmp_0, %false, "-2 !~ 2" : i32 @@ -56,9 +56,9 @@ vm.module @comparison_ops_f32 { // off by 84 ULPs, arbitrary threshold sets these as "near enough" %c1a = vm.const.f32 1.00002 - %c1a_dno = util.optimization_barrier %c1a : f32 + %c1a_dno = vm.optimization_barrier %c1a : f32 %c1b = vm.const.f32 1.00003 - %c1b_dno = util.optimization_barrier %c1b : f32 + %c1b_dno = vm.optimization_barrier %c1b : f32 %cmp_4 = vm.cmp.eq.f32.near %c1a_dno, %c1b_dno : f32 vm.check.eq %cmp_4, %true, "1.00002 ~ 1.00003" : i32 @@ -74,9 +74,9 @@ vm.module @comparison_ops_f32 { %false = vm.const.i32 0 %cn2 = vm.const.f32 -2.0 - %cn2_dno = util.optimization_barrier %cn2 : f32 + %cn2_dno = vm.optimization_barrier %cn2 : f32 %c2 = vm.const.f32 2.0 - %c2_dno = util.optimization_barrier %c2 : f32 + %c2_dno = vm.optimization_barrier %c2 : f32 
%cmp_0 = vm.cmp.lte.f32.o %cn2_dno, %c2_dno : f32 vm.check.eq %cmp_0, %true, "-2 <= 2" : i32 @@ -94,9 +94,9 @@ vm.module @comparison_ops_f32 { %false = vm.const.i32 0 %cn2 = vm.const.f32 -2.0 - %cn2_dno = util.optimization_barrier %cn2 : f32 + %cn2_dno = vm.optimization_barrier %cn2 : f32 %c2 = vm.const.f32 2.0 - %c2_dno = util.optimization_barrier %c2 : f32 + %c2_dno = vm.optimization_barrier %c2 : f32 %cmp_0 = vm.cmp.gt.f32.o %cn2_dno, %c2_dno : f32 vm.check.eq %cmp_0, %false, "-2 > 2" : i32 @@ -114,9 +114,9 @@ vm.module @comparison_ops_f32 { %false = vm.const.i32 0 %cn2 = vm.const.f32 -2.0 - %cn2_dno = util.optimization_barrier %cn2 : f32 + %cn2_dno = vm.optimization_barrier %cn2 : f32 %c2 = vm.const.f32 2.0 - %c2_dno = util.optimization_barrier %c2 : f32 + %c2_dno = vm.optimization_barrier %c2 : f32 %cmp_0 = vm.cmp.gte.f32.o %cn2_dno, %c2_dno : f32 vm.check.eq %cmp_0, %false, "-2 >= 2" : i32 diff --git a/runtime/src/iree/vm/test/comparison_ops_f64.mlir b/runtime/src/iree/vm/test/comparison_ops_f64.mlir index fb7a67f95332..01b451774fa0 100644 --- a/runtime/src/iree/vm/test/comparison_ops_f64.mlir +++ b/runtime/src/iree/vm/test/comparison_ops_f64.mlir @@ -7,9 +7,9 @@ vm.module @comparison_ops_f64 { vm.export @test_cmp_lt_0_f64 vm.func @test_cmp_lt_0_f64() { %lhs = vm.const.f64 4.0 - %lhs_dno = util.optimization_barrier %lhs : f64 + %lhs_dno = vm.optimization_barrier %lhs : f64 %rhs = vm.const.f64 -4.0 - %rhs_dno = util.optimization_barrier %rhs : f64 + %rhs_dno = vm.optimization_barrier %rhs : f64 %actual = vm.cmp.lt.f64.o %lhs_dno, %rhs_dno : f64 %expected = vm.const.i32 0 vm.check.eq %actual, %expected, "4.0 < -4.0" : i32 @@ -19,9 +19,9 @@ vm.module @comparison_ops_f64 { vm.export @test_cmp_lt_1_f64 vm.func @test_cmp_lt_1_f64() { %lhs = vm.const.f64 -4.0 - %lhs_dno = util.optimization_barrier %lhs : f64 + %lhs_dno = vm.optimization_barrier %lhs : f64 %rhs = vm.const.f64 4.0 - %rhs_dno = util.optimization_barrier %rhs : f64 + %rhs_dno = vm.optimization_barrier 
%rhs : f64 %actual = vm.cmp.lt.f64.o %lhs_dno, %rhs_dno : f64 %expected = vm.const.i32 1 vm.check.eq %actual, %expected, "-4.0 < 4.0" : i32 @@ -41,9 +41,9 @@ vm.module @comparison_ops_f64 { %false = vm.const.i32 0 %cn2 = vm.const.f64 -2.0 - %cn2_dno = util.optimization_barrier %cn2 : f64 + %cn2_dno = vm.optimization_barrier %cn2 : f64 %c2 = vm.const.f64 2.0 - %c2_dno = util.optimization_barrier %c2 : f64 + %c2_dno = vm.optimization_barrier %c2 : f64 %cmp_0 = vm.cmp.eq.f64.near %cn2_dno, %c2_dno : f64 vm.check.eq %cmp_0, %false, "-2 !~ 2" : i32 @@ -56,9 +56,9 @@ vm.module @comparison_ops_f64 { // off by 84 ULPs, arbitrary threshold sets these as "near enough" %c1a = vm.const.f64 1.00002 - %c1a_dno = util.optimization_barrier %c1a : f64 + %c1a_dno = vm.optimization_barrier %c1a : f64 %c1b = vm.const.f64 1.00003 - %c1b_dno = util.optimization_barrier %c1b : f64 + %c1b_dno = vm.optimization_barrier %c1b : f64 %cmp_4 = vm.cmp.eq.f64.near %c1a_dno, %c1b_dno : f64 vm.check.eq %cmp_4, %true, "1.00002 ~ 1.00003" : i32 @@ -74,9 +74,9 @@ vm.module @comparison_ops_f64 { %false = vm.const.i32 0 %cn2 = vm.const.f64 -2.0 - %cn2_dno = util.optimization_barrier %cn2 : f64 + %cn2_dno = vm.optimization_barrier %cn2 : f64 %c2 = vm.const.f64 2.0 - %c2_dno = util.optimization_barrier %c2 : f64 + %c2_dno = vm.optimization_barrier %c2 : f64 %cmp_0 = vm.cmp.lte.f64.o %cn2_dno, %c2_dno : f64 vm.check.eq %cmp_0, %true, "-2 <= 2" : i32 @@ -94,9 +94,9 @@ vm.module @comparison_ops_f64 { %false = vm.const.i32 0 %cn2 = vm.const.f64 -2.0 - %cn2_dno = util.optimization_barrier %cn2 : f64 + %cn2_dno = vm.optimization_barrier %cn2 : f64 %c2 = vm.const.f64 2.0 - %c2_dno = util.optimization_barrier %c2 : f64 + %c2_dno = vm.optimization_barrier %c2 : f64 %cmp_0 = vm.cmp.gt.f64.o %cn2_dno, %c2_dno : f64 vm.check.eq %cmp_0, %false, "-2 > 2" : i32 @@ -114,9 +114,9 @@ vm.module @comparison_ops_f64 { %false = vm.const.i32 0 %cn2 = vm.const.f64 -2.0 - %cn2_dno = util.optimization_barrier %cn2 : f64 + %cn2_dno 
= vm.optimization_barrier %cn2 : f64 %c2 = vm.const.f64 2.0 - %c2_dno = util.optimization_barrier %c2 : f64 + %c2_dno = vm.optimization_barrier %c2 : f64 %cmp_0 = vm.cmp.gte.f64.o %cn2_dno, %c2_dno : f64 vm.check.eq %cmp_0, %false, "-2 >= 2" : i32 diff --git a/runtime/src/iree/vm/test/comparison_ops_i64.mlir b/runtime/src/iree/vm/test/comparison_ops_i64.mlir index 3c10ef8e0c11..a8a44be0f7ed 100644 --- a/runtime/src/iree/vm/test/comparison_ops_i64.mlir +++ b/runtime/src/iree/vm/test/comparison_ops_i64.mlir @@ -7,9 +7,9 @@ vm.module @comparison_ops_i64 { vm.export @test_cmp_lt_s_0_i64 vm.func @test_cmp_lt_s_0_i64() { %lhs = vm.const.i64 4294967295 - %lhs_dno = util.optimization_barrier %lhs : i64 + %lhs_dno = vm.optimization_barrier %lhs : i64 %rhs = vm.const.i64 -4294967295 - %rhs_dno = util.optimization_barrier %rhs : i64 + %rhs_dno = vm.optimization_barrier %rhs : i64 %actual = vm.cmp.lt.i64.s %lhs_dno, %rhs_dno : i64 %expected = vm.const.i32 0 vm.check.eq %actual, %expected, "4294967295 (UINT_MAX) < -4294967295 (UINT_MAX)" : i32 @@ -19,9 +19,9 @@ vm.module @comparison_ops_i64 { vm.export @test_cmp_lt_s_1_i64 vm.func @test_cmp_lt_s_1_i64() { %lhs = vm.const.i64 -4294967295 - %lhs_dno = util.optimization_barrier %lhs : i64 + %lhs_dno = vm.optimization_barrier %lhs : i64 %rhs = vm.const.i64 4294967295 - %rhs_dno = util.optimization_barrier %rhs : i64 + %rhs_dno = vm.optimization_barrier %rhs : i64 %actual = vm.cmp.lt.i64.s %lhs_dno, %rhs_dno : i64 %expected = vm.const.i32 1 vm.check.eq %actual, %expected, "-4294967295 (UINT_MAX) < 4294967295 (UINT_MAX)" : i32 @@ -32,9 +32,9 @@ vm.module @comparison_ops_i64 { vm.export @test_cmp_lt_s_2_i64 vm.func @test_cmp_lt_s_2_i64() { %lhs = vm.const.i64 18446744073709551615 - %lhs_dno = util.optimization_barrier %lhs : i64 + %lhs_dno = vm.optimization_barrier %lhs : i64 %rhs = vm.const.i64 2 - %rhs_dno = util.optimization_barrier %rhs : i64 + %rhs_dno = vm.optimization_barrier %rhs : i64 %actual = vm.cmp.lt.i64.s %lhs_dno, 
%rhs_dno : i64 %expected = vm.const.i32 1 vm.check.eq %actual, %expected, "18446744073709551615 (ULONG_MAX) < 2" : i32 @@ -48,9 +48,9 @@ vm.module @comparison_ops_i64 { vm.export @test_cmp_lt_u_0_i64 vm.func @test_cmp_lt_u_0_i64() { %lhs = vm.const.i64 2 - %lhs_dno = util.optimization_barrier %lhs : i64 + %lhs_dno = vm.optimization_barrier %lhs : i64 %rhs = vm.const.i64 -2 - %rhs_dno = util.optimization_barrier %rhs : i64 + %rhs_dno = vm.optimization_barrier %rhs : i64 %actual = vm.cmp.lt.i64.u %lhs_dno, %rhs_dno : i64 %expected = vm.const.i32 1 vm.check.eq %actual, %expected, "2 < -2 (as unsigned)" : i32 @@ -60,9 +60,9 @@ vm.module @comparison_ops_i64 { vm.export @test_cmp_lt_u_1_i64 vm.func @test_cmp_lt_u_1_i64() { %lhs = vm.const.i64 -2 - %lhs_dno = util.optimization_barrier %lhs : i64 + %lhs_dno = vm.optimization_barrier %lhs : i64 %rhs = vm.const.i64 2 - %rhs_dno = util.optimization_barrier %rhs : i64 + %rhs_dno = vm.optimization_barrier %rhs : i64 %actual = vm.cmp.lt.i64.u %lhs_dno, %rhs_dno : i64 %expected = vm.const.i32 0 vm.check.eq %actual, %expected, "-2 < 2 (as unsigned)" : i32 @@ -72,9 +72,9 @@ vm.module @comparison_ops_i64 { vm.export @test_cmp_lt_u_2_i64 vm.func @test_cmp_lt_u_2_i64() { %lhs = vm.const.i64 18446744073709551615 - %lhs_dno = util.optimization_barrier %lhs : i64 + %lhs_dno = vm.optimization_barrier %lhs : i64 %rhs = vm.const.i64 2 - %rhs_dno = util.optimization_barrier %rhs : i64 + %rhs_dno = vm.optimization_barrier %rhs : i64 %actual = vm.cmp.lt.i64.u %lhs_dno, %rhs_dno : i64 %expected = vm.const.i32 0 vm.check.eq %actual, %expected, "18446744073709551615 (ULONG_MAX) < 2 (as unsigned)" : i32 @@ -94,9 +94,9 @@ vm.module @comparison_ops_i64 { %false = vm.const.i32 0 %cn2 = vm.const.i64 -2 - %cn2_dno = util.optimization_barrier %cn2 : i64 + %cn2_dno = vm.optimization_barrier %cn2 : i64 %c2 = vm.const.i64 2 - %c2_dno = util.optimization_barrier %c2 : i64 + %c2_dno = vm.optimization_barrier %c2 : i64 %cmp_0 = vm.cmp.lte.i64.s %cn2_dno, 
%c2_dno : i64 vm.check.eq %cmp_0, %true, "-2 <= 2" : i32 @@ -121,9 +121,9 @@ vm.module @comparison_ops_i64 { %false = vm.const.i32 0 %cn2 = vm.const.i64 -2 - %cn2_dno = util.optimization_barrier %cn2 : i64 + %cn2_dno = vm.optimization_barrier %cn2 : i64 %c2 = vm.const.i64 2 - %c2_dno = util.optimization_barrier %c2 : i64 + %c2_dno = vm.optimization_barrier %c2 : i64 %cmp_0 = vm.cmp.gt.i64.s %cn2_dno, %c2_dno : i64 vm.check.eq %cmp_0, %false, "-2 > 2" : i32 @@ -148,9 +148,9 @@ vm.module @comparison_ops_i64 { %false = vm.const.i32 0 %cn2 = vm.const.i64 -2 - %cn2_dno = util.optimization_barrier %cn2 : i64 + %cn2_dno = vm.optimization_barrier %cn2 : i64 %c2 = vm.const.i64 2 - %c2_dno = util.optimization_barrier %c2 : i64 + %c2_dno = vm.optimization_barrier %c2 : i64 %cmp_0 = vm.cmp.gte.i64.s %cn2_dno, %c2_dno : i64 vm.check.eq %cmp_0, %false, "-2 >= 2" : i32 diff --git a/runtime/src/iree/vm/test/control_flow_ops.mlir b/runtime/src/iree/vm/test/control_flow_ops.mlir index a091f942b7c1..902d838c965d 100644 --- a/runtime/src/iree/vm/test/control_flow_ops.mlir +++ b/runtime/src/iree/vm/test/control_flow_ops.mlir @@ -26,7 +26,7 @@ vm.module @control_flow_ops { vm.export @test_check_eq_always vm.func @test_check_eq_always() { %c1 = vm.const.i32 1 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 vm.check.eq %c1, %c1dno, "error!" : i32 vm.return } @@ -35,8 +35,8 @@ vm.module @control_flow_ops { vm.func @fail_check_eq_never() { %c1 = vm.const.i32 1 %c2 = vm.const.i32 2 - %c1dno = util.optimization_barrier %c1 : i32 - %c2dno = util.optimization_barrier %c2 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 vm.check.eq %c1dno, %c2dno, "error!" 
: i32 vm.return } @@ -72,7 +72,7 @@ vm.module @control_flow_ops { vm.export @test_cond_br vm.func @test_cond_br() { %c1 = vm.const.i32 1 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 vm.cond_br %c1dno, ^bb1, ^bb2 ^bb1: vm.check.eq %c1dno, %c1dno, "error!" : i32 @@ -85,7 +85,7 @@ vm.module @control_flow_ops { vm.export @test_cond_br_int_arg vm.func @test_cond_br_int_arg() { %c1 = vm.const.i32 1 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 vm.cond_br %c1dno, ^bb1(%c1dno : i32), ^bb2(%c1dno : i32) ^bb1(%arg1 : i32): vm.check.eq %arg1, %c1dno, "error!" : i32 @@ -98,7 +98,7 @@ vm.module @control_flow_ops { vm.export @test_cond_br_ref_arg vm.func @test_cond_br_ref_arg() { %c1 = vm.const.i32 1 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %ref = vm.const.ref.zero : !vm.ref vm.cond_br %c1dno, ^bb1(%ref : !vm.ref), ^bb2(%ref : !vm.ref) ^bb1(%arg1 : !vm.ref): @@ -115,9 +115,9 @@ vm.module @control_flow_ops { vm.export @test_cond_br_same_successor attributes {emitc.exclude} vm.func private @test_cond_br_same_successor() { %c1 = vm.const.i32 1 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.i32 2 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 vm.cond_br %c1dno, ^bb1(%c1dno : i32), ^bb1(%c2dno : i32) ^bb1(%arg1 : i32): vm.check.eq %arg1, %c1dno, "error!" 
: i32 @@ -129,7 +129,7 @@ vm.module @control_flow_ops { %c0 = vm.const.i32 0 %c1 = vm.const.i32 1 %c2 = vm.const.i32 2 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 vm.br_table %c1dno { default: ^bb1(%c2 : i32), 0: ^bb2(%c0 : i32), @@ -148,7 +148,7 @@ vm.module @control_flow_ops { %c1 = vm.const.i32 1 %c2 = vm.const.i32 2 %c-1 = vm.const.i32 -1 - %c-1dno = util.optimization_barrier %c-1 : i32 + %c-1dno = vm.optimization_barrier %c-1 : i32 vm.br_table %c-1dno { default: ^bb1(%c0 : i32), 0: ^bb2(%c1 : i32), diff --git a/runtime/src/iree/vm/test/conversion_ops.mlir b/runtime/src/iree/vm/test/conversion_ops.mlir index 22374a8af34f..d6bdb11cc666 100644 --- a/runtime/src/iree/vm/test/conversion_ops.mlir +++ b/runtime/src/iree/vm/test/conversion_ops.mlir @@ -7,7 +7,7 @@ vm.module @conversion_ops { vm.export @test_trunc_i32_i8 vm.func private @test_trunc_i32_i8() { %c1 = vm.const.i32 2147483647 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %v = vm.trunc.i32.i8 %c1dno : i32 -> i32 %c2 = vm.const.i32 255 vm.check.eq %v, %c2, "truncate unsigned i32 to unsigned i8" : i32 @@ -17,7 +17,7 @@ vm.module @conversion_ops { vm.export @test_trunc_i32_i16 vm.func private @test_trunc_i32_i16() { %c1 = vm.const.i32 2147483647 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %v = vm.trunc.i32.i16 %c1dno : i32 -> i32 %c2 = vm.const.i32 65535 vm.check.eq %v, %c2, "truncate unsigned i32 to unsigned i16" : i32 @@ -30,7 +30,7 @@ vm.module @conversion_ops { %alignment = vm.const.i32 16 %buffer = vm.buffer.alloc %c128, %alignment : !vm.buffer %any = vm.cast.ref.any %buffer : !vm.buffer -> !vm.ref - %any_dno = util.optimization_barrier %any : !vm.ref + %any_dno = vm.optimization_barrier %any : !vm.ref %cast = vm.cast.any.ref %any_dno : !vm.ref -> !vm.buffer vm.check.eq %buffer, %cast, "cast should succeed" : !vm.buffer vm.return @@ -40,7 +40,7 @@ vm.module 
@conversion_ops { vm.func private @test_cast_any_ref_null() { %null = vm.const.ref.zero : !vm.buffer %any = vm.cast.ref.any %null : !vm.buffer -> !vm.ref - %any_dno = util.optimization_barrier %any : !vm.ref + %any_dno = vm.optimization_barrier %any : !vm.ref %cast = vm.cast.any.ref %any_dno : !vm.ref -> !vm.buffer vm.check.eq %null, %cast, "cast should succeed on nulls" : !vm.buffer vm.return @@ -52,10 +52,10 @@ vm.module @conversion_ops { %alignment = vm.const.i32 16 %buffer = vm.buffer.alloc %c128, %alignment : !vm.buffer %any = vm.cast.ref.any %buffer : !vm.buffer -> !vm.ref - %any_dno = util.optimization_barrier %any : !vm.ref + %any_dno = vm.optimization_barrier %any : !vm.ref // Should fail at runtime because of the type mismatch. %cast = vm.cast.any.ref %any_dno : !vm.ref -> !vm.list - util.optimization_barrier %cast : !vm.list + vm.optimization_barrier %cast : !vm.list vm.return } diff --git a/runtime/src/iree/vm/test/conversion_ops_f32.mlir b/runtime/src/iree/vm/test/conversion_ops_f32.mlir index bb893f77ddbf..dbc6b55b6d0f 100644 --- a/runtime/src/iree/vm/test/conversion_ops_f32.mlir +++ b/runtime/src/iree/vm/test/conversion_ops_f32.mlir @@ -7,7 +7,7 @@ vm.module @conversion_ops_f32 { vm.export @test_bitcast_i32_f32 vm.func @test_bitcast_i32_f32() { %c1 = vm.const.i32 0x40B00000 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %v = vm.bitcast.i32.f32 %c1dno : i32 -> f32 %c2 = vm.const.f32 5.5 vm.check.eq %v, %c2, "bitcast i32 to f32" : f32 @@ -17,7 +17,7 @@ vm.module @conversion_ops_f32 { vm.export @test_bitcast_f32_i32 vm.func @test_bitcast_f32_i32() { %c1 = vm.const.f32 5.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.bitcast.f32.i32 %c1dno : f32 -> i32 %c2 = vm.const.i32 0x40B00000 vm.check.eq %v, %c2, "bitcast f32 to i32" : i32 @@ -27,7 +27,7 @@ vm.module @conversion_ops_f32 { vm.export @test_cast_si32_f32_int_max vm.func @test_cast_si32_f32_int_max() { 
%c1 = vm.const.i32 2147483647 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %v = vm.cast.si32.f32 %c1dno : i32 -> f32 %c2 = vm.const.f32 2147483647.0 vm.check.eq %v, %c2, "cast signed integer to a floating-point value" : f32 @@ -37,7 +37,7 @@ vm.module @conversion_ops_f32 { vm.export @test_cast_si32_f32_int_min vm.func @test_cast_si32_f32_int_min() { %c1 = vm.const.i32 -2147483648 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %v = vm.cast.si32.f32 %c1dno : i32 -> f32 %c2 = vm.const.f32 -2147483648.0 vm.check.eq %v, %c2, "cast signed integer to a floating-point value" : f32 @@ -47,7 +47,7 @@ vm.module @conversion_ops_f32 { vm.export @test_cast_ui32_f32_int_max vm.func @test_cast_ui32_f32_int_max() { %c1 = vm.const.i32 4294967295 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %v = vm.cast.ui32.f32 %c1dno : i32 -> f32 %c2 = vm.const.f32 4294967295.0 vm.check.eq %v, %c2, "cast unsigned integer to a floating-point value" : f32 @@ -59,7 +59,7 @@ vm.module @conversion_ops_f32 { // This is the maximum value that is representable precisely as both i32 // and f32. An exponent of 30 with all mantissa bits set. 
%c1 = vm.const.f32 0x4effffff - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cast.f32.si32 %c1dno : f32 -> i32 %c2 = vm.const.i32 0x7FFFFF80 vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i32 @@ -69,7 +69,7 @@ vm.module @conversion_ops_f32 { vm.export @test_cast_f32_si32_int_min vm.func @test_cast_f32_si32_int_min() { %c1 = vm.const.f32 -2147483648.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cast.f32.si32 %c1dno : f32 -> i32 %c2 = vm.const.i32 -2147483648 vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i32 @@ -79,7 +79,7 @@ vm.module @conversion_ops_f32 { vm.export @test_cast_f32_si32_away_from_zero_pos vm.func @test_cast_f32_si32_away_from_zero_pos() { %c1 = vm.const.f32 2.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cast.f32.si32 %c1dno : f32 -> i32 %c2 = vm.const.i32 3 vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i32 @@ -89,7 +89,7 @@ vm.module @conversion_ops_f32 { vm.export @test_cast_f32_si32_away_from_zero_neg vm.func @test_cast_f32_si32_away_from_zero_neg() { %c1 = vm.const.f32 -2.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cast.f32.si32 %c1dno : f32 -> i32 %c2 = vm.const.i32 -3 vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i32 @@ -101,7 +101,7 @@ vm.module @conversion_ops_f32 { // This is the maximum value that is representable precisely as both i64 // and f32. An exponent of 62 with all mantissa bits set. 
%c1 = vm.const.f32 0x5effffff - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cast.f32.si64 %c1dno : f32 -> i64 %c2 = vm.const.i64 0x7FFFFF8000000000 vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i64 @@ -111,13 +111,13 @@ vm.module @conversion_ops_f32 { vm.export @test_cast_f32_si64_int_min vm.func @test_cast_f32_si64_int_min() { %c1 = vm.const.f32 -9223372036854775808.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cast.f32.si64 %c1dno : f32 -> i64 // Directly providing the true INT64_MIN of -9223372036854775808 // gives an error so we do -(INT64_MAX) - 1 // See: https://stackoverflow.com/a/65008288 %c2 = vm.const.i64 -9223372036854775807 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %c3 = vm.const.i64 1 %c4 = vm.sub.i64 %c2dno, %c3 : i64 vm.check.eq %v, %c4, "cast floating-point value to a signed integer" : i64 @@ -127,7 +127,7 @@ vm.module @conversion_ops_f32 { vm.export @test_cast_f32_si64_away_from_zero_pos vm.func @test_cast_f32_si64_away_from_zero_pos() { %c1 = vm.const.f32 2.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cast.f32.si64 %c1dno : f32 -> i64 %c2 = vm.const.i64 3 vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i64 @@ -137,19 +137,22 @@ vm.module @conversion_ops_f32 { vm.export @test_cast_f32_si64_away_from_zero_neg vm.func @test_cast_f32_si64_away_from_zero_neg() { %c1 = vm.const.f32 -2.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cast.f32.si64 %c1dno : f32 -> i64 %c2 = vm.const.i64 -3 vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i64 vm.return } - vm.export @test_cast_f32_ui32_int_big + // EmitC constant folding breaks through vm.optimization_barrier, causing + // this test to be folded to unconditional error. 
Excluded until EmitC is + // removed or barrier handling is fixed. + vm.export @test_cast_f32_ui32_int_big attributes {emitc.exclude} vm.func @test_cast_f32_ui32_int_big() { // This is the maximum value that is representable precisely as both ui32 // and f32. An exponent of 31 with all mantissa bits set. %c1 = vm.const.f32 0x4f7fffff - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cast.f32.ui32 %c1dno : f32 -> i32 %c2 = vm.const.i32 0xFFFFFF00 vm.check.eq %v, %c2, "cast floating-point value to an unsigned integer" : i32 @@ -159,19 +162,22 @@ vm.module @conversion_ops_f32 { vm.export @test_cast_f32_ui32_away_from_zero vm.func @test_cast_f32_ui32_away_from_zero() { %c1 = vm.const.f32 2.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cast.f32.ui32 %c1dno : f32 -> i32 %c2 = vm.const.i32 3 vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i32 vm.return } - vm.export @test_cast_f32_ui64_int_big + // EmitC constant folding breaks through vm.optimization_barrier, causing + // this test to be folded to unconditional error. Excluded until EmitC is + // removed or barrier handling is fixed. + vm.export @test_cast_f32_ui64_int_big attributes {emitc.exclude} vm.func @test_cast_f32_ui64_int_big() { // This is the maximum value that is representable precisely as both ui64 // and f32. An exponent of 63 with all mantissa bits set. 
%c1 = vm.const.f32 0x5F7FFFFF - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cast.f32.ui64 %c1dno : f32 -> i64 %c2 = vm.const.i64 0xFFFFFF0000000000 vm.check.eq %v, %c2, "cast floating-point value to an unsigned integer" : i64 @@ -181,7 +187,7 @@ vm.module @conversion_ops_f32 { vm.export @test_cast_f32_ui64_away_from_zero vm.func @test_cast_f32_ui64_away_from_zero() { %c1 = vm.const.f32 2.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cast.f32.ui64 %c1dno : f32 -> i64 %c2 = vm.const.i64 3 vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i64 diff --git a/runtime/src/iree/vm/test/conversion_ops_f64.mlir b/runtime/src/iree/vm/test/conversion_ops_f64.mlir index 850983425f38..1de7f205e3e1 100644 --- a/runtime/src/iree/vm/test/conversion_ops_f64.mlir +++ b/runtime/src/iree/vm/test/conversion_ops_f64.mlir @@ -7,7 +7,7 @@ vm.module @conversion_ops_f64 { vm.export @test_bitcast_i64_f64 vm.func @test_bitcast_i64_f64() { %c1 = vm.const.i64 0x4016000000000000 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %v = vm.bitcast.i64.f64 %c1dno : i64 -> f64 %c2 = vm.const.f64 5.5 vm.check.eq %v, %c2, "bitcast i64 to f64" : f64 @@ -17,7 +17,7 @@ vm.module @conversion_ops_f64 { vm.export @test_bitcast_f64_i64 vm.func @test_bitcast_f64_i64() { %c1 = vm.const.f64 5.5 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.bitcast.f64.i64 %c1dno : f64 -> i64 %c2 = vm.const.i64 0x4016000000000000 vm.check.eq %v, %c2, "bitcast f64 to i64" : i64 @@ -27,7 +27,7 @@ vm.module @conversion_ops_f64 { vm.export @test_cast_si64_f64_int_max vm.func @test_cast_si64_f64_int_max() { %c1 = vm.const.i64 2147483647 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %v = vm.cast.si64.f64 %c1dno : i64 -> f64 %c2 = vm.const.f64 2147483647.0 vm.check.eq %v, 
%c2, "cast signed integer to a floating-point value" : f64 @@ -37,7 +37,7 @@ vm.module @conversion_ops_f64 { vm.export @test_cast_si64_f64_int_min vm.func @test_cast_si64_f64_int_min() { %c1 = vm.const.i64 -2147483648 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %v = vm.cast.si64.f64 %c1dno : i64 -> f64 %c2 = vm.const.f64 -2147483648.0 vm.check.eq %v, %c2, "cast signed integer to a floating-point value" : f64 @@ -47,7 +47,7 @@ vm.module @conversion_ops_f64 { vm.export @test_cast_ui64_f64_int_max vm.func @test_cast_ui64_f64_int_max() { %c1 = vm.const.i64 4294967295 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %v = vm.cast.ui64.f64 %c1dno : i64 -> f64 %c2 = vm.const.f64 4294967295.0 vm.check.eq %v, %c2, "cast unsigned integer to a floating-point value" : f64 @@ -57,7 +57,7 @@ vm.module @conversion_ops_f64 { vm.export @test_cast_f64_si64_int_min vm.func @test_cast_f64_si64_int_min() { %c1 = vm.const.f64 -2147483648.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.cast.f64.si64 %c1dno : f64 -> i64 %c2 = vm.const.i64 -2147483648 vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i64 @@ -67,7 +67,7 @@ vm.module @conversion_ops_f64 { vm.export @test_cast_f64_si64_away_from_zero_pos vm.func @test_cast_f64_si64_away_from_zero_pos() { %c1 = vm.const.f64 2.5 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.cast.f64.si64 %c1dno : f64 -> i64 %c2 = vm.const.i64 3 vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i64 @@ -77,7 +77,7 @@ vm.module @conversion_ops_f64 { vm.export @test_cast_f64_si64_away_from_zero_neg vm.func @test_cast_f64_si64_away_from_zero_neg() { %c1 = vm.const.f64 -2.5 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.cast.f64.si64 %c1dno : f64 -> i64 %c2 = vm.const.i64 -3 
vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i64 @@ -87,7 +87,7 @@ vm.module @conversion_ops_f64 { vm.export @test_cast_f64_ui64_away_from_zero vm.func @test_cast_f64_ui64_away_from_zero() { %c1 = vm.const.f64 2.5 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.cast.f64.ui64 %c1dno : f64 -> i64 %c2 = vm.const.i64 3 vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i64 diff --git a/runtime/src/iree/vm/test/conversion_ops_i64.mlir b/runtime/src/iree/vm/test/conversion_ops_i64.mlir index 4ab99fa5e1fd..dc17376d9af2 100644 --- a/runtime/src/iree/vm/test/conversion_ops_i64.mlir +++ b/runtime/src/iree/vm/test/conversion_ops_i64.mlir @@ -7,7 +7,7 @@ vm.module @conversion_ops_i64 { vm.export @test_trunc_i64_i32 vm.func @test_trunc_i64_i32() { %c1 = vm.const.i64 9223372036854775807 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %v = vm.trunc.i64.i32 %c1dno : i64 -> i32 %c2 = vm.const.i32 4294967295 vm.check.eq %v, %c2, "truncate unsigned i64 to unsigned i32" : i32 diff --git a/runtime/src/iree/vm/test/emitc/BUILD.bazel b/runtime/src/iree/vm/test/emitc/BUILD.bazel index 1516cbf796a0..a19aebabe3d2 100644 --- a/runtime/src/iree/vm/test/emitc/BUILD.bazel +++ b/runtime/src/iree/vm/test/emitc/BUILD.bazel @@ -8,13 +8,14 @@ load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_test") load("//build_tools/bazel:iree_c_module.bzl", "iree_c_module") package( + default_visibility = ["//runtime/src/iree/vm:__subpackages__"], features = ["layering_check"], licenses = ["notice"], # Apache 2.0 ) iree_runtime_cc_test( - name = "module_test", - srcs = ["module_test.cc"], + name = "emitc_module_test", + srcs = ["emitc_module_test.cc"], deps = [ ":arithmetic_ops", ":arithmetic_ops_f32", @@ -46,6 +47,7 @@ iree_runtime_cc_test( "//runtime/src/iree/vm:ops", "//runtime/src/iree/vm:ops_emitc", "//runtime/src/iree/vm:shims_emitc", + 
"//runtime/src/iree/vm/testing:test_runner", ], ) diff --git a/runtime/src/iree/vm/test/emitc/CMakeLists.txt b/runtime/src/iree/vm/test/emitc/CMakeLists.txt index c80da9ee6f81..d174730f9c6f 100644 --- a/runtime/src/iree/vm/test/emitc/CMakeLists.txt +++ b/runtime/src/iree/vm/test/emitc/CMakeLists.txt @@ -10,14 +10,10 @@ if(IREE_OUTPUT_FORMAT_C) iree_cc_test( NAME - module_test + emitc_module_test SRCS - "module_test.cc" + "emitc_module_test.cc" DEPS - iree::base - iree::testing::gtest - iree::testing::gtest_main - iree::vm ::arithmetic_ops ::arithmetic_ops_f32 ::arithmetic_ops_i64 @@ -41,6 +37,14 @@ iree_cc_test( ::ref_ops ::shift_ops ::shift_ops_i64 + iree::base + iree::testing::gtest + iree::testing::gtest_main + iree::vm + iree::vm::ops + iree::vm::ops_emitc + iree::vm::shims_emitc + iree::vm::testing::test_runner ) iree_c_module( diff --git a/runtime/src/iree/vm/test/emitc/emitc_module_test.cc b/runtime/src/iree/vm/test/emitc/emitc_module_test.cc new file mode 100644 index 000000000000..1830a76d80fb --- /dev/null +++ b/runtime/src/iree/vm/test/emitc/emitc_module_test.cc @@ -0,0 +1,122 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +// We should not be including C implementation-only headers in a C++ +// module like this. In order to make this work for the moment across +// runtime libraries that are strict, do a global using of the std namespace. +// EmitC is deprecated and will not be gaining any additional test support so +// this is an "as long as it works it's fine" compromise. 
+using namespace std; + +#include "iree/base/api.h" +#include "iree/vm/api.h" +#include "iree/vm/testing/test_runner.h" + +#define EMITC_IMPLEMENTATION +#include "iree/vm/test/emitc/arithmetic_ops.h" +#include "iree/vm/test/emitc/arithmetic_ops_f32.h" +#include "iree/vm/test/emitc/arithmetic_ops_i64.h" +#include "iree/vm/test/emitc/assignment_ops.h" +#include "iree/vm/test/emitc/assignment_ops_f32.h" +#include "iree/vm/test/emitc/assignment_ops_i64.h" +#include "iree/vm/test/emitc/buffer_ops.h" +#include "iree/vm/test/emitc/call_ops.h" +#include "iree/vm/test/emitc/comparison_ops.h" +#include "iree/vm/test/emitc/comparison_ops_f32.h" +#include "iree/vm/test/emitc/comparison_ops_i64.h" +#include "iree/vm/test/emitc/control_flow_ops.h" +#include "iree/vm/test/emitc/conversion_ops.h" +#include "iree/vm/test/emitc/conversion_ops_f32.h" +#include "iree/vm/test/emitc/conversion_ops_i64.h" +#include "iree/vm/test/emitc/global_ops.h" +#include "iree/vm/test/emitc/global_ops_f32.h" +#include "iree/vm/test/emitc/global_ops_i64.h" +#include "iree/vm/test/emitc/list_ops.h" +#include "iree/vm/test/emitc/list_variant_ops.h" +#include "iree/vm/test/emitc/ref_ops.h" +#include "iree/vm/test/emitc/shift_ops.h" +#include "iree/vm/test/emitc/shift_ops_i64.h" + +IREE_VM_TEST_RUNNER_STATIC_STORAGE(); + +namespace iree::vm::testing { +namespace { + +typedef iree_status_t (*emitc_create_fn_t)(iree_vm_instance_t*, + iree_allocator_t, + iree_vm_module_t**); + +struct EmitcModuleInfo { + iree_vm_native_module_descriptor_t descriptor; + emitc_create_fn_t create_fn; +}; + +std::vector GetEmitcTestParams() { + std::vector test_params; + + std::vector modules = { + {arithmetic_ops_descriptor_, arithmetic_ops_create}, + {arithmetic_ops_f32_descriptor_, arithmetic_ops_f32_create}, + {arithmetic_ops_i64_descriptor_, arithmetic_ops_i64_create}, + {assignment_ops_descriptor_, assignment_ops_create}, + {assignment_ops_f32_descriptor_, assignment_ops_f32_create}, + {assignment_ops_i64_descriptor_, 
assignment_ops_i64_create}, + {buffer_ops_descriptor_, buffer_ops_create}, + {call_ops_descriptor_, call_ops_create}, + {comparison_ops_descriptor_, comparison_ops_create}, + {comparison_ops_f32_descriptor_, comparison_ops_f32_create}, + {comparison_ops_i64_descriptor_, comparison_ops_i64_create}, + {control_flow_ops_descriptor_, control_flow_ops_create}, + {conversion_ops_descriptor_, conversion_ops_create}, + {conversion_ops_f32_descriptor_, conversion_ops_f32_create}, + {conversion_ops_i64_descriptor_, conversion_ops_i64_create}, + {global_ops_descriptor_, global_ops_create}, + {global_ops_f32_descriptor_, global_ops_f32_create}, + {global_ops_i64_descriptor_, global_ops_i64_create}, + {list_ops_descriptor_, list_ops_create}, + {list_variant_ops_descriptor_, list_variant_ops_create}, + {ref_ops_descriptor_, ref_ops_create}, + {shift_ops_descriptor_, shift_ops_create}, + {shift_ops_i64_descriptor_, shift_ops_i64_create}, + }; + + for (const auto& mod : modules) { + std::string module_name(mod.descriptor.name.data, mod.descriptor.name.size); + emitc_create_fn_t create_fn = mod.create_fn; + + for (iree_host_size_t i = 0; i < mod.descriptor.export_count; ++i) { + const iree_vm_native_export_descriptor_t& export_desc = + mod.descriptor.exports[i]; + std::string fn_name(export_desc.local_name.data, + export_desc.local_name.size); + test_params.push_back({ + module_name, + fn_name, + [create_fn](iree_vm_instance_t* inst, iree_vm_module_t** out_mod) { + return create_fn(inst, iree_allocator_system(), out_mod); + }, + /*expects_failure=*/fn_name.find("fail_") == 0, + /*prerequisite_modules=*/{}, + }); + } + } + + return test_params; +} + +class VMEmitcTest : public VMTestRunner<> {}; + +IREE_VM_TEST_F(VMEmitcTest) + +INSTANTIATE_TEST_SUITE_P(emitc, VMEmitcTest, + ::testing::ValuesIn(GetEmitcTestParams()), + ::testing::PrintToStringParamName()); + +} // namespace +} // namespace iree::vm::testing diff --git a/runtime/src/iree/vm/test/emitc/module_test.cc 
b/runtime/src/iree/vm/test/emitc/module_test.cc deleted file mode 100644 index fd73e3044c73..000000000000 --- a/runtime/src/iree/vm/test/emitc/module_test.cc +++ /dev/null @@ -1,186 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -// TODO: We should not be including C implementation-only headers in a C++ -// module like this. In order to make this work for the moment across -// runtime libraries that are strict, do a global using of the std namespace. -// See #7605 -#include -using namespace std; - -#include "iree/base/api.h" -#include "iree/testing/gtest.h" -#include "iree/vm/api.h" -#define EMITC_IMPLEMENTATION -#include "iree/vm/test/emitc/arithmetic_ops.h" -#include "iree/vm/test/emitc/arithmetic_ops_f32.h" -#include "iree/vm/test/emitc/arithmetic_ops_i64.h" -#include "iree/vm/test/emitc/assignment_ops.h" -#include "iree/vm/test/emitc/assignment_ops_f32.h" -#include "iree/vm/test/emitc/assignment_ops_i64.h" -#include "iree/vm/test/emitc/buffer_ops.h" -#include "iree/vm/test/emitc/call_ops.h" -#include "iree/vm/test/emitc/comparison_ops.h" -#include "iree/vm/test/emitc/comparison_ops_f32.h" -#include "iree/vm/test/emitc/comparison_ops_i64.h" -#include "iree/vm/test/emitc/control_flow_ops.h" -#include "iree/vm/test/emitc/conversion_ops.h" -#include "iree/vm/test/emitc/conversion_ops_f32.h" -#include "iree/vm/test/emitc/conversion_ops_i64.h" -#include "iree/vm/test/emitc/global_ops.h" -#include "iree/vm/test/emitc/global_ops_f32.h" -#include "iree/vm/test/emitc/global_ops_i64.h" -#include "iree/vm/test/emitc/list_ops.h" -#include "iree/vm/test/emitc/list_variant_ops.h" -#include "iree/vm/test/emitc/ref_ops.h" -#include "iree/vm/test/emitc/shift_ops.h" -#include "iree/vm/test/emitc/shift_ops_i64.h" - -namespace { - -typedef iree_status_t (*create_function_t)(iree_vm_instance_t*, 
- iree_allocator_t, - iree_vm_module_t**); - -struct TestParams { - std::string module_name; - std::string local_name; - create_function_t create_function; -}; - -struct ModuleDescription { - iree_vm_native_module_descriptor_t descriptor; - create_function_t create_function; -}; - -std::ostream& operator<<(std::ostream& os, const TestParams& params) { - std::string qualified_name = params.module_name + "." + params.local_name; - auto name_sv = - iree_make_string_view(qualified_name.data(), qualified_name.size()); - iree_string_view_replace_char(name_sv, ':', '_'); - iree_string_view_replace_char(name_sv, '.', '_'); - return os << qualified_name; -} - -std::vector GetModuleTestParams() { - std::vector test_params; - - // TODO(simon-camp): get these automatically - std::vector modules = { - {arithmetic_ops_descriptor_, arithmetic_ops_create}, - {arithmetic_ops_f32_descriptor_, arithmetic_ops_f32_create}, - {arithmetic_ops_i64_descriptor_, arithmetic_ops_i64_create}, - {assignment_ops_descriptor_, assignment_ops_create}, - {assignment_ops_f32_descriptor_, assignment_ops_f32_create}, - {assignment_ops_i64_descriptor_, assignment_ops_i64_create}, - {buffer_ops_descriptor_, buffer_ops_create}, - {call_ops_descriptor_, call_ops_create}, - {comparison_ops_descriptor_, comparison_ops_create}, - {comparison_ops_f32_descriptor_, comparison_ops_f32_create}, - {comparison_ops_i64_descriptor_, comparison_ops_i64_create}, - {control_flow_ops_descriptor_, control_flow_ops_create}, - {conversion_ops_descriptor_, conversion_ops_create}, - {conversion_ops_f32_descriptor_, conversion_ops_f32_create}, - {conversion_ops_i64_descriptor_, conversion_ops_i64_create}, - {global_ops_descriptor_, global_ops_create}, - {global_ops_f32_descriptor_, global_ops_f32_create}, - {global_ops_i64_descriptor_, global_ops_i64_create}, - {list_ops_descriptor_, list_ops_create}, - {list_variant_ops_descriptor_, list_variant_ops_create}, - {ref_ops_descriptor_, ref_ops_create}, - {shift_ops_descriptor_, 
shift_ops_create}, - {shift_ops_i64_descriptor_, shift_ops_i64_create}}; - - for (size_t i = 0; i < modules.size(); i++) { - iree_vm_native_module_descriptor_t descriptor = modules[i].descriptor; - create_function_t function = modules[i].create_function; - - std::string module_name = - std::string(descriptor.name.data, descriptor.name.size); - - for (iree_host_size_t i = 0; i < descriptor.export_count; i++) { - iree_vm_native_export_descriptor_t export_descriptor = - descriptor.exports[i]; - std::string local_name = std::string(export_descriptor.local_name.data, - export_descriptor.local_name.size); - test_params.push_back({module_name, local_name, function}); - } - } - - return test_params; -} - -class VMCModuleTest : public ::testing::Test, - public ::testing::WithParamInterface { - protected: - virtual void SetUp() { - const auto& test_params = GetParam(); - - IREE_CHECK_OK(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, - iree_allocator_system(), &instance_)); - - iree_vm_module_t* module_ = nullptr; - IREE_CHECK_OK(test_params.create_function( - instance_, iree_allocator_system(), &module_)); - - std::vector modules = {module_}; - IREE_CHECK_OK(iree_vm_context_create_with_modules( - instance_, IREE_VM_CONTEXT_FLAG_NONE, modules.size(), modules.data(), - iree_allocator_system(), &context_)); - - iree_vm_module_release(module_); - } - - virtual void TearDown() { - iree_vm_context_release(context_); - iree_vm_instance_release(instance_); - } - - iree_status_t RunFunction(std::string module_name, std::string local_name) { - std::string qualified_name = module_name + "." 
+ local_name; - iree_vm_function_t function; - IREE_CHECK_OK(iree_vm_context_resolve_function( - context_, - iree_string_view_t{qualified_name.data(), static_cast( - qualified_name.size())}, - &function)); - - return iree_vm_invoke(context_, function, IREE_VM_INVOCATION_FLAG_NONE, - /*policy=*/nullptr, /*inputs=*/nullptr, - /*outputs=*/nullptr, iree_allocator_system()); - } - - iree_vm_instance_t* instance_ = nullptr; - iree_vm_context_t* context_ = nullptr; -}; - -TEST_P(VMCModuleTest, Check) { - const auto& test_params = GetParam(); - bool expect_failure = test_params.local_name.find("fail_") == 0; - - iree::Status result = - RunFunction(test_params.module_name, test_params.local_name); - if (result.ok()) { - if (expect_failure) { - GTEST_FAIL() << "Function expected failure but succeeded"; - } else { - GTEST_SUCCEED(); - } - } else { - if (expect_failure) { - GTEST_SUCCEED(); - } else { - GTEST_FAIL() << "Function expected success but failed with error: " - << result.ToString(); - } - } -} - -INSTANTIATE_TEST_SUITE_P(VMIRFunctions, VMCModuleTest, - ::testing::ValuesIn(GetModuleTestParams()), - ::testing::PrintToStringParamName()); - -} // namespace diff --git a/runtime/src/iree/vm/test/global_ops.mlir b/runtime/src/iree/vm/test/global_ops.mlir index 263e7b5028a5..fc6aab40905f 100644 --- a/runtime/src/iree/vm/test/global_ops.mlir +++ b/runtime/src/iree/vm/test/global_ops.mlir @@ -22,7 +22,7 @@ vm.module @global_ops { vm.func @test_global_load_ref() { %actual = vm.global.load.ref @g0 : !vm.buffer %expected = vm.const.ref.zero : !vm.buffer - %expecteddno = util.optimization_barrier %expected : !vm.buffer + %expecteddno = vm.optimization_barrier %expected : !vm.buffer vm.check.eq %actual, %expecteddno : !vm.buffer vm.return } diff --git a/runtime/src/iree/vm/test/list_ops.mlir b/runtime/src/iree/vm/test/list_ops.mlir index 696be360616e..4947a675beec 100644 --- a/runtime/src/iree/vm/test/list_ops.mlir +++ b/runtime/src/iree/vm/test/list_ops.mlir @@ -12,7 +12,7 @@ 
vm.module @list_ops { %list = vm.list.alloc %c42 : (i32) -> !vm.list vm.list.reserve %list, %c100 : (!vm.list, i32) %sz = vm.list.size %list : (!vm.list) -> i32 - %sz_dno = util.optimization_barrier %sz : i32 + %sz_dno = vm.optimization_barrier %sz : i32 vm.check.eq %sz_dno, %c0, "list.empty.size()=0" : i32 vm.return } @@ -107,7 +107,7 @@ vm.module @list_ops { %list = vm.list.alloc %c1 : (i32) -> !vm.list vm.list.resize %list, %c1 : (!vm.list, i32) %v = vm.list.get.i32 %list, %c1 : (!vm.list, i32) -> i32 - %v_dno = util.optimization_barrier %v : i32 + %v_dno = vm.optimization_barrier %v : i32 // Add a dummy use of %v_dno to please recent versions of clang for the C target vm.list.set.i32 %list, %c1, %v_dno : (!vm.list, i32, i32) vm.return diff --git a/runtime/src/iree/vm/test/list_variant_ops.mlir b/runtime/src/iree/vm/test/list_variant_ops.mlir index 202c92ececbd..10cfb400fd32 100644 --- a/runtime/src/iree/vm/test/list_variant_ops.mlir +++ b/runtime/src/iree/vm/test/list_variant_ops.mlir @@ -113,7 +113,7 @@ vm.module @list_variant_ops { vm.list.resize %list, %c1 : (!vm.list, i32) %ref = vm.list.get.ref %list, %c1 : (!vm.list, i32) -> !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer vm.return } diff --git a/runtime/src/iree/vm/test/ref_ops.mlir b/runtime/src/iree/vm/test/ref_ops.mlir index 019cde3f83fb..ca83e92a9b65 100644 --- a/runtime/src/iree/vm/test/ref_ops.mlir +++ b/runtime/src/iree/vm/test/ref_ops.mlir @@ -17,7 +17,7 @@ vm.module @ref_ops { vm.export @test_zero_ref_eq vm.func @test_zero_ref_eq() { %ref = vm.const.ref.zero : !vm.ref - %ref_dno = util.optimization_barrier %ref : !vm.ref + %ref_dno = vm.optimization_barrier %ref : !vm.ref vm.check.eq %ref_dno, %ref_dno : !vm.ref vm.return } @@ -30,9 +30,9 @@ vm.module @ref_ops { vm.export @test_ref_eq attributes {emitc.exclude} vm.func @test_ref_eq() { %ref_1 = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_1_dno = 
util.optimization_barrier %ref_1 : !vm.buffer + %ref_1_dno = vm.optimization_barrier %ref_1 : !vm.buffer %ref_2 = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_2_dno = util.optimization_barrier %ref_2 : !vm.buffer + %ref_2_dno = vm.optimization_barrier %ref_2 : !vm.buffer vm.check.eq %ref_1_dno, %ref_2_dno : !vm.buffer vm.return } @@ -40,9 +40,9 @@ vm.module @ref_ops { vm.export @test_ref_ne vm.func @test_ref_ne() { %ref_a = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_a_dno = util.optimization_barrier %ref_a : !vm.buffer + %ref_a_dno = vm.optimization_barrier %ref_a : !vm.buffer %ref_b = vm.const.ref.rodata @buffer_b : !vm.buffer - %ref_b_dno = util.optimization_barrier %ref_b : !vm.buffer + %ref_b_dno = vm.optimization_barrier %ref_b : !vm.buffer vm.check.ne %ref_a_dno, %ref_b_dno : !vm.buffer vm.return } @@ -50,7 +50,7 @@ vm.module @ref_ops { vm.export @test_ref_nz vm.func @test_ref_nz() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer vm.check.nz %ref_dno : !vm.buffer vm.return } @@ -64,7 +64,7 @@ vm.module @ref_ops { vm.export @test_ref_survives_call attributes {emitc.exclude} vm.func @test_ref_survives_call() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer vm.check.nz %ref_dno, "ref valid before call" : !vm.buffer vm.call @_consume_ref(%ref_dno) : (!vm.buffer) -> () // Ref should still be valid after the call. 
@@ -74,7 +74,7 @@ vm.module @ref_ops { vm.func private @_consume_ref(%arg : !vm.buffer) attributes {inlining_policy = #util.inline.never} { - %arg_dno = util.optimization_barrier %arg : !vm.buffer + %arg_dno = vm.optimization_barrier %arg : !vm.buffer vm.check.nz %arg_dno, "ref valid in callee" : !vm.buffer vm.return } @@ -83,7 +83,7 @@ vm.module @ref_ops { vm.export @test_same_ref_multiple_args attributes {emitc.exclude} vm.func @test_same_ref_multiple_args() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer vm.call @_consume_two_refs(%ref_dno, %ref_dno) : (!vm.buffer, !vm.buffer) -> () // Ref should still be valid after the call. vm.check.nz %ref_dno, "ref valid after call with same ref twice" : !vm.buffer @@ -92,8 +92,8 @@ vm.module @ref_ops { vm.func private @_consume_two_refs(%arg0 : !vm.buffer, %arg1 : !vm.buffer) attributes {inlining_policy = #util.inline.never} { - %arg0_dno = util.optimization_barrier %arg0 : !vm.buffer - %arg1_dno = util.optimization_barrier %arg1 : !vm.buffer + %arg0_dno = vm.optimization_barrier %arg0 : !vm.buffer + %arg1_dno = vm.optimization_barrier %arg1 : !vm.buffer vm.check.nz %arg0_dno, "first arg valid" : !vm.buffer vm.check.nz %arg1_dno, "second arg valid" : !vm.buffer vm.check.eq %arg0_dno, %arg1_dno, "both args are same ref" : !vm.buffer @@ -104,7 +104,7 @@ vm.module @ref_ops { vm.export @test_ref_returned_from_call attributes {emitc.exclude} vm.func @test_ref_returned_from_call() { %ref = vm.call @_return_ref() : () -> !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer vm.check.nz %ref_dno, "returned ref is valid" : !vm.buffer vm.return } @@ -119,9 +119,9 @@ vm.module @ref_ops { vm.export @test_ref_passthrough attributes {emitc.exclude} vm.func @test_ref_passthrough() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = 
util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %returned = vm.call @_passthrough_ref(%ref_dno) : (!vm.buffer) -> !vm.buffer - %returned_dno = util.optimization_barrier %returned : !vm.buffer + %returned_dno = vm.optimization_barrier %returned : !vm.buffer vm.check.eq %ref_dno, %returned_dno, "passthrough returns same ref" : !vm.buffer vm.return } @@ -139,9 +139,9 @@ vm.module @ref_ops { vm.export @test_ref_cond_br_both_paths attributes {emitc.exclude} vm.func @test_ref_cond_br_both_paths() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %c1 = vm.const.i32 1 - %c1_dno = util.optimization_barrier %c1 : i32 + %c1_dno = vm.optimization_barrier %c1 : i32 vm.cond_br %c1_dno, ^bb1(%ref_dno : !vm.buffer), ^bb2(%ref_dno : !vm.buffer) ^bb1(%arg1 : !vm.buffer): vm.check.nz %arg1, "ref valid in bb1" : !vm.buffer @@ -157,9 +157,9 @@ vm.module @ref_ops { vm.export @test_ref_cond_br_one_path attributes {emitc.exclude} vm.func @test_ref_cond_br_one_path() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %c1 = vm.const.i32 1 - %c1_dno = util.optimization_barrier %c1 : i32 + %c1_dno = vm.optimization_barrier %c1 : i32 vm.cond_br %c1_dno, ^bb1(%ref_dno : !vm.buffer), ^bb2 ^bb1(%arg1 : !vm.buffer): vm.check.nz %arg1, "ref valid in bb1" : !vm.buffer @@ -172,7 +172,7 @@ vm.module @ref_ops { vm.export @test_ref_in_loop attributes {emitc.exclude} vm.func @test_ref_in_loop() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %c0 = vm.const.i32 0 %c1 = vm.const.i32 1 %c3 = vm.const.i32 3 @@ -191,9 +191,9 @@ vm.module @ref_ops { vm.export @test_multiple_refs_in_loop attributes {emitc.exclude} vm.func 
@test_multiple_refs_in_loop() { %ref_a = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_a_dno = util.optimization_barrier %ref_a : !vm.buffer + %ref_a_dno = vm.optimization_barrier %ref_a : !vm.buffer %ref_b = vm.const.ref.rodata @buffer_b : !vm.buffer - %ref_b_dno = util.optimization_barrier %ref_b : !vm.buffer + %ref_b_dno = vm.optimization_barrier %ref_b : !vm.buffer %c0 = vm.const.i32 0 %c1 = vm.const.i32 1 %c3 = vm.const.i32 3 @@ -217,10 +217,10 @@ vm.module @ref_ops { vm.export @test_global_store_load_ref attributes {emitc.exclude} vm.func @test_global_store_load_ref() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer vm.global.store.ref %ref_dno, @global_ref : !vm.buffer %loaded = vm.global.load.ref @global_ref : !vm.buffer - %loaded_dno = util.optimization_barrier %loaded : !vm.buffer + %loaded_dno = vm.optimization_barrier %loaded : !vm.buffer vm.check.eq %ref_dno, %loaded_dno, "loaded ref equals stored ref" : !vm.buffer vm.return } @@ -229,7 +229,7 @@ vm.module @ref_ops { vm.export @test_ref_valid_after_global_store attributes {emitc.exclude} vm.func @test_ref_valid_after_global_store() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer vm.check.nz %ref_dno, "ref valid before store" : !vm.buffer vm.global.store.ref %ref_dno, @global_ref : !vm.buffer // Original ref should still be valid after storing to global. 
@@ -248,12 +248,12 @@ vm.module @ref_ops { %c0 = vm.const.i32 0 %c1 = vm.const.i32 1 %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %list = vm.list.alloc %c1 : (i32) -> !vm.list vm.list.resize %list, %c1 : (!vm.list, i32) vm.list.set.ref %list, %c0, %ref_dno : (!vm.list, i32, !vm.buffer) %retrieved = vm.list.get.ref %list, %c0 : (!vm.list, i32) -> !vm.buffer - %retrieved_dno = util.optimization_barrier %retrieved : !vm.buffer + %retrieved_dno = vm.optimization_barrier %retrieved : !vm.buffer vm.check.eq %ref_dno, %retrieved_dno, "retrieved ref equals set ref" : !vm.buffer vm.return } @@ -265,17 +265,17 @@ vm.module @ref_ops { %c1 = vm.const.i32 1 %c2 = vm.const.i32 2 %ref_a = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_a_dno = util.optimization_barrier %ref_a : !vm.buffer + %ref_a_dno = vm.optimization_barrier %ref_a : !vm.buffer %ref_b = vm.const.ref.rodata @buffer_b : !vm.buffer - %ref_b_dno = util.optimization_barrier %ref_b : !vm.buffer + %ref_b_dno = vm.optimization_barrier %ref_b : !vm.buffer %list = vm.list.alloc %c2 : (i32) -> !vm.list vm.list.resize %list, %c2 : (!vm.list, i32) vm.list.set.ref %list, %c0, %ref_a_dno : (!vm.list, i32, !vm.buffer) vm.list.set.ref %list, %c1, %ref_b_dno : (!vm.list, i32, !vm.buffer) %retrieved_a = vm.list.get.ref %list, %c0 : (!vm.list, i32) -> !vm.buffer - %retrieved_a_dno = util.optimization_barrier %retrieved_a : !vm.buffer + %retrieved_a_dno = vm.optimization_barrier %retrieved_a : !vm.buffer %retrieved_b = vm.list.get.ref %list, %c1 : (!vm.list, i32) -> !vm.buffer - %retrieved_b_dno = util.optimization_barrier %retrieved_b : !vm.buffer + %retrieved_b_dno = vm.optimization_barrier %retrieved_b : !vm.buffer vm.check.eq %ref_a_dno, %retrieved_a_dno, "retrieved ref_a equals set ref_a" : !vm.buffer vm.check.eq %ref_b_dno, %retrieved_b_dno, "retrieved ref_b equals set ref_b" : !vm.buffer vm.check.ne 
%retrieved_a_dno, %retrieved_b_dno, "refs are different" : !vm.buffer @@ -288,12 +288,12 @@ vm.module @ref_ops { %c0 = vm.const.i32 0 %c1 = vm.const.i32 1 %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %list = vm.list.alloc %c1 : (i32) -> !vm.list vm.list.resize %list, %c1 : (!vm.list, i32) vm.list.set.ref %list, %c0, %ref_dno : (!vm.list, i32, !vm.buffer) %retrieved = vm.list.get.ref %list, %c0 : (!vm.list, i32) -> !vm.buffer - %retrieved_dno = util.optimization_barrier %retrieved : !vm.buffer + %retrieved_dno = vm.optimization_barrier %retrieved : !vm.buffer // Use retrieved ref multiple times. vm.check.nz %retrieved_dno, "retrieved ref valid (use 1)" : !vm.buffer vm.check.nz %retrieved_dno, "retrieved ref valid (use 2)" : !vm.buffer @@ -307,7 +307,7 @@ vm.module @ref_ops { %c0 = vm.const.i32 0 %c1 = vm.const.i32 1 %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer vm.check.nz %ref_dno, "ref valid before list set" : !vm.buffer %list = vm.list.alloc %c1 : (i32) -> !vm.list vm.list.resize %list, %c1 : (!vm.list, i32) @@ -325,13 +325,13 @@ vm.module @ref_ops { vm.export @test_select_ref_true attributes {emitc.exclude} vm.func @test_select_ref_true() { %ref_a = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_a_dno = util.optimization_barrier %ref_a : !vm.buffer + %ref_a_dno = vm.optimization_barrier %ref_a : !vm.buffer %ref_b = vm.const.ref.rodata @buffer_b : !vm.buffer - %ref_b_dno = util.optimization_barrier %ref_b : !vm.buffer + %ref_b_dno = vm.optimization_barrier %ref_b : !vm.buffer %c1 = vm.const.i32 1 - %c1_dno = util.optimization_barrier %c1 : i32 + %c1_dno = vm.optimization_barrier %c1 : i32 %result = vm.select.ref %c1_dno, %ref_a_dno, %ref_b_dno : !vm.buffer - %result_dno = util.optimization_barrier %result : !vm.buffer + %result_dno 
= vm.optimization_barrier %result : !vm.buffer vm.check.eq %result_dno, %ref_a_dno, "select true returns first ref" : !vm.buffer vm.return } @@ -339,13 +339,13 @@ vm.module @ref_ops { vm.export @test_select_ref_false attributes {emitc.exclude} vm.func @test_select_ref_false() { %ref_a = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_a_dno = util.optimization_barrier %ref_a : !vm.buffer + %ref_a_dno = vm.optimization_barrier %ref_a : !vm.buffer %ref_b = vm.const.ref.rodata @buffer_b : !vm.buffer - %ref_b_dno = util.optimization_barrier %ref_b : !vm.buffer + %ref_b_dno = vm.optimization_barrier %ref_b : !vm.buffer %c0 = vm.const.i32 0 - %c0_dno = util.optimization_barrier %c0 : i32 + %c0_dno = vm.optimization_barrier %c0 : i32 %result = vm.select.ref %c0_dno, %ref_a_dno, %ref_b_dno : !vm.buffer - %result_dno = util.optimization_barrier %result : !vm.buffer + %result_dno = vm.optimization_barrier %result : !vm.buffer vm.check.eq %result_dno, %ref_b_dno, "select false returns second ref" : !vm.buffer vm.return } @@ -354,13 +354,13 @@ vm.module @ref_ops { vm.export @test_select_ref_input_survives attributes {emitc.exclude} vm.func @test_select_ref_input_survives() { %ref_a = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_a_dno = util.optimization_barrier %ref_a : !vm.buffer + %ref_a_dno = vm.optimization_barrier %ref_a : !vm.buffer %ref_b = vm.const.ref.rodata @buffer_b : !vm.buffer - %ref_b_dno = util.optimization_barrier %ref_b : !vm.buffer + %ref_b_dno = vm.optimization_barrier %ref_b : !vm.buffer %c1 = vm.const.i32 1 - %c1_dno = util.optimization_barrier %c1 : i32 + %c1_dno = vm.optimization_barrier %c1 : i32 %result = vm.select.ref %c1_dno, %ref_a_dno, %ref_b_dno : !vm.buffer - %result_dno = util.optimization_barrier %result : !vm.buffer + %result_dno = vm.optimization_barrier %result : !vm.buffer // Both input refs should still be valid after select. 
vm.check.nz %ref_a_dno, "ref_a valid after select" : !vm.buffer vm.check.nz %ref_b_dno, "ref_b valid after select" : !vm.buffer @@ -376,7 +376,7 @@ vm.module @ref_ops { vm.export @test_ref_multiple_sequential_uses attributes {emitc.exclude} vm.func @test_ref_multiple_sequential_uses() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer // Use 1: check nz vm.check.nz %ref_dno, "use 1" : !vm.buffer // Use 2: pass to call @@ -394,9 +394,9 @@ vm.module @ref_ops { vm.export @test_ref_call_chain attributes {emitc.exclude} vm.func @test_ref_call_chain() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %result = vm.call @_call_chain_a(%ref_dno) : (!vm.buffer) -> !vm.buffer - %result_dno = util.optimization_barrier %result : !vm.buffer + %result_dno = vm.optimization_barrier %result : !vm.buffer vm.check.eq %ref_dno, %result_dno, "chain returns same ref" : !vm.buffer vm.return } @@ -416,9 +416,9 @@ vm.module @ref_ops { vm.export @test_return_multiple_refs attributes {emitc.exclude} vm.func @test_return_multiple_refs() { %ref_a = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_a_dno = util.optimization_barrier %ref_a : !vm.buffer + %ref_a_dno = vm.optimization_barrier %ref_a : !vm.buffer %ref_b = vm.const.ref.rodata @buffer_b : !vm.buffer - %ref_b_dno = util.optimization_barrier %ref_b : !vm.buffer + %ref_b_dno = vm.optimization_barrier %ref_b : !vm.buffer %results:2 = vm.call @_return_two_refs(%ref_a_dno, %ref_b_dno) : (!vm.buffer, !vm.buffer) -> (!vm.buffer, !vm.buffer) vm.check.eq %results#0, %ref_a_dno, "first result is ref_a" : !vm.buffer @@ -436,9 +436,9 @@ vm.module @ref_ops { vm.export @test_return_refs_swapped attributes {emitc.exclude} vm.func @test_return_refs_swapped() { %ref_a = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_a_dno = 
util.optimization_barrier %ref_a : !vm.buffer + %ref_a_dno = vm.optimization_barrier %ref_a : !vm.buffer %ref_b = vm.const.ref.rodata @buffer_b : !vm.buffer - %ref_b_dno = util.optimization_barrier %ref_b : !vm.buffer + %ref_b_dno = vm.optimization_barrier %ref_b : !vm.buffer %results:2 = vm.call @_return_refs_swapped(%ref_a_dno, %ref_b_dno) : (!vm.buffer, !vm.buffer) -> (!vm.buffer, !vm.buffer) vm.check.eq %results#0, %ref_b_dno, "first result is ref_b (swapped)" : !vm.buffer @@ -467,7 +467,7 @@ vm.module @ref_ops { vm.export @test_discard_single_ref attributes {emitc.exclude} vm.func private @test_discard_single_ref() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer vm.check.nz %ref_dno, "ref valid before discard" : !vm.buffer vm.discard.refs %ref_dno : !vm.buffer // Note: After discard, the ref is released. We shouldn't use it. @@ -478,9 +478,9 @@ vm.module @ref_ops { vm.export @test_discard_multiple_refs attributes {emitc.exclude} vm.func private @test_discard_multiple_refs() { %ref_a = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_a_dno = util.optimization_barrier %ref_a : !vm.buffer + %ref_a_dno = vm.optimization_barrier %ref_a : !vm.buffer %ref_b = vm.const.ref.rodata @buffer_b : !vm.buffer - %ref_b_dno = util.optimization_barrier %ref_b : !vm.buffer + %ref_b_dno = vm.optimization_barrier %ref_b : !vm.buffer vm.check.nz %ref_a_dno, "ref_a valid before discard" : !vm.buffer vm.check.nz %ref_b_dno, "ref_b valid before discard" : !vm.buffer vm.discard.refs %ref_a_dno, %ref_b_dno : !vm.buffer, !vm.buffer @@ -491,9 +491,9 @@ vm.module @ref_ops { vm.export @test_discard_in_branch attributes {emitc.exclude} vm.func private @test_discard_in_branch() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %c1 = vm.const.i32 1 - %c1_dno = 
util.optimization_barrier %c1 : i32 + %c1_dno = vm.optimization_barrier %c1 : i32 vm.cond_br %c1_dno, ^bb1, ^bb2 ^bb1: vm.discard.refs %ref_dno : !vm.buffer @@ -513,7 +513,7 @@ vm.module @ref_ops { vm.export @test_nested_loop_outer_ref attributes {emitc.exclude} vm.func private @test_nested_loop_outer_ref() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %c0 = vm.const.i32 0 %c1 = vm.const.i32 1 %c2 = vm.const.i32 2 @@ -543,9 +543,9 @@ vm.module @ref_ops { vm.export @test_ping_pong_swap attributes {emitc.exclude} vm.func private @test_ping_pong_swap() { %ref_a = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_a_dno = util.optimization_barrier %ref_a : !vm.buffer + %ref_a_dno = vm.optimization_barrier %ref_a : !vm.buffer %ref_b = vm.const.ref.rodata @buffer_b : !vm.buffer - %ref_b_dno = util.optimization_barrier %ref_b : !vm.buffer + %ref_b_dno = vm.optimization_barrier %ref_b : !vm.buffer %c0 = vm.const.i32 0 %c1 = vm.const.i32 1 %c3 = vm.const.i32 3 @@ -571,9 +571,9 @@ vm.module @ref_ops { vm.export @test_diamond_asymmetric_use attributes {emitc.exclude} vm.func private @test_diamond_asymmetric_use() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %c1 = vm.const.i32 1 - %c1_dno = util.optimization_barrier %c1 : i32 + %c1_dno = vm.optimization_barrier %c1 : i32 vm.cond_br %c1_dno, ^use_path(%ref_dno : !vm.buffer), ^nouse_path(%ref_dno : !vm.buffer) ^use_path(%r1 : !vm.buffer): vm.check.nz %r1, "ref valid in use_path" : !vm.buffer @@ -589,9 +589,9 @@ vm.module @ref_ops { vm.export @test_diamond_asymmetric_nouse attributes {emitc.exclude} vm.func private @test_diamond_asymmetric_nouse() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : 
!vm.buffer %c0 = vm.const.i32 0 - %c0_dno = util.optimization_barrier %c0 : i32 + %c0_dno = vm.optimization_barrier %c0 : i32 vm.cond_br %c0_dno, ^use_path(%ref_dno : !vm.buffer), ^nouse_path(%ref_dno : !vm.buffer) ^use_path(%r1 : !vm.buffer): vm.check.nz %r1, "ref valid in use_path" : !vm.buffer diff --git a/runtime/src/iree/vm/test/shift_ops.mlir b/runtime/src/iree/vm/test/shift_ops.mlir index b1e618d6a310..d6b258cf4436 100644 --- a/runtime/src/iree/vm/test/shift_ops.mlir +++ b/runtime/src/iree/vm/test/shift_ops.mlir @@ -7,7 +7,7 @@ vm.module @shift_ops { vm.export @test_shl_i32 vm.func @test_shl_i32() { %c1 = vm.const.i32 1 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.i32 2 %v = vm.shl.i32 %c1dno, %c2 : i32 %c4 = vm.const.i32 4 @@ -18,7 +18,7 @@ vm.module @shift_ops { vm.export @test_shr_i32s vm.func @test_shr_i32s() { %cn1 = vm.const.i32 -1 - %cn1dno = util.optimization_barrier %cn1 : i32 + %cn1dno = vm.optimization_barrier %cn1 : i32 %c2 = vm.const.i32 2 %v = vm.shr.i32.s %cn1dno, %c2 : i32 vm.check.eq %v, %cn1dno, "-1>>2=-1" : i32 @@ -28,7 +28,7 @@ vm.module @shift_ops { vm.export @test_shr_i32u vm.func @test_shr_i32u() { %c4 = vm.const.i32 4 - %c4dno = util.optimization_barrier %c4 : i32 + %c4dno = vm.optimization_barrier %c4 : i32 %c2 = vm.const.i32 2 %v = vm.shr.i32.u %c4dno, %c2 : i32 %c1 = vm.const.i32 1 diff --git a/runtime/src/iree/vm/test/shift_ops_i64.mlir b/runtime/src/iree/vm/test/shift_ops_i64.mlir index 00c072423595..6a10d14d4a8e 100644 --- a/runtime/src/iree/vm/test/shift_ops_i64.mlir +++ b/runtime/src/iree/vm/test/shift_ops_i64.mlir @@ -7,7 +7,7 @@ vm.module @shift_ops_i64 { vm.export @test_shl_i64 vm.func @test_shl_i64() { %c1 = vm.const.i64 1 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %shamt = vm.const.i32 2 %v = vm.shl.i64 %c1dno, %shamt : i64 %c4 = vm.const.i64 4 @@ -18,7 +18,7 @@ vm.module @shift_ops_i64 { vm.export @test_shr_i64s 
vm.func @test_shr_i64s() { %c1 = vm.const.i64 -1 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %shamt = vm.const.i32 2 %v = vm.shr.i64.s %c1dno, %shamt : i64 %cn1 = vm.const.i64 -1 @@ -29,7 +29,7 @@ vm.module @shift_ops_i64 { vm.export @test_shr_i64u vm.func @test_shr_i64u() { %c4 = vm.const.i64 4 - %c4dno = util.optimization_barrier %c4 : i64 + %c4dno = vm.optimization_barrier %c4 : i64 %shamt = vm.const.i32 2 %v = vm.shr.i64.u %c4dno, %shamt : i64 %c1 = vm.const.i64 1 diff --git a/runtime/src/iree/vm/testing/BUILD.bazel b/runtime/src/iree/vm/testing/BUILD.bazel new file mode 100644 index 000000000000..f2623cac8524 --- /dev/null +++ b/runtime/src/iree/vm/testing/BUILD.bazel @@ -0,0 +1,35 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_library") + +package( + default_visibility = ["//runtime/src/iree/vm:__subpackages__"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_runtime_cc_library( + name = "test_runner", + testonly = True, + srcs = ["test_runner.cc"], + hdrs = ["test_runner.h"], + deps = [ + "//runtime/src/iree/base", + "//runtime/src/iree/testing:gtest", + "//runtime/src/iree/vm", + ], +) + +iree_runtime_cc_library( + name = "yieldable_test_module", + testonly = True, + hdrs = ["yieldable_test_module.h"], + deps = [ + "//runtime/src/iree/base", + "//runtime/src/iree/vm", + ], +) diff --git a/runtime/src/iree/vm/testing/CMakeLists.txt b/runtime/src/iree/vm/testing/CMakeLists.txt new file mode 100644 index 000000000000..e967460df364 --- /dev/null +++ b/runtime/src/iree/vm/testing/CMakeLists.txt @@ -0,0 +1,40 @@ +################################################################################ +# Autogenerated by 
build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# runtime/src/iree/vm/testing/BUILD.bazel # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_cc_library( + NAME + test_runner + HDRS + "test_runner.h" + SRCS + "test_runner.cc" + DEPS + iree::base + iree::testing::gtest + iree::vm + TESTONLY + PUBLIC +) + +iree_cc_library( + NAME + yieldable_test_module + HDRS + "yieldable_test_module.h" + DEPS + iree::base + iree::vm + TESTONLY + PUBLIC +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/runtime/src/iree/vm/testing/test_runner.cc b/runtime/src/iree/vm/testing/test_runner.cc new file mode 100644 index 000000000000..4696039fb4c3 --- /dev/null +++ b/runtime/src/iree/vm/testing/test_runner.cc @@ -0,0 +1,20 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/vm/testing/test_runner.h" + +namespace iree::vm::testing { + +std::ostream& operator<<(std::ostream& os, const VMTestParams& params) { + std::string name = params.module_name + "_" + params.function_name; + // Replace special characters for valid test names. + for (char& c : name) { + if (c == ':' || c == '.') c = '_'; + } + return os << name; +} + +} // namespace iree::vm::testing diff --git a/runtime/src/iree/vm/testing/test_runner.h b/runtime/src/iree/vm/testing/test_runner.h new file mode 100644 index 000000000000..0ee371f19fb8 --- /dev/null +++ b/runtime/src/iree/vm/testing/test_runner.h @@ -0,0 +1,225 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// Shared test framework for VM module testing. +// +// This framework provides a common test runner that works across different +// VM module implementations (bytecode interpreter, EmitC, JIT, etc.). +// Tests are defined in MLIR files under iree/vm/test/ and compiled to +// different formats per backend. + +#ifndef IREE_VM_TESTING_TEST_RUNNER_H_ +#define IREE_VM_TESTING_TEST_RUNNER_H_ + +#include +#include +#include + +#include "iree/base/api.h" +#include "iree/testing/gtest.h" +#include "iree/testing/status_matchers.h" +#include "iree/vm/api.h" + +namespace iree::vm::testing { + +//===----------------------------------------------------------------------===// +// VMTestParams +//===----------------------------------------------------------------------===// +// Parameters for a single VM test function. + +// Module creation function type. +// Different backends implement this differently: +// - Bytecode: loads from embedded binary data +// - EmitC: calls static _create() function +// - JIT: compiles and loads at runtime +using VMModuleCreateFn = + std::function; + +// Parameters describing a single test to run. +struct VMTestParams { + // Module name (e.g., "arithmetic_ops"). + std::string module_name; + // Function name within the module (e.g., "test_add_i32"). + std::string function_name; + // Factory function to create the module under test. + VMModuleCreateFn create_module; + // Whether this function is expected to fail (fail_ prefix). + bool expects_failure = false; + // Factory functions for prerequisite modules that must be loaded before the + // module under test. These are loaded in order and added to the context + // first. Examples: native yieldable test module, HAL module, custom import + // modules. + std::vector prerequisite_modules; +}; + +// Allows test names to be printed nicely in gtest output. 
+std::ostream& operator<<(std::ostream& os, const VMTestParams& params); + +//===----------------------------------------------------------------------===// +// VMTestResources +//===----------------------------------------------------------------------===// +// Static resources shared across all tests in a suite. + +class VMTestResources { + public: + static iree_vm_instance_t* instance_; +}; + +//===----------------------------------------------------------------------===// +// VMTestRunner +//===----------------------------------------------------------------------===// +// Base test fixture for VM module testing. +// +// Usage: +// 1. Backend-specific test files include this header +// 2. Backend implements GetTestParams() returning vector +// 3. INSTANTIATE_TEST_SUITE_P with the params +// +// The runner automatically: +// - Creates VM instance/context +// - Loads the module under test +// - Optionally loads the native yieldable test module +// - Executes functions and checks results +// - Handles async/yieldable functions transparently + +template +class VMTestRunner : public BaseType, + public ::testing::WithParamInterface, + public VMTestResources { + public: + static void SetUpTestSuite() { + IREE_ASSERT_OK(iree_vm_instance_create( + IREE_VM_TYPE_CAPACITY_DEFAULT, iree_allocator_system(), &instance_)); + } + + static void TearDownTestSuite() { + if (instance_) { + iree_vm_instance_release(instance_); + instance_ = nullptr; + } + } + + void SetUp() override { + const auto& params = this->GetParam(); + + // Build module list for context. + std::vector modules; + + // Create and add prerequisite modules first (in order). + for (const auto& create_fn : params.prerequisite_modules) { + iree_vm_module_t* prereq_module = nullptr; + IREE_ASSERT_OK(create_fn(instance_, &prereq_module)); + prerequisite_modules_.push_back(prereq_module); + modules.push_back(prereq_module); + } + + // Create the module under test and add last. 
+ IREE_ASSERT_OK(params.create_module(instance_, &test_module_)); + modules.push_back(test_module_); + + IREE_ASSERT_OK(iree_vm_context_create_with_modules( + instance_, IREE_VM_CONTEXT_FLAG_NONE, modules.size(), modules.data(), + iree_allocator_system(), &context_)); + } + + void TearDown() override { + if (context_) { + iree_vm_context_release(context_); + context_ = nullptr; + } + if (test_module_) { + iree_vm_module_release(test_module_); + test_module_ = nullptr; + } + for (auto* module : prerequisite_modules_) { + iree_vm_module_release(module); + } + prerequisite_modules_.clear(); + } + + // Runs a function by name. + // Handles DEFERRED status by resuming until completion. + // NOTE: Only supports void-returning functions; test functions should perform + // internal assertions via vm.check.* ops rather than returning values. + iree_status_t RunFunction(const char* function_name) { + iree_vm_function_t function; + IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_name( + test_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, + iree_make_cstring_view(function_name), &function)); + + IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, + iree_vm_context_state_resolver(context_), + iree_allocator_system()); + + iree_vm_function_call_t call; + memset(&call, 0, sizeof(call)); + call.function = function; + + iree_status_t status = + function.module->begin_call(function.module->self, stack, call); + + // Resume until completion. + // Limit iterations to catch infinite yield loops in tests. 
+ constexpr int kMaxResumeCount = 10000; + int resume_count = 0; + while (iree_status_code(status) == IREE_STATUS_DEFERRED) { + iree_status_ignore(status); + if (++resume_count > kMaxResumeCount) { + iree_vm_stack_deinitialize(stack); + return iree_make_status( + IREE_STATUS_RESOURCE_EXHAUSTED, + "resume limit (%d) exceeded for function '%s'; possible infinite " + "yield loop", + kMaxResumeCount, function_name); + } + status = function.module->resume_call(function.module->self, stack, + call.results); + } + + iree_vm_stack_deinitialize(stack); + return status; + } + + protected: + iree_vm_context_t* context_ = nullptr; + iree_vm_module_t* test_module_ = nullptr; + std::vector prerequisite_modules_; +}; + +// Storage for static members. +// Note: This must only be included in one translation unit per test binary. +// The generated test template will include this. +#define IREE_VM_TEST_RUNNER_STATIC_STORAGE() \ + namespace iree::vm::testing { \ + /*static*/ iree_vm_instance_t* VMTestResources::instance_ = nullptr; \ + } + +//===----------------------------------------------------------------------===// +// Standard Test Macros +//===----------------------------------------------------------------------===// +// The parameterized test that runs each function. 
+ +#define IREE_VM_TEST_F(test_class) \ + TEST_P(test_class, Check) { \ + const auto& params = GetParam(); \ + iree_status_t status = RunFunction(params.function_name.c_str()); \ + if (iree_status_is_ok(status)) { \ + if (params.expects_failure) { \ + GTEST_FAIL() << "Function expected failure but succeeded"; \ + } \ + } else { \ + if (params.expects_failure) { \ + iree_status_ignore(status); \ + } else { \ + GTEST_FAIL() << "Function expected success but failed with error: " \ + << iree::Status(std::move(status)).ToString(); \ + } \ + } \ + } + +} // namespace iree::vm::testing + +#endif // IREE_VM_TESTING_TEST_RUNNER_H_ diff --git a/runtime/src/iree/vm/test/async_ops_test_module.h b/runtime/src/iree/vm/testing/yieldable_test_module.h similarity index 70% rename from runtime/src/iree/vm/test/async_ops_test_module.h rename to runtime/src/iree/vm/testing/yieldable_test_module.h index 2bab57e3f227..7d841f2239d3 100644 --- a/runtime/src/iree/vm/test/async_ops_test_module.h +++ b/runtime/src/iree/vm/testing/yieldable_test_module.h @@ -4,16 +4,32 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// A simple native module for testing vm.call.yieldable to imports. +// A native module for testing vm.call.yieldable to imports. // Exports a single function that yields N times before returning. +// +// This module provides a controlled way to test async/yield behavior: +// yield_n(arg: i32, yield_count: i32) -> i32 +// Returns arg + yield_count after yielding yield_count times. +// +// NOTE: This module stores coroutine state (yield_count, accumulator) in module +// state, which means it is not reentrant. Concurrent or interleaved calls on +// the same module instance through the same VM context are not supported. This +// is consistent with IREE's threading model where modules are thread-compatible +// (safe for sequential use) but not thread-safe (no concurrent access). 
+ +#ifndef IREE_VM_TESTING_YIELDABLE_TEST_MODULE_H_ +#define IREE_VM_TESTING_YIELDABLE_TEST_MODULE_H_ #include "iree/base/api.h" #include "iree/vm/native_module.h" +#ifdef __cplusplus +extern "C" { +#endif + //===----------------------------------------------------------------------===// // yieldable_test_module //===----------------------------------------------------------------------===// -// Native module with a single yieldable function for testing. typedef struct yieldable_test_module_state_t { iree_allocator_t allocator; @@ -37,18 +53,50 @@ static iree_status_t yieldable_test_module_yield_variadic_sum_shim( // Parse variadic arguments. // Layout: [segment_count: i32] [values: i32 * segment_count] [yield_count: // i32] + + // Validate minimum size for segment_count field. + if (args_storage.data_length < sizeof(int32_t)) { + return iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "argument buffer too small for segment_count; have %" PRIhsz + " bytes, need at least %" PRIhsz, + args_storage.data_length, sizeof(int32_t)); + } + const uint8_t* p = args_storage.data; - int32_t segment_count = *(const int32_t*)p; + int32_t segment_count; + memcpy(&segment_count, p, sizeof(int32_t)); p += sizeof(int32_t); + // Validate segment_count is non-negative and buffer has sufficient space. + // Required size: segment_count (1) + values (segment_count) + yield_count + // (1). + if (segment_count < 0) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "segment_count must be non-negative; got %d", + segment_count); + } + iree_host_size_t required_size = + (iree_host_size_t)(segment_count + 2) * sizeof(int32_t); + if (args_storage.data_length < required_size) { + return iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "argument buffer too small for %d variadic args; have %" PRIhsz + " bytes, need %" PRIhsz, + segment_count, args_storage.data_length, required_size); + } + // Sum all variadic values. 
int32_t sum = 0; for (int32_t i = 0; i < segment_count; ++i) { - sum += *(const int32_t*)p; + int32_t value; + memcpy(&value, p, sizeof(int32_t)); + sum += value; p += sizeof(int32_t); } - int32_t yield_count = *(const int32_t*)p; + int32_t yield_count; + memcpy(&yield_count, p, sizeof(int32_t)); // Initialize state. state->yield_count = yield_count; @@ -99,11 +147,22 @@ static iree_status_t yieldable_test_module_yield_n_shim( int32_t arg; int32_t yield_count; } args_t; - const args_t* args = (const args_t*)args_storage.data; + + // Validate buffer size. + if (args_storage.data_length < sizeof(args_t)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "argument buffer too small; have %" PRIhsz + " bytes, need %" PRIhsz, + args_storage.data_length, sizeof(args_t)); + } + + // Use memcpy for alignment-safe access. + args_t args; + memcpy(&args, args_storage.data, sizeof(args_t)); // Initialize state for coroutine. - state->yield_count = args->yield_count; - state->accumulator = args->arg; + state->yield_count = args.yield_count; + state->accumulator = args.arg; if (state->yield_count > 0) { state->accumulator += 1; @@ -200,3 +259,9 @@ static iree_status_t yieldable_test_module_create( &yieldable_test_module_descriptor_, instance, allocator, out_module); } + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // IREE_VM_TESTING_YIELDABLE_TEST_MODULE_H_ diff --git a/samples/custom_dispatch/cuda/kernels/README.md b/samples/custom_dispatch/cuda/kernels/README.md index d4ebc0bf6e6f..3b784617a5b4 100644 --- a/samples/custom_dispatch/cuda/kernels/README.md +++ b/samples/custom_dispatch/cuda/kernels/README.md @@ -75,11 +75,11 @@ nvcc ... 
(TODO, see CMakeLists.txt) -o kernels_sm_80.ptx #hal.pipeline.binding, #hal.pipeline.binding, #hal.pipeline.binding - ]>) attributes {workgroup_size = [64 : index, 1 : index, 1 : index]} count(%device: !hal.device, %workload: index) -> (index, index, index) { + ]>) count(%device: !hal.device, %workload: index) -> (index, index, index) { %x = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%workload] %c1 = arith.constant 1 : index hal.return %x, %c1, %c1 : index, index, index - } + } attributes {workgroup_size = [64 : index, 1 : index, 1 : index]} } ``` diff --git a/samples/custom_dispatch/cuda/kernels/example.mlir b/samples/custom_dispatch/cuda/kernels/example.mlir index 69f66e6008ae..f332e7302c7e 100644 --- a/samples/custom_dispatch/cuda/kernels/example.mlir +++ b/samples/custom_dispatch/cuda/kernels/example.mlir @@ -79,11 +79,7 @@ module @example attributes {hal.device.targets = [#cuda_target]} { #hal.pipeline.binding, #hal.pipeline.binding, #hal.pipeline.binding - ]>) attributes { - // Certain backends (like CUDA) require a workgroup size (aka block - // size) to be defined ahead of time. - workgroup_size = [64 : index, 1 : index, 1 : index] - } count(%device: !hal.device, %workload: index) -> (index, index, index) { + ]>) count(%device: !hal.device, %workload: index) -> (index, index, index) { // This host function is used to compute the XYZ workgroup count // dispatched at runtime. It can query the %device for capabilities // and limits (shared memory size, etc). The other arguments are the @@ -92,6 +88,10 @@ module @example attributes {hal.device.targets = [#cuda_target]} { %x = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%workload] %c1 = arith.constant 1 : index hal.return %x, %c1, %c1 : index, index, index + } attributes { + // Certain backends (like CUDA) require a workgroup size (aka block + // size) to be defined ahead of time. 
+ workgroup_size = [64 : index, 1 : index, 1 : index] } // Similar to the above but in-place by using a read/write binding. @@ -99,12 +99,12 @@ module @example attributes {hal.device.targets = [#cuda_target]} { layout(#hal.pipeline.layout, #hal.pipeline.binding - ]>) attributes { - workgroup_size = [64 : index, 1 : index, 1 : index] - } count(%device: !hal.device, %workload: index) -> (index, index, index) { + ]>) count(%device: !hal.device, %workload: index) -> (index, index, index) { %x = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%workload] %c1 = arith.constant 1 : index hal.return %x, %c1, %c1 : index, index, index + } attributes { + workgroup_size = [64 : index, 1 : index, 1 : index] } } // hal.executable.source diff --git a/samples/custom_dispatch/hip/kernels/example.mlir b/samples/custom_dispatch/hip/kernels/example.mlir index ed44046cb3c8..aa83fd5490f2 100644 --- a/samples/custom_dispatch/hip/kernels/example.mlir +++ b/samples/custom_dispatch/hip/kernels/example.mlir @@ -70,11 +70,7 @@ module @example attributes {hal.device.targets = [#rocm_target]} { #hal.pipeline.binding, #hal.pipeline.binding, #hal.pipeline.binding - ]>) attributes { - // Certain backends (like ROCM) require a workgroup size (aka block - // size) to be defined ahead of time. - workgroup_size = [64 : index, 1 : index, 1 : index] - } count(%device: !hal.device, %workload: index) -> (index, index, index) { + ]>) count(%device: !hal.device, %workload: index) -> (index, index, index) { // This host function is used to compute the XYZ workgroup count // dispatched at runtime. It can query the %device for capabilities // and limits (shared memory size, etc). 
The other arguments are the @@ -83,6 +79,10 @@ module @example attributes {hal.device.targets = [#rocm_target]} { %x = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%workload] %c1 = arith.constant 1 : index hal.return %x, %c1, %c1 : index, index, index + } attributes { + // Certain backends (like ROCM) require a workgroup size (aka block + // size) to be defined ahead of time. + workgroup_size = [64 : index, 1 : index, 1 : index] } // Similar to the above but in-place by using a read/write binding. @@ -90,12 +90,12 @@ module @example attributes {hal.device.targets = [#rocm_target]} { layout(#hal.pipeline.layout, #hal.pipeline.binding - ]>) attributes { - workgroup_size = [64 : index, 1 : index, 1 : index] - } count(%device: !hal.device, %workload: index) -> (index, index, index) { + ]>) count(%device: !hal.device, %workload: index) -> (index, index, index) { %x = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%workload] %c1 = arith.constant 1 : index hal.return %x, %c1, %c1 : index, index, index + } attributes { + workgroup_size = [64 : index, 1 : index, 1 : index] } } // hal.executable.source diff --git a/tests/compiler_driver/BUILD.bazel b/tests/compiler_driver/BUILD.bazel index a9b99b8558d2..fafafa864bf0 100644 --- a/tests/compiler_driver/BUILD.bazel +++ b/tests/compiler_driver/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "executable_benchmarks.mlir", "hal_executable.mlir", diff --git a/tests/compiler_driver/streams.mlir b/tests/compiler_driver/streams.mlir index c7d7df7fe9c0..1181afff67f9 100644 --- a/tests/compiler_driver/streams.mlir +++ b/tests/compiler_driver/streams.mlir @@ -1,9 +1,7 @@ // RUN: iree-compile --split-input-file \ // RUN: --iree-hal-target-device=local \ // RUN: --iree-hal-local-target-device-backends=vmvx \ -// RUN: --output-format=vm-bytecode \ -// RUN: --iree-vm-bytecode-module-output-format=flatbuffer-text %s \ -// RUN: 
--mlir-print-ir-after=iree-vm-ordinal-allocation 2>&1 | FileCheck %s +// RUN: --compile-to=vm %s | FileCheck %s // This file has a few test programs that show how to mix `flow` dispatches into // those created by the `linalg` dispatch region formation: the idea is to use diff --git a/tests/e2e/linalg/BUILD.bazel b/tests/e2e/linalg/BUILD.bazel index ab8360219b06..80d1f03eb50b 100644 --- a/tests/e2e/linalg/BUILD.bazel +++ b/tests/e2e/linalg/BUILD.bazel @@ -73,8 +73,10 @@ iree_check_single_backend_test_suite( VMVX_SRCS = enforce_glob( # keep sorted [ + "argmax.mlir", "conv2d.mlir", "gather_like_ops.mlir", + "index.mlir", "narrow_n_matmuls.mlir", "pack.mlir", "pack_dynamic_inner_tiles.mlir", @@ -84,10 +86,8 @@ VMVX_SRCS = enforce_glob( ], include = ["*.mlir"], exclude = [ - "argmax.mlir", "fp_to_subbyte.mlir", "fp4_f32_conversion.mlir", - "index.mlir", "large_linalg_matmul.mlir", "subbyte_to_fp.mlir", ], @@ -124,18 +124,18 @@ iree_check_single_backend_test_suite( VULKAN_SRCS = enforce_glob( # keep sorted [ + "argmax.mlir", "conv2d.mlir", "gather_like_ops.mlir", + "index.mlir", "narrow_n_matmuls.mlir", "softmax.mlir", "subbyte_to_fp.mlir", ], include = ["*.mlir"], exclude = [ - "argmax.mlir", "fp_to_subbyte.mlir", "fp4_f32_conversion.mlir", - "index.mlir", "large_linalg_matmul.mlir", "pack.mlir", "pack_dynamic_inner_tiles.mlir", @@ -221,20 +221,20 @@ ROCM_SRCS = enforce_glob( # keep sorted [ "argmax.mlir", + "conv2d.mlir", + "fp4_f32_conversion.mlir", + "fp_to_subbyte.mlir", "gather_like_ops.mlir", + "index.mlir", + "narrow_n_matmuls.mlir", "pack_i8.mlir", "softmax.mlir", + "subbyte_to_fp.mlir", "unpack.mlir", ], include = ["*.mlir"], exclude = [ - "conv2d.mlir", - "fp_to_subbyte.mlir", - "fp4_f32_conversion.mlir", - "index.mlir", "large_linalg_matmul.mlir", - "narrow_n_matmuls.mlir", - "subbyte_to_fp.mlir", # https://github.com/llvm/llvm-project/issues/131386 causes # See bug #20294 "pack.mlir", diff --git a/tests/e2e/linalg/CMakeLists.txt 
b/tests/e2e/linalg/CMakeLists.txt index 9f941a2f6024..33db43f627b4 100644 --- a/tests/e2e/linalg/CMakeLists.txt +++ b/tests/e2e/linalg/CMakeLists.txt @@ -55,8 +55,10 @@ iree_check_single_backend_test_suite( NAME check_vmvx_local-task SRCS + "argmax.mlir" "conv2d.mlir" "gather_like_ops.mlir" + "index.mlir" "narrow_n_matmuls.mlir" "pack.mlir" "pack_dynamic_inner_tiles.mlir" @@ -89,8 +91,10 @@ iree_check_single_backend_test_suite( NAME check_vulkan-spirv_vulkan SRCS + "argmax.mlir" "conv2d.mlir" "gather_like_ops.mlir" + "index.mlir" "narrow_n_matmuls.mlir" "softmax.mlir" "subbyte_to_fp.mlir" @@ -156,9 +160,15 @@ iree_check_single_backend_test_suite( check_rocm_hip SRCS "argmax.mlir" + "conv2d.mlir" + "fp4_f32_conversion.mlir" + "fp_to_subbyte.mlir" "gather_like_ops.mlir" + "index.mlir" + "narrow_n_matmuls.mlir" "pack_i8.mlir" "softmax.mlir" + "subbyte_to_fp.mlir" "unpack.mlir" TARGET_BACKEND "rocm" diff --git a/tests/e2e/linalg_ext_ops/BUILD.bazel b/tests/e2e/linalg_ext_ops/BUILD.bazel index 0fa227cc29c1..68295287a35f 100644 --- a/tests/e2e/linalg_ext_ops/BUILD.bazel +++ b/tests/e2e/linalg_ext_ops/BUILD.bazel @@ -19,6 +19,7 @@ ALL_SRCS = enforce_glob( "attention.mlir", "attention_i1_mask_encoding.mlir", "gather.mlir", + "map_gather.mlir", "map_scatter.mlir", "scan.mlir", "scatter.mlir", @@ -68,7 +69,9 @@ VMVX_SRCS = enforce_glob( # keep sorted [ "arg_compare.mlir", + "attention.mlir", "gather.mlir", + "map_gather.mlir", "map_scatter.mlir", "scan.mlir", "scatter.mlir", @@ -79,7 +82,6 @@ VMVX_SRCS = enforce_glob( ], include = ["*.mlir"], exclude = [ - "attention.mlir", "attention_i1_mask.mlir", "attention_i1_mask_encoding.mlir", ], @@ -109,6 +111,7 @@ LLVM_GPU_SRCS = enforce_glob( "attention.mlir", "attention_i1_mask.mlir", "attention_i1_mask_encoding.mlir", + "map_gather.mlir", "map_scatter.mlir", ], ) @@ -134,6 +137,7 @@ ROCM_HIP_SRCS = enforce_glob( [ "arg_compare.mlir", "gather.mlir", + "map_gather.mlir", "map_scatter.mlir", "scan.mlir", "scatter.mlir", @@ -175,6 
+179,7 @@ iree_check_single_backend_test_suite( "attention.mlir", "attention_i1_mask.mlir", "attention_i1_mask_encoding.mlir", + "map_gather.mlir", "map_scatter.mlir", "top-k.mlir", ], @@ -190,9 +195,12 @@ iree_check_single_backend_test_suite( [ "arg_compare.mlir", "gather.mlir", + "map_gather.mlir", + "map_scatter.mlir", "scan.mlir", "scatter.mlir", "sort.mlir", + "top-k.mlir", "winograd_input.mlir", "winograd_output.mlir", ], @@ -201,8 +209,6 @@ iree_check_single_backend_test_suite( "attention.mlir", "attention_i1_mask.mlir", "attention_i1_mask_encoding.mlir", - "map_scatter.mlir", - "top-k.mlir", ], ), driver = "vulkan", diff --git a/tests/e2e/linalg_ext_ops/CMakeLists.txt b/tests/e2e/linalg_ext_ops/CMakeLists.txt index 8dcf9032dd93..8ce864dcfdb2 100644 --- a/tests/e2e/linalg_ext_ops/CMakeLists.txt +++ b/tests/e2e/linalg_ext_ops/CMakeLists.txt @@ -18,6 +18,7 @@ iree_check_single_backend_test_suite( "attention.mlir" "attention_i1_mask_encoding.mlir" "gather.mlir" + "map_gather.mlir" "map_scatter.mlir" "scan.mlir" "scatter.mlir" @@ -56,7 +57,9 @@ iree_check_single_backend_test_suite( check_vmvx_local-task SRCS "arg_compare.mlir" + "attention.mlir" "gather.mlir" + "map_gather.mlir" "map_scatter.mlir" "scan.mlir" "scatter.mlir" @@ -100,6 +103,7 @@ iree_check_single_backend_test_suite( SRCS "arg_compare.mlir" "gather.mlir" + "map_gather.mlir" "map_scatter.mlir" "scan.mlir" "scatter.mlir" @@ -135,9 +139,12 @@ iree_check_single_backend_test_suite( SRCS "arg_compare.mlir" "gather.mlir" + "map_gather.mlir" + "map_scatter.mlir" "scan.mlir" "scatter.mlir" "sort.mlir" + "top-k.mlir" "winograd_input.mlir" "winograd_output.mlir" TARGET_BACKEND diff --git a/tests/e2e/linalg_ext_ops/map_gather.mlir b/tests/e2e/linalg_ext_ops/map_gather.mlir new file mode 100644 index 000000000000..96a7b27d09de --- /dev/null +++ b/tests/e2e/linalg_ext_ops/map_gather.mlir @@ -0,0 +1,54 @@ +func.func @copy_like() { + %source = util.unfoldable_constant dense<123.0> : tensor<4x16x64xf32> + %output = 
tensor.empty() : tensor<4x16x64xf32> + %padding = arith.constant 0.0 : f32 + %0 = iree_linalg_ext.map_gather %source into %output { + ^bb0(%idx0: index, %idx1: index, %idx2: index): + iree_linalg_ext.yield %idx0, %idx1, %idx2, %padding : index, index, index, f32 + } : tensor<4x16x64xf32> into tensor<4x16x64xf32> -> tensor<4x16x64xf32> + check.expect_almost_eq(%0, %source) : tensor<4x16x64xf32> + return +} + +func.func @expand_shape_like() { + %source = util.unfoldable_constant dense<[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0]> : tensor<16xf32> + %padding = arith.constant 0.0 : f32 + %output = tensor.empty() : tensor<4x4xf32> + %result = iree_linalg_ext.map_gather %source into %output { + ^bb0(%idx0: index, %idx1: index): + %linear = affine.linearize_index disjoint [%idx0, %idx1] by (4, 4) : index + iree_linalg_ext.yield %linear, %padding : index, f32 + } : tensor<16xf32> into tensor<4x4xf32> -> tensor<4x4xf32> + %expected = tensor.expand_shape %source [[0, 1]] output_shape [4, 4] : tensor<16xf32> into tensor<4x4xf32> + check.expect_almost_eq(%result, %expected) : tensor<4x4xf32> + return +} + +func.func @collapse_shape_like() { + %source = util.unfoldable_constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi32> + %padding = arith.constant 0 : i32 + %output = tensor.empty() : tensor<16xi32> + %result = iree_linalg_ext.map_gather %source into %output { + ^bb0(%idx0: index): + %2:2 = affine.delinearize_index %idx0 into (4, 4) : index, index + iree_linalg_ext.yield %2#0, %2#1, %padding : index, index, i32 + } : tensor<4x4xi32> into tensor<16xi32> -> tensor<16xi32> + %expected = tensor.collapse_shape %source [[0, 1]] : tensor<4x4xi32> into tensor<16xi32> + check.expect_eq(%result, %expected) : tensor<16xi32> + return +} + +func.func @pad_slice_like() { + // Source is 4 elements, output is 8 elements (with padding for out-of-bounds) + %source = util.unfoldable_constant dense<[1.0, 2.0, 
3.0, 4.0]> : tensor<4xf32> + %padding = arith.constant 0.0 : f32 + %output = tensor.empty() : tensor<8xf32> + %result = iree_linalg_ext.map_gather %source into %output { + ^bb0(%idx0: index): + // Identity mapping - indices 0-3 are in-bounds, 4-7 get padding + iree_linalg_ext.yield %idx0, %padding : index, f32 + } : tensor<4xf32> into tensor<8xf32> -> tensor<8xf32> + %expected = util.unfoldable_constant dense<[1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0]> : tensor<8xf32> + check.expect_almost_eq(%result, %expected) : tensor<8xf32> + return +} diff --git a/tests/e2e/parameters/BUILD.bazel b/tests/e2e/parameters/BUILD.bazel index 81e36b107eb9..fe20806c512f 100644 --- a/tests/e2e/parameters/BUILD.bazel +++ b/tests/e2e/parameters/BUILD.bazel @@ -15,7 +15,9 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ + "encode_parameters.mlir", "export_parameters.mlir", "generate_splat_archive.mlir", ], @@ -29,6 +31,7 @@ iree_lit_test_suite( tools = [ "//tools:iree-compile", "//tools:iree-dump-parameters", + "//tools:iree-encode-parameters", "//tools:iree-run-module", "@llvm-project//llvm:FileCheck", ], diff --git a/tests/e2e/parameters/CMakeLists.txt b/tests/e2e/parameters/CMakeLists.txt index 61ad996589da..e24056bae27c 100644 --- a/tests/e2e/parameters/CMakeLists.txt +++ b/tests/e2e/parameters/CMakeLists.txt @@ -15,12 +15,14 @@ iree_lit_test_suite( NAME lit SRCS + "encode_parameters.mlir" "export_parameters.mlir" "generate_splat_archive.mlir" TOOLS FileCheck iree-compile iree-dump-parameters + iree-encode-parameters iree-run-module LABELS "driver=local-task" diff --git a/tests/e2e/parameters/encode_parameters.mlir b/tests/e2e/parameters/encode_parameters.mlir new file mode 100644 index 000000000000..9613be11f520 --- /dev/null +++ b/tests/e2e/parameters/encode_parameters.mlir @@ -0,0 +1,68 @@ +// RUN: rm -f %t_main.vmfb %t_encoder.mlir %t_encoder.vmfb %t_input.irpa %t_output.irpa +// +// Compile main module with encoder MLIR output and splat 
parameter export. +// RUN: iree-compile %s \ +// RUN: --iree-hal-target-device=local \ +// RUN: --iree-hal-local-target-device-backends=vmvx \ +// RUN: --iree-parameter-encoder-output-file=%t_encoder.mlir \ +// RUN: --iree-parameter-splat=%t_input.irpa \ +// RUN: -o %t_main.vmfb +// +// Compile the encoder module separately. +// RUN: iree-compile %t_encoder.mlir \ +// RUN: --iree-hal-target-device=local \ +// RUN: --iree-hal-local-target-device-backends=vmvx \ +// RUN: -o %t_encoder.vmfb +// +// Run the encoder to transform parameters. +// RUN: iree-encode-parameters \ +// RUN: --module=%t_encoder.vmfb \ +// RUN: --parameters=model=%t_input.irpa \ +// RUN: --output=encoded=%t_output.irpa \ +// RUN: --quiet +// +// Run the main module with both input and encoded parameters. +// The encoded parameters contain the pre-computed transformed values. +// RUN: iree-run-module \ +// RUN: --device=local-sync \ +// RUN: --module=%t_main.vmfb \ +// RUN: --function=main \ +// RUN: --parameters=model=%t_input.irpa \ +// RUN: --parameters=encoded=%t_output.irpa | \ +// RUN: FileCheck %s + +// Test parameter transformation with encoder. +// The global loads a parameter and applies an add operation to transform it. +// The encoder runs the add offline, and the main module loads the +// pre-computed result from the encoded parameter scope. + +// CHECK-LABEL: EXEC @main +// CHECK: 256xi32=42 42 42 42 + +// Parameter loaded from input archive (model scope). +// The splat export creates this with all zeros. +util.global private @raw_param = #flow.parameter.named<"model"::"param_global"> : tensor<256xi32> + +// This global holds the transformed value. +util.global private @transformed : tensor<256xi32> + +util.initializer { + // Load the raw parameter (all zeros from splat). + %raw = util.global.load @raw_param : tensor<256xi32> + // Add 42 to each element - this uses the parameter values and can be encoded. + // With input of 0s, result is 42s. 
+ %c42 = arith.constant 42 : i32 + %init = tensor.empty() : tensor<256xi32> + %c42_tensor = linalg.fill ins(%c42 : i32) outs(%init : tensor<256xi32>) -> tensor<256xi32> + %added = linalg.add ins(%raw, %c42_tensor : tensor<256xi32>, tensor<256xi32>) outs(%init : tensor<256xi32>) -> tensor<256xi32> + util.global.store %added, @transformed : tensor<256xi32> + util.return +} + +func.func @main() -> tensor<256xi32> { + // Load and return the full transformed tensor. + // If encoding worked, all elements should be 42 (0 + 42). + // If encoding didn't work, all elements would be 0 (splat init). + %tensor = util.global.load @transformed : tensor<256xi32> + return %tensor : tensor<256xi32> +} diff --git a/tests/e2e/parameters/export_parameters.mlir b/tests/e2e/parameters/export_parameters.mlir index 4288a90bd823..ef8321ee6ffd 100644 --- a/tests/e2e/parameters/export_parameters.mlir +++ b/tests/e2e/parameters/export_parameters.mlir @@ -1,8 +1,8 @@ // RUN: iree-compile %s \ // RUN: --iree-hal-target-device=local \ // RUN: --iree-hal-local-target-device-backends=vmvx \ -// RUN: --iree-opt-export-parameters=scope=%t.irpa \ -// RUN: --iree-opt-export-parameter-minimum-size=0 | \ +// RUN: --iree-parameter-export=scope=%t.irpa \ +// RUN: --iree-parameter-export-minimum-size=0 | \ // RUN: iree-run-module \ // RUN: --device=local-sync \ // RUN: --module=- \ diff --git a/tests/e2e/parameters/generate_splat_archive.mlir b/tests/e2e/parameters/generate_splat_archive.mlir index 79a8b632ddd9..fb888c7aef65 100644 --- a/tests/e2e/parameters/generate_splat_archive.mlir +++ b/tests/e2e/parameters/generate_splat_archive.mlir @@ -2,7 +2,7 @@ // RUN: iree-compile %s \ // RUN: --iree-hal-target-device=local \ // RUN: --iree-hal-local-target-device-backends=vmvx \ -// RUN: --iree-opt-splat-parameters=%t.irpa | \ +// RUN: --iree-parameter-splat=%t.irpa | \ // RUN: iree-run-module \ // RUN: --device=local-sync \ // RUN: --module=- \ diff --git a/tests/e2e/regression/stablehlo/BUILD.bazel 
b/tests/e2e/regression/stablehlo/BUILD.bazel index 7e2fff3b9132..1a99b7fac129 100644 --- a/tests/e2e/regression/stablehlo/BUILD.bazel +++ b/tests/e2e/regression/stablehlo/BUILD.bazel @@ -35,6 +35,7 @@ NON_CHECK_TESTS = [ iree_check_single_backend_test_suite( name = "check_stablehlo_regression_llvm-cpu", srcs = enforce_glob( + # keep sorted CHECK_TESTS + CPU_SPECIFIC_TESTS, include = ["*.mlir"], exclude = NON_CHECK_TESTS, @@ -47,6 +48,7 @@ iree_check_single_backend_test_suite( iree_check_single_backend_test_suite( name = "check_stablehlo_regression_vmvx", srcs = enforce_glob( + # keep sorted CHECK_TESTS, include = ["*.mlir"], exclude = CPU_SPECIFIC_TESTS + NON_CHECK_TESTS, @@ -58,6 +60,7 @@ iree_check_single_backend_test_suite( iree_check_single_backend_test_suite( name = "check_stablehlo_regression_vulkan-spirv", srcs = enforce_glob( + # keep sorted CHECK_TESTS, include = ["*.mlir"], exclude = CPU_SPECIFIC_TESTS + NON_CHECK_TESTS, @@ -69,6 +72,7 @@ iree_check_single_backend_test_suite( iree_check_single_backend_test_suite( name = "check_stablehlo_regression_cuda", srcs = enforce_glob( + # keep sorted CHECK_TESTS, include = ["*.mlir"], exclude = CPU_SPECIFIC_TESTS + NON_CHECK_TESTS, @@ -88,6 +92,7 @@ iree_check_single_backend_test_suite( iree_check_single_backend_test_suite( name = "check_stablehlo_regression_hip", srcs = enforce_glob( + # keep sorted CHECK_TESTS, include = ["*.mlir"], exclude = CPU_SPECIFIC_TESTS + NON_CHECK_TESTS, @@ -109,6 +114,7 @@ iree_check_single_backend_test_suite( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted NON_CHECK_TESTS, include = ["*.mlir"], exclude = CPU_SPECIFIC_TESTS + CHECK_TESTS, diff --git a/tests/e2e/stablehlo_models/BUILD.bazel b/tests/e2e/stablehlo_models/BUILD.bazel index 3e68c6253689..78b4a6457c89 100644 --- a/tests/e2e/stablehlo_models/BUILD.bazel +++ b/tests/e2e/stablehlo_models/BUILD.bazel @@ -20,6 +20,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ 
"collatz.mlir", "edge_detection.mlir", diff --git a/tests/e2e/stablehlo_ops/BUILD.bazel b/tests/e2e/stablehlo_ops/BUILD.bazel index 5b8fc814f49b..c8297028824f 100644 --- a/tests/e2e/stablehlo_ops/BUILD.bazel +++ b/tests/e2e/stablehlo_ops/BUILD.bazel @@ -13,6 +13,7 @@ package( ) ALL_SRCS = enforce_glob( + # keep sorted [ "abs.mlir", "add.mlir", @@ -240,6 +241,7 @@ iree_check_single_backend_test_suite( "reduce_window.mlir", "remainder.mlir", "reshape.mlir", + "reverse.mlir", "rng_normal.mlir", "rng_uniform.mlir", "round.mlir", @@ -263,7 +265,6 @@ iree_check_single_backend_test_suite( exclude = [ "exponential_fp16.mlir", "fft.mlir", # TODO(#9583) - "reverse.mlir", # TODO(#12415): disabled due to miscompilation on Pixel 6. ], ), compiler_flags = [ diff --git a/tests/e2e/stablehlo_ops/CMakeLists.txt b/tests/e2e/stablehlo_ops/CMakeLists.txt index 2d357621cee5..c76b852331cb 100644 --- a/tests/e2e/stablehlo_ops/CMakeLists.txt +++ b/tests/e2e/stablehlo_ops/CMakeLists.txt @@ -287,6 +287,7 @@ iree_check_single_backend_test_suite( "reduce_window.mlir" "remainder.mlir" "reshape.mlir" + "reverse.mlir" "rng_normal.mlir" "rng_uniform.mlir" "round.mlir" diff --git a/tests/e2e/subbyte_types/BUILD.bazel b/tests/e2e/subbyte_types/BUILD.bazel index ff1b1f3ea643..5782fd122dec 100644 --- a/tests/e2e/subbyte_types/BUILD.bazel +++ b/tests/e2e/subbyte_types/BUILD.bazel @@ -21,6 +21,7 @@ package( iree_check_single_backend_test_suite( name = "check_llvm-cpu_subbyte_emulation", srcs = enforce_glob( + # keep sorted [ "subbyte_types.mlir", ], diff --git a/tests/e2e/tosa_ops/BUILD.bazel b/tests/e2e/tosa_ops/BUILD.bazel index c43028acb86d..fd7c52baac42 100644 --- a/tests/e2e/tosa_ops/BUILD.bazel +++ b/tests/e2e/tosa_ops/BUILD.bazel @@ -13,6 +13,7 @@ package( ) ALL_SRCS = enforce_glob( + # keep sorted [ "abs.mlir", "add.mlir", @@ -99,6 +100,7 @@ iree_check_single_backend_test_suite( ) ROCM_AND_CUDA_SRCS = enforce_glob( + # keep sorted [ "abs.mlir", "add.mlir", diff --git 
a/tests/external/iree-test-suites/onnx_models/onnx_models_gpu_hip_rdna4.json b/tests/external/iree-test-suites/onnx_models/onnx_models_gpu_hip_rdna4.json new file mode 100644 index 000000000000..cca9d5e3cc8f --- /dev/null +++ b/tests/external/iree-test-suites/onnx_models/onnx_models_gpu_hip_rdna4.json @@ -0,0 +1,30 @@ +{ + "config_name": "gpu_hip_rdna4", + "iree_compile_flags": [ + "--iree-hal-target-device=hip", + "--iree-hip-target=gfx1201" + ], + "iree_run_module_flags": [ + "--device=hip" + ], + "tests_and_expected_outcomes": { + "default": "skip", + "tests/model_zoo/validated/vision/body_analysis_models_test.py::test_models[age_gender/models/age_googlenet.onnx]": "pass", + "tests/model_zoo/validated/vision/classification_models_test.py::test_models[alexnet/model/bvlcalexnet-12.onnx]": "pass", + "tests/model_zoo/validated/vision/classification_models_test.py::test_models[caffenet/model/caffenet-12.onnx]": "pass", + "tests/model_zoo/validated/vision/classification_models_test.py::test_models[densenet-121/model/densenet-12.onnx]": "pass", + "tests/model_zoo/validated/vision/classification_models_test.py::test_models[efficientnet-lite4/model/efficientnet-lite4-11.onnx]": "pass", + "tests/model_zoo/validated/vision/classification_models_test.py::test_models[inception_and_googlenet/googlenet/model/googlenet-12.onnx]": "pass", + "tests/model_zoo/validated/vision/classification_models_test.py::test_models[inception_and_googlenet/inception_v2/model/inception-v2-9.onnx]": "pass", + "tests/model_zoo/validated/vision/classification_models_test.py::test_models[mnist/model/mnist-12.onnx]": "pass", + "tests/model_zoo/validated/vision/classification_models_test.py::test_models[mobilenet/model/mobilenetv2-12.onnx]": "pass", + "tests/model_zoo/validated/vision/classification_models_test.py::test_models[rcnn_ilsvrc13/model/rcnn-ilsvrc13-9.onnx]": "pass", + "tests/model_zoo/validated/vision/classification_models_test.py::test_models[resnet/model/resnet50-v1-12.onnx]": "pass", + 
"tests/model_zoo/validated/vision/classification_models_test.py::test_models[resnet/model/resnet50-v2-7.onnx]": "pass", + "tests/model_zoo/validated/vision/classification_models_test.py::test_models[shufflenet/model/shufflenet-9.onnx]": "pass", + "tests/model_zoo/validated/vision/classification_models_test.py::test_models[shufflenet/model/shufflenet-v2-12.onnx]": "pass", + "tests/model_zoo/validated/vision/object_detection_segmentation_models_test.py::test_models[tiny-yolov2/model/tinyyolov2-8.onnx]": "pass", + "tests/model_zoo/validated/vision/object_detection_segmentation_models_test.py::test_models[yolov2-coco/model/yolov2-coco-9.onnx]": "pass", + "tests/model_zoo/validated/vision/super_resolution_models_test.py::test_models[sub_pixel_cnn_2016/model/super-resolution-10.onnx]": "pass" + } +} diff --git a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_hip_rdna4_O3.json b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_hip_rdna4_O3.json new file mode 100644 index 000000000000..b7afbda04c6a --- /dev/null +++ b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_hip_rdna4_O3.json @@ -0,0 +1,27 @@ +{ + "config_name": "gpu_hip_rdna4", + "iree_compile_flags": [ + "--iree-hal-target-device=hip", + "--iree-hip-target=gfx1201", + "--iree-input-demote-f64-to-f32=false", + "--iree-opt-level=O3" + ], + "iree_run_module_flags": [ + "--device=hip" + ], + "skip_compile_tests": [ + "onnx/node/generated/test_dequantizelinear", + "onnx/node/generated/test_einsum_inner_prod", + "onnx/node/generated/test_group_normalization_epsilon_expanded", + "onnx/node/generated/test_group_normalization_example_expanded", + "onnx/node/generated/test_nonmaxsuppression_two_batches", + "onnx/node/generated/test_constantofshape_int_shape_zero" + ], + "skip_run_tests": [ + "onnx/node/generated/test_top_k", + "onnx/node/generated/test_top_k_negative_axis", + "onnx/node/generated/test_top_k_smallest" + ], + "expected_compile_failures": [], + "expected_run_failures": [] +} diff --git 
a/tests/external/iree-test-suites/torch_ops/torch_ops_cpu_llvm_sync.json b/tests/external/iree-test-suites/torch_ops/torch_ops_cpu_llvm_sync.json index 9f064617587a..376e5cfc1ebc 100644 --- a/tests/external/iree-test-suites/torch_ops/torch_ops_cpu_llvm_sync.json +++ b/tests/external/iree-test-suites/torch_ops/torch_ops_cpu_llvm_sync.json @@ -6,8 +6,23 @@ "--iree-llvmcpu-target-cpu=host" ], "iree_run_module_flags": [], - "skip_compile_tests": [], - "skip_run_tests": [], + "skip_compile_tests": [ + "InterestingShapesBiasAdd/997x997xi8_NN_bias" + ], + "skip_run_tests": [ + "AB/8192x8192xf32_bench", + "AB/4096x4096xf32_bench", + "AB/2048x2048xf32_bench" + ], "expected_compile_failures": [], - "expected_run_failures": [] + "expected_run_failures": [], + "golden_times_ms": { + "AB/8192x8192xf32_bench": 5587.488262355328, + "AB/1024x1024xf32_bench": 1.1874544098876767, + "AB/256x256xf32_bench": 0.044473891122477315, + "AB/128x128xf32_bench": 0.03577919309132721, + "AB/2048x2048xf32_bench": 10.41092509722771, + "AB/4096x4096xf32_bench": 131.7884701769799, + "AB/512x512xf32_bench": 0.12528392536236224 + } } diff --git a/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx1100_O3.json b/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx1100_O3.json index 414cefe81cb0..4f8f6ff4b11c 100644 --- a/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx1100_O3.json +++ b/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx1100_O3.json @@ -8,11 +8,22 @@ "iree_run_module_flags": [ "--device=hip" ], - "skip_compile_tests": [], + "skip_compile_tests": [ + "InterestingShapesBiasAdd/997x997xi8_NN_bias" + ], "skip_run_tests": [ - "generated/test_a_b_plus_c_float16", - "generated/test_a_t_b_float16" + "ABPlusC/64x64xf16", + "ATB/64x64xf16" ], "expected_compile_failures": [], - "expected_run_failures": [] + "expected_run_failures": [], + "golden_times_ms": { + "AB/8192x8192xf32_bench": 211.36675303181013, + "AB/1024x1024xf32_bench": 
0.473261977629155, + "AB/256x256xf32_bench": 0.13523232175050667, + "AB/128x128xf32_bench": 0.10226182854380102, + "AB/2048x2048xf32_bench": 3.35047472617589, + "AB/4096x4096xf32_bench": 26.499415814344378, + "AB/512x512xf32_bench": 0.16945170770798412 + } } diff --git a/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx1201_O3.json b/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx1201_O3.json new file mode 100644 index 000000000000..baebf6555755 --- /dev/null +++ b/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx1201_O3.json @@ -0,0 +1,36 @@ +{ + "config_name": "gpu_hip_gfx1201", + "iree_compile_flags": [ + "--iree-hal-target-device=hip", + "--iree-hip-target=gfx1201", + "--iree-opt-level=O3" + ], + "iree_run_module_flags": [ + "--device=hip" + ], + "skip_compile_tests": [ + "InterestingShapesBiasAdd/997x997xi8_NN_bias" + ], + "skip_run_tests": [ + "ABPlusC/64x64xf16", + "ATB/64x64xf16", + "AB/1024x1024xf32_bench", + "AB/128x128xf32_bench", + "AB/2048x2048xf32_bench", + "AB/256x256xf32_bench", + "AB/4096x4096xf32_bench", + "AB/512x512xf32_bench", + "AB/8192x8192xf32_bench" + ], + "expected_compile_failures": [], + "expected_run_failures": [], + "golden_times_ms": { + "AB/1024x1024xf32_bench" : 0.0, + "AB/128x128xf32_bench" : 0.0, + "AB/2048x2048xf32_bench" : 0.0, + "AB/256x256xf32_bench" : 0.0, + "AB/4096x4096xf32_bench" : 0.0, + "AB/512x512xf32_bench" : 0.0, + "AB/8192x8192xf32_bench" : 0.0 + } +} diff --git a/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx942_O3.json b/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx942_O3.json index 281b9e432a51..858a5eadcc0c 100644 --- a/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx942_O3.json +++ b/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx942_O3.json @@ -8,10 +8,21 @@ "iree_run_module_flags": [ "--device=hip" ], - "skip_compile_tests": [], + "skip_compile_tests": [ + "InterestingShapesBiasAdd/997x997xi8_NN_bias" 
+ ], "skip_run_tests": [ - "generated/test_a_t_b_float16" + "ATB/64x64xf16" ], "expected_compile_failures": [], - "expected_run_failures": [] + "expected_run_failures": [], + "golden_times_ms": { + "AB/8192x8192xf32_bench": 10.345919956763586, + "AB/1024x1024xf32_bench": 0.12306747736492638, + "AB/256x256xf32_bench": 0.06101156551354003, + "AB/128x128xf32_bench": 0.052202587451140196, + "AB/2048x2048xf32_bench": 0.2345012331137894, + "AB/4096x4096xf32_bench": 1.423236482683913, + "AB/512x512xf32_bench": 0.07336693902980182 + } } diff --git a/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_vulkan_O3.json b/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_vulkan_O3.json index a2d406c149cd..d511f7942c3b 100644 --- a/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_vulkan_O3.json +++ b/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_vulkan_O3.json @@ -8,15 +8,29 @@ "--device=vulkan" ], "skip_compile_tests": [ - "generated/test_a_t_b_float16", - "generated/test_a_b_plus_c_float16", - "generated/test_a_b_t_float16", - "generated/test_relu_a_b_plus_c_float16", - "generated/test_gelu_a_b_plus_c_float16" + "ATB/64x64xf16", + "ABPlusC/64x64xf16", + "ABT/64x64xf16", + "ReluABPlusC/64x64xf16", + "GeluABPlusC/64x64xf16", + "AB/64x64xf16", + "AB/Nx64xf16_64xNxf16", + "InterestingShapesBiasAdd/997x997xi8_NN_bias" ], "skip_run_tests": [ - "generated/test_a_b_float16" + "InterestingShapesBiasAdd/1152x997xf16_matmul_997x576xf16_NN", + "InterestingShapesBiasAdd/6144x419xbf16_matmul_419x384xbf16_NT", + "InterestingShapesBiasAdd/997x997xf16_NT_bias" ], "expected_compile_failures": [], - "expected_run_failures": [] + "expected_run_failures": [], + "golden_times_ms": { + "AB/8192x8192xf32_bench": 107.60851647438749, + "AB/1024x1024xf32_bench": 0.4509026762196051, + "AB/256x256xf32_bench": 0.1743873563575457, + "AB/128x128xf32_bench": 0.148048073505022, + "AB/2048x2048xf32_bench": 1.5943956199949283, + "AB/4096x4096xf32_bench": 10.252960922347533, + 
"AB/512x512xf32_bench": 0.23526767679662958 + } } diff --git a/third_party/llvm-project b/third_party/llvm-project index fc66e8eaa7e8..31f0e3e64485 160000 --- a/third_party/llvm-project +++ b/third_party/llvm-project @@ -1 +1 @@ -Subproject commit fc66e8eaa7e843305b917f749ba02775d3a3d5ac +Subproject commit 31f0e3e644857ed4886884b650530ef791680f95 diff --git a/tools/BUILD.bazel b/tools/BUILD.bazel index 666779b937c5..32fce9c6f8c5 100644 --- a/tools/BUILD.bazel +++ b/tools/BUILD.bazel @@ -154,6 +154,27 @@ iree_runtime_cc_binary( ], ) +iree_runtime_cc_binary( + name = "iree-encode-parameters", + srcs = ["iree-encode-parameters-main.c"], + deps = [ + "//runtime/src/iree/base", + "//runtime/src/iree/base/internal:flags", + "//runtime/src/iree/hal", + "//runtime/src/iree/io:file_handle", + "//runtime/src/iree/io:parameter_index", + "//runtime/src/iree/io:parameter_index_provider", + "//runtime/src/iree/io:scope_map", + "//runtime/src/iree/io:stream", + "//runtime/src/iree/io/formats/irpa", + "//runtime/src/iree/modules/hal", + "//runtime/src/iree/tooling:context_util", + "//runtime/src/iree/tooling:function_util", + "//runtime/src/iree/tooling:parameter_util", + "//runtime/src/iree/vm", + ], +) + iree_runtime_cc_binary( name = "iree-fatelf", srcs = ["iree-fatelf.c"], diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 7a70cfb36300..329f589cc70d 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -221,6 +221,30 @@ iree_cc_binary( INSTALL_COMPONENT IREETools-Runtime ) +iree_cc_binary( + NAME + iree-encode-parameters + SRCS + "iree-encode-parameters-main.c" + DEPS + iree::base + iree::base::internal::flags + iree::hal + iree::io::file_handle + iree::io::formats::irpa + iree::io::parameter_index + iree::io::parameter_index_provider + iree::io::scope_map + iree::io::stream + iree::modules::hal + iree::tooling::context_util + iree::tooling::function_util + iree::tooling::parameter_util + iree::vm + COVERAGE ${IREE_ENABLE_RUNTIME_COVERAGE} + 
INSTALL_COMPONENT IREETools-Runtime +) + # Only enable fatelf tool when we're compiling it in. # Currently it requires that the host and target both support embedded ELFs as # the ELF implementation is only compiled when the target supports it. diff --git a/tools/iree-dump-module-main.c b/tools/iree-dump-module-main.c index a6a596da1cb0..af384c196621 100644 --- a/tools/iree-dump-module-main.c +++ b/tools/iree-dump-module-main.c @@ -336,6 +336,23 @@ static void iree_tooling_print_rwdata_segment_defs( } } +// Returns the first export name for the given internal ordinal, or empty if not +// exported. +static iree_string_view_t iree_tooling_lookup_export_name( + iree_host_size_t internal_ordinal, + iree_vm_ExportFunctionDef_vec_t export_defs) { + for (size_t j = 0; j < iree_vm_ExportFunctionDef_vec_len(export_defs); ++j) { + iree_vm_ExportFunctionDef_table_t export_def = + iree_vm_ExportFunctionDef_vec_at(export_defs, j); + if (iree_vm_ExportFunctionDef_internal_ordinal(export_def) == + internal_ordinal) { + const char* name = iree_vm_ExportFunctionDef_local_name(export_def); + return iree_make_string_view(name, strlen(name)); + } + } + return iree_string_view_empty(); +} + static void iree_tooling_print_function_descriptors( iree_vm_FunctionDescriptor_vec_t descriptors, iree_vm_ExportFunctionDef_vec_t export_defs) { @@ -550,11 +567,23 @@ static iree_status_t iree_tooling_dump_module_disassembly( iree_status_t status = iree_vm_bytecode_module_create( instance, IREE_VM_BYTECODE_MODULE_FLAG_ALLOW_PLACEHOLDER_TYPES, archive_contents, iree_allocator_null(), host_allocator, &module); + iree_const_byte_span_t flatbuffer_contents = iree_const_byte_span_empty(); + iree_host_size_t rodata_offset = 0; + if (iree_status_is_ok(status)) { + status = iree_vm_bytecode_archive_parse_header( + archive_contents, &flatbuffer_contents, &rodata_offset); + } if (iree_status_is_ok(status)) { iree_string_builder_t builder; iree_string_builder_initialize(host_allocator, &builder); - // Iterate 
over exported functions and build the disassembly output. + // Extract export names from the flatbuffer module definition. + iree_vm_BytecodeModuleDef_table_t module_def = + iree_vm_BytecodeModuleDef_as_root(flatbuffer_contents.data); + iree_vm_ExportFunctionDef_vec_t export_defs = + iree_vm_BytecodeModuleDef_exported_functions(module_def); + + // Iterate over internal functions and build the disassembly output. iree_vm_module_signature_t signature = iree_vm_module_signature(module); for (iree_host_size_t i = 0; i < signature.internal_function_count; ++i) { iree_vm_function_t function; @@ -562,8 +591,15 @@ static iree_status_t iree_tooling_dump_module_disassembly( module, IREE_VM_FUNCTION_LINKAGE_INTERNAL, i, &function); if (!iree_status_is_ok(status)) break; + // Get function name from exports if available, otherwise use internal + // name. + iree_string_view_t export_name = + iree_tooling_lookup_export_name(i, export_defs); + iree_string_view_t function_name = iree_string_view_is_empty(export_name) + ? iree_vm_function_name(&function) + : export_name; + // Apply filter (ordinal or name) if provided. - iree_string_view_t function_name = iree_vm_function_name(&function); if (!iree_string_view_is_empty(function_filter)) { uint32_t filter_ordinal = -1; if (iree_string_view_atoi_uint32(function_filter, &filter_ordinal)) { diff --git a/tools/iree-encode-parameters-main.c b/tools/iree-encode-parameters-main.c new file mode 100644 index 000000000000..447bf188cec6 --- /dev/null +++ b/tools/iree-encode-parameters-main.c @@ -0,0 +1,1116 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include + +#include "iree/base/api.h" +#include "iree/base/internal/flags.h" +#include "iree/hal/api.h" +#include "iree/io/file_handle.h" +#include "iree/io/formats/irpa/irpa_builder.h" +#include "iree/io/parameter_index.h" +#include "iree/io/parameter_index_provider.h" +#include "iree/io/scope_map.h" +#include "iree/io/stream.h" +#include "iree/modules/hal/module.h" +#include "iree/tooling/context_util.h" +#include "iree/tooling/function_util.h" +#include "iree/tooling/parameter_util.h" +#include "iree/vm/api.h" + +//===----------------------------------------------------------------------===// +// Flags +//===----------------------------------------------------------------------===// + +IREE_FLAG(bool, list_targets, false, + "Lists the targets an encoding module can produce parameters for and " + "exit."); + +IREE_FLAG(bool, list_parameters, false, + "Lists the parameters that will be encoded and exit."); + +IREE_FLAG(string, target, "", + "Target to use for encoding. If not specified, uses auto-detection."); + +IREE_FLAG(bool, quiet, false, + "Suppress output except for errors. Exit code indicates success."); + +IREE_FLAG_LIST(string, output, + "Specifies an output parameter file per scope.\n" + "Format: `scope=path.irpa` or `path.irpa` for default scope.\n" + "Example: `--output=encoded=output.irpa`"); + +//===----------------------------------------------------------------------===// +// Encoder target discovery +//===----------------------------------------------------------------------===// + +// Encoder function set for a single target. +typedef struct iree_encode_target_t { + iree_string_view_t target; + iree_vm_function_t indices_fn; + iree_vm_function_t steps_fn; + iree_vm_function_t encode_fn; +} iree_encode_target_t; + +// Storage for discovered encoder targets. 
+typedef struct iree_encode_target_set_t { + iree_vm_function_t detect_target_fn; + iree_host_size_t target_count; + iree_host_size_t target_capacity; + iree_encode_target_t* targets; + iree_allocator_t allocator; +} iree_encode_target_set_t; + +static void iree_encode_target_set_initialize( + iree_allocator_t allocator, iree_encode_target_set_t* out_target_set) { + memset(out_target_set, 0, sizeof(*out_target_set)); + out_target_set->allocator = allocator; +} + +static void iree_encode_target_set_deinitialize( + iree_encode_target_set_t* target_set) { + if (target_set->targets) { + iree_allocator_free(target_set->allocator, target_set->targets); + } + memset(target_set, 0, sizeof(*target_set)); +} + +static iree_status_t iree_encode_target_set_add( + iree_encode_target_set_t* target_set, iree_string_view_t target_name, + iree_encode_target_t** out_target) { + // Check if target already exists. + for (iree_host_size_t i = 0; i < target_set->target_count; ++i) { + if (iree_string_view_equal(target_set->targets[i].target, target_name)) { + *out_target = &target_set->targets[i]; + return iree_ok_status(); + } + } + // Grow if needed. + if (target_set->target_count >= target_set->target_capacity) { + iree_host_size_t new_capacity = + target_set->target_capacity ? target_set->target_capacity * 2 : 4; + IREE_RETURN_IF_ERROR(iree_allocator_realloc( + target_set->allocator, new_capacity * sizeof(iree_encode_target_t), + (void**)&target_set->targets)); + target_set->target_capacity = new_capacity; + } + // Add new target. + iree_encode_target_t* target = &target_set->targets[target_set->target_count]; + memset(target, 0, sizeof(*target)); + target->target = target_name; + ++target_set->target_count; + *out_target = target; + return iree_ok_status(); +} + +// Looks up a reflection attribute value by key. 
+static iree_string_view_t iree_encode_lookup_reflection_attr( + iree_vm_function_t* function, iree_string_view_t key) { + return iree_vm_function_lookup_attr_by_name(function, key); +} + +// Discovers encoder functions from the module by scanning exported function +// attributes. +static iree_status_t iree_encode_discover_functions( + iree_vm_module_t* module, iree_encode_target_set_t* target_set) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_vm_module_signature_t signature = iree_vm_module_signature(module); + + for (iree_host_size_t i = 0; i < signature.export_function_count; ++i) { + iree_vm_function_t function; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_vm_module_lookup_function_by_ordinal( + module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &function)); + + // Check for iree.encode.function attribute. + iree_string_view_t encode_function = iree_encode_lookup_reflection_attr( + &function, IREE_SV("iree.encode.function")); + if (iree_string_view_is_empty(encode_function)) continue; + + if (iree_string_view_equal(encode_function, IREE_SV("detect_target"))) { + target_set->detect_target_fn = function; + } else { + // Get target name for indices/steps/encode functions. 
+ iree_string_view_t target_name = iree_encode_lookup_reflection_attr( + &function, IREE_SV("iree.encode.target")); + if (iree_string_view_is_empty(target_name)) { + IREE_TRACE_ZONE_END(z0); + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "encoder function missing iree.encode.target"); + } + + iree_encode_target_t* target = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_encode_target_set_add(target_set, target_name, &target)); + + if (iree_string_view_equal(encode_function, IREE_SV("indices"))) { + target->indices_fn = function; + } else if (iree_string_view_equal(encode_function, IREE_SV("steps"))) { + target->steps_fn = function; + } else if (iree_string_view_equal(encode_function, IREE_SV("encode"))) { + target->encode_fn = function; + } + } + } + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// Output scope/archive types +//===----------------------------------------------------------------------===// + +typedef struct iree_output_scope_t { + iree_string_view_t scope; + iree_string_view_t path; +} iree_output_scope_t; + +typedef struct iree_output_scope_list_t { + iree_host_size_t count; + iree_output_scope_t* entries; + iree_allocator_t allocator; +} iree_output_scope_list_t; + +static void iree_output_scope_list_initialize(iree_allocator_t allocator, + iree_output_scope_list_t* list) { + memset(list, 0, sizeof(*list)); + list->allocator = allocator; +} + +static void iree_output_scope_list_deinitialize( + iree_output_scope_list_t* list) { + if (list->entries) { + iree_allocator_free(list->allocator, list->entries); + } + memset(list, 0, sizeof(*list)); +} + +// Archive context for a single output scope. 
+typedef struct iree_output_archive_t { + iree_string_view_t scope; + iree_string_view_t path; + iree_io_parameter_archive_builder_t builder; + iree_io_file_handle_t* file_handle; + iree_io_parameter_index_t* index; + iree_io_parameter_provider_t* provider; +} iree_output_archive_t; + +static void iree_output_archive_deinitialize(iree_output_archive_t* archive) { + iree_io_parameter_provider_release(archive->provider); + iree_io_parameter_index_release(archive->index); + iree_io_file_handle_release(archive->file_handle); + iree_io_parameter_archive_builder_deinitialize(&archive->builder); +} + +//===----------------------------------------------------------------------===// +// Load modules and discover encoder functions +//===----------------------------------------------------------------------===// + +static iree_status_t iree_encode_load_and_discover( + iree_vm_instance_t* instance, iree_allocator_t host_allocator, + iree_tooling_module_list_t* out_module_list, + iree_vm_module_t** out_encoder_module, + iree_encode_target_set_t* out_target_set) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_tooling_module_list_initialize(out_module_list); + iree_encode_target_set_initialize(host_allocator, out_target_set); + + // Load modules from flags. + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_tooling_load_modules_from_flags(instance, host_allocator, + out_module_list)); + + if (out_module_list->count == 0) { + IREE_TRACE_ZONE_END(z0); + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "no modules specified; use --module=path.vmfb"); + } + + // Encoder module is the last module (by convention). + *out_encoder_module = out_module_list->values[out_module_list->count - 1]; + + // Discover encoder functions. 
+ IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_encode_discover_functions(*out_encoder_module, out_target_set)); + + if (out_target_set->target_count == 0) { + IREE_TRACE_ZONE_END(z0); + return iree_make_status( + IREE_STATUS_NOT_FOUND, + "no encoder functions found in module; ensure the module was produced " + "by iree-compile with --iree-parameter-encoder-output-file"); + } + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// Select target +//===----------------------------------------------------------------------===// + +static iree_status_t iree_encode_select_target( + iree_encode_target_set_t* target_set, + iree_encode_target_t** out_selected_target) { + iree_string_view_t target_flag = iree_make_cstring_view(FLAG_target); + + if (iree_string_view_is_empty(target_flag)) { + // Use first target. + *out_selected_target = &target_set->targets[0]; + return iree_ok_status(); + } + + // Find matching target. 
+  for (iree_host_size_t i = 0; i < target_set->target_count; ++i) {
+    if (iree_string_view_equal(target_set->targets[i].target, target_flag)) {
+      *out_selected_target = &target_set->targets[i];
+      return iree_ok_status();
+    }
+  }
+
+  return iree_make_status(IREE_STATUS_NOT_FOUND,
+                          "target '%s' not found in encoder module; "
+                          "use --list_targets to see available targets",
+                          FLAG_target);
+}
+
+static iree_status_t iree_encode_validate_target(iree_encode_target_t* target) {
+  if (!target->indices_fn.module) {
+    return iree_make_status(IREE_STATUS_NOT_FOUND,
+                            "indices function not found for target '%.*s'; "
+                            "encoder module may be incomplete",
+                            (int)target->target.size, target->target.data);
+  }
+  if (!target->encode_fn.module) {
+    return iree_make_status(IREE_STATUS_NOT_FOUND,
+                            "encode function not found for target '%.*s'; "
+                            "encoder module may be incomplete",
+                            (int)target->target.size, target->target.data);
+  }
+  return iree_ok_status();
+}
+
+//===----------------------------------------------------------------------===//
+// --list_targets implementation
+//===----------------------------------------------------------------------===//
+
+static iree_status_t iree_encode_print_targets(
+    iree_vm_module_t* encoder_module, iree_encode_target_set_t* target_set) {
+  iree_string_view_t module_name = iree_vm_module_name(encoder_module);
+  fprintf(stdout, "Encoder module: %.*s\n", (int)module_name.size,
+          module_name.data);
+  fprintf(stdout, "Available targets:\n");
+
+  for (iree_host_size_t i = 0; i < target_set->target_count; ++i) {
+    iree_encode_target_t* target = &target_set->targets[i];
+    fprintf(stdout, "  %.*s\n", (int)target->target.size, target->target.data);
+
+    iree_string_view_t scopes = iree_encode_lookup_reflection_attr(
+        &target->indices_fn, IREE_SV("iree.encode.scopes"));
+    if (!iree_string_view_is_empty(scopes)) {
+      fprintf(stdout, "    scopes: %.*s\n", (int)scopes.size, scopes.data);
+    }
+  }
+
+  return iree_ok_status();
+}
+
+//===----------------------------------------------------------------------===// +// Call indices function +//===----------------------------------------------------------------------===// + +// Creates a temporary context and calls the indices function. +// The indices function returns constant data and doesn't need parameters. +// TODO(benvanik): Consider calling without full context if function has no +// imports. +static iree_status_t iree_encode_call_indices( + iree_vm_instance_t* instance, iree_tooling_module_list_t* module_list, + iree_encode_target_t* target, iree_allocator_t host_allocator, + iree_vm_list_t** out_indices_list) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_vm_context_t* context = NULL; + iree_hal_device_t* device = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_tooling_create_context_from_flags( + instance, module_list->count, module_list->values, + /*default_device_uri=*/iree_string_view_empty(), host_allocator, + &context, &device, /*out_device_allocator=*/NULL)); + + // Invoke indices function. + iree_vm_list_t* outputs = NULL; + iree_status_t status = iree_vm_list_create(iree_vm_make_undefined_type_def(), + 1, host_allocator, &outputs); + if (iree_status_is_ok(status)) { + status = iree_vm_invoke( + context, target->indices_fn, IREE_VM_INVOCATION_FLAG_NONE, + /*policy=*/NULL, /*inputs=*/NULL, outputs, host_allocator); + } + + // Extract result list. 
+ if (iree_status_is_ok(status)) { + iree_vm_ref_t list_ref = iree_vm_ref_null(); + status = iree_vm_list_get_ref_assign(outputs, 0, &list_ref); + if (iree_status_is_ok(status)) { + *out_indices_list = iree_vm_list_deref(list_ref); + if (*out_indices_list) { + iree_vm_list_retain(*out_indices_list); + } + } + } + + iree_vm_list_release(outputs); + iree_hal_device_release(device); + iree_vm_context_release(context); + + IREE_TRACE_ZONE_END(z0); + return status; +} + +//===----------------------------------------------------------------------===// +// --list_parameters implementation +//===----------------------------------------------------------------------===// + +static iree_status_t iree_encode_print_parameters( + iree_vm_list_t* indices_list) { + iree_host_size_t scope_count = iree_vm_list_size(indices_list); + + for (iree_host_size_t scope_i = 0; scope_i < scope_count; ++scope_i) { + iree_vm_ref_t scope_struct_ref = iree_vm_ref_null(); + if (!iree_status_is_ok(iree_vm_list_get_ref_assign(indices_list, scope_i, + &scope_struct_ref))) { + continue; + } + iree_vm_list_t* scope_struct = iree_vm_list_deref(scope_struct_ref); + if (!scope_struct || iree_vm_list_size(scope_struct) < 2) continue; + + // Get scope name. + iree_vm_ref_t scope_name_ref = iree_vm_ref_null(); + iree_vm_list_get_ref_assign(scope_struct, 0, &scope_name_ref); + iree_vm_buffer_t* scope_name_buffer = iree_vm_buffer_deref(scope_name_ref); + iree_string_view_t scope_name = + scope_name_buffer ? iree_vm_buffer_as_string(scope_name_buffer) + : IREE_SV(""); + + fprintf(stdout, "Scope: \"%.*s\"\n", (int)scope_name.size, scope_name.data); + + // Get entries list. + iree_vm_ref_t entries_ref = iree_vm_ref_null(); + iree_vm_list_get_ref_assign(scope_struct, 1, &entries_ref); + iree_vm_list_t* entries = iree_vm_list_deref(entries_ref); + if (!entries) continue; + + // Print each entry. 
+ iree_host_size_t entry_count = iree_vm_list_size(entries); + for (iree_host_size_t entry_i = 0; entry_i < entry_count; ++entry_i) { + iree_vm_ref_t entry_ref = iree_vm_ref_null(); + if (!iree_status_is_ok( + iree_vm_list_get_ref_assign(entries, entry_i, &entry_ref))) { + continue; + } + iree_vm_list_t* entry = iree_vm_list_deref(entry_ref); + if (!entry || iree_vm_list_size(entry) < 5) continue; + + iree_vm_value_t type_value, length_value; + iree_vm_list_get_value(entry, 0, &type_value); + iree_vm_list_get_value(entry, 3, &length_value); + + iree_vm_ref_t key_ref = iree_vm_ref_null(); + iree_vm_list_get_ref_assign(entry, 1, &key_ref); + iree_vm_buffer_t* key_buffer = iree_vm_buffer_deref(key_ref); + iree_string_view_t key = key_buffer ? iree_vm_buffer_as_string(key_buffer) + : IREE_SV(""); + + if (type_value.i64 == 0) { + // SPLAT entry. + iree_vm_value_t pattern_value, pattern_length_value; + iree_vm_list_get_value(entry, 4, &pattern_value); + iree_vm_list_get_value(entry, 5, &pattern_length_value); + fprintf(stdout, + " %.*s: SPLAT, %" PRIu64 " bytes, pattern=0x%0*" PRIx64 "\n", + (int)key.size, key.data, (uint64_t)length_value.i64, + (int)pattern_length_value.i64 * 2, (uint64_t)pattern_value.i64); + } else { + // DATA entry. 
+ iree_vm_value_t alignment_value; + iree_vm_list_get_value(entry, 4, &alignment_value); + fprintf(stdout, + " %.*s: DATA, %" PRIu64 " bytes, alignment %" PRIu64 "\n", + (int)key.size, key.data, (uint64_t)length_value.i64, + (uint64_t)alignment_value.i64); + } + } + } + + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// Parse output flags +//===----------------------------------------------------------------------===// + +static iree_status_t iree_encode_parse_output_flags( + iree_output_scope_list_t* list) { + iree_host_size_t count = FLAG_output_list().count; + if (count == 0) return iree_ok_status(); + + IREE_RETURN_IF_ERROR(iree_allocator_malloc( + list->allocator, count * sizeof(iree_output_scope_t), + (void**)&list->entries)); + list->count = count; + + for (iree_host_size_t i = 0; i < count; ++i) { + iree_string_view_t flag = FLAG_output_list().values[i]; + iree_string_view_t scope, path; + if (iree_string_view_split(flag, '=', &scope, &path) == -1) { + // No scope provided - use empty scope. + path = scope; + scope = iree_string_view_empty(); + } + list->entries[i].scope = scope; + list->entries[i].path = path; + } + + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// Create output archives +//===----------------------------------------------------------------------===// + +// Parses parameter indices and populates archive builders. +static iree_status_t iree_encode_parse_indices_into_archives( + iree_vm_list_t* indices_list, iree_output_archive_t* archives, + iree_host_size_t archive_count) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_host_size_t scope_count = iree_vm_list_size(indices_list); + for (iree_host_size_t scope_i = 0; scope_i < scope_count; ++scope_i) { + // Get scope struct: [scope_name, entries_list]. 
+ iree_vm_ref_t scope_struct_ref = iree_vm_ref_null(); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + iree_vm_list_get_ref_assign(indices_list, scope_i, &scope_struct_ref)); + iree_vm_list_t* scope_struct = iree_vm_list_deref(scope_struct_ref); + if (!scope_struct || iree_vm_list_size(scope_struct) < 2) { + IREE_TRACE_ZONE_END(z0); + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "invalid scope struct in indices"); + } + + // Get scope name. + iree_vm_ref_t scope_name_ref = iree_vm_ref_null(); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_vm_list_get_ref_assign(scope_struct, 0, &scope_name_ref)); + iree_vm_buffer_t* scope_name_buffer = iree_vm_buffer_deref(scope_name_ref); + iree_string_view_t scope_name = + scope_name_buffer ? iree_vm_buffer_as_string(scope_name_buffer) + : iree_string_view_empty(); + + // Find matching archive. + iree_output_archive_t* archive = NULL; + for (iree_host_size_t j = 0; j < archive_count; ++j) { + if (iree_string_view_equal(archives[j].scope, scope_name)) { + archive = &archives[j]; + break; + } + } + if (!archive) continue; // Scope not in output list. + + // Get entries list. + iree_vm_ref_t entries_ref = iree_vm_ref_null(); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_vm_list_get_ref_assign(scope_struct, 1, &entries_ref)); + iree_vm_list_t* entries = iree_vm_list_deref(entries_ref); + if (!entries) { + IREE_TRACE_ZONE_END(z0); + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "invalid entries in scope struct"); + } + + // Process each parameter entry. 
+ iree_host_size_t entry_count = iree_vm_list_size(entries); + for (iree_host_size_t entry_i = 0; entry_i < entry_count; ++entry_i) { + iree_vm_ref_t entry_ref = iree_vm_ref_null(); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_vm_list_get_ref_assign(entries, entry_i, &entry_ref)); + iree_vm_list_t* entry = iree_vm_list_deref(entry_ref); + if (!entry || iree_vm_list_size(entry) < 5) { + IREE_TRACE_ZONE_END(z0); + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "invalid entry in entries list"); + } + + // Parse entry fields: [type, key, metadata, length, ...]. + iree_vm_value_t type_value, length_value; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_vm_list_get_value(entry, 0, &type_value)); + + iree_vm_ref_t key_ref = iree_vm_ref_null(); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_vm_list_get_ref_assign(entry, 1, &key_ref)); + iree_vm_buffer_t* key_buffer = iree_vm_buffer_deref(key_ref); + if (!key_buffer) { + IREE_TRACE_ZONE_END(z0); + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "parameter entry missing key"); + } + iree_string_view_t key = iree_vm_buffer_as_string(key_buffer); + + iree_vm_ref_t metadata_ref = iree_vm_ref_null(); + iree_vm_list_get_ref_assign(entry, 2, &metadata_ref); + iree_vm_buffer_t* metadata_buffer = iree_vm_buffer_deref(metadata_ref); + iree_const_byte_span_t metadata = iree_const_byte_span_empty(); + if (metadata_buffer) { + metadata = iree_vm_buffer_const_contents(metadata_buffer); + } + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_vm_list_get_value(entry, 3, &length_value)); + uint64_t length = (uint64_t)length_value.i64; + + if (type_value.i64 == 0) { + // SPLAT entry: [type, key, metadata, length, pattern, pattern_length]. 
+ iree_vm_value_t pattern_value, pattern_length_value; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_vm_list_get_value(entry, 4, &pattern_value)); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_vm_list_get_value(entry, 5, &pattern_length_value)); + + uint64_t pattern = (uint64_t)pattern_value.i64; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_io_parameter_archive_builder_add_splat_entry( + &archive->builder, key, metadata, &pattern, + (uint8_t)pattern_length_value.i64, length)); + } else { + // DATA entry: [type, key, metadata, length, alignment]. + iree_vm_value_t alignment_value; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_vm_list_get_value(entry, 4, &alignment_value)); + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_io_parameter_archive_builder_add_data_entry( + &archive->builder, key, metadata, + (uint64_t)alignment_value.i64, length)); + } + } + } + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +// Creates archive files and providers for each output scope. +static iree_status_t iree_encode_create_archives( + iree_vm_list_t* indices_list, iree_output_scope_list_t* output_list, + iree_allocator_t host_allocator, iree_output_archive_t** out_archives) { + IREE_TRACE_ZONE_BEGIN(z0); + + // Allocate archive array. + iree_output_archive_t* archives = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + iree_allocator_malloc(host_allocator, + output_list->count * sizeof(iree_output_archive_t), + (void**)&archives)); + + // Initialize archive builders. + iree_status_t status = iree_ok_status(); + for (iree_host_size_t i = 0; i < output_list->count; ++i) { + memset(&archives[i], 0, sizeof(archives[i])); + archives[i].scope = output_list->entries[i].scope; + archives[i].path = output_list->entries[i].path; + status = iree_io_parameter_archive_builder_initialize(host_allocator, + &archives[i].builder); + if (!iree_status_is_ok(status)) break; + } + + // Parse indices into archive builders. 
+ if (iree_status_is_ok(status)) { + status = iree_encode_parse_indices_into_archives(indices_list, archives, + output_list->count); + } + + // Create files and write headers. + if (iree_status_is_ok(status)) { + for (iree_host_size_t i = 0; i < output_list->count; ++i) { + iree_output_archive_t* archive = &archives[i]; + + iree_io_physical_size_t archive_size = + iree_io_parameter_archive_builder_total_size(&archive->builder); + + // Create null-terminated path. + char* path_cstr = NULL; + status = iree_allocator_malloc(host_allocator, archive->path.size + 1, + (void**)&path_cstr); + if (!iree_status_is_ok(status)) break; + memcpy(path_cstr, archive->path.data, archive->path.size); + path_cstr[archive->path.size] = '\0'; + + // Create output file. + status = iree_io_file_handle_create( + IREE_IO_FILE_MODE_READ | IREE_IO_FILE_MODE_WRITE, + iree_make_cstring_view(path_cstr), archive_size, host_allocator, + &archive->file_handle); + iree_allocator_free(host_allocator, path_cstr); + if (!iree_status_is_ok(status)) break; + + // Create stream and index. + iree_io_stream_t* stream = NULL; + status = + iree_io_stream_open(IREE_IO_STREAM_MODE_WRITABLE, + archive->file_handle, 0, host_allocator, &stream); + if (!iree_status_is_ok(status)) break; + + status = iree_io_parameter_index_create(host_allocator, &archive->index); + if (!iree_status_is_ok(status)) { + iree_io_stream_release(stream); + break; + } + + // Write archive header. + status = iree_io_parameter_archive_builder_write( + &archive->builder, archive->file_handle, 0, stream, archive->index); + iree_io_stream_release(stream); + if (!iree_status_is_ok(status)) break; + + // Create provider backed by the archive. 
+ status = iree_io_parameter_index_provider_create( + archive->scope, archive->index, + IREE_IO_PARAMETER_INDEX_PROVIDER_DEFAULT_MAX_CONCURRENT_OPERATIONS, + host_allocator, &archive->provider); + if (!iree_status_is_ok(status)) break; + } + } + + if (!iree_status_is_ok(status)) { + for (iree_host_size_t i = 0; i < output_list->count; ++i) { + iree_output_archive_deinitialize(&archives[i]); + } + iree_allocator_free(host_allocator, archives); + IREE_TRACE_ZONE_END(z0); + return status; + } + + *out_archives = archives; + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// Create encoding context with output providers +//===----------------------------------------------------------------------===// + +// Creates the encoding context with output providers attached. +// TODO(benvanik): Allow adding providers to existing parameters module to avoid +// recreating context. +static iree_status_t iree_encode_create_encoding_context( + iree_vm_instance_t* instance, iree_tooling_module_list_t* module_list, + iree_output_archive_t* archives, iree_host_size_t archive_count, + iree_allocator_t host_allocator, iree_vm_context_t** out_context, + iree_hal_device_t** out_device) { + IREE_TRACE_ZONE_BEGIN(z0); + + // Collect output providers. + iree_host_size_t provider_count = 0; + for (iree_host_size_t i = 0; i < archive_count; ++i) { + if (archives[i].provider) ++provider_count; + } + + iree_io_parameter_provider_t** providers = + (iree_io_parameter_provider_t**)iree_alloca( + provider_count * sizeof(iree_io_parameter_provider_t*)); + for (iree_host_size_t i = 0, j = 0; i < archive_count; ++i) { + if (archives[i].provider) { + providers[j++] = archives[i].provider; + } + } + + // Create parameters module with output providers. 
+ iree_vm_module_t* params_module = NULL; + iree_status_t status = iree_tooling_create_parameters_module_from_flags( + instance, provider_count, providers, host_allocator, ¶ms_module); + + // Pre-populate resolved_list with params module so resolver won't create + // default. + iree_tooling_module_list_t resolved_list; + iree_tooling_module_list_initialize(&resolved_list); + + if (iree_status_is_ok(status)) { + status = iree_tooling_module_list_push_back(&resolved_list, params_module); + } + + // Resolve dependencies (adds HAL, etc.). + if (iree_status_is_ok(status)) { + status = iree_tooling_resolve_modules( + instance, module_list->count, module_list->values, + /*default_device_uri=*/iree_string_view_empty(), host_allocator, + &resolved_list, out_device, /*out_device_allocator=*/NULL); + } + + // Create context. + if (iree_status_is_ok(status)) { + status = iree_vm_context_create_with_modules( + instance, IREE_VM_CONTEXT_FLAG_NONE, resolved_list.count, + resolved_list.values, host_allocator, out_context); + } + + iree_tooling_module_list_reset(&resolved_list); + iree_vm_module_release(params_module); + + IREE_TRACE_ZONE_END(z0); + return status; +} + +//===----------------------------------------------------------------------===// +// Call steps function +//===----------------------------------------------------------------------===// + +static iree_status_t iree_encode_call_steps(iree_vm_context_t* context, + iree_encode_target_t* target, + iree_allocator_t host_allocator, + iree_vm_list_t** out_steps_list) { + IREE_TRACE_ZONE_BEGIN(z0); + + *out_steps_list = NULL; + if (!target->steps_fn.module) { + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); // Steps function is optional. 
+ } + + iree_vm_list_t* outputs = NULL; + iree_status_t status = iree_vm_list_create(iree_vm_make_undefined_type_def(), + 1, host_allocator, &outputs); + if (iree_status_is_ok(status)) { + status = iree_vm_invoke( + context, target->steps_fn, IREE_VM_INVOCATION_FLAG_NONE, + /*policy=*/NULL, /*inputs=*/NULL, outputs, host_allocator); + } + + if (iree_status_is_ok(status)) { + iree_vm_ref_t list_ref = iree_vm_ref_null(); + status = iree_vm_list_get_ref_assign(outputs, 0, &list_ref); + if (iree_status_is_ok(status)) { + *out_steps_list = iree_vm_list_deref(list_ref); + if (*out_steps_list) { + iree_vm_list_retain(*out_steps_list); + } + } + } + + iree_vm_list_release(outputs); + IREE_TRACE_ZONE_END(z0); + return status; +} + +//===----------------------------------------------------------------------===// +// Execute encoder +//===----------------------------------------------------------------------===// + +static iree_status_t iree_encode_execute(iree_vm_context_t* context, + iree_hal_device_t* device, + iree_encode_target_t* target, + iree_vm_list_t* steps_list, + iree_allocator_t host_allocator) { + IREE_TRACE_ZONE_BEGIN(z0); + + // Build inputs: [steps_list, wait_fence, signal_fence]. + iree_vm_list_t* inputs = NULL; + iree_status_t status = iree_vm_list_create(iree_vm_make_undefined_type_def(), + 3, host_allocator, &inputs); + + // Push steps list (may be NULL). + if (iree_status_is_ok(status)) { + if (steps_list) { + iree_vm_ref_t steps_ref = iree_vm_list_retain_ref(steps_list); + status = iree_vm_list_push_ref_move(inputs, &steps_ref); + } else { + iree_vm_ref_t null_ref = iree_vm_ref_null(); + status = iree_vm_list_push_ref_move(inputs, &null_ref); + } + } + + // Append async fences. + iree_hal_fence_t* signal_fence = NULL; + if (iree_status_is_ok(status)) { + status = + iree_tooling_append_async_fences(inputs, target->encode_fn, device, + /*wait_fence=*/NULL, &signal_fence); + } + + // Invoke encoder. 
+ if (iree_status_is_ok(status)) { + status = iree_vm_invoke( + context, target->encode_fn, IREE_VM_INVOCATION_FLAG_NONE, + /*policy=*/NULL, inputs, /*outputs=*/NULL, host_allocator); + } + + iree_vm_list_release(inputs); + + // Wait for completion. + if (iree_status_is_ok(status) && signal_fence) { + status = iree_hal_fence_wait(signal_fence, iree_infinite_timeout(), + IREE_HAL_WAIT_FLAG_DEFAULT); + } + + iree_hal_fence_release(signal_fence); + IREE_TRACE_ZONE_END(z0); + return status; +} + +//===----------------------------------------------------------------------===// +// Dump output parameters +//===----------------------------------------------------------------------===// + +// Dumps the contents of output archives similar to iree-dump-parameters. +static iree_status_t iree_encode_dump_outputs(iree_output_archive_t* archives, + iree_host_size_t archive_count, + iree_allocator_t host_allocator) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_string_builder_t sb; + iree_string_builder_initialize(host_allocator, &sb); + + iree_status_t status = iree_ok_status(); + for (iree_host_size_t i = 0; i < archive_count && iree_status_is_ok(status); + ++i) { + iree_output_archive_t* archive = &archives[i]; + if (!archive->index) continue; + + status = iree_string_builder_append_cstring( + &sb, + "//" + "===-----------------------------------------------------------------" + "---------------------------------------------===//\n"); + if (!iree_status_is_ok(status)) break; + + // Print archive header. + iree_io_physical_size_t archive_size = + iree_io_parameter_archive_builder_total_size(&archive->builder); + status = iree_string_builder_append_format( + &sb, "// Output: %.*s (%" PRIu64 " bytes)\n", (int)archive->path.size, + archive->path.data, archive_size); + if (!iree_status_is_ok(status)) break; + + // Dump parameter index. 
+ status = iree_io_parameter_index_dump(archive->scope, archive->index, &sb); + } + + if (iree_status_is_ok(status)) { + fprintf(stdout, "%.*s", (int)iree_string_builder_size(&sb), + iree_string_builder_buffer(&sb)); + } + + iree_string_builder_deinitialize(&sb); + IREE_TRACE_ZONE_END(z0); + return status; +} + +//===----------------------------------------------------------------------===// +// Main encoding workflow +//===----------------------------------------------------------------------===// + +static iree_status_t iree_tooling_encode_parameters( + iree_allocator_t host_allocator) { + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = iree_ok_status(); + + // Create VM instance. + iree_vm_instance_t* instance = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_tooling_create_instance(host_allocator, &instance)); + + // Load modules and discover encoder functions. + iree_tooling_module_list_t module_list; + iree_vm_module_t* encoder_module = NULL; + iree_encode_target_set_t target_set; + status = iree_encode_load_and_discover(instance, host_allocator, &module_list, + &encoder_module, &target_set); + + // Select target. + iree_encode_target_t* selected_target = NULL; + if (iree_status_is_ok(status)) { + status = iree_encode_select_target(&target_set, &selected_target); + } + if (iree_status_is_ok(status)) { + status = iree_encode_validate_target(selected_target); + } + + // Handle --list_targets (early exit). + if (iree_status_is_ok(status) && FLAG_list_targets) { + status = iree_encode_print_targets(encoder_module, &target_set); + iree_encode_target_set_deinitialize(&target_set); + iree_tooling_module_list_reset(&module_list); + iree_vm_instance_release(instance); + IREE_TRACE_ZONE_END(z0); + return status; + } + + // Call indices function. 
+ iree_vm_list_t* indices_list = NULL; + if (iree_status_is_ok(status)) { + status = iree_encode_call_indices(instance, &module_list, selected_target, + host_allocator, &indices_list); + } + + // Handle --list_parameters (early exit). + if (iree_status_is_ok(status) && FLAG_list_parameters) { + status = iree_encode_print_parameters(indices_list); + iree_vm_list_release(indices_list); + iree_encode_target_set_deinitialize(&target_set); + iree_tooling_module_list_reset(&module_list); + iree_vm_instance_release(instance); + IREE_TRACE_ZONE_END(z0); + return status; + } + + // Parse output flags. + iree_output_scope_list_t output_list; + iree_output_scope_list_initialize(host_allocator, &output_list); + if (iree_status_is_ok(status)) { + status = iree_encode_parse_output_flags(&output_list); + } + if (iree_status_is_ok(status) && output_list.count == 0) { + status = iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "no output specified; use --output=[scope=]path.irpa " + "(e.g., --output=encoded=output.irpa or --output=output.irpa)"); + } + + // Create output archives. + iree_output_archive_t* archives = NULL; + if (iree_status_is_ok(status)) { + status = iree_encode_create_archives(indices_list, &output_list, + host_allocator, &archives); + } + + // Create encoding context with output providers. + iree_vm_context_t* context = NULL; + iree_hal_device_t* device = NULL; + if (iree_status_is_ok(status)) { + status = iree_encode_create_encoding_context( + instance, &module_list, archives, output_list.count, host_allocator, + &context, &device); + } + + // Call steps function. + iree_vm_list_t* steps_list = NULL; + if (iree_status_is_ok(status)) { + status = iree_encode_call_steps(context, selected_target, host_allocator, + &steps_list); + } + + // Execute encoder. + if (iree_status_is_ok(status)) { + status = iree_encode_execute(context, device, selected_target, steps_list, + host_allocator); + } + + // Dump output parameters (unless quiet mode). 
+ if (iree_status_is_ok(status) && !FLAG_quiet) { + status = + iree_encode_dump_outputs(archives, output_list.count, host_allocator); + } + + // Cleanup. + iree_vm_list_release(steps_list); + iree_vm_list_release(indices_list); + if (archives) { + for (iree_host_size_t i = 0; i < output_list.count; ++i) { + iree_output_archive_deinitialize(&archives[i]); + } + iree_allocator_free(host_allocator, archives); + } + iree_hal_device_release(device); + iree_vm_context_release(context); + iree_output_scope_list_deinitialize(&output_list); + iree_encode_target_set_deinitialize(&target_set); + iree_tooling_module_list_reset(&module_list); + iree_vm_instance_release(instance); + + IREE_TRACE_ZONE_END(z0); + return status; +} + +//===----------------------------------------------------------------------===// +// Entry point +//===----------------------------------------------------------------------===// + +int main(int argc, char** argv) { + IREE_TRACE_APP_ENTER(); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_allocator_t host_allocator = iree_allocator_system(); + int exit_code = EXIT_SUCCESS; + + iree_flags_set_usage( + "iree-encode-parameters", + "Encodes parameter files using an encoding module.\n" + "\n" + "This tool transforms model parameters using an encoder module produced\n" + "by iree-compile with --iree-parameter-encoder-output-file. The encoder\n" + "pre-computes parameter transformations (packing, encoding, dispatches)\n" + "that would otherwise run at model load time.\n" + "\n" + "WORKFLOW:\n" + " 1. Compile main module with encoder output:\n" + " iree-compile model.mlir \\\n" + " --iree-parameter-encoder-output-file=encoder.mlir \\\n" + " --iree-parameter-splat-path=input.irpa \\\n" + " -o main.vmfb\n" + "\n" + " 2. Compile the encoder module:\n" + " iree-compile encoder.mlir -o encoder.vmfb\n" + "\n" + " 3. 
Run the encoder to transform parameters:\n" + " iree-encode-parameters \\\n" + " --module=encoder.vmfb \\\n" + " --parameters=model=input.irpa \\\n" + " --output=encoded=output.irpa\n" + "\n" + " 4. Run the main module with encoded parameters:\n" + " iree-run-module \\\n" + " --module=main.vmfb \\\n" + " --parameters=model=input.irpa \\\n" + " --parameters=encoded=output.irpa\n" + "\n" + "FLAGS:\n" + " --module=path.vmfb Encoder module (required)\n" + " --parameters=scope=path Input parameter file(s)\n" + " --output=scope=path.irpa Output encoded parameter file(s)\n" + " --list-targets List available encoding targets\n" + " --list-parameters List parameters that will be encoded\n" + " --target=name Select specific target (default: auto-detect)\n" + " --quiet Suppress output except errors\n"); + iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv); + + if (argc > 1) { + fprintf(stderr, "Error: no positional arguments expected.\n"); + fprintf(stderr, + "Use one or more --parameters=file.ext flags to specify parameter " + "files.\n"); + IREE_TRACE_ZONE_END(z0); + IREE_TRACE_APP_EXIT(exit_code); + return EXIT_FAILURE; + } + + iree_status_t status = iree_tooling_encode_parameters(host_allocator); + + fflush(stdout); + if (!iree_status_is_ok(status)) { + iree_status_fprint(stderr, status); + iree_status_free(status); + exit_code = EXIT_FAILURE; + } + fflush(stderr); + + IREE_TRACE_ZONE_END(z0); + IREE_TRACE_APP_EXIT(exit_code); + return exit_code; +} diff --git a/tools/test/iree-dump-module.mlir b/tools/test/iree-dump-module.mlir index db95f5c1f604..2609a754e14e 100644 --- a/tools/test/iree-dump-module.mlir +++ b/tools/test/iree-dump-module.mlir @@ -8,20 +8,20 @@ // RUN: %t.vmfb | \ // RUN: FileCheck %s -// CHECK-LABEL: @module : version 0 +// CHECK: @module : version 0 -// CHECK-LABEL: module.fn0 +// CHECK: fn0 func.func @fn0(%input : tensor) -> (tensor) { - // CHECK: [{{[0-9]+}}] + // CHECK: [{{[0-9]+}}]{{.*}} %result = math.absf %input : tensor return 
%result : tensor } -// CHECK-LABEL: module.fn1 +// CHECK: fn1 func.func @fn1(%input : tensor) -> (tensor) { - // CHECK: [{{[0-9]+}}] + // CHECK: [{{[0-9]+}}]{{.*}} %result = arith.mulf %input, %input : tensor return %result : tensor } -// CHECK-LABEL: module.__init +// CHECK: __init